@@ -215,29 +215,19 @@ def _ensure_types_are_defined(self):
215
215
216
216
def _restore_original_io_types (self ):
217
217
"""Restore original I/O types."""
218
- # Restore input types
219
- for input_tensor in self .model .graph .input :
220
- if input_tensor .name in self .original_network_io :
221
- original_type = self .original_network_io [input_tensor .name ]
222
- if input_tensor .type .tensor_type .elem_type != original_type :
223
- input_tensor .type .tensor_type .elem_type = original_type
224
- # Update value_info_map if tensor exists there
225
- if input_tensor .name in self .value_info_map :
226
- self .value_info_map [
227
- input_tensor .name
228
- ].type .tensor_type .elem_type = original_type
229
-
230
- # Restore output types
231
- for output_tensor in self .model .graph .output :
232
- if output_tensor .name in self .original_network_io :
233
- original_type = self .original_network_io [output_tensor .name ]
234
- if output_tensor .type .tensor_type .elem_type != original_type :
235
- output_tensor .type .tensor_type .elem_type = original_type
218
+
219
+ def restore_tensor_type (tensor ):
220
+ if tensor .name in self .original_network_io :
221
+ original_type = self .original_network_io [tensor .name ]
222
+ if tensor .type .tensor_type .elem_type != original_type :
223
+ tensor .type .tensor_type .elem_type = original_type
236
224
# Update value_info_map if tensor exists there
237
- if output_tensor .name in self .value_info_map :
238
- self .value_info_map [
239
- output_tensor .name
240
- ].type .tensor_type .elem_type = original_type
225
+ if tensor .name in self .value_info_map :
226
+ self .value_info_map [tensor .name ].type .tensor_type .elem_type = original_type
227
+
228
+ # Restore input and output types
229
+ for tensor in self .model .graph .input + self .model .graph .output :
230
+ restore_tensor_type (tensor )
241
231
242
232
def _propagate_types_shapes_custom_ops (self , model ):
243
233
"""Propagate types and shapes after insertion of 'Cast' nodes or other graph modifications."""
0 commit comments