|
15 | 15 | from keras.src.backend.openvino.core import convert_to_tensor
|
16 | 16 | from keras.src.backend.openvino.core import get_ov_output
|
17 | 17 | from keras.src.backend.openvino.core import ov_to_keras_type
|
18 |
| -# --- Chnage for issue 29115 --- |
19 |
| -import openvino.runtime.opset14 as ov |
20 |
| - |
21 |
| -from .core import OpenVINOKerasTensor # already present in file |
22 |
| -from .core import _convert_to_node, _wrap_node # adapt if your file names differ |
| 18 | +from .core import _convert_to_node |
23 | 19 |
|
def diagonal(x, offset=0, axis1=0, axis2=1):
    """OpenVINO backend decomposition for `keras.ops.diagonal`.

    Matches `numpy.diagonal` semantics: `axis1` and `axis2` are moved to
    the end of the tensor, and a new trailing axis of length
    `min(d1 - max(0, offset), d2 - max(0, -offset))` holds the requested
    diagonal of each 2-D plane.

    Args:
        x: Input tensor (anything `_convert_to_node` accepts).
        offset: Offset of the diagonal from the main diagonal; positive
            values select diagonals above it, negative below. Defaults to 0.
        axis1: First axis of the 2-D planes to take diagonals from.
            Defaults to 0.
        axis2: Second axis of the 2-D planes. Defaults to 1.

    Returns:
        An `OpenVINOKerasTensor` wrapping the gathered diagonals.

    Raises:
        ValueError: If the rank of `x` is not statically known, or if
            `axis1` and `axis2` refer to the same axis.
    """
    x_node = _convert_to_node(x)  # -> ov.Node

    # The transpose permutation and GatherND's `batch_dims` attribute must
    # be compile-time constants, so the rank has to be static. Individual
    # dimension sizes may still be dynamic; those are read from the graph.
    rank = x_node.get_partial_shape().rank
    if rank.is_dynamic:
        raise ValueError("`diagonal` requires an input with a static rank.")
    rank = rank.get_length()

    # Normalize possibly-negative axes in Python (numpy-style wrap-around).
    axis1 = axis1 % rank
    axis2 = axis2 % rank
    if axis1 == axis2:
        # numpy.diagonal raises in this case as well.
        raise ValueError("axis1 and axis2 cannot be the same axis.")

    # Build the permutation moving axis1, axis2 to the end:
    # perm = [all axes except axis1/axis2 in order] + [axis1, axis2]
    perm = [a for a in range(rank) if a not in (axis1, axis2)]
    perm += [axis1, axis2]
    x_perm = ov_opset.transpose(x_node, ov_opset.constant(perm, dtype="i64"))

    # Sizes of the two trailing (diagonal) dims, read dynamically.
    permuted_shape = ov_opset.shape_of(x_perm)  # i64 vector
    gather_axis = ov_opset.constant(0, dtype="i64")
    d1 = ov_opset.squeeze(
        ov_opset.gather(
            permuted_shape, ov_opset.constant([-2], dtype="i64"), gather_axis
        )
    )  # scalar
    d2 = ov_opset.squeeze(
        ov_opset.gather(
            permuted_shape, ov_opset.constant([-1], dtype="i64"), gather_axis
        )
    )  # scalar

    # start1 = max(0, offset), start2 = max(0, -offset): the first row and
    # column touched by the requested diagonal.
    offset_const = ov_opset.constant(int(offset), dtype="i64")
    zero = ov_opset.constant(0, dtype="i64")
    one = ov_opset.constant(1, dtype="i64")
    start1 = ov_opset.maximum(zero, offset_const)
    start2 = ov_opset.maximum(zero, ov_opset.negative(offset_const))

    # Diagonal length L = min(d1 - start1, d2 - start2).
    diag_len = ov_opset.minimum(
        ov_opset.subtract(d1, start1), ov_opset.subtract(d2, start2)
    )

    # Index pairs [(start1 + i, start2 + i) for i in range(L)] -> [L, 2].
    steps = ov_opset.range(zero, diag_len, one)
    idx_row = ov_opset.unsqueeze(ov_opset.add(steps, start1), one)  # [L,1]
    idx_col = ov_opset.unsqueeze(ov_opset.add(steps, start2), one)  # [L,1]
    diag_idx = ov_opset.concat([idx_row, idx_col], 1)  # [L,2]

    batch_rank = rank - 2
    if batch_rank > 0:
        # GatherND does not broadcast indices over batch dims, so tile the
        # [L, 2] index table to shape (*batch, L, 2) explicitly.
        batch_shape = ov_opset.gather(
            permuted_shape,
            ov_opset.constant(list(range(batch_rank)), dtype="i64"),
            gather_axis,
        )
        target_shape = ov_opset.concat(
            [
                batch_shape,
                ov_opset.unsqueeze(diag_len, zero),  # [L] as a 1-elem vector
                ov_opset.constant([2], dtype="i64"),
            ],
            0,
        )
        diag_idx = ov_opset.broadcast(diag_idx, target_shape)

    # batch_dims must be a compile-time int attribute, not a graph Node.
    gathered = ov_opset.gather_nd(x_perm, diag_idx, batch_rank)

    return OpenVINOKerasTensor(gathered)
|
|
0 commit comments