19 changes: 19 additions & 0 deletions modelopt/onnx/autocast/precisionconverter.py
@@ -200,6 +200,9 @@ def convert(
        # Remove redundant casts
        self._cleanup()

        if self.keep_io_types:
            self._restore_original_io_types()

        self._sanity_check()

        return self.model
@@ -210,6 +213,22 @@ def _ensure_types_are_defined(self):
            if vi.type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED:
                vi.type.tensor_type.elem_type = self.low_precision_type.onnx_type

    def _restore_original_io_types(self):
        """Restore original I/O types."""

        def restore_tensor_type(tensor):
            if tensor.name in self.original_network_io:
                original_type = self.original_network_io[tensor.name]
                if tensor.type.tensor_type.elem_type != original_type:
                    tensor.type.tensor_type.elem_type = original_type
You're just updating the tensor's metadata here. Shouldn't we also inject a cast here?

Looking at the bigger picture, the keep_io_types flag is covered by the unit tests and is also set as the default in MOQ, so generally we know it works, but there's an edge case we're first seeing in https://nvbugspro.nvidia.com/bug/5532019.
Inputs are clearly handled in lines 145-148, and outputs are presumably handled in _cleanup_pre_output_same_type_cast.
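
For illustration, a minimal sketch of what injecting such a cast for a graph output could look like (the helper name, the intermediate-tensor naming, and the call site are assumptions for this discussion, not code from the PR):

```python
import onnx
from onnx import helper


def inject_output_cast(graph: onnx.GraphProto, output_name: str, original_type: int) -> None:
    """Reroute the producer of `output_name` through a Cast back to the original type."""
    internal_name = f"{output_name}_lowp"  # hypothetical intermediate tensor name
    # Redirect whichever node currently produces the graph output.
    for node in graph.node:
        for i, out in enumerate(node.output):
            if out == output_name:
                node.output[i] = internal_name
    # Cast the intermediate tensor back to the original output type.
    cast_node = helper.make_node(
        "Cast",
        inputs=[internal_name],
        outputs=[output_name],
        name=f"cast_{output_name}_to_original",
        to=original_type,  # e.g. onnx.TensorProto.FLOAT
    )
    graph.node.append(cast_node)
```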

Contributor Author

I have removed this function and added the conversion logic to the GraphSanitizer instead.

                # Update value_info_map if tensor exists there
                if tensor.name in self.value_info_map:
                    self.value_info_map[tensor.name].type.tensor_type.elem_type = original_type

        # Restore input and output types
        for tensor in self.model.graph.input + self.model.graph.output:
            restore_tensor_type(tensor)
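
For context on where original_network_io comes from: it has to be captured before any element types are rewritten, so that keep_io_types can restore them afterwards. A rough sketch of such a capture step (the attribute name is taken from the diff above; the exact capture site is an assumption):

```python
# Record the original element type of every network input/output up front,
# before precision conversion mutates the graph's type annotations.
self.original_network_io = {
    tensor.name: tensor.type.tensor_type.elem_type
    for tensor in list(self.model.graph.input) + list(self.model.graph.output)
}
```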

    def _propagate_types_shapes_custom_ops(self, model):
        """Propagate types and shapes after insertion of 'Cast' nodes or other graph modifications."""
        logger.info("Propagating tensor shapes and types in model with custom ops.")