@@ -2151,9 +2151,8 @@ def right_align(sizes, indices, starts, maxval):
2151
2151
else :
2152
2152
ydiags = ctx .make_node ('Min' , [ydiags_2 .output [0 ], const_ymax ])
2153
2153
2154
- # flatten last dimension of matrix
2154
+ # flatten last dimension of matrix, extract diagonal values
2155
2155
m2 = ctx .make_node ('Reshape' , [m_padded .output [0 ], const_m2_shape ])
2156
-
2157
2156
diags_0 = ctx .make_node ('Concat' , [xdiags .output [0 ], ydiags .output [0 ]], attr = {'axis' : 0 })
2158
2157
diags_1 = ctx .make_node ('Reshape' , [diags_0 .output [0 ], const_neg_one ])
2159
2158
diags_2 = ctx .make_node ('Expand' , [diags_1 .output [0 ], const_gather_shape ])
@@ -2163,11 +2162,11 @@ def compute_out_shape(k0_k1_same=False):
2163
2162
g = ctx .create_new_graph_with_same_config ()
2164
2163
g .parent_graph = ctx
2165
2164
if k0_k1_same :
2166
- outshape = g . make_node ( 'Concat' , [const_partial_shape , maxsize_0 .output [0 ]], attr = { 'axis' : 0 })
2165
+ dims = [const_partial_shape , maxsize_0 .output [0 ]]
2167
2166
else :
2168
- outshape = g . make_node ( 'Concat' , [const_partial_shape , const_neg_one , maxsize_0 .output [0 ]],
2169
- attr = {'axis' : 0 })
2170
- g .add_graph_output (outshape .output [0 ], TensorProto .INT64 , [- 1 ])
2167
+ dims = [const_partial_shape , const_neg_one , maxsize_0 .output [0 ]]
2168
+ out_shape = g . make_node ( 'Concat' , dims , attr = {'axis' : 0 })
2169
+ g .add_graph_output (out_shape .output [0 ], TensorProto .INT64 , [- 1 ])
2171
2170
return g
2172
2171
2173
2172
# if k0==k1, rank of output matrix is 1 less than usual
0 commit comments