Skip to content

Commit 01822ec

Browse files
committed
Simplify restore_original_io_types
Signed-off-by: ajrasane <[email protected]>
1 parent 1e4d697 commit 01822ec

File tree

1 file changed

+12
-22
lines changed

1 file changed

+12
-22
lines changed

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -215,29 +215,19 @@ def _ensure_types_are_defined(self):
215215

216216
def _restore_original_io_types(self):
217217
"""Restore original I/O types."""
218-
# Restore input types
219-
for input_tensor in self.model.graph.input:
220-
if input_tensor.name in self.original_network_io:
221-
original_type = self.original_network_io[input_tensor.name]
222-
if input_tensor.type.tensor_type.elem_type != original_type:
223-
input_tensor.type.tensor_type.elem_type = original_type
224-
# Update value_info_map if tensor exists there
225-
if input_tensor.name in self.value_info_map:
226-
self.value_info_map[
227-
input_tensor.name
228-
].type.tensor_type.elem_type = original_type
229-
230-
# Restore output types
231-
for output_tensor in self.model.graph.output:
232-
if output_tensor.name in self.original_network_io:
233-
original_type = self.original_network_io[output_tensor.name]
234-
if output_tensor.type.tensor_type.elem_type != original_type:
235-
output_tensor.type.tensor_type.elem_type = original_type
218+
219+
def restore_tensor_type(tensor):
220+
if tensor.name in self.original_network_io:
221+
original_type = self.original_network_io[tensor.name]
222+
if tensor.type.tensor_type.elem_type != original_type:
223+
tensor.type.tensor_type.elem_type = original_type
236224
# Update value_info_map if tensor exists there
237-
if output_tensor.name in self.value_info_map:
238-
self.value_info_map[
239-
output_tensor.name
240-
].type.tensor_type.elem_type = original_type
225+
if tensor.name in self.value_info_map:
226+
self.value_info_map[tensor.name].type.tensor_type.elem_type = original_type
227+
228+
# Restore input and output types
229+
for tensor in self.model.graph.input + self.model.graph.output:
230+
restore_tensor_type(tensor)
241231

242232
def _propagate_types_shapes_custom_ops(self, model):
243233
"""Propagate types and shapes after insertion of 'Cast' nodes or other graph modifications."""

0 commit comments

Comments
 (0)