diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py
index 8b6fd545d..4d6917fd2 100644
--- a/modelopt/onnx/autocast/precisionconverter.py
+++ b/modelopt/onnx/autocast/precisionconverter.py
@@ -179,13 +179,16 @@ def convert(
         for vi in self.model.graph.value_info:
             vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
             for idx, d in enumerate(vi.type.tensor_type.shape.dim):
-                vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
+                if d.dim_value:
+                    vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
         for out in self.model.graph.output:
             out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
             for idx, d in enumerate(out.type.tensor_type.shape.dim):
-                out.type.tensor_type.shape.dim[idx].dim_param = "unk"
+                if d.dim_value:
+                    out.type.tensor_type.shape.dim[idx].dim_param = "unk"
         # Populate type information with inferred types
         self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=False)
+        self._ensure_types_are_defined()
 
         # Sanity check: Verify type correctness
         self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=True)
@@ -201,6 +204,12 @@ def convert(
 
         return self.model
 
+    def _ensure_types_are_defined(self):
+        """Ensure that all tensor types are defined."""
+        for vi in self.model.graph.value_info:
+            if vi.type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED:
+                vi.type.tensor_type.elem_type = self.low_precision_type.onnx_type
+
     def _propagate_types_shapes_custom_ops(self, model):
         """Propagate types and shapes after insertion of 'Cast' nodes or other graph modifications."""
         logger.info("Propagating tensor shapes and types in model with custom ops.")
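
For context, here is a minimal standalone sketch of the two behaviours the diff introduces; it is not part of the patch, and it uses `FLOAT16` purely as a stand-in for the converter's `low_precision_type`. The `if d.dim_value:` guard only rewrites dims that carry a concrete integer value, leaving existing symbolic dims untouched, and the UNDEFINED-type check mirrors the fallback that `_ensure_types_are_defined` applies to tensors that shape inference could not type.

```python
import onnx
from onnx import TensorProto, helper

# A value_info whose shape mixes a symbolic batch dim with concrete dims.
vi = helper.make_tensor_value_info("act", TensorProto.FLOAT, ["batch", 3, 224])

# Guarded reset, as in the patch: only dims carrying a concrete dim_value are
# replaced with the symbolic "unk"; the existing "batch" symbol is preserved.
for idx, d in enumerate(vi.type.tensor_type.shape.dim):
    if d.dim_value:
        vi.type.tensor_type.shape.dim[idx].dim_param = "unk"

print([d.dim_param or d.dim_value for d in vi.type.tensor_type.shape.dim])
# ['batch', 'unk', 'unk']

# Fallback analogous to _ensure_types_are_defined(): if a tensor's element
# type is still UNDEFINED after inference, pin it to the low-precision type
# (FLOAT16 is assumed here purely for illustration).
untyped = onnx.ValueInfoProto(name="maybe_untyped")
if untyped.type.tensor_type.elem_type == TensorProto.UNDEFINED:
    untyped.type.tensor_type.elem_type = TensorProto.FLOAT16
```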