Skip to content

Commit 1e4d697

Browse files
committed
Remove check for keep_io_types
Signed-off-by: ajrasane <[email protected]>
1 parent 1b85355 commit 1e4d697

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,11 @@ def convert(
181181
for idx, d in enumerate(vi.type.tensor_type.shape.dim):
182182
if d.dim_value:
183183
vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
184-
if not self.keep_io_types:
185-
for out in self.model.graph.output:
186-
out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
187-
for idx, d in enumerate(out.type.tensor_type.shape.dim):
188-
if d.dim_value:
189-
out.type.tensor_type.shape.dim[idx].dim_param = "unk"
184+
for out in self.model.graph.output:
185+
out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
186+
for idx, d in enumerate(out.type.tensor_type.shape.dim):
187+
if d.dim_value:
188+
out.type.tensor_type.shape.dim[idx].dim_param = "unk"
190189
# Populate type information with inferred types
191190
self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=False)
192191
self._ensure_types_are_defined()
@@ -451,9 +450,8 @@ def _get_tensors_to_cast(self, low_precision_nodes: list[str]) -> tuple[list[str
451450
for node in high_precision_nodes:
452451
# Add cast up for network inputs
453452
cast_to_fp32.extend([input for input in node.input if input in network_inputs])
454-
# Add cast down for network outputs (only if not keeping I/O types)
455-
if not self.keep_io_types:
456-
cast_to_fp16.extend([output for output in node.output if output in network_outputs])
453+
# Add cast down for network outputs
454+
cast_to_fp16.extend([output for output in node.output if output in network_outputs])
457455

458456
# Remove initializers, they are handled separately
459457
initializers = {init.name for init in self.model.graph.initializer}

0 commit comments

Comments
 (0)