Skip to content

Commit 1b85355

Browse files
committed
[AutoCast] restore outputs if keep_io_types is set
Signed-off-by: ajrasane <[email protected]>
1 parent 7d5f636 commit 1b85355

File tree

1 file changed

+38
-7
lines changed

1 file changed

+38
-7
lines changed

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -181,11 +181,12 @@ def convert(
181181
for idx, d in enumerate(vi.type.tensor_type.shape.dim):
182182
if d.dim_value:
183183
vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
184-
for out in self.model.graph.output:
185-
out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
186-
for idx, d in enumerate(out.type.tensor_type.shape.dim):
187-
if d.dim_value:
188-
out.type.tensor_type.shape.dim[idx].dim_param = "unk"
184+
if not self.keep_io_types:
185+
for out in self.model.graph.output:
186+
out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
187+
for idx, d in enumerate(out.type.tensor_type.shape.dim):
188+
if d.dim_value:
189+
out.type.tensor_type.shape.dim[idx].dim_param = "unk"
189190
# Populate type information with inferred types
190191
self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=False)
191192
self._ensure_types_are_defined()
@@ -200,6 +201,9 @@ def convert(
200201
# Remove redundant casts
201202
self._cleanup()
202203

204+
if self.keep_io_types:
205+
self._restore_original_io_types()
206+
203207
self._sanity_check()
204208

205209
return self.model
@@ -210,6 +214,32 @@ def _ensure_types_are_defined(self):
210214
if vi.type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED:
211215
vi.type.tensor_type.elem_type = self.low_precision_type.onnx_type
212216

217+
def _restore_original_io_types(self):
    """Restore the recorded element types of the graph inputs and outputs.

    Every tensor in ``self.model.graph.input`` and ``self.model.graph.output``
    whose name is a key of ``self.original_network_io`` gets its
    ``type.tensor_type.elem_type`` set back to the value stored under that
    name. When ``self.value_info_map`` holds an entry with the same name,
    that entry's element type is brought in line as well so the two views of
    the tensor cannot disagree. Tensors without a recorded type are left
    untouched.
    """
    graph = self.model.graph
    # Inputs and outputs receive identical treatment, so handle both in one pass.
    for io_tensors in (graph.input, graph.output):
        for io_tensor in io_tensors:
            name = io_tensor.name
            if name not in self.original_network_io:
                continue
            original_type = self.original_network_io[name]

            # The graph I/O entry itself, plus its value_info twin if one exists.
            targets = [io_tensor]
            if name in self.value_info_map:
                targets.append(self.value_info_map[name])

            for target in targets:
                tensor_type = target.type.tensor_type
                # Only write on a mismatch; matching entries stay untouched.
                if tensor_type.elem_type != original_type:
                    tensor_type.elem_type = original_type
242+
213243
def _propagate_types_shapes_custom_ops(self, model):
214244
"""Propagate types and shapes after insertion of 'Cast' nodes or other graph modifications."""
215245
logger.info("Propagating tensor shapes and types in model with custom ops.")
@@ -421,8 +451,9 @@ def _get_tensors_to_cast(self, low_precision_nodes: list[str]) -> tuple[list[str
421451
for node in high_precision_nodes:
422452
# Add cast up for network inputs
423453
cast_to_fp32.extend([input for input in node.input if input in network_inputs])
424-
# Add cast down for network outputs
425-
cast_to_fp16.extend([output for output in node.output if output in network_outputs])
454+
# Add cast down for network outputs (only if not keeping I/O types)
455+
if not self.keep_io_types:
456+
cast_to_fp16.extend([output for output in node.output if output in network_outputs])
426457

427458
# Remove initializers, they are handled separately
428459
initializers = {init.name for init in self.model.graph.initializer}

0 commit comments

Comments
 (0)