Skip to content

Commit a6e2e19

Browse files
committed
Assign names for None dims
When onnx shape inference is run on symbolic input dims, it will not handle the dim name propagation and instead create a None. As long as we rely on the current version onnx shape inference there is not better information we can get. However, since in the optimizer we also have some custom shape propagator implemented (e.g. for Identity) that will propagate sym dims, we should encode the equivalents for those dimensions as much as possible. This PR assigns a string to all None dims produced by onnx shape inference, so that the string names can get propagated when possible by the optimizer. Signed-off-by: Justin Chu <[email protected]>
1 parent cb6f873 commit a6e2e19

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

onnxscript/optimizer/_constant_folding.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -974,6 +974,7 @@ def __init__(
974974
self._sizes: dict[str, int] = {}
975975
self._modified: bool = False
976976
self._state = OptimizerState()
977+
self._unknown_dim_count = 0
977978
self._reset()
978979

979980
def _reset(self) -> None:
@@ -982,6 +983,7 @@ def _reset(self) -> None:
982983
self._sizes = {}
983984
self._modified = False
984985
self._state = OptimizerState()
986+
self._unknown_dim_count = 0
985987

986988
def _do_inference(self, node: ir.Node) -> None:
987989
output_types = {}
@@ -1029,7 +1031,12 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
10291031
inferred_shape = ir.serde.deserialize_type_proto_for_shape(
10301032
inferred_type
10311033
)
1032-
output.shape = _merge_shapes(output.shape, inferred_shape)
1034+
merged_shape = _merge_shapes(output.shape, inferred_shape)
1035+
assert merged_shape is not None
1036+
output.shape = merged_shape
1037+
for i in range(len(merged_shape)):
1038+
if merged_shape.is_unknown_dim(i):
1039+
merged_shape[i] = ir.SymbolicDim(self._new_unknown_dim_name())
10331040
output.type = ir.serde.deserialize_type_proto_for_type(inferred_type)
10341041
except Exception as e:
10351042
logger.debug(
@@ -1038,6 +1045,11 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
10381045
e,
10391046
)
10401047

1048+
def _new_unknown_dim_name(self) -> str:
1049+
name = f"unknown_{self._unknown_dim_count}"
1050+
self._unknown_dim_count += 1
1051+
return name
1052+
10411053
def new_constant(self, node: ir.Node, value) -> ir.Node | None:
10421054
irvalue = node.outputs[0]
10431055
if not isinstance(value, np.ndarray):

0 commit comments

Comments
 (0)