Skip to content

Commit ea0a40f

Browse files
gmni comit
1 parent 0f77687 commit ea0a40f

File tree

1 file changed

+69
-45
lines changed

1 file changed

+69
-45
lines changed

keras/src/backend/openvino/numpy.py

Lines changed: 69 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -15,75 +15,99 @@
1515
from keras.src.backend.openvino.core import convert_to_tensor
1616
from keras.src.backend.openvino.core import get_ov_output
1717
from keras.src.backend.openvino.core import ov_to_keras_type
18-
# --- Change for issue 29115 ---
19-
import openvino.runtime.opset14 as ov
20-
21-
from .core import OpenVINOKerasTensor # already present in file
22-
from .core import _convert_to_node, _wrap_node # adapt if your file names differ
18+
from .core import _convert_to_node
2319

2420
def diagonal(x, offset=0, axis1=0, axis2=1):
2521
"""OpenVINO backend decomposition for keras.ops.diagonal."""
2622
x_node = _convert_to_node(x) # -> ov.Node
27-
offset_const = ov.constant(int(offset), dtype="i64")
23+
offset_const = ov_opset.constant(int(offset), dtype="i64")
2824

2925
# rank & normalize axes
30-
shape = ov.shape_of(x_node) # i64 vector
31-
rank = ov.shape_of(shape) # scalar i64 (len of shape)
32-
rank_val = ov.squeeze(rank) # [] -> scalar
33-
axis1_node = ov.mod(ov.add(ov.constant(int(axis1), dtype="i64"), rank_val), rank_val)
34-
axis2_node = ov.mod(ov.add(ov.constant(int(axis2), dtype="i64"), rank_val), rank_val)
26+
shape = ov_opset.shape_of(x_node) # i64 vector
27+
rank = ov_opset.shape_of(shape) # 1-element i64 vector holding len(shape)
28+
rank_val = ov_opset.squeeze(rank) # [1] -> scalar
29+
axis1_node = ov_opset.floor_mod(
30+
ov_opset.add(ov_opset.constant(int(axis1), dtype="i64"), rank_val), rank_val
31+
)
32+
axis2_node = ov_opset.floor_mod(
33+
ov_opset.add(ov_opset.constant(int(axis2), dtype="i64"), rank_val), rank_val
34+
)
3535

3636
# If axis1 == axis2, NumPy raises an error. Keras tests don't exercise this case,
3737
# so we skip an explicit check to keep the decomposition graph-friendly.
3838

3939
# Build permutation to move axis1, axis2 to the end
4040
# perm = [all axes except axis1/axis2 in order] + [axis1, axis2]
41-
arange = ov.range(ov.constant(0, dtype="i64"), rank_val, ov.constant(1, dtype="i64"))
42-
mask1 = ov.equal(arange, axis1_node)
43-
mask2 = ov.equal(arange, axis2_node)
44-
not12 = ov.logical_not(ov.logical_or(mask1, mask2))
45-
others = ov.squeeze(ov.non_zero(not12), [1]) # gather positions != axis1, axis2
46-
perm = ov.concat([others, ov.reshape(axis1_node, [1]), ov.reshape(axis2_node, [1])], 0)
47-
48-
x_perm = ov.transpose(x_node, perm)
49-
permuted_shape = ov.shape_of(x_perm)
50-
# last two dims
51-
last2 = ov.gather(permuted_shape, ov.constant([-2, -1], dtype="i64"), ov.constant(0, dtype="i64"))
52-
d1 = ov.gather(permuted_shape, ov.constant([-2], dtype="i64"), ov.constant(0, dtype="i64"))
53-
d2 = ov.gather(permuted_shape, ov.constant([-1], dtype="i64"), ov.constant(0, dtype="i64"))
54-
d1 = ov.squeeze(d1) # scalar
55-
d2 = ov.squeeze(d2) # scalar
41+
arange = ov_opset.range(
42+
ov_opset.constant(0, dtype="i64"), rank_val, ov_opset.constant(1, dtype="i64")
43+
)
44+
mask1 = ov_opset.equal(arange, axis1_node)
45+
mask2 = ov_opset.equal(arange, axis2_node)
46+
not12 = ov_opset.logical_not(ov_opset.logical_or(mask1, mask2))
47+
others = ov_opset.squeeze(
48+
ov_opset.non_zero(not12), [1]
49+
) # gather positions != axis1, axis2
50+
perm = ov_opset.concat(
51+
[others, ov_opset.reshape(axis1_node, [1]), ov_opset.reshape(axis2_node, [1])], 0
52+
)
53+
54+
x_perm = ov_opset.transpose(x_node, perm)
55+
permuted_shape = ov_opset.shape_of(x_perm)
56+
d1 = ov_opset.gather(
57+
permuted_shape,
58+
ov_opset.constant([-2], dtype="i64"),
59+
ov_opset.constant(0, dtype="i64"),
60+
)
61+
d2 = ov_opset.gather(
62+
permuted_shape,
63+
ov_opset.constant([-1], dtype="i64"),
64+
ov_opset.constant(0, dtype="i64"),
65+
)
66+
d1 = ov_opset.squeeze(d1) # scalar
67+
d2 = ov_opset.squeeze(d2) # scalar
5668

