Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions modelopt/onnx/autocast/precisionconverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,16 @@ def convert(
for vi in self.model.graph.value_info:
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
for idx, d in enumerate(vi.type.tensor_type.shape.dim):
vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
if d.dim_value:
vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
for out in self.model.graph.output:
out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
for idx, d in enumerate(out.type.tensor_type.shape.dim):
out.type.tensor_type.shape.dim[idx].dim_param = "unk"
if d.dim_value:
out.type.tensor_type.shape.dim[idx].dim_param = "unk"
# Populate type information with inferred types
self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=False)
self._ensure_types_are_defined()
# Sanity check: Verify type correctness
self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=True)

Expand All @@ -201,6 +204,12 @@ def convert(

return self.model

def _ensure_types_are_defined(self):
"""Ensure that all tensor types are defined."""
for vi in self.model.graph.value_info:
if vi.type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED:
vi.type.tensor_type.elem_type = self.low_precision_type.onnx_type

def _propagate_types_shapes_custom_ops(self, model):
"""Propagate types and shapes after insertion of 'Cast' nodes or other graph modifications."""
logger.info("Propagating tensor shapes and types in model with custom ops.")
Expand Down
Loading