@@ -1565,15 +1565,94 @@ def trace(x, offset=0, axis1=0, axis2=1):
1565
1565
1566
1566
1567
1567
def tri (N , M = None , k = 0 , dtype = None ):
1568
- raise NotImplementedError ("`tri` is not supported with openvino backend" )
1568
+ if M is None :
1569
+ M = N
1570
+ if dtype is None :
1571
+ dtype = "float32"
1572
+
1573
+ ov_dtype = OPENVINO_DTYPES [dtype ]
1574
+
1575
+ def ensure_constant (value , default_type = Type .i32 ):
1576
+ if isinstance (value , (int , float )):
1577
+ return ov_opset .constant (value , default_type )
1578
+ elif hasattr (value , "get_element_type" ):
1579
+ if value .get_element_type () != Type .i32 :
1580
+ value = ov_opset .convert (value , Type .i32 )
1581
+ return ov_opset .squeeze (value , ov_opset .constant ([0 ], Type .i32 ))
1582
+ else :
1583
+ return ov_opset .constant (value , default_type )
1584
+
1585
+ N_const = ensure_constant (N )
1586
+ M_const = ensure_constant (M )
1587
+ k_const = ensure_constant (k )
1588
+
1589
+ # Create row and column indices
1590
+ row_range = ov_opset .range (
1591
+ ov_opset .constant (0 , Type .i32 ),
1592
+ N_const ,
1593
+ ov_opset .constant (1 , Type .i32 ),
1594
+ output_type = Type .i32 ,
1595
+ )
1596
+ col_range = ov_opset .range (
1597
+ ov_opset .constant (0 , Type .i32 ),
1598
+ M_const ,
1599
+ ov_opset .constant (1 , Type .i32 ),
1600
+ output_type = Type .i32 ,
1601
+ )
1602
+
1603
+ # Reshape indices for broadcasting
1604
+ row_idx = ov_opset .unsqueeze (row_range , ov_opset .constant ([1 ], Type .i32 ))
1605
+ col_idx = ov_opset .unsqueeze (col_range , ov_opset .constant ([0 ], Type .i32 ))
1606
+
1607
+ mask = ov_opset .less_equal (col_idx , ov_opset .add (row_idx , k_const ))
1608
+
1609
+ if ov_dtype == Type .boolean :
1610
+ result = mask
1611
+ else :
1612
+ result = ov_opset .convert (mask , ov_dtype )
1613
+
1614
+ return OpenVINOKerasTensor (result .output (0 ))
1569
1615
1570
1616
1571
1617
def tril (x , k = 0 ):
1572
- raise NotImplementedError ("`tril` is not supported with openvino backend" )
1618
+ x = get_ov_output (x )
1619
+ ov_type = x .get_element_type ()
1620
+ shape = ov_opset .shape_of (x , Type .i32 )
1621
+ zero_const = ov_opset .constant (0 , Type .i32 )
1622
+ minus2 = ov_opset .constant ([- 2 ], Type .i32 )
1623
+ minus1 = ov_opset .constant ([- 1 ], Type .i32 )
1624
+ M = ov_opset .squeeze (ov_opset .gather (shape , minus2 , zero_const ), zero_const )
1625
+ N = ov_opset .squeeze (ov_opset .gather (shape , minus1 , zero_const ), zero_const )
1626
+ tri_mask = tri (M , N , k = k , dtype = "bool" ).output
1627
+ mask = ov_opset .convert (tri_mask , ov_type )
1628
+ if ov_type == Type .boolean :
1629
+ out = ov_opset .logical_and (x , mask )
1630
+ else :
1631
+ out = ov_opset .multiply (x , mask )
1632
+ return OpenVINOKerasTensor (out .output (0 ))
1573
1633
1574
1634
1575
1635
def triu (x , k = 0 ):
1576
- raise NotImplementedError ("`triu` is not supported with openvino backend" )
1636
+ x = get_ov_output (x )
1637
+ ov_type = x .get_element_type ()
1638
+ shape = ov_opset .shape_of (x , Type .i32 )
1639
+ zero_const = ov_opset .constant (0 , Type .i32 )
1640
+ minus2 = ov_opset .constant ([- 2 ], Type .i32 )
1641
+ minus1 = ov_opset .constant ([- 1 ], Type .i32 )
1642
+ M = ov_opset .squeeze (ov_opset .gather (shape , minus2 , zero_const ), zero_const )
1643
+ N = ov_opset .squeeze (ov_opset .gather (shape , minus1 , zero_const ), zero_const )
1644
+ tri_mask = tri (M , N , k = k - 1 , dtype = "bool" ).output
1645
+ if ov_type == Type .boolean :
1646
+ mask = ov_opset .logical_not (tri_mask )
1647
+ else :
1648
+ const_one = ov_opset .constant (1 , ov_type )
1649
+ converted_mask = ov_opset .convert (tri_mask , ov_type )
1650
+ mask = ov_opset .subtract (const_one , converted_mask )
1651
+ if ov_type == Type .boolean :
1652
+ out = ov_opset .logical_and (x , mask )
1653
+ else :
1654
+ out = ov_opset .multiply (x , mask )
1655
+ return OpenVINOKerasTensor (out .output (0 ))
1577
1656
1578
1657
1579
1658
def vdot (x1 , x2 ):
0 commit comments