Skip to content

Commit 37a0c43

Browse files
committed
Fix shape inference error
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 8a7de40 commit 37a0c43

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

onnxscript/optimizer/_constant_folding.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -919,7 +919,7 @@ def merge_dims(dim1, dim2):
919919
if other_shape is None:
920920
return preferred_shape
921921
if len(preferred_shape) != len(other_shape):
922-
raise ValueError("Shapes must have the same rank.")
922+
raise ValueError(f"Shapes must have the same rank, got preferred_shape={preferred_shape}, other_shape={other_shape}")
923923
return ir.Shape(
924924
[merge_dims(dim1, dim2) for dim1, dim2 in zip(preferred_shape, other_shape)]
925925
)
@@ -1035,7 +1035,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
10351035
except Exception as e:
10361036
logger.debug(
10371037
"Skipping shape inference for node %r due to exception: %s",
1038-
node.name,
1038+
node,
10391039
e,
10401040
)
10411041

@@ -1124,7 +1124,12 @@ def process_node(self, node: ir.Node) -> Replacement | None:
11241124
for optimizer in op_optimizers:
11251125
assert optimizer
11261126
context = RewriterContext()
1127-
output = optimizer(node, context, self._state)
1127+
try:
1128+
output = optimizer(node, context, self._state)
1129+
except Exception as e:
1130+
raise RuntimeError(
1131+
f"Error during constant folding for node {node.name!r} ({node.domain}::{node.op_type})"
1132+
) from e
11281133
if output is not None:
11291134
if isinstance(output, Replacement):
11301135
return output

0 commit comments

Comments
 (0)