Skip to content

Commit a138ccb

Browse files
committed
Reactivate constant folding (for ConcatV2)
Allow output shape of (?,). Fix Squeeze op when tensor shape is (?,).
1 parent 4b6b8d1 commit a138ccb

File tree

3 files changed

+15
-12
lines changed

3 files changed

+15
-12
lines changed

tf2onnx/graph.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1067,7 +1067,9 @@ def make_onnx_graph_io(self, ids):
10671067
shape = self.get_shape(name)
10681068

10691069
utils.make_sure(dtype is not None, "missing output dtype for " + name)
1070-
utils.make_sure(shape is not None, "missing output shape for " + name)
1070+
# TODO: allow None output shape or not? e.g. shape=(?,)
1071+
#utils.make_sure(shape is not None, "missing output shape for " + name)
1072+
if shape is None: logger.warning(f"missing output shape for " + name)
10711073

10721074
v = utils.make_onnx_inputs_outputs(name, dtype, shape)
10731075
tensor_value_infos.append(v)

tf2onnx/onnx_opset/tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ def version_1(cls, ctx, node, **kwargs):
196196
shape = ctx.get_shape(node.input[0])
197197
utils.make_sure(shape is not None, "squeeze input shape cannot be None")
198198
axis = [i for i, j in enumerate(shape) if j == 1]
199+
if not axis: axis = [0]
199200
node.set_attr("axes", axis)
200201

201202
@classmethod

tf2onnx/tfonnx.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,17 +42,17 @@ def rewrite_constant_fold(g, ops):
4242
tensorflow missed something, make another pass over the graph and fix want we care about.
4343
"""
4444
func_map = {
45-
"Add": np.add,
46-
"GreaterEqual": np.greater_equal,
47-
"Cast": np.cast,
45+
# "Add": np.add,
46+
# "GreaterEqual": np.greater_equal,
47+
# "Cast": np.cast,
4848
"ConcatV2": np.concatenate,
49-
"Less": np.less,
50-
"ListDiff": np.setdiff1d,
51-
"Mul": np.multiply,
52-
"Pack": np.stack,
53-
"Range": np.arange,
54-
"Sqrt": np.sqrt,
55-
"Sub": np.subtract,
49+
# "Less": np.less,
50+
# "ListDiff": np.setdiff1d,
51+
# "Mul": np.multiply,
52+
# "Pack": np.stack,
53+
# "Range": np.arange,
54+
# "Sqrt": np.sqrt,
55+
# "Sub": np.subtract,
5656
}
5757
ref_cnt_per_node = {}
5858
for idx, op in enumerate(ops):
@@ -466,7 +466,7 @@ def compat_handler(ctx, node, **kwargs):
466466
rewrite_single_direction_lstm, rewrite_bi_direction_lstm,
467467
rewrite_single_direction_gru, rewrite_bi_direction_gru,
468468
rewrite_custom_rnn_cell, rewrite_generic_loop, rewrite_cond,
469-
rewrite_biasadd_with_conv2d,
469+
rewrite_biasadd_with_conv2d, rewrite_constant_fold
470470
]
471471

472472
if custom_rewriter is not None:

0 commit comments

Comments
 (0)