Skip to content

Commit d6d2e75

Browse files
authored
[5452146] Fix: 'Invalid tensor data type 0' (#308)
Signed-off-by: gcunhase <[email protected]>
1 parent a6fa34c commit d6d2e75

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,13 +179,16 @@ def convert(
179179
for vi in self.model.graph.value_info:
180180
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
181181
for idx, d in enumerate(vi.type.tensor_type.shape.dim):
182-
vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
182+
if d.dim_value:
183+
vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
183184
for out in self.model.graph.output:
184185
out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
185186
for idx, d in enumerate(out.type.tensor_type.shape.dim):
186-
out.type.tensor_type.shape.dim[idx].dim_param = "unk"
187+
if d.dim_value:
188+
out.type.tensor_type.shape.dim[idx].dim_param = "unk"
187189
# Populate type information with inferred types
188190
self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=False)
191+
self._ensure_types_are_defined()
189192
# Sanity check: Verify type correctness
190193
self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=True)
191194

@@ -201,6 +204,12 @@ def convert(
201204

202205
return self.model
203206

207+
def _ensure_types_are_defined(self):
208+
"""Ensure that all tensor types are defined."""
209+
for vi in self.model.graph.value_info:
210+
if vi.type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED:
211+
vi.type.tensor_type.elem_type = self.low_precision_type.onnx_type
212+
204213
def _propagate_types_shapes_custom_ops(self, model):
205214
"""Propagate types and shapes after insertion of 'Cast' nodes or other graph modifications."""
206215
logger.info("Propagating tensor shapes and types in model with custom ops.")

0 commit comments

Comments
 (0)