Skip to content

Commit 2a35495

Browse files
committed
Formatting
1 parent 2b8cf57 commit 2a35495

File tree

2 files changed

+5
-7
lines changed

2 files changed

+5
-7
lines changed

tf2onnx/onnx_opset/tensor.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2151,9 +2151,8 @@ def right_align(sizes, indices, starts, maxval):
21512151
else:
21522152
ydiags = ctx.make_node('Min', [ydiags_2.output[0], const_ymax])
21532153

2154-
# flatten last dimension of matrix
2154+
# flatten last dimension of matrix, extract diagonal values
21552155
m2 = ctx.make_node('Reshape', [m_padded.output[0], const_m2_shape])
2156-
21572156
diags_0 = ctx.make_node('Concat', [xdiags.output[0], ydiags.output[0]], attr={'axis': 0})
21582157
diags_1 = ctx.make_node('Reshape', [diags_0.output[0], const_neg_one])
21592158
diags_2 = ctx.make_node('Expand', [diags_1.output[0], const_gather_shape])
@@ -2163,11 +2162,11 @@ def compute_out_shape(k0_k1_same=False):
21632162
g = ctx.create_new_graph_with_same_config()
21642163
g.parent_graph = ctx
21652164
if k0_k1_same:
2166-
outshape = g.make_node('Concat', [const_partial_shape, maxsize_0.output[0]], attr={'axis': 0})
2165+
dims = [const_partial_shape, maxsize_0.output[0]]
21672166
else:
2168-
outshape = g.make_node('Concat', [const_partial_shape, const_neg_one, maxsize_0.output[0]],
2169-
attr={'axis': 0})
2170-
g.add_graph_output(outshape.output[0], TensorProto.INT64, [-1])
2167+
dims = [const_partial_shape, const_neg_one, maxsize_0.output[0]]
2168+
out_shape = g.make_node('Concat', dims, attr={'axis': 0})
2169+
g.add_graph_output(out_shape.output[0], TensorProto.INT64, [-1])
21712170
return g
21722171

21732172
# if k0==k1, rank of output matrix is 1 less than usual

tf2onnx/tf_loader.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,6 @@ def tf_reload_graph(tf_graph):
356356
)
357357

358358
graph_def = tf_graph.as_graph_def(add_shapes=True)
359-
#tf_reset_default_graph()
360359
with tf.Graph().as_default() as inferred_graph:
361360
tf.import_graph_def(graph_def, name="")
362361
return inferred_graph

0 commit comments

Comments
 (0)