@@ -1046,74 +1046,148 @@ def maximum(x1, x2):
1046
1046
1047
1047
1048
1048
def median (x , axis = None , keepdims = False ):
1049
+ if np .isscalar (x ):
1050
+ x = get_ov_output (x )
1051
+ return OpenVINOKerasTensor (x )
1052
+
1049
1053
x = get_ov_output (x )
1050
- x_shape_original = ov_opset .shape_of (x ).output (0 )
1051
-
1054
+ x_type = x .get_element_type ()
1055
+ if x_type == Type .boolean or x_type .is_integral ():
1056
+ x = ov_opset .convert (x , Type .f32 ).output (0 )
1057
+ x_type = x .get_element_type ()
1058
+ x_shape_original = ov_opset .shape_of (x , Type .i32 ).output (0 )
1059
+
1052
1060
if axis is None :
1053
1061
flatten_shape = ov_opset .constant ([- 1 ], Type .i32 ).output (0 )
1054
1062
x = ov_opset .reshape (x , flatten_shape , False ).output (0 )
1055
1063
axis = 0
1064
+ ov_axis = get_ov_output (axis )
1056
1065
flattened = True
1057
- int_axis = False
1058
- x_shape = ov_opset .shape_of (x ).output (0 )
1059
- k_value = ov_opset .convert (x_shape , Type .i32 ).output (0 )
1066
+ k_value = ov_opset .gather (
1067
+ ov_opset .shape_of (x , Type .i32 ).output (0 ),
1068
+ ov_opset .constant ([0 ], Type .i32 ).output (0 ),
1069
+ ov_axis ,
1070
+ ).output (0 )
1060
1071
elif isinstance (axis , int ):
1061
1072
flattened = False
1062
- int_axis = True
1063
- ov_axis = ov_opset .constant (axis , Type .i32 ).output (0 )
1064
- x_shape = ov_opset .shape_of (x ).output (0 )
1065
- k_value = ov_opset .convert (ov_opset .gather (x_shape , ov_axis , ov_opset .constant ([0 ], Type .i32 ).output (0 )).output (0 ), Type .i32 ).output (0 )
1073
+ ov_axis = get_ov_output (axis )
1074
+ x_shape = ov_opset .shape_of (x , Type .i32 ).output (0 )
1075
+ k_value = ov_opset .gather (
1076
+ x_shape , ov_axis , ov_opset .constant ([0 ], Type .i32 ).output (0 )
1077
+ ).output (0 )
1066
1078
else :
1067
- # axis = (2, 1)
1068
1079
flattened = False
1069
- int_axis = False
1070
- ov_axis = ov_opset .constant (axis , Type .i32 ).output (0 ) # (2, 1)
1071
- x_rank = ov_opset .shape_of (x_shape_original ).output (0 ) # 4
1072
- axis_range = ov_opset .range (ov_opset .constant ([0 ], Type .i32 ).output (0 ), x_rank , ov_opset .constant ([1 ], Type .i32 ).output (0 )).output (0 )
1073
- axis_compare = ov_opset .equal (ov_opset .unsqueeze (ov_axis , 1 ).output (0 ), ov_opset .unsqueeze (axis_range , 0 ).output (0 )).output (0 )
1074
- mask_remove = ov_opset .reduce_logical_or (axis_compare , ov_opset .constant ([0 ], Type .i32 ).output (0 )).output (0 )
1075
- mask_keep = ov_opset .logical_not (mask_remove ).output (0 )
1076
- nz = ov_opset .non_zero (mask_keep , "i32" ).output (0 )
1077
- indices_keep = ov_opset .squeeze (nz , [0 ]).output (0 )
1078
- axis_range = ov_opset .gather (axis_range , indices_keep , ov_opset .constant ([0 ], Type .i32 ).output (0 )).output (0 ) # (0, 3)
1079
- axis_range = ov_opset .concat ([axis_range , ov_axis ], ov_opset .constant ([0 ], Type .i32 ).output (0 )).output (0 ) # (0, 3, 2, 1)
1080
- x = ov_opset .transpose (x , axis_range ).output (0 ) # x = (d0, d3, d2, d1)
1081
-
1082
- flat_rank = ov_opset .subtract (x_rank , ov_opset .constant ([1 ], Type .i32 )).output (0 )
1083
- flatten_shape = ov_opset .constant ([0 ], shape = flat_rank , type_info = Type .i32 ).output (0 )
1084
- flatten_shape = ov_opset .scatter_elements_update (flatten_shape , ov_opset .constant ([- 1 ], Type .i32 ).output (0 ), [- 1 ], [0 ], "sum" )
1085
-
1086
- x = ov_opset .reshape (x , flatten_shape , True ).output (0 ) # x = (d0, d3, d2*d1)
1080
+ ov_axis = get_ov_output (axis )
1081
+ x_rank = ov_opset .gather (
1082
+ ov_opset .shape_of (x_shape_original , Type .i32 ).output (0 ),
1083
+ ov_opset .constant ([0 ], Type .i32 ).output (0 ),
1084
+ ).output (0 )
1085
+ axis_as_range = ov_opset .range (
1086
+ ov_opset .constant ([0 ], Type .i32 ).output (0 ),
1087
+ x_rank ,
1088
+ ov_opset .constant ([1 ], Type .i32 ).output (0 ),
1089
+ ).output (0 )
1090
+ axis_compare = ov_opset .not_equal (
1091
+ ov_opset .unsqueeze (axis_as_range , 1 ).output (0 ),
1092
+ ov_opset .unsqueeze (ov_axis , 0 ).output (0 ),
1093
+ "NUMPY" ,
1094
+ ).output (0 )
1095
+ keep_axes = ov_opset .reduce_logical_or (
1096
+ axis_compare , ov_opset .constant ([1 ], Type .i32 ).output (0 )
1097
+ ).output (0 )
1098
+ nz = ov_opset .non_zero (keep_axes , Type .i32 ).output (0 )
1099
+ keep_axes = ov_opset .reduce_sum (
1100
+ nz , ov_opset .constant ([1 ], Type .i32 ).output (0 )
1101
+ ).output (0 )
1102
+ reordered_axes = ov_opset .concat (
1103
+ [keep_axes , ov_axis ], ov_opset .constant ([0 ], Type .i32 ).output (0 )
1104
+ ).output (0 )
1105
+ x = ov_opset .transpose (x , reordered_axes ).output (0 )
1106
+
1107
+ flat_rank = ov_opset .subtract (
1108
+ x_rank , ov_opset .constant ([1 ], Type .i32 )
1109
+ ).output (0 )
1110
+ flatten_shape = ov_opset .broadcast (
1111
+ ov_opset .constant ([0 ], Type .i32 ).output (0 ), flat_rank
1112
+ ).output (0 )
1113
+ flatten_shape = ov_opset .scatter_elements_update (
1114
+ flatten_shape ,
1115
+ ov_opset .constant ([- 1 ], Type .i32 ).output (0 ),
1116
+ ov_opset .constant ([- 1 ], Type .i32 ).output (0 ),
1117
+ 0 ,
1118
+ "sum" ,
1119
+ ).output (0 )
1120
+
1121
+ x = ov_opset .reshape (x , flatten_shape , True ).output (0 )
1087
1122
axis = - 1
1088
- x_shape = ov_opset .shape_of (x ).output (0 )
1089
- k_value = ov_opset .gather (x_shape , ov_opset .constant ([- 1 ], Type .i32 ).output (0 ), ov_opset .constant ([0 ], Type .i32 ).output (0 )).output (0 )
1090
- k_value = ov_opset .convert (k_value , Type .i32 ).output (0 )
1091
-
1092
- x_sorted = ov_opset .topk (x , k_value , axis , 'min' , 'value' , stable = True ).output (0 )
1093
- half_index = ov_opset .divide (k_value , ov_opset .constant ([2 ], Type .i32 )).output (0 )
1123
+ x_shape = ov_opset .shape_of (x , Type .i32 ).output (0 )
1124
+ k_value = ov_opset .gather (
1125
+ x_shape ,
1126
+ ov_opset .constant ([- 1 ], Type .i32 ).output (0 ),
1127
+ ov_opset .constant ([0 ], Type .i32 ).output (0 ),
1128
+ ).output (0 )
1129
+
1130
+ if axis < 0 :
1131
+ x_rank = ov_opset .gather (
1132
+ ov_opset .shape_of (x , Type .i32 ).output (0 ),
1133
+ ov_opset .constant ([0 ], Type .i32 ).output (0 ),
1134
+ ).output (0 )
1135
+ axis_as_range = ov_opset .range (
1136
+ ov_opset .constant ([0 ], Type .i32 ).output (0 ),
1137
+ x_rank ,
1138
+ ov_opset .constant ([1 ], Type .i32 ).output (0 ),
1139
+ ).output (0 )
1140
+ ov_axis_positive = ov_opset .gather (
1141
+ axis_as_range , ov_axis , ov_opset .constant ([0 ], Type .i32 )
1142
+ ).output (0 )
1143
+ else :
1144
+ ov_axis_positive = ov_axis
1145
+
1146
+ x_sorted = ov_opset .topk (
1147
+ x , k_value , axis , "min" , "value" , stable = True
1148
+ ).output (0 )
1149
+ half_index = ov_opset .floor (
1150
+ ov_opset .divide (k_value , ov_opset .constant ([2 ], Type .i32 )).output (0 )
1151
+ ).output (0 )
1152
+ half_index = ov_opset .convert (half_index , Type .i32 ).output (0 )
1094
1153
x_mod = ov_opset .mod (k_value , ov_opset .constant ([2 ], Type .i32 )).output (0 )
1095
1154
is_even = ov_opset .equal (x_mod , ov_opset .constant ([0 ], Type .i32 )).output (0 )
1096
- med_index_0 = ov_opset .gather (x_sorted , ov_opset .floor (half_index ).output (0 ), axis ).output (0 ) # COME BACK, does it sort out higher dimensions?
1097
- med_index_1 = ov_opset .gather (x_sorted , ov_opset .add (med_index_0 , ov_opset .constant ([1 ], Type .i32 )).output (0 ), axis ).output (0 )
1098
-
1099
- median_odd = med_index_0
1100
- median_even = ov_opset .divide (ov_opset .add (med_index_1 , med_index_0 ).output (0 ), ov_opset .constant ([2 ], Type .i32 ))
1101
-
1155
+
1156
+ med_0 = ov_opset .gather (x_sorted , half_index , ov_axis_positive ).output (0 )
1157
+ med_1 = ov_opset .select (
1158
+ is_even ,
1159
+ ov_opset .gather (
1160
+ x_sorted ,
1161
+ ov_opset .subtract (
1162
+ half_index , ov_opset .constant ([1 ], Type .i32 )
1163
+ ).output (0 ),
1164
+ ov_axis_positive ,
1165
+ ).output (0 ),
1166
+ med_0 ,
1167
+ ).output (0 )
1168
+
1169
+ median_odd = med_0
1170
+ median_even = ov_opset .divide (
1171
+ ov_opset .add (med_1 , med_0 ).output (0 ),
1172
+ ov_opset .constant ([2 ], Type .f32 ),
1173
+ )
1174
+
1102
1175
median_eval = ov_opset .select (is_even , median_even , median_odd )
1103
-
1104
- if keepdims == True :
1105
- if flattened == True :
1106
- median_shape = ov_opset .divide (x_shape_original , x_shape_original ). output ( 0 )
1107
- median_eval = ov_opset . reshape ( median_eval , median_shape , False ). output ( 0 )
1108
- elif int_axis == True :
1109
- median_shape = ov_opset .shape_of ( median_eval ). output ( 0 )
1110
- median_shape = ov_opset . unsqueeze ( median_shape , axis ). output ( 0 )
1111
- median_eval = ov_opset . reshape ( median_eval , median_shape , False ).output (0 )
1176
+
1177
+ if keepdims :
1178
+ if flattened :
1179
+ median_shape = ov_opset .divide (
1180
+ x_shape_original , x_shape_original , "none"
1181
+ ). output ( 0 )
1182
+ median_eval = ov_opset .reshape (
1183
+ median_eval , median_shape , False
1184
+ ).output (0 )
1112
1185
else :
1113
1186
median_eval = ov_opset .unsqueeze (median_eval , ov_axis ).output (0 )
1114
-
1187
+
1115
1188
return OpenVINOKerasTensor (median_eval )
1116
1189
1190
+
1117
1191
def meshgrid (* x , indexing = "xy" ):
1118
1192
raise NotImplementedError (
1119
1193
"`meshgrid` is not supported with openvino backend"
0 commit comments