@@ -1644,6 +1644,27 @@ def version_11(cls, ctx, node, **kwargs):
1644
1644
1645
1645
@tf_op ("MatrixBandPart" )
1646
1646
class MatrixBandPart :
1647
@classmethod
def _apply_mask_and_transform(cls, ctx, node, mask):
    """Replace ``node`` with the element-wise product of ``mask`` and ``node.input[0]``.

    Args:
        ctx: graph being rewritten; used to look up dtypes and to create
            and remove nodes.
        node: the node being converted. It is removed from ``ctx`` and a
            new node reusing its name and outputs is created in its place.
        mask: node whose outputs hold the mask. The mask is cast to the
            dtype used for the multiplication before ``Mul`` is applied.

    Returns:
        None. The graph held by ``ctx`` is modified in place.
    """
    data = node.input[0]
    dtype = ctx.get_dtype(data)
    # Read the output metadata before the node is removed so it can be
    # attached to the replacement node.
    # NOTE(review): presumably remove_node drops this metadata - confirm.
    shapes = node.output_shapes
    dtypes = node.output_dtypes

    if dtype == TensorProto.BOOL:
        # bool is not supported for 'Mul', so convert the mask and the input
        # to a supported dtype, multiply, and cast the product back.
        float_mask = ctx.make_node("Cast", inputs=mask.output, attr={'to': TensorProto.FLOAT}).output[0]
        float_data = ctx.make_node("Cast", [data], attr={'to': TensorProto.FLOAT}).output[0]
        product = ctx.make_node(op_type="Mul", inputs=[float_mask, float_data], shapes=shapes,
                                dtypes=[TensorProto.FLOAT])
        ctx.remove_node(node.name)
        # Pass shapes as well as dtypes so the replacement carries the same
        # output metadata as in the non-bool branch.
        ctx.make_node("Cast", inputs=product.output, attr={'to': dtype},
                      name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
    else:
        typed_mask = ctx.make_node(op_type="Cast", inputs=mask.output, attr={"to": dtype}).output[0]
        ctx.remove_node(node.name)
        ctx.make_node(op_type="Mul", inputs=[typed_mask, data],
                      name=node.name, outputs=node.output, shapes=shapes,
                      dtypes=dtypes)
1667
+
1647
1668
@classmethod
1648
1669
def version_7 (cls , ctx , node , ** kwargs ):
1649
1670
# T output = MatrixBandPart(T input, int num_lower, int num_upper)
@@ -1714,14 +1735,7 @@ def version_7(cls, ctx, node, **kwargs):
1714
1735
mask_matrix = ctx .make_node (op_type = "Transpose" , inputs = cast1 .output )
1715
1736
else :
1716
1737
mask_matrix = squeeze
1717
- cast2 = ctx .make_node (op_type = "Cast" , inputs = mask_matrix .output ,
1718
- attr = {"to" : ctx .get_dtype (node .input [0 ])})
1719
- shapes = node .output_shapes
1720
- dtypes = node .output_dtypes
1721
- ctx .remove_node (node .name )
1722
- ctx .make_node (op_type = "Mul" , inputs = [cast2 .output [0 ], node .input [0 ]],
1723
- name = node .name , outputs = node .output , shapes = shapes ,
1724
- dtypes = dtypes )
1738
+ cls ._apply_mask_and_transform (ctx , node , mask_matrix )
1725
1739
1726
1740
@classmethod
1727
1741
def version_11 (cls , ctx , node , ** kwargs ):
@@ -1739,17 +1753,12 @@ def version_11(cls, ctx, node, **kwargs):
1739
1753
{'data' : whole_shape , 'starts' : [- 2 ], 'ends' : [int_max_val ], 'axes' : [0 ]})
1740
1754
if num_lower_const == 0 and num_upper_const == 0 :
1741
1755
if rank == 2 :
1742
- identity_node = ctx .make_node ("EyeLike" , [data ]). output [ 0 ]
1756
+ identity_node = ctx .make_node ("EyeLike" , [data ])
1743
1757
else :
1744
1758
zero_tensor = helper .make_tensor ("value" , dtype , dims = [1 ], vals = [0 ])
1745
1759
const_of_shape = ctx .make_node ("ConstantOfShape" , [shape ], attr = {'value' : zero_tensor }).output [0 ]
1746
- identity_node = ctx .make_node ("EyeLike" , [const_of_shape ]).output [0 ]
1747
- shapes = node .output_shapes
1748
- dtypes = node .output_dtypes
1749
- ctx .remove_node (node .name )
1750
- ctx .make_node (op_type = "Mul" , inputs = [identity_node , data ],
1751
- name = node .name , outputs = node .output , shapes = shapes ,
1752
- dtypes = dtypes )
1760
+ identity_node = ctx .make_node ("EyeLike" , [const_of_shape ])
1761
+ cls ._apply_mask_and_transform (ctx , node , identity_node )
1753
1762
return
1754
1763
zero_const = ctx .make_const (utils .make_name ("zero" ), np .array (0 , np .int64 )).output [0 ]
1755
1764
one_const = ctx .make_const (utils .make_name ("one" ), np .array (1 , np .int64 )).output [0 ]
@@ -1771,14 +1780,14 @@ def version_11(cls, ctx, node, **kwargs):
1771
1780
if ctx .get_dtype (num_upper ) != TensorProto .INT64 :
1772
1781
num_upper = ctx .make_node ("Cast" , [num_upper ], attr = {'to' : TensorProto .INT64 }).output [0 ]
1773
1782
greater = ctx .make_node ("Greater" , [idx_diff , num_upper ]).output [0 ]
1774
- less_or_equal = ctx .make_node ("Not" , [greater ]). output [ 0 ]
1783
+ less_or_equal = ctx .make_node ("Not" , [greater ])
1775
1784
conditions .append (less_or_equal )
1776
1785
if num_lower_const is None or num_lower_const >= 0 :
1777
1786
if ctx .get_dtype (num_lower ) != TensorProto .INT64 :
1778
1787
num_lower = ctx .make_node ("Cast" , [num_lower ], attr = {'to' : TensorProto .INT64 }).output [0 ]
1779
1788
num_lower_neg = ctx .make_node ("Neg" , [num_lower ]).output [0 ]
1780
1789
greater = ctx .make_node ("Greater" , [num_lower_neg , idx_diff ]).output [0 ]
1781
- less_or_equal = ctx .make_node ("Not" , [greater ]). output [ 0 ]
1790
+ less_or_equal = ctx .make_node ("Not" , [greater ])
1782
1791
conditions .append (less_or_equal )
1783
1792
if len (conditions ) == 0 :
1784
1793
node .type = "Identity"
@@ -1787,14 +1796,8 @@ def version_11(cls, ctx, node, **kwargs):
1787
1796
if len (conditions ) == 1 :
1788
1797
cond = conditions [0 ]
1789
1798
if len (conditions ) == 2 :
1790
- cond = ctx .make_node ("And" , conditions ).output [0 ]
1791
- mask = ctx .make_node ("Cast" , [cond ], attr = {'to' : ctx .get_dtype (data )}).output [0 ]
1792
- shapes = node .output_shapes
1793
- dtypes = node .output_dtypes
1794
- ctx .remove_node (node .name )
1795
- ctx .make_node (op_type = "Mul" , inputs = [mask , data ],
1796
- name = node .name , outputs = node .output , shapes = shapes ,
1797
- dtypes = dtypes )
1799
+ cond = ctx .make_node ("And" , inputs = [c .output [0 ] for c in conditions ])
1800
+ cls ._apply_mask_and_transform (ctx , node , cond )
1798
1801
1799
1802
1800
1803
def _make_softmax_cross_entropy_with_logits (ctx , label , logit , tf_ori_node ):
0 commit comments