Skip to content

Commit de80f51

Browse files
Prevent tf1 loop ops from const folding (#1561)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent ab816b3 commit de80f51

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tf2onnx/tf_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,8 @@ def is_huge_shape(x):
216216
outputs_to_values[output_names[0]] = np.array(shape[i], dtype=np_dtype)
217217
outputs_to_dtypes[node.outputs[0].name] = node.outputs[0].dtype
218218
progress = True
219-
can_fold = node.type not in ['Enter', 'Placeholder', 'PlaceholderWithDefault']
219+
can_fold = node.type not in ['Enter', 'Placeholder', 'PlaceholderWithDefault', 'Switch', 'Merge',
220+
'NextIteration', 'Exit']
220221
can_fold = can_fold and not node.type.startswith('Random')
221222
can_fold = can_fold and len(input_names) > 0 and all(inp in outputs_to_values for inp in input_names)
222223
# We can only fold nodes with a single output

0 commit comments

Comments
 (0)