Skip to content

Commit 98f982e

Browse files
committed
Finished median(), linted
1 parent 80e38da commit 98f982e

File tree

1 file changed

+124
-50
lines changed

1 file changed

+124
-50
lines changed

keras/src/backend/openvino/numpy.py

Lines changed: 124 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,74 +1046,148 @@ def maximum(x1, x2):
10461046

10471047

10481048
def median(x, axis=None, keepdims=False):
1049+
if np.isscalar(x):
1050+
x = get_ov_output(x)
1051+
return OpenVINOKerasTensor(x)
1052+
10491053
x = get_ov_output(x)
1050-
x_shape_original = ov_opset.shape_of(x).output(0)
1051-
1054+
x_type = x.get_element_type()
1055+
if x_type == Type.boolean or x_type.is_integral():
1056+
x = ov_opset.convert(x, Type.f32).output(0)
1057+
x_type = x.get_element_type()
1058+
x_shape_original = ov_opset.shape_of(x, Type.i32).output(0)
1059+
10521060
if axis is None:
10531061
flatten_shape = ov_opset.constant([-1], Type.i32).output(0)
10541062
x = ov_opset.reshape(x, flatten_shape, False).output(0)
10551063
axis = 0
1064+
ov_axis = get_ov_output(axis)
10561065
flattened = True
1057-
int_axis = False
1058-
x_shape = ov_opset.shape_of(x).output(0)
1059-
k_value = ov_opset.convert(x_shape, Type.i32).output(0)
1066+
k_value = ov_opset.gather(
1067+
ov_opset.shape_of(x, Type.i32).output(0),
1068+
ov_opset.constant([0], Type.i32).output(0),
1069+
ov_axis,
1070+
).output(0)
10601071
elif isinstance(axis, int):
10611072
flattened = False
1062-
int_axis = True
1063-
ov_axis = ov_opset.constant(axis, Type.i32).output(0)
1064-
x_shape = ov_opset.shape_of(x).output(0)
1065-
k_value = ov_opset.convert(ov_opset.gather(x_shape, ov_axis, ov_opset.constant([0], Type.i32).output(0)).output(0), Type.i32).output(0)
1073+
ov_axis = get_ov_output(axis)
1074+
x_shape = ov_opset.shape_of(x, Type.i32).output(0)
1075+
k_value = ov_opset.gather(
1076+
x_shape, ov_axis, ov_opset.constant([0], Type.i32).output(0)
1077+
).output(0)
10661078
else:
1067-
# axis = (2, 1)
10681079
flattened = False
1069-
int_axis = False
1070-
ov_axis = ov_opset.constant(axis, Type.i32).output(0) # (2, 1)
1071-
x_rank = ov_opset.shape_of(x_shape_original).output(0) # 4
1072-
axis_range = ov_opset.range(ov_opset.constant([0], Type.i32).output(0), x_rank, ov_opset.constant([1], Type.i32).output(0)).output(0)
1073-
axis_compare = ov_opset.equal(ov_opset.unsqueeze(ov_axis, 1).output(0), ov_opset.unsqueeze(axis_range, 0).output(0)).output(0)
1074-
mask_remove = ov_opset.reduce_logical_or(axis_compare, ov_opset.constant([0], Type.i32).output(0)).output(0)
1075-
mask_keep = ov_opset.logical_not(mask_remove).output(0)
1076-
nz = ov_opset.non_zero(mask_keep, "i32").output(0)
1077-
indices_keep = ov_opset.squeeze(nz, [0]).output(0)
1078-
axis_range = ov_opset.gather(axis_range, indices_keep, ov_opset.constant([0], Type.i32).output(0)).output(0) # (0, 3)
1079-
axis_range = ov_opset.concat([axis_range, ov_axis], ov_opset.constant([0], Type.i32).output(0)).output(0) # (0, 3, 2, 1)
1080-
x = ov_opset.transpose(x, axis_range).output(0) # x = (d0, d3, d2, d1)
1081-
1082-
flat_rank = ov_opset.subtract(x_rank, ov_opset.constant([1], Type.i32)).output(0)
1083-
flatten_shape = ov_opset.constant([0], shape=flat_rank, type_info=Type.i32).output(0)
1084-
flatten_shape = ov_opset.scatter_elements_update(flatten_shape, ov_opset.constant([-1], Type.i32).output(0), [-1], [0], "sum")
1085-
1086-
x = ov_opset.reshape(x, flatten_shape, True).output(0) # x = (d0, d3, d2*d1)
1080+
ov_axis = get_ov_output(axis)
1081+
x_rank = ov_opset.gather(
1082+
ov_opset.shape_of(x_shape_original, Type.i32).output(0),
1083+
ov_opset.constant([0], Type.i32).output(0),
1084+
).output(0)
1085+
axis_as_range = ov_opset.range(
1086+
ov_opset.constant([0], Type.i32).output(0),
1087+
x_rank,
1088+
ov_opset.constant([1], Type.i32).output(0),
1089+
).output(0)
1090+
axis_compare = ov_opset.not_equal(
1091+
ov_opset.unsqueeze(axis_as_range, 1).output(0),
1092+
ov_opset.unsqueeze(ov_axis, 0).output(0),
1093+
"NUMPY",
1094+
).output(0)
1095+
keep_axes = ov_opset.reduce_logical_or(
1096+
axis_compare, ov_opset.constant([1], Type.i32).output(0)
1097+
).output(0)
1098+
nz = ov_opset.non_zero(keep_axes, Type.i32).output(0)
1099+
keep_axes = ov_opset.reduce_sum(
1100+
nz, ov_opset.constant([1], Type.i32).output(0)
1101+
).output(0)
1102+
reordered_axes = ov_opset.concat(
1103+
[keep_axes, ov_axis], ov_opset.constant([0], Type.i32).output(0)
1104+
).output(0)
1105+
x = ov_opset.transpose(x, reordered_axes).output(0)
1106+
1107+
flat_rank = ov_opset.subtract(
1108+
x_rank, ov_opset.constant([1], Type.i32)
1109+
).output(0)
1110+
flatten_shape = ov_opset.broadcast(
1111+
ov_opset.constant([0], Type.i32).output(0), flat_rank
1112+
).output(0)
1113+
flatten_shape = ov_opset.scatter_elements_update(
1114+
flatten_shape,
1115+
ov_opset.constant([-1], Type.i32).output(0),
1116+
ov_opset.constant([-1], Type.i32).output(0),
1117+
0,
1118+
"sum",
1119+
).output(0)
1120+
1121+
x = ov_opset.reshape(x, flatten_shape, True).output(0)
10871122
axis = -1
1088-
x_shape = ov_opset.shape_of(x).output(0)
1089-
k_value = ov_opset.gather(x_shape, ov_opset.constant([-1], Type.i32).output(0), ov_opset.constant([0], Type.i32).output(0)).output(0)
1090-
k_value = ov_opset.convert(k_value, Type.i32).output(0)
1091-
1092-
x_sorted = ov_opset.topk(x, k_value, axis, 'min', 'value', stable=True).output(0)
1093-
half_index = ov_opset.divide(k_value, ov_opset.constant([2], Type.i32)).output(0)
1123+
x_shape = ov_opset.shape_of(x, Type.i32).output(0)
1124+
k_value = ov_opset.gather(
1125+
x_shape,
1126+
ov_opset.constant([-1], Type.i32).output(0),
1127+
ov_opset.constant([0], Type.i32).output(0),
1128+
).output(0)
1129+
1130+
if axis < 0:
1131+
x_rank = ov_opset.gather(
1132+
ov_opset.shape_of(x, Type.i32).output(0),
1133+
ov_opset.constant([0], Type.i32).output(0),
1134+
).output(0)
1135+
axis_as_range = ov_opset.range(
1136+
ov_opset.constant([0], Type.i32).output(0),
1137+
x_rank,
1138+
ov_opset.constant([1], Type.i32).output(0),
1139+
).output(0)
1140+
ov_axis_positive = ov_opset.gather(
1141+
axis_as_range, ov_axis, ov_opset.constant([0], Type.i32)
1142+
).output(0)
1143+
else:
1144+
ov_axis_positive = ov_axis
1145+
1146+
x_sorted = ov_opset.topk(
1147+
x, k_value, axis, "min", "value", stable=True
1148+
).output(0)
1149+
half_index = ov_opset.floor(
1150+
ov_opset.divide(k_value, ov_opset.constant([2], Type.i32)).output(0)
1151+
).output(0)
1152+
half_index = ov_opset.convert(half_index, Type.i32).output(0)
10941153
x_mod = ov_opset.mod(k_value, ov_opset.constant([2], Type.i32)).output(0)
10951154
is_even = ov_opset.equal(x_mod, ov_opset.constant([0], Type.i32)).output(0)
1096-
med_index_0 = ov_opset.gather(x_sorted, ov_opset.floor(half_index).output(0), axis).output(0) # COME BACK, does it sort out higher dimensions?
1097-
med_index_1 = ov_opset.gather(x_sorted, ov_opset.add(med_index_0, ov_opset.constant([1], Type.i32)).output(0), axis).output(0)
1098-
1099-
median_odd = med_index_0
1100-
median_even = ov_opset.divide(ov_opset.add(med_index_1, med_index_0).output(0), ov_opset.constant([2], Type.i32))
1101-
1155+
1156+
med_0 = ov_opset.gather(x_sorted, half_index, ov_axis_positive).output(0)
1157+
med_1 = ov_opset.select(
1158+
is_even,
1159+
ov_opset.gather(
1160+
x_sorted,
1161+
ov_opset.subtract(
1162+
half_index, ov_opset.constant([1], Type.i32)
1163+
).output(0),
1164+
ov_axis_positive,
1165+
).output(0),
1166+
med_0,
1167+
).output(0)
1168+
1169+
median_odd = med_0
1170+
median_even = ov_opset.divide(
1171+
ov_opset.add(med_1, med_0).output(0),
1172+
ov_opset.constant([2], Type.f32),
1173+
)
1174+
11021175
median_eval = ov_opset.select(is_even, median_even, median_odd)
1103-
1104-
if keepdims == True:
1105-
if flattened == True:
1106-
median_shape = ov_opset.divide(x_shape_original, x_shape_original).output(0)
1107-
median_eval = ov_opset.reshape(median_eval, median_shape, False).output(0)
1108-
elif int_axis == True:
1109-
median_shape = ov_opset.shape_of(median_eval).output(0)
1110-
median_shape = ov_opset.unsqueeze(median_shape, axis).output(0)
1111-
median_eval = ov_opset.reshape(median_eval, median_shape, False).output(0)
1176+
1177+
if keepdims:
1178+
if flattened:
1179+
median_shape = ov_opset.divide(
1180+
x_shape_original, x_shape_original, "none"
1181+
).output(0)
1182+
median_eval = ov_opset.reshape(
1183+
median_eval, median_shape, False
1184+
).output(0)
11121185
else:
11131186
median_eval = ov_opset.unsqueeze(median_eval, ov_axis).output(0)
1114-
1187+
11151188
return OpenVINOKerasTensor(median_eval)
11161189

1190+
11171191
def meshgrid(*x, indexing="xy"):
11181192
raise NotImplementedError(
11191193
"`meshgrid` is not supported with openvino backend"

0 commit comments

Comments
 (0)