Skip to content

Commit 2615b5b

Browse files
[OpenVINO backend] support tri, triu, and tril (#21408)
1 parent d9ca374 commit 2615b5b

File tree

2 files changed

+82
-8
lines changed

2 files changed

+82
-8
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
NumPyTestRot90
22
NumpyArrayCreateOpsCorrectnessTest::test_eye
3-
NumpyArrayCreateOpsCorrectnessTest::test_tri
43
NumpyDtypeTest::test_absolute_bool
54
NumpyDtypeTest::test_add_
65
NumpyDtypeTest::test_all
@@ -62,7 +61,6 @@ NumpyDtypeTest::test_swapaxes
6261
NumpyDtypeTest::test_tensordot_
6362
NumpyDtypeTest::test_tile
6463
NumpyDtypeTest::test_trace
65-
NumpyDtypeTest::test_tri
6664
NumpyDtypeTest::test_trunc
6765
NumpyDtypeTest::test_unravel
6866
NumpyDtypeTest::test_var
@@ -128,9 +126,6 @@ NumpyOneInputOpsCorrectnessTest::test_swapaxes
128126
NumpyOneInputOpsCorrectnessTest::test_tile
129127
NumpyOneInputOpsCorrectnessTest::test_trace
130128
NumpyOneInputOpsCorrectnessTest::test_transpose
131-
NumpyOneInputOpsCorrectnessTest::test_tril
132-
NumpyOneInputOpsCorrectnessTest::test_tril_in_layer
133-
NumpyOneInputOpsCorrectnessTest::test_triu
134129
NumpyOneInputOpsCorrectnessTest::test_trunc
135130
NumpyOneInputOpsCorrectnessTest::test_unravel_index
136131
NumpyOneInputOpsCorrectnessTest::test_var

keras/src/backend/openvino/numpy.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1565,15 +1565,94 @@ def trace(x, offset=0, axis1=0, axis2=1):
15651565

15661566

15671567
def tri(N, M=None, k=0, dtype=None):
1568-
raise NotImplementedError("`tri` is not supported with openvino backend")
1568+
if M is None:
1569+
M = N
1570+
if dtype is None:
1571+
dtype = "float32"
1572+
1573+
ov_dtype = OPENVINO_DTYPES[dtype]
1574+
1575+
def ensure_constant(value, default_type=Type.i32):
1576+
if isinstance(value, (int, float)):
1577+
return ov_opset.constant(value, default_type)
1578+
elif hasattr(value, "get_element_type"):
1579+
if value.get_element_type() != Type.i32:
1580+
value = ov_opset.convert(value, Type.i32)
1581+
return ov_opset.squeeze(value, ov_opset.constant([0], Type.i32))
1582+
else:
1583+
return ov_opset.constant(value, default_type)
1584+
1585+
N_const = ensure_constant(N)
1586+
M_const = ensure_constant(M)
1587+
k_const = ensure_constant(k)
1588+
1589+
# Create row and column indices
1590+
row_range = ov_opset.range(
1591+
ov_opset.constant(0, Type.i32),
1592+
N_const,
1593+
ov_opset.constant(1, Type.i32),
1594+
output_type=Type.i32,
1595+
)
1596+
col_range = ov_opset.range(
1597+
ov_opset.constant(0, Type.i32),
1598+
M_const,
1599+
ov_opset.constant(1, Type.i32),
1600+
output_type=Type.i32,
1601+
)
1602+
1603+
# Reshape indices for broadcasting
1604+
row_idx = ov_opset.unsqueeze(row_range, ov_opset.constant([1], Type.i32))
1605+
col_idx = ov_opset.unsqueeze(col_range, ov_opset.constant([0], Type.i32))
1606+
1607+
mask = ov_opset.less_equal(col_idx, ov_opset.add(row_idx, k_const))
1608+
1609+
if ov_dtype == Type.boolean:
1610+
result = mask
1611+
else:
1612+
result = ov_opset.convert(mask, ov_dtype)
1613+
1614+
return OpenVINOKerasTensor(result.output(0))
15691615

15701616

15711617
def tril(x, k=0):
1572-
raise NotImplementedError("`tril` is not supported with openvino backend")
1618+
x = get_ov_output(x)
1619+
ov_type = x.get_element_type()
1620+
shape = ov_opset.shape_of(x, Type.i32)
1621+
zero_const = ov_opset.constant(0, Type.i32)
1622+
minus2 = ov_opset.constant([-2], Type.i32)
1623+
minus1 = ov_opset.constant([-1], Type.i32)
1624+
M = ov_opset.squeeze(ov_opset.gather(shape, minus2, zero_const), zero_const)
1625+
N = ov_opset.squeeze(ov_opset.gather(shape, minus1, zero_const), zero_const)
1626+
tri_mask = tri(M, N, k=k, dtype="bool").output
1627+
mask = ov_opset.convert(tri_mask, ov_type)
1628+
if ov_type == Type.boolean:
1629+
out = ov_opset.logical_and(x, mask)
1630+
else:
1631+
out = ov_opset.multiply(x, mask)
1632+
return OpenVINOKerasTensor(out.output(0))
15731633

15741634

15751635
def triu(x, k=0):
1576-
raise NotImplementedError("`triu` is not supported with openvino backend")
1636+
x = get_ov_output(x)
1637+
ov_type = x.get_element_type()
1638+
shape = ov_opset.shape_of(x, Type.i32)
1639+
zero_const = ov_opset.constant(0, Type.i32)
1640+
minus2 = ov_opset.constant([-2], Type.i32)
1641+
minus1 = ov_opset.constant([-1], Type.i32)
1642+
M = ov_opset.squeeze(ov_opset.gather(shape, minus2, zero_const), zero_const)
1643+
N = ov_opset.squeeze(ov_opset.gather(shape, minus1, zero_const), zero_const)
1644+
tri_mask = tri(M, N, k=k - 1, dtype="bool").output
1645+
if ov_type == Type.boolean:
1646+
mask = ov_opset.logical_not(tri_mask)
1647+
else:
1648+
const_one = ov_opset.constant(1, ov_type)
1649+
converted_mask = ov_opset.convert(tri_mask, ov_type)
1650+
mask = ov_opset.subtract(const_one, converted_mask)
1651+
if ov_type == Type.boolean:
1652+
out = ov_opset.logical_and(x, mask)
1653+
else:
1654+
out = ov_opset.multiply(x, mask)
1655+
return OpenVINOKerasTensor(out.output(0))
15771656

15781657

15791658
def vdot(x1, x2):

0 commit comments

Comments
 (0)