@@ -181,11 +181,12 @@ def convert(
181181 for idx , d in enumerate (vi .type .tensor_type .shape .dim ):
182182 if d .dim_value :
183183 vi .type .tensor_type .shape .dim [idx ].dim_param = "unk"
184- for out in self .model .graph .output :
185- out .type .tensor_type .elem_type = onnx .TensorProto .UNDEFINED
186- for idx , d in enumerate (out .type .tensor_type .shape .dim ):
187- if d .dim_value :
188- out .type .tensor_type .shape .dim [idx ].dim_param = "unk"
184+ if not self .keep_io_types :
185+ for out in self .model .graph .output :
186+ out .type .tensor_type .elem_type = onnx .TensorProto .UNDEFINED
187+ for idx , d in enumerate (out .type .tensor_type .shape .dim ):
188+ if d .dim_value :
189+ out .type .tensor_type .shape .dim [idx ].dim_param = "unk"
189190 # Populate type information with inferred types
190191 self .model = onnx_utils .infer_shapes (self .model , strict_mode = True , check_type = False )
191192 self ._ensure_types_are_defined ()
@@ -200,6 +201,9 @@ def convert(
200201 # Remove redundant casts
201202 self ._cleanup ()
202203
204+ if self .keep_io_types :
205+ self ._restore_original_io_types ()
206+
203207 self ._sanity_check ()
204208
205209 return self .model
@@ -210,6 +214,32 @@ def _ensure_types_are_defined(self):
210214 if vi .type .tensor_type .elem_type == onnx .TensorProto .UNDEFINED :
211215 vi .type .tensor_type .elem_type = self .low_precision_type .onnx_type
212216
217+ def _restore_original_io_types (self ):
218+ """Restore original I/O types."""
219+ # Restore input types
220+ for input_tensor in self .model .graph .input :
221+ if input_tensor .name in self .original_network_io :
222+ original_type = self .original_network_io [input_tensor .name ]
223+ if input_tensor .type .tensor_type .elem_type != original_type :
224+ input_tensor .type .tensor_type .elem_type = original_type
225+ # Update value_info_map if tensor exists there
226+ if input_tensor .name in self .value_info_map :
227+ self .value_info_map [
228+ input_tensor .name
229+ ].type .tensor_type .elem_type = original_type
230+
231+ # Restore output types
232+ for output_tensor in self .model .graph .output :
233+ if output_tensor .name in self .original_network_io :
234+ original_type = self .original_network_io [output_tensor .name ]
235+ if output_tensor .type .tensor_type .elem_type != original_type :
236+ output_tensor .type .tensor_type .elem_type = original_type
237+ # Update value_info_map if tensor exists there
238+ if output_tensor .name in self .value_info_map :
239+ self .value_info_map [
240+ output_tensor .name
241+ ].type .tensor_type .elem_type = original_type
242+
213243 def _propagate_types_shapes_custom_ops (self , model ):
214244 """Propagate types and shapes after insertion of 'Cast' nodes or other graph modifications."""
215245 logger .info ("Propagating tensor shapes and types in model with custom ops." )
@@ -421,8 +451,9 @@ def _get_tensors_to_cast(self, low_precision_nodes: list[str]) -> tuple[list[str
421451 for node in high_precision_nodes :
422452 # Add cast up for network inputs
423453 cast_to_fp32 .extend ([input for input in node .input if input in network_inputs ])
424- # Add cast down for network outputs
425- cast_to_fp16 .extend ([output for output in node .output if output in network_outputs ])
454+ # Add cast down for network outputs (only if not keeping I/O types)
455+ if not self .keep_io_types :
456+ cast_to_fp16 .extend ([output for output in node .output if output in network_outputs ])
426457
427458 # Remove initializers, they are handled separately
428459 initializers = {init .name for init in self .model .graph .initializer }