@@ -181,12 +181,11 @@ def convert(
181
181
for idx , d in enumerate (vi .type .tensor_type .shape .dim ):
182
182
if d .dim_value :
183
183
vi .type .tensor_type .shape .dim [idx ].dim_param = "unk"
184
- if not self .keep_io_types :
185
- for out in self .model .graph .output :
186
- out .type .tensor_type .elem_type = onnx .TensorProto .UNDEFINED
187
- for idx , d in enumerate (out .type .tensor_type .shape .dim ):
188
- if d .dim_value :
189
- out .type .tensor_type .shape .dim [idx ].dim_param = "unk"
184
+ for out in self .model .graph .output :
185
+ out .type .tensor_type .elem_type = onnx .TensorProto .UNDEFINED
186
+ for idx , d in enumerate (out .type .tensor_type .shape .dim ):
187
+ if d .dim_value :
188
+ out .type .tensor_type .shape .dim [idx ].dim_param = "unk"
190
189
# Populate type information with inferred types
191
190
self .model = onnx_utils .infer_shapes (self .model , strict_mode = True , check_type = False )
192
191
self ._ensure_types_are_defined ()
@@ -452,8 +451,7 @@ def _get_tensors_to_cast(self, low_precision_nodes: list[str]) -> tuple[list[str
452
451
# Add cast up for network inputs
453
452
cast_to_fp32 .extend ([input for input in node .input if input in network_inputs ])
454
453
# Add cast down for network outputs (only if not keeping I/O types)
455
- if not self .keep_io_types :
456
- cast_to_fp16 .extend ([output for output in node .output if output in network_outputs ])
454
+ cast_to_fp16 .extend ([output for output in node .output if output in network_outputs ])
457
455
458
456
# Remove initializers, they are handled separately
459
457
initializers = {init .name for init in self .model .graph .initializer }
0 commit comments