@@ -181,11 +181,12 @@ def convert(
181
181
for idx , d in enumerate (vi .type .tensor_type .shape .dim ):
182
182
if d .dim_value :
183
183
vi .type .tensor_type .shape .dim [idx ].dim_param = "unk"
184
- for out in self .model .graph .output :
185
- out .type .tensor_type .elem_type = onnx .TensorProto .UNDEFINED
186
- for idx , d in enumerate (out .type .tensor_type .shape .dim ):
187
- if d .dim_value :
188
- out .type .tensor_type .shape .dim [idx ].dim_param = "unk"
184
+ if not self .keep_io_types :
185
+ for out in self .model .graph .output :
186
+ out .type .tensor_type .elem_type = onnx .TensorProto .UNDEFINED
187
+ for idx , d in enumerate (out .type .tensor_type .shape .dim ):
188
+ if d .dim_value :
189
+ out .type .tensor_type .shape .dim [idx ].dim_param = "unk"
189
190
# Populate type information with inferred types
190
191
self .model = onnx_utils .infer_shapes (self .model , strict_mode = True , check_type = False )
191
192
self ._ensure_types_are_defined ()
@@ -200,6 +201,9 @@ def convert(
200
201
# Remove redundant casts
201
202
self ._cleanup ()
202
203
204
+ if self .keep_io_types :
205
+ self ._restore_original_io_types ()
206
+
203
207
self ._sanity_check ()
204
208
205
209
return self .model
@@ -210,6 +214,32 @@ def _ensure_types_are_defined(self):
210
214
if vi .type .tensor_type .elem_type == onnx .TensorProto .UNDEFINED :
211
215
vi .type .tensor_type .elem_type = self .low_precision_type .onnx_type
212
216
217
+ def _restore_original_io_types (self ):
218
+ """Restore original I/O types."""
219
+ # Restore input types
220
+ for input_tensor in self .model .graph .input :
221
+ if input_tensor .name in self .original_network_io :
222
+ original_type = self .original_network_io [input_tensor .name ]
223
+ if input_tensor .type .tensor_type .elem_type != original_type :
224
+ input_tensor .type .tensor_type .elem_type = original_type
225
+ # Update value_info_map if tensor exists there
226
+ if input_tensor .name in self .value_info_map :
227
+ self .value_info_map [
228
+ input_tensor .name
229
+ ].type .tensor_type .elem_type = original_type
230
+
231
+ # Restore output types
232
+ for output_tensor in self .model .graph .output :
233
+ if output_tensor .name in self .original_network_io :
234
+ original_type = self .original_network_io [output_tensor .name ]
235
+ if output_tensor .type .tensor_type .elem_type != original_type :
236
+ output_tensor .type .tensor_type .elem_type = original_type
237
+ # Update value_info_map if tensor exists there
238
+ if output_tensor .name in self .value_info_map :
239
+ self .value_info_map [
240
+ output_tensor .name
241
+ ].type .tensor_type .elem_type = original_type
242
+
213
243
def _propagate_types_shapes_custom_ops (self , model ):
214
244
"""Propagate types and shapes after insertion of 'Cast' nodes or other graph modifications."""
215
245
logger .info ("Propagating tensor shapes and types in model with custom ops." )
@@ -421,8 +451,9 @@ def _get_tensors_to_cast(self, low_precision_nodes: list[str]) -> tuple[list[str
421
451
for node in high_precision_nodes :
422
452
# Add cast up for network inputs
423
453
cast_to_fp32 .extend ([input for input in node .input if input in network_inputs ])
424
- # Add cast down for network outputs
425
- cast_to_fp16 .extend ([output for output in node .output if output in network_outputs ])
454
+ # Add cast down for network outputs (only if not keeping I/O types)
455
+ if not self .keep_io_types :
456
+ cast_to_fp16 .extend ([output for output in node .output if output in network_outputs ])
426
457
427
458
# Remove initializers, they are handled separately
428
459
initializers = {init .name for init in self .model .graph .initializer }
0 commit comments