@@ -179,13 +179,16 @@ def convert(
179179 for vi in self .model .graph .value_info :
180180 vi .type .tensor_type .elem_type = onnx .TensorProto .UNDEFINED
181181 for idx , d in enumerate (vi .type .tensor_type .shape .dim ):
182- vi .type .tensor_type .shape .dim [idx ].dim_param = "unk"
182+ if d .dim_value :
183+ vi .type .tensor_type .shape .dim [idx ].dim_param = "unk"
183184 for out in self .model .graph .output :
184185 out .type .tensor_type .elem_type = onnx .TensorProto .UNDEFINED
185186 for idx , d in enumerate (out .type .tensor_type .shape .dim ):
186- out .type .tensor_type .shape .dim [idx ].dim_param = "unk"
187+ if d .dim_value :
188+ out .type .tensor_type .shape .dim [idx ].dim_param = "unk"
187189 # Populate type information with inferred types
188190 self .model = onnx_utils .infer_shapes (self .model , strict_mode = True , check_type = False )
191+ self ._ensure_types_are_defined ()
189192 # Sanity check: Verify type correctness
190193 self .model = onnx_utils .infer_shapes (self .model , strict_mode = True , check_type = True )
191194
@@ -201,6 +204,12 @@ def convert(
201204
202205 return self .model
203206
207+ def _ensure_types_are_defined (self ):
208+ """Ensure that all tensor types are defined."""
209+ for vi in self .model .graph .value_info :
210+ if vi .type .tensor_type .elem_type == onnx .TensorProto .UNDEFINED :
211+ vi .type .tensor_type .elem_type = self .low_precision_type .onnx_type
212+
204213 def _propagate_types_shapes_custom_ops (self , model ):
205214 """Propagate types and shapes after insertion of 'Cast' nodes or other graph modifications."""
206215 logger .info ("Propagating tensor shapes and types in model with custom ops." )
0 commit comments