Skip to content

Commit 7ccc35e

Browse files
committed
Remove check for keep_io_types
Signed-off-by: ajrasane <[email protected]>
1 parent 1b85355 commit 7ccc35e

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,11 @@ def convert(
181181
for idx, d in enumerate(vi.type.tensor_type.shape.dim):
182182
if d.dim_value:
183183
vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
184-
if not self.keep_io_types:
185-
for out in self.model.graph.output:
186-
out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
187-
for idx, d in enumerate(out.type.tensor_type.shape.dim):
188-
if d.dim_value:
189-
out.type.tensor_type.shape.dim[idx].dim_param = "unk"
184+
for out in self.model.graph.output:
185+
out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
186+
for idx, d in enumerate(out.type.tensor_type.shape.dim):
187+
if d.dim_value:
188+
out.type.tensor_type.shape.dim[idx].dim_param = "unk"
190189
# Populate type information with inferred types
191190
self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=False)
192191
self._ensure_types_are_defined()
@@ -452,8 +451,7 @@ def _get_tensors_to_cast(self, low_precision_nodes: list[str]) -> tuple[list[str
452451
# Add cast up for network inputs
453452
cast_to_fp32.extend([input for input in node.input if input in network_inputs])
454453
# Add cast down for network outputs (only if not keeping I/O types)
455-
if not self.keep_io_types:
456-
cast_to_fp16.extend([output for output in node.output if output in network_outputs])
454+
cast_to_fp16.extend([output for output in node.output if output in network_outputs])
457455

458456
# Remove initializers, they are handled separately
459457
initializers = {init.name for init in self.model.graph.initializer}

0 commit comments

Comments
 (0)