Skip to content

Commit cc56474

Browse files
authored
[Keras 3 OpenVINO Backend]: Support numpy.median operation (#21667)
* feat: numpy median for openvino backend * feat: included tests for numpy median * fix: code format
1 parent 286c58f commit cc56474

File tree

2 files changed

+132
-3
lines changed

2 files changed

+132
-3
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ NumpyDtypeTest::test_logspace
4545
NumpyDtypeTest::test_matmul_
4646
NumpyDtypeTest::test_max
4747
NumpyDtypeTest::test_mean
48-
NumpyDtypeTest::test_median
4948
NumpyDtypeTest::test_minimum_python_types
5049
NumpyDtypeTest::test_multiply
5150
NumpyDtypeTest::test_power
@@ -98,7 +97,6 @@ NumpyOneInputOpsCorrectnessTest::test_isposinf
9897
NumpyOneInputOpsCorrectnessTest::test_logaddexp2
9998
NumpyOneInputOpsCorrectnessTest::test_max
10099
NumpyOneInputOpsCorrectnessTest::test_mean
101-
NumpyOneInputOpsCorrectnessTest::test_median
102100
NumpyOneInputOpsCorrectnessTest::test_pad_float16_constant_2
103101
NumpyOneInputOpsCorrectnessTest::test_pad_float32_constant_2
104102
NumpyOneInputOpsCorrectnessTest::test_pad_float64_constant_2

keras/src/backend/openvino/numpy.py

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1184,7 +1184,138 @@ def maximum(x1, x2):
11841184

11851185

11861186
def median(x, axis=None, keepdims=False):
1187-
raise NotImplementedError("`median` is not supported with openvino backend")
1187+
x = get_ov_output(x)
1188+
x_shape = x.get_partial_shape()
1189+
rank = x_shape.rank.get_length()
1190+
1191+
if rank == 0:
1192+
return OpenVINOKerasTensor(x)
1193+
1194+
# Handle axis=None by flattening the input
1195+
flattened_all = False
1196+
if axis is None:
1197+
x = ov_opset.reshape(x, [-1], False).output(0)
1198+
axis = 0
1199+
original_rank = rank
1200+
rank = 1
1201+
flattened_all = True
1202+
else:
1203+
# Handle tuple axis - for median, we only support single axis
1204+
if isinstance(axis, (tuple, list)):
1205+
if len(axis) != 1:
1206+
raise ValueError("median only supports single axis reduction")
1207+
axis = axis[0]
1208+
1209+
# Handle negative axis
1210+
if axis < 0:
1211+
axis = rank + axis
1212+
original_rank = rank
1213+
1214+
# Get the size of the dimension to sort
1215+
shape_tensor = ov_opset.shape_of(x, output_type=Type.i32).output(0)
1216+
k = ov_opset.gather(
1217+
shape_tensor,
1218+
ov_opset.constant([axis], Type.i32).output(0),
1219+
ov_opset.constant(0, Type.i32).output(0),
1220+
).output(0)
1221+
1222+
# Convert k to a scalar value
1223+
k_scalar = ov_opset.squeeze(k, [0]).output(0)
1224+
1225+
# Use topk with k=size_of_axis to get all elements sorted
1226+
topk_outputs = ov_opset.topk(
1227+
x, k=k_scalar, axis=axis, mode="min", sort="value", stable=True
1228+
)
1229+
1230+
# Get the sorted values
1231+
sorted_values = topk_outputs.output(0)
1232+
1233+
# Convert to float for median calculation
1234+
x1_type = ov_to_keras_type(sorted_values.get_element_type())
1235+
result_type = dtypes.result_type(x1_type, float)
1236+
result_type = OPENVINO_DTYPES[result_type]
1237+
sorted_values = ov_opset.convert(sorted_values, result_type).output(0)
1238+
1239+
# Calculate median indices
1240+
# For odd length: median_idx = (k-1) // 2
1241+
# For even length: we need indices (k//2 - 1) and k//2, then average
1242+
1243+
k_minus_1 = ov_opset.subtract(
1244+
k_scalar, ov_opset.constant(1, Type.i32).output(0)
1245+
).output(0)
1246+
k_div_2 = ov_opset.divide(
1247+
k_scalar, ov_opset.constant(2, Type.i32).output(0)
1248+
).output(0)
1249+
k_minus_1_div_2 = ov_opset.divide(
1250+
k_minus_1, ov_opset.constant(2, Type.i32).output(0)
1251+
).output(0)
1252+
1253+
# Check if k is odd
1254+
k_mod_2 = ov_opset.mod(
1255+
k_scalar, ov_opset.constant(2, Type.i32).output(0)
1256+
).output(0)
1257+
is_odd = ov_opset.equal(
1258+
k_mod_2, ov_opset.constant(1, Type.i32).output(0)
1259+
).output(0)
1260+
1261+
# For odd case: take the middle element
1262+
odd_idx = k_minus_1_div_2
1263+
1264+
# For even case: take average of two middle elements
1265+
even_idx1 = ov_opset.subtract(
1266+
k_div_2, ov_opset.constant(1, Type.i32).output(0)
1267+
).output(0)
1268+
even_idx2 = k_div_2
1269+
1270+
# Gather elements for both cases
1271+
# Create gather indices tensor for the axis
1272+
gather_indices_odd = ov_opset.unsqueeze(odd_idx, [0]).output(0)
1273+
gather_indices_even1 = ov_opset.unsqueeze(even_idx1, [0]).output(0)
1274+
gather_indices_even2 = ov_opset.unsqueeze(even_idx2, [0]).output(0)
1275+
1276+
# Gather the median elements
1277+
odd_result = ov_opset.gather(
1278+
sorted_values,
1279+
gather_indices_odd,
1280+
ov_opset.constant(axis, Type.i32).output(0),
1281+
).output(0)
1282+
even_result1 = ov_opset.gather(
1283+
sorted_values,
1284+
gather_indices_even1,
1285+
ov_opset.constant(axis, Type.i32).output(0),
1286+
).output(0)
1287+
even_result2 = ov_opset.gather(
1288+
sorted_values,
1289+
gather_indices_even2,
1290+
ov_opset.constant(axis, Type.i32).output(0),
1291+
).output(0)
1292+
1293+
# Average the two middle elements for even case
1294+
even_sum = ov_opset.add(even_result1, even_result2).output(0)
1295+
even_result = ov_opset.divide(
1296+
even_sum, ov_opset.constant(2.0, result_type).output(0)
1297+
).output(0)
1298+
1299+
# Select between odd and even results
1300+
median_result = ov_opset.select(is_odd, odd_result, even_result).output(0)
1301+
1302+
# Remove the gathered dimension (squeeze)
1303+
median_result = ov_opset.squeeze(median_result, [axis]).output(0)
1304+
1305+
# Handle keepdims
1306+
if keepdims:
1307+
if flattened_all:
1308+
# When axis=None, keepdims should restore all dimensions as 1
1309+
ones_shape = ov_opset.constant(
1310+
[1] * original_rank, Type.i32
1311+
).output(0)
1312+
median_result = ov_opset.reshape(
1313+
median_result, ones_shape, False
1314+
).output(0)
1315+
else:
1316+
median_result = ov_opset.unsqueeze(median_result, [axis]).output(0)
1317+
1318+
return OpenVINOKerasTensor(median_result)
11881319

11891320

11901321
def meshgrid(*x, indexing="xy"):

0 commit comments

Comments
 (0)