Skip to content

Commit 4b79f9b

Browse files
authored
Merge output shape with input shape instead of override
1 parent dddf0c2 commit 4b79f9b

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

onnxscript/optimizer/_constant_folding.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -491,9 +491,7 @@ def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
491491
# should handle this. Only the optimization to eliminate redundant Cast ops
492492
# should be needed here.
493493

494-
input_shape = input.shape
495-
if input_shape is not None:
496-
output.shape = input_shape.copy()
494+
output.shape = _merge_shapes(output.shape, input.shape)
497495

498496
input_dtype = _get_input_element_type(node, 0)
499497
output_dtype = _get_int_attribute(node, "to", None)

0 commit comments

Comments
 (0)