Skip to content

Commit 110a269

Browse files
committed
Minor comments
1 parent c60ff32 commit 110a269

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

tf2onnx/onnx_opset/tensor.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2076,15 +2076,15 @@ def mkconst(npval, desc):
20762076
xalign, yalign = align.split('_')
20772077

20782078
# consts
2079-
const_neg_one = mkconst(np.array([-1]).astype(np.int64), 'const_neg_one')
2080-
const_pad_vals = mkconst(pads, 'pads')
20812079
const_zero = mkconst(np.array([0], np.int64), 'const_zero_dtype')
20822080
const_one = mkconst(np.array([1], np.int64), 'const_one_dtype')
2081+
const_neg_one = mkconst(np.array([-1]).astype(np.int64), 'const_neg_one')
2082+
const_pad_vals = mkconst(pads, 'pads')
20832083
const_t = mkconst(np.array([-1, 1], np.int64), 'const_t')
20842084
const_xlen = mkconst(np.array([xlen], np.int64), 'const_xlen')
20852085
const_ylen = mkconst(np.array([ylen], np.int64), 'const_ylen')
2086-
const_stride = mkconst(np.array([xlenp + 1], np.int64), 'const_stride')
20872086
const_xlenp = mkconst(np.array([xlenp], np.int64), 'const_xlenp')
2087+
const_stride = mkconst(np.array([xlenp + 1], np.int64), 'const_stride')
20882088
const_minxy = mkconst(np.array([min(xlen, ylen)], np.int64), 'const_minxy')
20892089
const_xmax = mkconst(np.array([xlen * xlenp + xlenp - 1], np.int64), 'const_xmax')
20902090
const_ymax = mkconst(np.array([xlenp * ylen - 1], np.int64), 'const_ymax')
@@ -2099,13 +2099,13 @@ def mkconst(npval, desc):
20992099
k1_scalar = ctx.make_node('Squeeze', [k1.output[0]])
21002100
m_padded = ctx.make_node('Pad', [m, const_pad_vals, node.input[2]])
21012101

2102-
# starting index for super diagonals
2102+
# starting indexes for super diagonals
21032103
xstart_0 = ctx.make_node('Max', [const_zero, k0.output[0]])
21042104
xstart_1 = ctx.make_node('Add', [xstart_0.output[0], const_neg_one])
21052105
xstart_2 = ctx.make_node('Range', [k1_scalar.output[0], xstart_1.output[0], const_neg_one])
21062106
xstart = ctx.make_node('Reshape', [xstart_2.output[0], const_t])
21072107

2108-
# starting indices for sub diagonals
2108+
# starting indexes for sub diagonals
21092109
ystart_0 = ctx.make_node('Min', [const_neg_one, k1.output[0]])
21102110
ystart_0_scalar = ctx.make_node('Squeeze', [ystart_0.output[0]])
21112111
ystart_1 = ctx.make_node('Add', [k0.output[0], const_neg_one])
@@ -2159,18 +2159,19 @@ def right_align(sizes, indices, starts, maxval):
21592159
diags_2 = ctx.make_node('Expand', [diags_1.output[0], const_gather_shape])
21602160
diags = ctx.make_node('GatherElements', [m2.output[0], diags_2.output[0]], attr={'axis': -1})
21612161

2162-
# if k0=k1, rank of output matrix is 1 less than usual.
2163-
# hence, need 'If' to compute right output matrix shape
21642162
def compute_out_shape(k0_k1_same=False):
21652163
g = ctx.create_new_graph_with_same_config()
21662164
g.parent_graph = ctx
21672165
if k0_k1_same:
21682166
outshape = g.make_node('Concat', [const_partial_shape, maxsize_0.output[0]], attr={'axis': 0})
21692167
else:
2170-
outshape = g.make_node('Concat', [const_partial_shape, const_neg_one, maxsize_0.output[0]], attr={'axis': 0})
2168+
outshape = g.make_node('Concat', [const_partial_shape, const_neg_one, maxsize_0.output[0]],
2169+
attr={'axis': 0})
21712170
g.add_graph_output(outshape.output[0], TensorProto.INT64, [-1])
21722171
return g
21732172

2173+
# if k0==k1, rank of output matrix is 1 less than usual
2174+
# hence, using 'If' to compute right output matrix shape
21742175
k0_k1_same = ctx.make_node('Equal', [k1.output[0], k0.output[0]])
21752176
if_node = ctx.make_node('If', [k0_k1_same.output[0]])
21762177
if_node.set_body_graph_as_attr('then_branch', compute_out_shape(True))

0 commit comments

Comments
 (0)