5769
# start1 = max(0, offset), start2 = max(0, -offset)
58-
zero = ov.constant(0, dtype="i64")
59-
start1 = ov.maximum(zero, offset_const)
60-
start2 = ov.maximum(zero, ov.negative(offset_const))
70+
zero = ov_opset.constant(0, dtype="i64")
71+
start1 = ov_opset.maximum(zero, offset_const)
72+
start2 = ov_opset.maximum(zero, ov_opset.negative(offset_const))
6173

6274
# L = min(d1 - start1, d2 - start2)
63-
l1 = ov.subtract(d1, start1)
64-
l2 = ov.subtract(d2, start2)
65-
L = ov.minimum(l1, l2)
75+
l1 = ov_opset.subtract(d1, start1)
76+
l2 = ov_opset.subtract(d2, start2)
77+
L = ov_opset.minimum(l1, l2)
6678

6779
# r = range(0, L, 1) -> shape [L]
68-
r = ov.range(zero, L, ov.constant(1, dtype="i64"))
69-
idx_row = ov.add(r, start1)
70-
idx_col = ov.add(r, start2)
71-
idx_row = ov.unsqueeze(idx_row, ov.constant(1, dtype="i64")) # [L,1]
72-
idx_col = ov.unsqueeze(idx_col, ov.constant(1, dtype="i64")) # [L,1]
73-
diag_idx = ov.concat([idx_row, idx_col], 1) # [L,2]
80+
r = ov_opset.range(zero, L, ov_opset.constant(1, dtype="i64"))
81+
idx_row = ov_opset.add(r, start1)
82+
idx_col = ov_opset.add(r, start2)
83+
idx_row = ov_opset.unsqueeze(
84+
idx_row, ov_opset.constant(1, dtype="i64")
85+
) # [L,1]
86+
idx_col = ov_opset.unsqueeze(
87+
idx_col, ov_opset.constant(1, dtype="i64")
88+
) # [L,1]
89+
diag_idx = ov_opset.concat([idx_row, idx_col], 1) # [L,2]
7490

7591
# Broadcast indices to batch dims: target shape = (*batch, L, 2)
7692
# batch_rank = rank(x) - 2
77-
two = ov.constant(2, dtype="i64")
78-
batch_rank = ov.subtract(rank_val, two)
93+
two = ov_opset.constant(2, dtype="i64")
94+
batch_rank = ov_opset.subtract(rank_val, two)
7995
# build target shape: concat(permuted_shape[:batch_rank], [L, 2])
80-
batch_shape = ov.slice(permuted_shape, ov.constant([0], dtype="i64"),
81-
ov.reshape(batch_rank, [1]), ov.constant([1], dtype="i64"))
82-
target_shape = ov.concat([batch_shape, ov.reshape(L, [1]), ov.constant([2], dtype="i64")], 0)
83-
bcast_idx = ov.broadcast(diag_idx, target_shape)
96+
batch_shape = ov_opset.strided_slice(
97+
permuted_shape,
98+
begin=ov_opset.constant([0], dtype="i64"),
99+
end=ov_opset.reshape(batch_rank, [1]),
100+
strides=ov_opset.constant([1], dtype="i64"),
101+
begin_mask=[0],
102+
end_mask=[0],
103+
)
104+
target_shape = ov_opset.concat(
105+
[batch_shape, ov_opset.reshape(L, [1]), ov_opset.constant([2], dtype="i64")], 0
106+
)
107+
bcast_idx = ov_opset.broadcast(diag_idx, target_shape)
84108

85109
# GatherND with batch_dims = batch_rank
86-
gathered = ov.gather_nd(x_perm, bcast_idx, batch_rank)
110+
gathered = ov_opset.gather_nd(x_perm, bcast_idx, batch_rank)
87111

88112
return OpenVINOKerasTensor(gathered)
89113

0 commit comments

Comments
 (0)