@@ -215,29 +215,19 @@ def _ensure_types_are_defined(self):
215215
216216 def _restore_original_io_types (self ):
217217 """Restore original I/O types."""
218- # Restore input types
219- for input_tensor in self .model .graph .input :
220- if input_tensor .name in self .original_network_io :
221- original_type = self .original_network_io [input_tensor .name ]
222- if input_tensor .type .tensor_type .elem_type != original_type :
223- input_tensor .type .tensor_type .elem_type = original_type
224- # Update value_info_map if tensor exists there
225- if input_tensor .name in self .value_info_map :
226- self .value_info_map [
227- input_tensor .name
228- ].type .tensor_type .elem_type = original_type
229-
230- # Restore output types
231- for output_tensor in self .model .graph .output :
232- if output_tensor .name in self .original_network_io :
233- original_type = self .original_network_io [output_tensor .name ]
234- if output_tensor .type .tensor_type .elem_type != original_type :
235- output_tensor .type .tensor_type .elem_type = original_type
218+
219+ def restore_tensor_type (tensor ):
220+ if tensor .name in self .original_network_io :
221+ original_type = self .original_network_io [tensor .name ]
222+ if tensor .type .tensor_type .elem_type != original_type :
223+ tensor .type .tensor_type .elem_type = original_type
236224 # Update value_info_map if tensor exists there
237- if output_tensor .name in self .value_info_map :
238- self .value_info_map [
239- output_tensor .name
240- ].type .tensor_type .elem_type = original_type
225+ if tensor .name in self .value_info_map :
226+ self .value_info_map [tensor .name ].type .tensor_type .elem_type = original_type
227+
228+ # Restore input and output types
229+ for tensor in self .model .graph .input + self .model .graph .output :
230+ restore_tensor_type (tensor )
241231
242232 def _propagate_types_shapes_custom_ops (self , model ):
243233 """Propagate types and shapes after insertion of 'Cast' nodes or other graph modifications."""
0 commit comments