|
15 | 15 | from keras.src.backend.openvino.core import convert_to_tensor
|
16 | 16 | from keras.src.backend.openvino.core import get_ov_output
|
17 | 17 | from keras.src.backend.openvino.core import ov_to_keras_type
|
| 18 | +# --- Chnage for issue 29115 --- |
| 19 | +import openvino.runtime.opset14 as ov |
| 20 | + |
| 21 | +from .core import OpenVINOKerasTensor # already present in file |
| 22 | +from .core import _convert_to_node, _wrap_node # adapt if your file names differ |
| 23 | + |
| 24 | +def diagonal(x, offset=0, axis1=0, axis2=1): |
| 25 | + """OpenVINO backend decomposition for keras.ops.diagonal.""" |
| 26 | + x_node = _convert_to_node(x) # -> ov.Node |
| 27 | + offset_const = ov.constant(int(offset), dtype="i64") |
| 28 | + |
| 29 | + # rank & normalize axes |
| 30 | + shape = ov.shape_of(x_node) # i64 vector |
| 31 | + rank = ov.shape_of(shape) # scalar i64 (len of shape) |
| 32 | + rank_val = ov.squeeze(rank) # [] -> scalar |
| 33 | + axis1_node = ov.mod(ov.add(ov.constant(int(axis1), dtype="i64"), rank_val), rank_val) |
| 34 | + axis2_node = ov.mod(ov.add(ov.constant(int(axis2), dtype="i64"), rank_val), rank_val) |
| 35 | + |
| 36 | + # If axis1 == axis2, behavior should match numpy error; Keras tests don't hit this, |
| 37 | + # so we skip explicit assert to keep graph-friendly. |
| 38 | + |
| 39 | + # Build permutation to move axis1, axis2 to the end |
| 40 | + # perm = [all axes except axis1/axis2 in order] + [axis1, axis2] |
| 41 | + arange = ov.range(ov.constant(0, dtype="i64"), rank_val, ov.constant(1, dtype="i64")) |
| 42 | + mask1 = ov.equal(arange, axis1_node) |
| 43 | + mask2 = ov.equal(arange, axis2_node) |
| 44 | + not12 = ov.logical_not(ov.logical_or(mask1, mask2)) |
| 45 | + others = ov.squeeze(ov.non_zero(not12), [1]) # gather positions != axis1, axis2 |
| 46 | + perm = ov.concat([others, ov.reshape(axis1_node, [1]), ov.reshape(axis2_node, [1])], 0) |
| 47 | + |
| 48 | + x_perm = ov.transpose(x_node, perm) |
| 49 | + permuted_shape = ov.shape_of(x_perm) |
| 50 | + # last two dims |
| 51 | + last2 = ov.gather(permuted_shape, ov.constant([-2, -1], dtype="i64"), ov.constant(0, dtype="i64")) |
| 52 | + d1 = ov.gather(permuted_shape, ov.constant([-2], dtype="i64"), ov.constant(0, dtype="i64")) |
| 53 | + d2 = ov.gather(permuted_shape, ov.constant([-1], dtype="i64"), ov.constant(0, dtype="i64")) |
| 54 | + d1 = ov.squeeze(d1) # scalar |
| 55 | + d2 = ov.squeeze(d2) # scalar |
| 56 | + |
| 57 | + # start1 = max(0, offset), start2 = max(0, -offset) |
| 58 | + zero = ov.constant(0, dtype="i64") |
| 59 | + start1 = ov.maximum(zero, offset_const) |
| 60 | + start2 = ov.maximum(zero, ov.negative(offset_const)) |
| 61 | + |
| 62 | + # L = min(d1 - start1, d2 - start2) |
| 63 | + l1 = ov.subtract(d1, start1) |
| 64 | + l2 = ov.subtract(d2, start2) |
| 65 | + L = ov.minimum(l1, l2) |
| 66 | + |
| 67 | + # r = range(0, L, 1) -> shape [L] |
| 68 | + r = ov.range(zero, L, ov.constant(1, dtype="i64")) |
| 69 | + idx_row = ov.add(r, start1) |
| 70 | + idx_col = ov.add(r, start2) |
| 71 | + idx_row = ov.unsqueeze(idx_row, ov.constant(1, dtype="i64")) # [L,1] |
| 72 | + idx_col = ov.unsqueeze(idx_col, ov.constant(1, dtype="i64")) # [L,1] |
| 73 | + diag_idx = ov.concat([idx_row, idx_col], 1) # [L,2] |
| 74 | + |
| 75 | + # Broadcast indices to batch dims: target shape = (*batch, L, 2) |
| 76 | + # batch_rank = rank(x) - 2 |
| 77 | + two = ov.constant(2, dtype="i64") |
| 78 | + batch_rank = ov.subtract(rank_val, two) |
| 79 | + # build target shape: concat(permuted_shape[:batch_rank], [L, 2]) |
| 80 | + batch_shape = ov.slice(permuted_shape, ov.constant([0], dtype="i64"), |
| 81 | + ov.reshape(batch_rank, [1]), ov.constant([1], dtype="i64")) |
| 82 | + target_shape = ov.concat([batch_shape, ov.reshape(L, [1]), ov.constant([2], dtype="i64")], 0) |
| 83 | + bcast_idx = ov.broadcast(diag_idx, target_shape) |
| 84 | + |
| 85 | + # GatherND with batch_dims = batch_rank |
| 86 | + gathered = ov.gather_nd(x_perm, bcast_idx, batch_rank) |
| 87 | + |
| 88 | + return OpenVINOKerasTensor(gathered) |
| 89 | + |
18 | 90 |
|
19 | 91 |
|
20 | 92 | def add(x1, x2):
|
|
0 commit comments