Skip to content
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions keras/src/backend/openvino/excluded_concrete_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ NumpyDtypeTest::test_logspace
NumpyDtypeTest::test_matmul_
NumpyDtypeTest::test_max
NumpyDtypeTest::test_mean
NumpyDtypeTest::test_median
NumpyDtypeTest::test_meshgrid
NumpyDtypeTest::test_minimum_python_types
NumpyDtypeTest::test_multiply
Expand Down Expand Up @@ -95,7 +94,6 @@ NumpyOneInputOpsCorrectnessTest::test_isinf
NumpyOneInputOpsCorrectnessTest::test_logaddexp
NumpyOneInputOpsCorrectnessTest::test_max
NumpyOneInputOpsCorrectnessTest::test_mean
NumpyOneInputOpsCorrectnessTest::test_median
NumpyOneInputOpsCorrectnessTest::test_meshgrid
NumpyOneInputOpsCorrectnessTest::test_pad_float16_constant_2
NumpyOneInputOpsCorrectnessTest::test_pad_float32_constant_2
Expand Down
140 changes: 139 additions & 1 deletion keras/src/backend/openvino/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,7 +1046,145 @@ def maximum(x1, x2):


def median(x, axis=None, keepdims=False):
raise NotImplementedError("`median` is not supported with openvino backend")
if np.isscalar(x):
x = get_ov_output(x)
return OpenVINOKerasTensor(x)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add comments to explain the logic of the algorithm you implemented

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comments added and some variables renamed to better explain the logic.

x = get_ov_output(x)
x_type = x.get_element_type()
x_rank_org = x.get_partial_shape().rank.get_length()
if x_type == Type.boolean or x_type.is_integral():
x_type = OPENVINO_DTYPES[config.floatx()]
x = ov_opset.convert(x, x_type).output(0)

x_shape_original = ov_opset.shape_of(x, Type.i32).output(0)

if axis is None:
flatten_shape = ov_opset.constant([-1], Type.i32).output(0)
x = ov_opset.reshape(x, flatten_shape, False).output(0)
axis = 0
axis_norm = axis
ov_axis_positive = get_ov_output(axis)
flattened = True
k_value = x.get_partial_shape().get_dimension(index=0).get_length()
elif isinstance(axis, int):
flattened = False
x_rank = x.get_partial_shape().rank.get_length()
if axis < 0:
axis_norm = x_rank + axis
else:
axis_norm = axis
ov_axis_positive = ov_axis = get_ov_output(axis)
k_value = (
x.get_partial_shape().get_dimension(index=axis_norm).get_length()
)
else:
# where axis is tuple or list of integers, move 'axis' dims to the
# rightmost positions and flatten them
flattened = False
if isinstance(axis, (tuple, list)):
ov_axis = axis = list(axis)
ov_axis = ov_opset.constant(axis, Type.i32).output(0)
x_rank = x.get_partial_shape().rank.get_length()
axis_as_range = ov_opset.range(
ov_opset.constant(0, Type.i32).output(0),
x_rank,
ov_opset.constant(1, Type.i32).output(0),
Type.i32,
).output(0)
# normalise any negative axes to their positive indices
ov_axis_positive = ov_opset.gather(
axis_as_range, ov_axis, ov_opset.constant([0], Type.i32)
).output(0)
# only move axis dims if tuple contains more than 1 axis
if ov_axis_positive.get_partial_shape().rank.get_length() > 1:
axis_compare = ov_opset.not_equal(
ov_opset.unsqueeze(axis_as_range, 1).output(0),
ov_opset.unsqueeze(ov_axis_positive, 0).output(0),
).output(0)
keep_axes = ov_opset.reduce_logical_or(
axis_compare, ov_opset.constant([1], Type.i32).output(0)
).output(0)
nz = ov_opset.non_zero(keep_axes, Type.i32).output(0)
keep_axes = ov_opset.reduce_sum(
nz, ov_opset.constant([1], Type.i32).output(0)
).output(0)
reordered_axes = ov_opset.concat(
[keep_axes, ov_axis_positive], 0
).output(0)
x = ov_opset.transpose(x, reordered_axes).output(0)

flat_rank = ov_opset.subtract(
x_rank, ov_opset.constant([1], Type.i64).output(0)
).output(0)
flatten_shape = ov_opset.broadcast(
ov_opset.constant([0], Type.i32).output(0), flat_rank
).output(0)
flatten_shape = ov_opset.scatter_elements_update(
flatten_shape,
ov_opset.constant([-1], Type.i32).output(0),
ov_opset.constant([-1], Type.i32).output(0),
0,
"sum",
).output(0)

x = ov_opset.reshape(x, flatten_shape, True).output(0)
axis = -1
x_rank = x.get_partial_shape().rank.get_length()
axis_norm = x_rank + axis
ov_axis_positive = get_ov_output(axis_norm)
k_value = (
x.get_partial_shape().get_dimension(index=axis_norm).get_length()
)

x_sorted = ov_opset.topk(
x, k_value, axis_norm, "min", "value", stable=True
).output(0)
k_value = ov_opset.convert(k_value, x_type).output(0)
half_index = ov_opset.floor(
ov_opset.divide(k_value, ov_opset.constant([2], x_type)).output(0)
).output(0)
half_index = ov_opset.convert(half_index, Type.i32).output(0)
x_mod = ov_opset.mod(k_value, ov_opset.constant([2], x_type)).output(0)
is_even = ov_opset.equal(x_mod, ov_opset.constant([0], x_type)).output(0)

med_0 = ov_opset.gather(x_sorted, half_index, ov_axis_positive).output(0)
med_1 = ov_opset.select(
is_even,
ov_opset.gather(
x_sorted,
ov_opset.subtract(
half_index, ov_opset.constant([1], Type.i32)
).output(0),
ov_axis_positive,
).output(0),
med_0,
).output(0)

median_odd = med_0
median_even = ov_opset.divide(
ov_opset.add(med_1, med_0).output(0),
ov_opset.constant([2], x_type),
).output(0)

median_eval = ov_opset.select(is_even, median_even, median_odd).output(0)

if keepdims:
if flattened:
median_shape = ov_opset.divide(
x_shape_original, x_shape_original, "none"
).output(0)
median_eval = ov_opset.reshape(
median_eval, median_shape, False
).output(0)
else:
if median_eval.get_partial_shape().rank.get_length() != x_rank_org:
median_eval = ov_opset.unsqueeze(median_eval, ov_axis).output(0)

else:
median_eval = ov_opset.squeeze(median_eval, ov_axis_positive).output(0)

return OpenVINOKerasTensor(median_eval)


def meshgrid(*x, indexing="xy"):
Expand Down
3 changes: 3 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[pytest]
env =
KERAS_BACKEND=openvino
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please revert

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

File has been deleted.