Skip to content

Commit d87ba34

Browse files
authored
#1763: fix output shape of flatten rewriter (#1766)
Signed-off-by: Guenther Schmuelling <[email protected]>
1 parent 3b0590c commit d87ba34

File tree

2 files changed

+8
-11
lines changed

2 files changed

+8
-11
lines changed

tf2onnx/onnx_opset/math.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,11 @@ class Softmax:
203203
def version_1(cls, ctx, node, **kwargs):
204204
# T output = Softmax(T logits). The axis softmax would be performed on is always on -1.
205205
# T output = Softmax(T input, @int axis). Default axis is 1.
206-
logits_rank = len(ctx.get_shape(node.input[0]))
207-
node.set_attr("axis", logits_rank - 1)
206+
axis = node.get_attr_value("axis")
207+
if axis is None:
208+
# by default use the last dim
209+
axis = len(ctx.get_shape(node.input[0])) - 1
210+
node.set_attr("axis", axis)
208211

209212
@classmethod
210213
def version_11(cls, ctx, node, **kwargs):

tf2onnx/rewriter/flatten_rewriter.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -85,15 +85,9 @@ def rewrite_flatten(g, ops):
8585
out_name = utils.port_name(op_name)
8686
g.make_node("Flatten", [reshape_node.input[0]], outputs=[out_name], name=op_name)
8787

88-
last_dim = input_shape[-1]
89-
sec_last_dim = input_shape[-2]
90-
new_dim = None
91-
if last_dim > 0 and sec_last_dim > 0:
92-
new_dim = last_dim * sec_last_dim
93-
else:
94-
new_dim = -1
95-
96-
g.set_shape(out_name, input_shape[:-2] + [new_dim])
88+
# take output shape from reshape()
89+
output_shape = g.get_shape(reshape_node.output[0])
90+
g.set_shape(out_name, output_shape)
9791
g.replace_all_inputs(reshape_node.output[0], out_name, ops=ops)
9892
for n in to_remove:
9993
g.remove_node(n.name)

0 commit comments

Comments
 (0)