Skip to content

Commit 8b7c90e

Browse files
[OpenVINO backend] support repeat (#21433)
* [OpenVINO backend] support repeat * add check for np.integer
1 parent b990e54 commit 8b7c90e

File tree

2 files changed

+56
-3
lines changed

2 files changed

+56
-3
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ NumpyDtypeTest::test_multiply
4848
NumpyDtypeTest::test_power
4949
NumpyDtypeTest::test_prod
5050
NumpyDtypeTest::test_quantile
51-
NumpyDtypeTest::test_repeat
5251
NumpyDtypeTest::test_roll
5352
NumpyDtypeTest::test_round
5453
NumpyDtypeTest::test_searchsorted
@@ -109,7 +108,6 @@ NumpyOneInputOpsCorrectnessTest::test_pad_uint8_constant_2
109108
NumpyOneInputOpsCorrectnessTest::test_pad_int32_constant_2
110109
NumpyOneInputOpsCorrectnessTest::test_prod
111110
NumpyOneInputOpsCorrectnessTest::test_real
112-
NumpyOneInputOpsCorrectnessTest::test_repeat
113111
NumpyOneInputOpsCorrectnessTest::test_reshape
114112
NumpyOneInputOpsCorrectnessTest::test_roll
115113
NumpyOneInputOpsCorrectnessTest::test_round

keras/src/backend/openvino/numpy.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1311,7 +1311,62 @@ def reciprocal(x):
13111311

13121312

13131313
def repeat(x, repeats, axis=None):
1314-
raise NotImplementedError("`repeat` is not supported with openvino backend")
1314+
x = get_ov_output(x)
1315+
const_0 = ov_opset.constant(0, Type.i32)
1316+
const_1 = ov_opset.constant(1, Type.i32)
1317+
const_neg_1 = ov_opset.constant([-1], Type.i32)
1318+
1319+
if axis is not None and axis < 0:
1320+
axis += len(x.get_partial_shape())
1321+
1322+
if axis is None:
1323+
x = ov_opset.reshape(x, const_neg_1, special_zero=False)
1324+
axis = 0
1325+
1326+
if isinstance(repeats, (int, np.integer)) or (
1327+
isinstance(repeats, np.ndarray)
1328+
and repeats.ndim == 1
1329+
and repeats.size == 1
1330+
):
1331+
repeats_val = (
1332+
int(repeats)
1333+
if isinstance(repeats, (np.integer, np.ndarray))
1334+
else repeats
1335+
)
1336+
dim_len = ov_opset.gather(
1337+
ov_opset.shape_of(x, Type.i32),
1338+
ov_opset.constant([axis], Type.i32),
1339+
const_0,
1340+
)
1341+
dim_len = ov_opset.squeeze(dim_len, ov_opset.constant([0], Type.i32))
1342+
idx_range = ov_opset.range(
1343+
const_0, dim_len, const_1, output_type=Type.i32
1344+
)
1345+
idx_range = ov_opset.unsqueeze(idx_range, const_1)
1346+
tiled = ov_opset.tile(
1347+
idx_range, ov_opset.constant([1, repeats_val], Type.i32)
1348+
)
1349+
idx = ov_opset.reshape(tiled, const_neg_1, special_zero=False)
1350+
result = ov_opset.gather(x, idx, ov_opset.constant(axis, Type.i32))
1351+
return OpenVINOKerasTensor(result.output(0))
1352+
repeats_tensor = get_ov_output(repeats)
1353+
cumsum = ov_opset.cumsum(repeats_tensor, const_0)
1354+
total = ov_opset.reduce_sum(
1355+
repeats_tensor, ov_opset.constant([0], Type.i32), keep_dims=False
1356+
)
1357+
total = ov_opset.convert(total, Type.i32)
1358+
out_indices = ov_opset.range(const_0, total, const_1, output_type=Type.i32)
1359+
cumsum_unsq = ov_opset.unsqueeze(cumsum, const_0)
1360+
out_indices_unsq = ov_opset.unsqueeze(out_indices, const_1)
1361+
cumsum_unsq = ov_opset.convert(cumsum_unsq, Type.i32)
1362+
mask = ov_opset.greater_equal(out_indices_unsq, cumsum_unsq)
1363+
gather_indices = ov_opset.reduce_sum(
1364+
ov_opset.convert(mask, Type.i32), ov_opset.constant([1], Type.i32)
1365+
)
1366+
result = ov_opset.gather(
1367+
x, gather_indices, ov_opset.constant(axis, Type.i32)
1368+
)
1369+
return OpenVINOKerasTensor(result.output(0))
13151370

13161371

13171372
def reshape(x, newshape):

0 commit comments

Comments
 (0)