Skip to content

Commit 96802c5

Browse files
Update numpy.py
1 parent 1635157 commit 96802c5

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

keras/src/backend/openvino/numpy.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import openvino.opset14 as ov_opset
33
from openvino import Type
4+
from openvino.runtime import opset13 as ov
45

56
from keras.src.backend import config
67
from keras.src.backend.common import dtypes
@@ -15,13 +16,9 @@
1516
from keras.src.backend.openvino.core import convert_to_tensor
1617
from keras.src.backend.openvino.core import get_ov_output
1718
from keras.src.backend.openvino.core import ov_to_keras_type
18-
import numpy as np
19-
from openvino.runtime import opset13 as ov
20-
2119

2220

2321
def diagonal(x, offset=0, axis1=0, axis2=1):
24-
"""OpenVINO backend decomposition for keras.ops.diagonal."""
2522
x_node = ov.constant(x) # -> ov.Node
2623
offset_const = ov_opset.constant(int(offset), dtype="i64")
2724

@@ -30,15 +27,17 @@ def diagonal(x, offset=0, axis1=0, axis2=1):
3027
rank = ov_opset.shape_of(shape) # scalar i64 (len of shape)
3128
rank_val = ov_opset.squeeze(rank) # [] -> scalar
3229
axis1_node = ov_opset.floor_mod(
33-
ov_opset.add(ov_opset.constant(int(axis1), dtype="i64"), rank_val), rank_val
30+
ov_opset.add(ov_opset.constant(int(axis1), dtype="i64"), rank_val),
31+
rank_val,
3432
)
3533
axis2_node = ov_opset.floor_mod(
36-
ov_opset.add(ov_opset.constant(int(axis2), dtype="i64"), rank_val), rank_val
34+
ov_opset.add(ov_opset.constant(int(axis2), dtype="i64"), rank_val),
35+
rank_val,
3736
)
38-
39-
4037
arange = ov_opset.range(
41-
ov_opset.constant(0, dtype="i64"), rank_val, ov_opset.constant(1, dtype="i64")
38+
ov_opset.constant(0, dtype="i64"),
39+
rank_val,
40+
ov_opset.constant(1, dtype="i64"),
4241
)
4342
mask1 = ov_opset.equal(arange, axis1_node)
4443
mask2 = ov_opset.equal(arange, axis2_node)
@@ -47,9 +46,14 @@ def diagonal(x, offset=0, axis1=0, axis2=1):
4746
ov_opset.non_zero(not12), [1]
4847
) # gather positions != axis1, axis2
4948
perm = ov_opset.concat(
50-
[others, ov_opset.reshape(axis1_node, [1]), ov_opset.reshape(axis2_node, [1])], 0
49+
[
50+
others,
51+
ov_opset.reshape(axis1_node, [1]),
52+
ov_opset.reshape(axis2_node, [1]),
53+
],
54+
0,
5155
)
52-
56+
5357
x_perm = ov_opset.transpose(x_node, perm)
5458
permuted_shape = ov_opset.shape_of(x_perm)
5559
d1 = ov_opset.gather(
@@ -101,7 +105,12 @@ def diagonal(x, offset=0, axis1=0, axis2=1):
101105
end_mask=[0],
102106
)
103107
target_shape = ov_opset.concat(
104-
[batch_shape, ov_opset.reshape(L, [1]), ov_opset.constant([2], dtype="i64")], 0
108+
[
109+
batch_shape,
110+
ov_opset.reshape(L, [1]),
111+
ov_opset.constant([2], dtype="i64"),
112+
],
113+
0,
105114
)
106115
bcast_idx = ov_opset.broadcast(diag_idx, target_shape)
107116

@@ -110,8 +119,6 @@ def diagonal(x, offset=0, axis1=0, axis2=1):
110119

111120
return OpenVINOKerasTensor(gathered)
112121

113-
114-
115122
def add(x1, x2):
116123
element_type = None
117124
if isinstance(x1, OpenVINOKerasTensor):
@@ -771,8 +778,6 @@ def deg2rad(x):
771778
def diag(x, k=0):
772779
raise NotImplementedError("`diag` is not supported with openvino backend")
773780

774-
775-
776781
def diff(a, n=1, axis=-1):
777782
if n == 0:
778783
return OpenVINOKerasTensor(get_ov_output(a))

0 commit comments

Comments
 (0)