@@ -179,13 +179,16 @@ def convert(
179
179
for vi in self .model .graph .value_info :
180
180
vi .type .tensor_type .elem_type = onnx .TensorProto .UNDEFINED
181
181
for idx , d in enumerate (vi .type .tensor_type .shape .dim ):
182
- vi .type .tensor_type .shape .dim [idx ].dim_param = "unk"
182
+ if d .dim_value :
183
+ vi .type .tensor_type .shape .dim [idx ].dim_param = "unk"
183
184
for out in self .model .graph .output :
184
185
out .type .tensor_type .elem_type = onnx .TensorProto .UNDEFINED
185
186
for idx , d in enumerate (out .type .tensor_type .shape .dim ):
186
- out .type .tensor_type .shape .dim [idx ].dim_param = "unk"
187
+ if d .dim_value :
188
+ out .type .tensor_type .shape .dim [idx ].dim_param = "unk"
187
189
# Populate type information with inferred types
188
190
self .model = onnx_utils .infer_shapes (self .model , strict_mode = True , check_type = False )
191
+ self ._ensure_types_are_defined ()
189
192
# Sanity check: Verify type correctness
190
193
self .model = onnx_utils .infer_shapes (self .model , strict_mode = True , check_type = True )
191
194
@@ -201,6 +204,12 @@ def convert(
201
204
202
205
return self .model
203
206
207
+ def _ensure_types_are_defined (self ):
208
+ """Ensure that all tensor types are defined."""
209
+ for vi in self .model .graph .value_info :
210
+ if vi .type .tensor_type .elem_type == onnx .TensorProto .UNDEFINED :
211
+ vi .type .tensor_type .elem_type = self .low_precision_type .onnx_type
212
+
204
213
def _propagate_types_shapes_custom_ops (self , model ):
205
214
"""Propagate types and shapes after insertion of 'Cast' nodes or other graph modifications."""
206
215
logger .info ("Propagating tensor shapes and types in model with custom ops." )
0 commit comments