@@ -2076,15 +2076,15 @@ def mkconst(npval, desc):
2076
2076
xalign , yalign = align .split ('_' )
2077
2077
2078
2078
# consts
2079
- const_neg_one = mkconst (np .array ([- 1 ]).astype (np .int64 ), 'const_neg_one' )
2080
- const_pad_vals = mkconst (pads , 'pads' )
2081
2079
const_zero = mkconst (np .array ([0 ], np .int64 ), 'const_zero_dtype' )
2082
2080
const_one = mkconst (np .array ([1 ], np .int64 ), 'const_one_dtype' )
2081
+ const_neg_one = mkconst (np .array ([- 1 ]).astype (np .int64 ), 'const_neg_one' )
2082
+ const_pad_vals = mkconst (pads , 'pads' )
2083
2083
const_t = mkconst (np .array ([- 1 , 1 ], np .int64 ), 'const_t' )
2084
2084
const_xlen = mkconst (np .array ([xlen ], np .int64 ), 'const_xlen' )
2085
2085
const_ylen = mkconst (np .array ([ylen ], np .int64 ), 'const_ylen' )
2086
- const_stride = mkconst (np .array ([xlenp + 1 ], np .int64 ), 'const_stride' )
2087
2086
const_xlenp = mkconst (np .array ([xlenp ], np .int64 ), 'const_xlenp' )
2087
+ const_stride = mkconst (np .array ([xlenp + 1 ], np .int64 ), 'const_stride' )
2088
2088
const_minxy = mkconst (np .array ([min (xlen , ylen )], np .int64 ), 'const_minxy' )
2089
2089
const_xmax = mkconst (np .array ([xlen * xlenp + xlenp - 1 ], np .int64 ), 'const_xmax' )
2090
2090
const_ymax = mkconst (np .array ([xlenp * ylen - 1 ], np .int64 ), 'const_ymax' )
@@ -2099,13 +2099,13 @@ def mkconst(npval, desc):
2099
2099
k1_scalar = ctx .make_node ('Squeeze' , [k1 .output [0 ]])
2100
2100
m_padded = ctx .make_node ('Pad' , [m , const_pad_vals , node .input [2 ]])
2101
2101
2102
- # starting index for super diagonals
2102
+ # starting indexes for super diagonals
2103
2103
xstart_0 = ctx .make_node ('Max' , [const_zero , k0 .output [0 ]])
2104
2104
xstart_1 = ctx .make_node ('Add' , [xstart_0 .output [0 ], const_neg_one ])
2105
2105
xstart_2 = ctx .make_node ('Range' , [k1_scalar .output [0 ], xstart_1 .output [0 ], const_neg_one ])
2106
2106
xstart = ctx .make_node ('Reshape' , [xstart_2 .output [0 ], const_t ])
2107
2107
2108
- # starting indices for sub diagonals
2108
+ # starting indexes for sub diagonals
2109
2109
ystart_0 = ctx .make_node ('Min' , [const_neg_one , k1 .output [0 ]])
2110
2110
ystart_0_scalar = ctx .make_node ('Squeeze' , [ystart_0 .output [0 ]])
2111
2111
ystart_1 = ctx .make_node ('Add' , [k0 .output [0 ], const_neg_one ])
@@ -2159,18 +2159,19 @@ def right_align(sizes, indices, starts, maxval):
2159
2159
diags_2 = ctx .make_node ('Expand' , [diags_1 .output [0 ], const_gather_shape ])
2160
2160
diags = ctx .make_node ('GatherElements' , [m2 .output [0 ], diags_2 .output [0 ]], attr = {'axis' : - 1 })
2161
2161
2162
- # if k0=k1, rank of output matrix is 1 less than usual.
2163
- # hence, need 'If' to compute right output matrix shape
2164
2162
def compute_out_shape (k0_k1_same = False ):
2165
2163
g = ctx .create_new_graph_with_same_config ()
2166
2164
g .parent_graph = ctx
2167
2165
if k0_k1_same :
2168
2166
outshape = g .make_node ('Concat' , [const_partial_shape , maxsize_0 .output [0 ]], attr = {'axis' : 0 })
2169
2167
else :
2170
- outshape = g .make_node ('Concat' , [const_partial_shape , const_neg_one , maxsize_0 .output [0 ]], attr = {'axis' : 0 })
2168
+ outshape = g .make_node ('Concat' , [const_partial_shape , const_neg_one , maxsize_0 .output [0 ]],
2169
+ attr = {'axis' : 0 })
2171
2170
g .add_graph_output (outshape .output [0 ], TensorProto .INT64 , [- 1 ])
2172
2171
return g
2173
2172
2173
+ # if k0==k1, rank of output matrix is 1 less than usual
2174
+ # hence, using 'If' to compute right output matrix shape
2174
2175
k0_k1_same = ctx .make_node ('Equal' , [k1 .output [0 ], k0 .output [0 ]])
2175
2176
if_node = ctx .make_node ('If' , [k0_k1_same .output [0 ]])
2176
2177
if_node .set_body_graph_as_attr ('then_branch' , compute_out_shape (True ))
0 commit comments