@@ -1184,7 +1184,138 @@ def maximum(x1, x2):
1184
1184
1185
1185
1186
1186
def median (x , axis = None , keepdims = False ):
1187
- raise NotImplementedError ("`median` is not supported with openvino backend" )
1187
+ x = get_ov_output (x )
1188
+ x_shape = x .get_partial_shape ()
1189
+ rank = x_shape .rank .get_length ()
1190
+
1191
+ if rank == 0 :
1192
+ return OpenVINOKerasTensor (x )
1193
+
1194
+ # Handle axis=None by flattening the input
1195
+ flattened_all = False
1196
+ if axis is None :
1197
+ x = ov_opset .reshape (x , [- 1 ], False ).output (0 )
1198
+ axis = 0
1199
+ original_rank = rank
1200
+ rank = 1
1201
+ flattened_all = True
1202
+ else :
1203
+ # Handle tuple axis - for median, we only support single axis
1204
+ if isinstance (axis , (tuple , list )):
1205
+ if len (axis ) != 1 :
1206
+ raise ValueError ("median only supports single axis reduction" )
1207
+ axis = axis [0 ]
1208
+
1209
+ # Handle negative axis
1210
+ if axis < 0 :
1211
+ axis = rank + axis
1212
+ original_rank = rank
1213
+
1214
+ # Get the size of the dimension to sort
1215
+ shape_tensor = ov_opset .shape_of (x , output_type = Type .i32 ).output (0 )
1216
+ k = ov_opset .gather (
1217
+ shape_tensor ,
1218
+ ov_opset .constant ([axis ], Type .i32 ).output (0 ),
1219
+ ov_opset .constant (0 , Type .i32 ).output (0 ),
1220
+ ).output (0 )
1221
+
1222
+ # Convert k to a scalar value
1223
+ k_scalar = ov_opset .squeeze (k , [0 ]).output (0 )
1224
+
1225
+ # Use topk with k=size_of_axis to get all elements sorted
1226
+ topk_outputs = ov_opset .topk (
1227
+ x , k = k_scalar , axis = axis , mode = "min" , sort = "value" , stable = True
1228
+ )
1229
+
1230
+ # Get the sorted values
1231
+ sorted_values = topk_outputs .output (0 )
1232
+
1233
+ # Convert to float for median calculation
1234
+ x1_type = ov_to_keras_type (sorted_values .get_element_type ())
1235
+ result_type = dtypes .result_type (x1_type , float )
1236
+ result_type = OPENVINO_DTYPES [result_type ]
1237
+ sorted_values = ov_opset .convert (sorted_values , result_type ).output (0 )
1238
+
1239
+ # Calculate median indices
1240
+ # For odd length: median_idx = (k-1) // 2
1241
+ # For even length: we need indices (k//2 - 1) and k//2, then average
1242
+
1243
+ k_minus_1 = ov_opset .subtract (
1244
+ k_scalar , ov_opset .constant (1 , Type .i32 ).output (0 )
1245
+ ).output (0 )
1246
+ k_div_2 = ov_opset .divide (
1247
+ k_scalar , ov_opset .constant (2 , Type .i32 ).output (0 )
1248
+ ).output (0 )
1249
+ k_minus_1_div_2 = ov_opset .divide (
1250
+ k_minus_1 , ov_opset .constant (2 , Type .i32 ).output (0 )
1251
+ ).output (0 )
1252
+
1253
+ # Check if k is odd
1254
+ k_mod_2 = ov_opset .mod (
1255
+ k_scalar , ov_opset .constant (2 , Type .i32 ).output (0 )
1256
+ ).output (0 )
1257
+ is_odd = ov_opset .equal (
1258
+ k_mod_2 , ov_opset .constant (1 , Type .i32 ).output (0 )
1259
+ ).output (0 )
1260
+
1261
+ # For odd case: take the middle element
1262
+ odd_idx = k_minus_1_div_2
1263
+
1264
+ # For even case: take average of two middle elements
1265
+ even_idx1 = ov_opset .subtract (
1266
+ k_div_2 , ov_opset .constant (1 , Type .i32 ).output (0 )
1267
+ ).output (0 )
1268
+ even_idx2 = k_div_2
1269
+
1270
+ # Gather elements for both cases
1271
+ # Create gather indices tensor for the axis
1272
+ gather_indices_odd = ov_opset .unsqueeze (odd_idx , [0 ]).output (0 )
1273
+ gather_indices_even1 = ov_opset .unsqueeze (even_idx1 , [0 ]).output (0 )
1274
+ gather_indices_even2 = ov_opset .unsqueeze (even_idx2 , [0 ]).output (0 )
1275
+
1276
+ # Gather the median elements
1277
+ odd_result = ov_opset .gather (
1278
+ sorted_values ,
1279
+ gather_indices_odd ,
1280
+ ov_opset .constant (axis , Type .i32 ).output (0 ),
1281
+ ).output (0 )
1282
+ even_result1 = ov_opset .gather (
1283
+ sorted_values ,
1284
+ gather_indices_even1 ,
1285
+ ov_opset .constant (axis , Type .i32 ).output (0 ),
1286
+ ).output (0 )
1287
+ even_result2 = ov_opset .gather (
1288
+ sorted_values ,
1289
+ gather_indices_even2 ,
1290
+ ov_opset .constant (axis , Type .i32 ).output (0 ),
1291
+ ).output (0 )
1292
+
1293
+ # Average the two middle elements for even case
1294
+ even_sum = ov_opset .add (even_result1 , even_result2 ).output (0 )
1295
+ even_result = ov_opset .divide (
1296
+ even_sum , ov_opset .constant (2.0 , result_type ).output (0 )
1297
+ ).output (0 )
1298
+
1299
+ # Select between odd and even results
1300
+ median_result = ov_opset .select (is_odd , odd_result , even_result ).output (0 )
1301
+
1302
+ # Remove the gathered dimension (squeeze)
1303
+ median_result = ov_opset .squeeze (median_result , [axis ]).output (0 )
1304
+
1305
+ # Handle keepdims
1306
+ if keepdims :
1307
+ if flattened_all :
1308
+ # When axis=None, keepdims should restore all dimensions as 1
1309
+ ones_shape = ov_opset .constant (
1310
+ [1 ] * original_rank , Type .i32
1311
+ ).output (0 )
1312
+ median_result = ov_opset .reshape (
1313
+ median_result , ones_shape , False
1314
+ ).output (0 )
1315
+ else :
1316
+ median_result = ov_opset .unsqueeze (median_result , [axis ]).output (0 )
1317
+
1318
+ return OpenVINOKerasTensor (median_result )
1188
1319
1189
1320
1190
1321
def meshgrid (* x , indexing = "xy" ):
0 commit comments