@@ -1046,8 +1046,73 @@ def maximum(x1, x2):
1046
1046
1047
1047
1048
1048
def median (x , axis = None , keepdims = False ):
1049
- raise NotImplementedError ("`median` is not supported with openvino backend" )
1050
-
1049
+ x = get_ov_output (x )
1050
+ x_shape_original = ov_opset .shape_of (x ).output (0 )
1051
+
1052
+ if axis is None :
1053
+ flatten_shape = ov_opset .constant ([- 1 ], Type .i32 ).output (0 )
1054
+ x = ov_opset .reshape (x , flatten_shape , False ).output (0 )
1055
+ axis = 0
1056
+ flattened = True
1057
+ int_axis = False
1058
+ x_shape = ov_opset .shape_of (x ).output (0 )
1059
+ k_value = ov_opset .convert (x_shape , Type .i32 ).output (0 )
1060
+ elif isinstance (axis , int ):
1061
+ flattened = False
1062
+ int_axis = True
1063
+ ov_axis = ov_opset .constant (axis , Type .i32 ).output (0 )
1064
+ x_shape = ov_opset .shape_of (x ).output (0 )
1065
+ k_value = ov_opset .convert (ov_opset .gather (x_shape , ov_axis , ov_opset .constant ([0 ], Type .i32 ).output (0 )).output (0 ), Type .i32 ).output (0 )
1066
+ else :
1067
+ # axis = (2, 1)
1068
+ flattened = False
1069
+ int_axis = False
1070
+ ov_axis = ov_opset .constant (axis , Type .i32 ).output (0 ) # (2, 1)
1071
+ x_rank = ov_opset .shape_of (x_shape_original ).output (0 ) # 4
1072
+ axis_range = ov_opset .range (ov_opset .constant ([0 ], Type .i32 ).output (0 ), x_rank , ov_opset .constant ([1 ], Type .i32 ).output (0 )).output (0 )
1073
+ axis_compare = ov_opset .equal (ov_opset .unsqueeze (ov_axis , 1 ).output (0 ), ov_opset .unsqueeze (axis_range , 0 ).output (0 )).output (0 )
1074
+ mask_remove = ov_opset .reduce_logical_or (axis_compare , ov_opset .constant ([0 ], Type .i32 ).output (0 )).output (0 )
1075
+ mask_keep = ov_opset .logical_not (mask_remove ).output (0 )
1076
+ nz = ov_opset .non_zero (mask_keep , "i32" ).output (0 )
1077
+ indices_keep = ov_opset .squeeze (nz , [0 ]).output (0 )
1078
+ axis_range = ov_opset .gather (axis_range , indices_keep , ov_opset .constant ([0 ], Type .i32 ).output (0 )).output (0 ) # (0, 3)
1079
+ axis_range = ov_opset .concat ([axis_range , ov_axis ], ov_opset .constant ([0 ], Type .i32 ).output (0 )).output (0 ) # (0, 3, 2, 1)
1080
+ x = ov_opset .transpose (x , axis_range ).output (0 ) # x = (d0, d3, d2, d1)
1081
+
1082
+ flat_rank = ov_opset .subtract (x_rank , ov_opset .constant ([1 ], Type .i32 )).output (0 )
1083
+ flatten_shape = ov_opset .constant ([0 ], shape = flat_rank , type_info = Type .i32 ).output (0 )
1084
+ flatten_shape = ov_opset .scatter_elements_update (flatten_shape , ov_opset .constant ([- 1 ], Type .i32 ).output (0 ), [- 1 ], [0 ], "sum" )
1085
+
1086
+ x = ov_opset .reshape (x , flatten_shape , True ).output (0 ) # x = (d0, d3, d2*d1)
1087
+ axis = - 1
1088
+ x_shape = ov_opset .shape_of (x ).output (0 )
1089
+ k_value = ov_opset .gather (x_shape , ov_opset .constant ([- 1 ], Type .i32 ).output (0 ), ov_opset .constant ([0 ], Type .i32 ).output (0 )).output (0 )
1090
+ k_value = ov_opset .convert (k_value , Type .i32 ).output (0 )
1091
+
1092
+ x_sorted = ov_opset .topk (x , k_value , axis , 'min' , 'value' , stable = True ).output (0 )
1093
+ half_index = ov_opset .divide (k_value , ov_opset .constant ([2 ], Type .i32 )).output (0 )
1094
+ x_mod = ov_opset .mod (k_value , ov_opset .constant ([2 ], Type .i32 )).output (0 )
1095
+ is_even = ov_opset .equal (x_mod , ov_opset .constant ([0 ], Type .i32 )).output (0 )
1096
+ med_index_0 = ov_opset .gather (x_sorted , ov_opset .floor (half_index ).output (0 ), axis ).output (0 ) # COME BACK, does it sort out higher dimensions?
1097
+ med_index_1 = ov_opset .gather (x_sorted , ov_opset .add (med_index_0 , ov_opset .constant ([1 ], Type .i32 )).output (0 ), axis ).output (0 )
1098
+
1099
+ median_odd = med_index_0
1100
+ median_even = ov_opset .divide (ov_opset .add (med_index_1 , med_index_0 ).output (0 ), ov_opset .constant ([2 ], Type .i32 ))
1101
+
1102
+ median_eval = ov_opset .select (is_even , median_even , median_odd )
1103
+
1104
+ if keepdims == True :
1105
+ if flattened == True :
1106
+ median_shape = ov_opset .divide (x_shape_original , x_shape_original ).output (0 )
1107
+ median_eval = ov_opset .reshape (median_eval , median_shape , False ).output (0 )
1108
+ elif int_axis == True :
1109
+ median_shape = ov_opset .shape_of (median_eval ).output (0 )
1110
+ median_shape = ov_opset .unsqueeze (median_shape , axis ).output (0 )
1111
+ median_eval = ov_opset .reshape (median_eval , median_shape , False ).output (0 )
1112
+ else :
1113
+ median_eval = ov_opset .unsqueeze (median_eval , ov_axis ).output (0 )
1114
+
1115
+ return OpenVINOKerasTensor (median_eval )
1051
1116
1052
1117
def meshgrid (* x , indexing = "xy" ):
1053
1118
raise NotImplementedError (
0 commit comments