@@ -1835,8 +1835,16 @@ class MatrixDiagPartV2V3:
1835
1835
@classmethod
1836
1836
def version_11 (cls , ctx , node , ** kwargs ):
1837
1837
# assemble MatrixDiagPart V2&V3 by looping k diagonals with proper pads
1838
+ const_zero = ctx .make_const (utils .make_name (node .name ) + 'const_zero' , np .array ([0 ]).astype (np .int64 ))
1839
+ const_one = ctx .make_const (utils .make_name (node .name ) + 'const_one' , np .array ([1 ]).astype (np .int64 ))
1840
+ const_two = ctx .make_const (utils .make_name (node .name ) + 'const_two' , np .array ([2 ]).astype (np .int64 ))
1841
+ const_neg_one = ctx .make_const (utils .make_name (node .name ) + 'const_neg_one' , np .array ([- 1 ]).astype (np .int64 ))
1842
+ const_neg_two = ctx .make_const (utils .make_name (node .name ) + 'const_neg_two' , np .array ([- 2 ]).astype (np .int64 ))
1843
+ def normalize ():
1844
+ raw_k = ctx .make_node ('Cast' , [node .input [1 ]], attr = {'to' : TensorProto .INT64 }).output [0 ]
1845
+ return ctx .make_node ('Reshape' , [raw_k , const_neg_one .output [0 ]]).output [0 ]
1838
1846
input_tensor = node .input [0 ]
1839
- k = ctx . make_node ( 'Cast' , [ node . input [ 1 ]], attr = { 'to' : TensorProto . INT64 }). output [ 0 ]
1847
+ k = normalize ()
1840
1848
padding = node .input [2 ]
1841
1849
align = 'LEFT_LEFT'
1842
1850
if node .op .op_type == 'MatrixDiagPartV3' :
@@ -1850,12 +1858,7 @@ def version_11(cls, ctx, node, **kwargs):
1850
1858
for out in ctx .find_output_consumers (node .output [0 ]):
1851
1859
if out .op .op_type == 'Identity' :
1852
1860
ctx .set_shape (out .output [0 ], raw_output_shape )
1853
- # define constants
1854
- const_zero = ctx .make_const (utils .make_name (node .name ) + 'const_zero' , np .array ([0 ]).astype (np .int64 ))
1855
- const_one = ctx .make_const (utils .make_name (node .name ) + 'const_one' , np .array ([1 ]).astype (np .int64 ))
1856
- const_two = ctx .make_const (utils .make_name (node .name ) + 'const_two' , np .array ([2 ]).astype (np .int64 ))
1857
- const_neg_one = ctx .make_const (utils .make_name (node .name ) + 'const_neg_one' , np .array ([- 1 ]).astype (np .int64 ))
1858
- const_neg_two = ctx .make_const (utils .make_name (node .name ) + 'const_neg_two' , np .array ([- 2 ]).astype (np .int64 ))
1861
+
1859
1862
# prepare new_shape of input
1860
1863
input_shape = ctx .make_node ('Shape' , [input_tensor ])
1861
1864
shape_input_shape = ctx .make_node ('Shape' , [input_shape .output [0 ]])
@@ -2075,7 +2078,7 @@ def mkconsts(values, dtype=np.int64):
2075
2078
xalign , yalign = align .split ('_' )
2076
2079
2077
2080
# consts
2078
- const_zero_float , const_neg_one_float = mkconsts ([[ 0 ], [ - 1 ] ], np .float32 )
2081
+ const_zero_float , const_neg_one_float = mkconsts ([0 , - 1 ], np .float32 )
2079
2082
const_zero , const_one , const_neg_one , const_neg_two , const_pad_vals , const_t = \
2080
2083
mkconsts ([[0 ], [1 ], [- 1 ], [- 2 ], pads , [- 1 , 1 ]])
2081
2084
const_zero_scalar , const_one_scalar , const_neg_one_scalar = mkconsts ([0 , 1 , - 1 ])
@@ -2098,8 +2101,12 @@ def mkconsts(values, dtype=np.int64):
2098
2101
m2_shape = ctx .make_node ('Concat' , [partial_shape , const_neg_one ], attr = {'axis' : 0 }).output [0 ]
2099
2102
gather_shape = ctx .make_node ('Concat' , [partial_shape , const_one ], attr = {'axis' : 0 }).output [0 ]
2100
2103
2104
+ def normalize ():
2105
+ raw_input1 = ctx .make_node ('Cast' , [node .input [1 ]], attr = {'to' : TensorProto .INT64 }).output [0 ]
2106
+ return ctx .make_node ('Reshape' , [raw_input1 , const_neg_one ])
2107
+
2101
2108
# get k0, k1 values. diags to be extracted
2102
- input1 = ctx . make_node ( 'Cast' , [ node . input [ 1 ]], attr = { 'to' : TensorProto . INT64 } )
2109
+ input1 = normalize ( )
2103
2110
k0 = ctx .make_node ('ReduceMin' , [input1 .output [0 ]]).output [0 ]
2104
2111
k1 = ctx .make_node ('ReduceMax' , [input1 .output [0 ]]).output [0 ]
2105
2112
k0_scalar = ctx .make_node ('Squeeze' , [k0 ]).output [0 ]
0 commit comments