Skip to content

Commit 80e38da

Browse files
committed
WIP: median support for OpenVINO back-end
1 parent e4bca84 commit 80e38da

File tree

3 files changed

+70
-4
lines changed

3 files changed

+70
-4
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ NumpyDtypeTest::test_logspace
3838
NumpyDtypeTest::test_matmul_
3939
NumpyDtypeTest::test_max
4040
NumpyDtypeTest::test_mean
41-
NumpyDtypeTest::test_median
4241
NumpyDtypeTest::test_meshgrid
4342
NumpyDtypeTest::test_minimum_python_types
4443
NumpyDtypeTest::test_multiply
@@ -95,7 +94,6 @@ NumpyOneInputOpsCorrectnessTest::test_isinf
9594
NumpyOneInputOpsCorrectnessTest::test_logaddexp
9695
NumpyOneInputOpsCorrectnessTest::test_max
9796
NumpyOneInputOpsCorrectnessTest::test_mean
98-
NumpyOneInputOpsCorrectnessTest::test_median
9997
NumpyOneInputOpsCorrectnessTest::test_meshgrid
10098
NumpyOneInputOpsCorrectnessTest::test_pad_float16_constant_2
10199
NumpyOneInputOpsCorrectnessTest::test_pad_float32_constant_2

keras/src/backend/openvino/numpy.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,8 +1046,73 @@ def maximum(x1, x2):
10461046

10471047

10481048
def median(x, axis=None, keepdims=False):
1049-
raise NotImplementedError("`median` is not supported with openvino backend")
1050-
1049+
x = get_ov_output(x)
1050+
x_shape_original = ov_opset.shape_of(x).output(0)
1051+
1052+
if axis is None:
1053+
flatten_shape = ov_opset.constant([-1], Type.i32).output(0)
1054+
x = ov_opset.reshape(x, flatten_shape, False).output(0)
1055+
axis = 0
1056+
flattened = True
1057+
int_axis = False
1058+
x_shape = ov_opset.shape_of(x).output(0)
1059+
k_value = ov_opset.convert(x_shape, Type.i32).output(0)
1060+
elif isinstance(axis, int):
1061+
flattened = False
1062+
int_axis = True
1063+
ov_axis = ov_opset.constant(axis, Type.i32).output(0)
1064+
x_shape = ov_opset.shape_of(x).output(0)
1065+
k_value = ov_opset.convert(ov_opset.gather(x_shape, ov_axis, ov_opset.constant([0], Type.i32).output(0)).output(0), Type.i32).output(0)
1066+
else:
1067+
# axis = (2, 1)
1068+
flattened = False
1069+
int_axis = False
1070+
ov_axis = ov_opset.constant(axis, Type.i32).output(0) # (2, 1)
1071+
x_rank = ov_opset.shape_of(x_shape_original).output(0) # 4
1072+
axis_range = ov_opset.range(ov_opset.constant([0], Type.i32).output(0), x_rank, ov_opset.constant([1], Type.i32).output(0)).output(0)
1073+
axis_compare = ov_opset.equal(ov_opset.unsqueeze(ov_axis, 1).output(0), ov_opset.unsqueeze(axis_range, 0).output(0)).output(0)
1074+
mask_remove = ov_opset.reduce_logical_or(axis_compare, ov_opset.constant([0], Type.i32).output(0)).output(0)
1075+
mask_keep = ov_opset.logical_not(mask_remove).output(0)
1076+
nz = ov_opset.non_zero(mask_keep, "i32").output(0)
1077+
indices_keep = ov_opset.squeeze(nz, [0]).output(0)
1078+
axis_range = ov_opset.gather(axis_range, indices_keep, ov_opset.constant([0], Type.i32).output(0)).output(0) # (0, 3)
1079+
axis_range = ov_opset.concat([axis_range, ov_axis], ov_opset.constant([0], Type.i32).output(0)).output(0) # (0, 3, 2, 1)
1080+
x = ov_opset.transpose(x, axis_range).output(0) # x = (d0, d3, d2, d1)
1081+
1082+
flat_rank = ov_opset.subtract(x_rank, ov_opset.constant([1], Type.i32)).output(0)
1083+
flatten_shape = ov_opset.constant([0], shape=flat_rank, type_info=Type.i32).output(0)
1084+
flatten_shape = ov_opset.scatter_elements_update(flatten_shape, ov_opset.constant([-1], Type.i32).output(0), [-1], [0], "sum")
1085+
1086+
x = ov_opset.reshape(x, flatten_shape, True).output(0) # x = (d0, d3, d2*d1)
1087+
axis = -1
1088+
x_shape = ov_opset.shape_of(x).output(0)
1089+
k_value = ov_opset.gather(x_shape, ov_opset.constant([-1], Type.i32).output(0), ov_opset.constant([0], Type.i32).output(0)).output(0)
1090+
k_value = ov_opset.convert(k_value, Type.i32).output(0)
1091+
1092+
x_sorted = ov_opset.topk(x, k_value, axis, 'min', 'value', stable=True).output(0)
1093+
half_index = ov_opset.divide(k_value, ov_opset.constant([2], Type.i32)).output(0)
1094+
x_mod = ov_opset.mod(k_value, ov_opset.constant([2], Type.i32)).output(0)
1095+
is_even = ov_opset.equal(x_mod, ov_opset.constant([0], Type.i32)).output(0)
1096+
med_index_0 = ov_opset.gather(x_sorted, ov_opset.floor(half_index).output(0), axis).output(0) # COME BACK, does it sort out higher dimensions?
1097+
med_index_1 = ov_opset.gather(x_sorted, ov_opset.add(med_index_0, ov_opset.constant([1], Type.i32)).output(0), axis).output(0)
1098+
1099+
median_odd = med_index_0
1100+
median_even = ov_opset.divide(ov_opset.add(med_index_1, med_index_0).output(0), ov_opset.constant([2], Type.i32))
1101+
1102+
median_eval = ov_opset.select(is_even, median_even, median_odd)
1103+
1104+
if keepdims == True:
1105+
if flattened == True:
1106+
median_shape = ov_opset.divide(x_shape_original, x_shape_original).output(0)
1107+
median_eval = ov_opset.reshape(median_eval, median_shape, False).output(0)
1108+
elif int_axis == True:
1109+
median_shape = ov_opset.shape_of(median_eval).output(0)
1110+
median_shape = ov_opset.unsqueeze(median_shape, axis).output(0)
1111+
median_eval = ov_opset.reshape(median_eval, median_shape, False).output(0)
1112+
else:
1113+
median_eval = ov_opset.unsqueeze(median_eval, ov_axis).output(0)
1114+
1115+
return OpenVINOKerasTensor(median_eval)
10511116

10521117
def meshgrid(*x, indexing="xy"):
10531118
raise NotImplementedError(

pytest.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[pytest]
2+
env =
3+
KERAS_BACKEND=openvino

0 commit comments

Comments
 (0)