Skip to content

Commit 55263ca

Browse files
[OpenVINO backend] support take_along_axis (#21447)
1 parent df481e9 commit 55263ca

File tree

2 files changed

+67
-4
lines changed

2 files changed

+67
-4
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ NumpyDtypeTest::test_std
5959
NumpyDtypeTest::test_subtract
6060
NumpyDtypeTest::test_sum
6161
NumpyDtypeTest::test_swapaxes
62-
NumpyDtypeTest::test_take_along_axis
6362
NumpyDtypeTest::test_tensordot_
6463
NumpyDtypeTest::test_tile
6564
NumpyDtypeTest::test_trace
@@ -151,7 +150,6 @@ NumpyTwoInputOpsCorrectnessTest::test_inner
151150
NumpyTwoInputOpsCorrectnessTest::test_linspace
152151
NumpyTwoInputOpsCorrectnessTest::test_logspace
153152
NumpyTwoInputOpsCorrectnessTest::test_quantile
154-
NumpyTwoInputOpsCorrectnessTest::test_take_along_axis
155153
NumpyTwoInputOpsCorrectnessTest::test_tensordot
156154
NumpyTwoInputOpsCorrectnessTest::test_vdot
157155
NumpyOneInputOpsDynamicShapeTest::test_angle

keras/src/backend/openvino/numpy.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,10 +1458,75 @@ def take(x, indices, axis=None):
14581458

14591459

14601460
def take_along_axis(x, indices, axis=None):
1461-
raise NotImplementedError(
1462-
"`take_along_axis` is not supported with openvino backend"
1461+
x = get_ov_output(x)
1462+
indices = get_ov_output(indices)
1463+
1464+
if axis is None:
1465+
target_shape = ov_opset.constant([-1], dtype=Type.i32).output(0)
1466+
x_flat = ov_opset.reshape(x, target_shape, False).output(0)
1467+
indices_flat = ov_opset.reshape(indices, target_shape, False).output(0)
1468+
result = ov_opset.gather_elements(x_flat, indices_flat, 0).output(0)
1469+
return OpenVINOKerasTensor(result)
1470+
1471+
x_rank = len(x.get_partial_shape())
1472+
if axis < 0:
1473+
axis += x_rank
1474+
1475+
x_shape = ov_opset.shape_of(x, Type.i32).output(0)
1476+
indices_shape = ov_opset.shape_of(indices, Type.i32).output(0)
1477+
1478+
zero_const = ov_opset.constant(0, dtype=Type.i32).output(0)
1479+
axis_index = ov_opset.constant([axis], dtype=Type.i32).output(0)
1480+
1481+
# Fix negative indices
1482+
dim_size = ov_opset.squeeze(
1483+
ov_opset.gather(x_shape, axis_index, zero_const).output(0), zero_const
1484+
).output(0)
1485+
zero_scalar = ov_opset.constant(0, indices.get_element_type()).output(0)
1486+
is_neg = ov_opset.less(indices, zero_scalar).output(0)
1487+
dim_size_cast = ov_opset.convert(
1488+
dim_size, indices.get_element_type()
1489+
).output(0)
1490+
indices = ov_opset.select(
1491+
is_neg, ov_opset.add(indices, dim_size_cast).output(0), indices
1492+
).output(0)
1493+
indices = ov_opset.convert(indices, Type.i32).output(0)
1494+
1495+
x_target_parts, indices_target_parts = [], []
1496+
1497+
for i in range(x_rank):
1498+
dim_idx = ov_opset.constant([i], dtype=Type.i32).output(0)
1499+
x_dim = ov_opset.gather(x_shape, dim_idx, zero_const).output(0)
1500+
indices_dim = ov_opset.gather(
1501+
indices_shape, dim_idx, zero_const
1502+
).output(0)
1503+
1504+
if i == axis:
1505+
# For axis dimension: keep original dimensions
1506+
x_target_parts.append(x_dim)
1507+
indices_target_parts.append(indices_dim)
1508+
else:
1509+
# For other dimensions: use maximum for broadcasting
1510+
max_dim = ov_opset.maximum(x_dim, indices_dim).output(0)
1511+
x_target_parts.append(max_dim)
1512+
indices_target_parts.append(max_dim)
1513+
1514+
x_target_shape = ov_opset.concat(x_target_parts, axis=0).output(0)
1515+
indices_target_shape = ov_opset.concat(indices_target_parts, axis=0).output(
1516+
0
14631517
)
14641518

1519+
# Broadcast to target shapes and gather elements
1520+
x_broadcasted = ov_opset.broadcast(x, x_target_shape).output(0)
1521+
indices_broadcasted = ov_opset.broadcast(
1522+
indices, indices_target_shape
1523+
).output(0)
1524+
result = ov_opset.gather_elements(
1525+
x_broadcasted, indices_broadcasted, axis
1526+
).output(0)
1527+
1528+
return OpenVINOKerasTensor(result)
1529+
14651530

14661531
def tan(x):
14671532
x = get_ov_output(x)

0 commit comments

Comments
 (0)