Skip to content

Commit 4415fcc

Browse files
Add meshgrid op support in Keras OpenVINO backend (#21600)
* Added Numpy.meshgrid operation for openvino backend * Performed the changes suggested by the BOT * Performed the suggested changes * Minor changes * Minor Changes
1 parent e988786 commit 4415fcc

File tree

2 files changed

+41
-5
lines changed

2 files changed

+41
-5
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ NumpyDtypeTest::test_matmul_
4545
NumpyDtypeTest::test_max
4646
NumpyDtypeTest::test_mean
4747
NumpyDtypeTest::test_median
48-
NumpyDtypeTest::test_meshgrid
4948
NumpyDtypeTest::test_minimum_python_types
5049
NumpyDtypeTest::test_multiply
5150
NumpyDtypeTest::test_power
@@ -102,7 +101,6 @@ NumpyOneInputOpsCorrectnessTest::test_logaddexp
102101
NumpyOneInputOpsCorrectnessTest::test_max
103102
NumpyOneInputOpsCorrectnessTest::test_mean
104103
NumpyOneInputOpsCorrectnessTest::test_median
105-
NumpyOneInputOpsCorrectnessTest::test_meshgrid
106104
NumpyOneInputOpsCorrectnessTest::test_pad_float16_constant_2
107105
NumpyOneInputOpsCorrectnessTest::test_pad_float32_constant_2
108106
NumpyOneInputOpsCorrectnessTest::test_pad_float64_constant_2

keras/src/backend/openvino/numpy.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,9 +1105,47 @@ def median(x, axis=None, keepdims=False):
11051105

11061106

11071107
def meshgrid(*x, indexing="xy"):
1108-
raise NotImplementedError(
1109-
"`meshgrid` is not supported with openvino backend"
1110-
)
1108+
if len(x) < 2:
1109+
raise ValueError(
1110+
"meshgrid requires at least 2 input arrays. "
1111+
f"Received: {len(x)} input array(s)."
1112+
)
1113+
if indexing not in ("xy", "ij"):
1114+
raise ValueError("indexing must be either 'xy' or 'ij'")
1115+
1116+
tensors = [get_ov_output(xi) for xi in x]
1117+
n = len(tensors)
1118+
1119+
shapes = [
1120+
ov_opset.shape_of(t, Type.i64).output(0) for t in tensors
1121+
] # each is [Ni]
1122+
one = ov_opset.constant([1], Type.i64).output(0)
1123+
1124+
if indexing == "xy":
1125+
shape_list = [shapes[1], shapes[0]] + shapes[2:]
1126+
out_shape = ov_opset.concat(shape_list, axis=0).output(0)
1127+
else:
1128+
out_shape = ov_opset.concat(shapes, axis=0).output(0)
1129+
1130+
outputs = []
1131+
for i, t in enumerate(tensors):
1132+
reshape_parts = [one] * n
1133+
if indexing == "xy":
1134+
if i == 0:
1135+
reshape_parts[1] = shapes[0]
1136+
elif i == 1:
1137+
reshape_parts[0] = shapes[1]
1138+
else:
1139+
reshape_parts[i] = shapes[i]
1140+
else:
1141+
reshape_parts[i] = shapes[i]
1142+
1143+
reshape_shape = ov_opset.concat(reshape_parts, axis=0).output(0)
1144+
reshaped = ov_opset.reshape(t, reshape_shape, False).output(0)
1145+
broadcasted = ov_opset.broadcast(reshaped, out_shape).output(0)
1146+
outputs.append(OpenVINOKerasTensor(broadcasted))
1147+
1148+
return outputs
11111149

11121150

11131151
def min(x, axis=None, keepdims=False, initial=None):

0 commit comments

Comments
 (0)