@@ -181,12 +181,11 @@ def convert(
                for idx, d in enumerate(vi.type.tensor_type.shape.dim):
                    if d.dim_value:
                        vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
-        if not self.keep_io_types:
-            for out in self.model.graph.output:
-                out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
-                for idx, d in enumerate(out.type.tensor_type.shape.dim):
-                    if d.dim_value:
-                        out.type.tensor_type.shape.dim[idx].dim_param = "unk"
+        for out in self.model.graph.output:
+            out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
+            for idx, d in enumerate(out.type.tensor_type.shape.dim):
+                if d.dim_value:
+                    out.type.tensor_type.shape.dim[idx].dim_param = "unk"
        # Populate type information with inferred types
        self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=False)
        self._ensure_types_are_defined()
@@ -451,9 +450,8 @@ def _get_tensors_to_cast(self, low_precision_nodes: list[str]) -> tuple[list[str
        for node in high_precision_nodes:
            # Add cast up for network inputs
            cast_to_fp32.extend([input for input in node.input if input in network_inputs])
-            # Add cast down for network outputs (only if not keeping I/O types)
-            if not self.keep_io_types:
-                cast_to_fp16.extend([output for output in node.output if output in network_outputs])
+            # Add cast down for network outputs
+            cast_to_fp16.extend([output for output in node.output if output in network_outputs])

        # Remove initializers, they are handled separately
        initializers = {init.name for init in self.model.graph.initializer}
0 commit comments