Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions modelopt/onnx/autocast/precisionconverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ def convert(
# Remove redundant casts
self._cleanup()

if self.keep_io_types:
self._restore_original_io_types()

self._sanity_check()

return self.model
Expand All @@ -210,6 +213,32 @@ def _ensure_types_are_defined(self):
if vi.type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED:
vi.type.tensor_type.elem_type = self.low_precision_type.onnx_type

def _restore_original_io_types(self):
    """Restore the element types of graph inputs and outputs to their recorded originals.

    Every tensor in ``self.model.graph.input`` and ``self.model.graph.output``
    whose name is a key of ``self.original_network_io`` has its
    ``type.tensor_type.elem_type`` reset to the recorded value when the two
    differ. When a type is actually changed and the tensor name is also present
    in ``self.value_info_map``, that entry is set to the same type. Tensors
    without a recorded original type are left untouched.

    The model and ``self.value_info_map`` are modified in place; nothing is
    returned.
    """
    # Inputs and outputs get identical treatment, so one loop covers both
    # lists (inputs first, then outputs) instead of two copy-pasted loops.
    for io_tensors in (self.model.graph.input, self.model.graph.output):
        for io_tensor in io_tensors:
            name = io_tensor.name
            if name not in self.original_network_io:
                continue

            original_type = self.original_network_io[name]
            if io_tensor.type.tensor_type.elem_type == original_type:
                # Already has its original type; nothing to restore.
                continue

            io_tensor.type.tensor_type.elem_type = original_type
            # Mirror the change in value_info_map if the tensor is tracked there.
            if name in self.value_info_map:
                self.value_info_map[name].type.tensor_type.elem_type = original_type

def _propagate_types_shapes_custom_ops(self, model):
"""Propagate types and shapes after insertion of 'Cast' nodes or other graph modifications."""
logger.info("Propagating tensor shapes and types in model with custom ops.")
Expand Down
Loading