38 changes: 38 additions & 0 deletions modelopt/onnx/autocast/graphsanitizer.py
@@ -63,6 +63,7 @@ def sanitize(self) -> None:
self.ensure_graph_name_exists()
onnx_utils.name_onnx_nodes(self.model.graph)
self.replace_custom_domain_nodes()
self.sanitize_io_casts()
self.cleanup_model()
self.set_ir_version(self.max_ir_version)
self.convert_fp64_to_fp32()
@@ -343,6 +344,43 @@ def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None:
logger.debug(f"Failed to match LayerNorm pattern at {mean_node.name}: {e!s}")
return None

def sanitize_io_casts(self) -> None:
"""Handle the special case where an input is casted directly to an output.

Inject an identity node after the cast node.
"""
model_input_names = {input.name for input in self.model.graph.input}
model_output_names = {output.name for output in self.model.graph.output}
nodes_to_add = []
for node in self.model.graph.node:
if (
node.op_type == "Cast"
and node.input
and node.output
and node.input[0] in model_input_names
and node.output[0] in model_output_names
):
# Unique per graph output to avoid collisions when multiple outputs are cast from the same input
cast_output_name = node.output[0]
cast_new_output_name = f"{cast_output_name}__io_cast_src"
nodes_to_add.append(
helper.make_node(
"Identity",
inputs=[cast_new_output_name],
outputs=[cast_output_name],
name=f"{node.name}__io_cast_identity",
)
)
# Rewire Cast to produce the new intermediate
node.output[0] = cast_new_output_name

for node in nodes_to_add:
self.model.graph.node.append(node)

# Make sure the graph is topologically sorted
gs_graph = gs.import_onnx(self.model).cleanup().toposort()
self.model = gs.export_onnx(gs_graph)

def _create_layernorm_node(self, pattern: dict) -> onnx.NodeProto:
"""Create a LayerNormalization node with optional bias."""
ln_name = f"LayerNorm_{pattern['mean_node'].name}"
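To make the new pass concrete, here is a minimal, self-contained sketch (using plain onnx only; tensor and node names are illustrative, not taken from the PR) that builds the input-to-output Cast pattern and applies the same Identity rewiring by hand:

import onnx
from onnx import TensorProto, helper

# Toy model: graph input "X" is cast directly to graph output "Y".
x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2])
y = helper.make_tensor_value_info("Y", TensorProto.FLOAT16, [2])
cast = helper.make_node("Cast", ["X"], ["Y"], name="io_cast", to=TensorProto.FLOAT16)
model = helper.make_model(helper.make_graph([cast], "io_cast_demo", [x], [y]))

# The rewiring sanitize_io_casts performs: the Cast now writes to an
# intermediate tensor and a new Identity node drives the graph output,
# so later passes never see a graph input cast straight into an output.
cast.output[0] = "Y__io_cast_src"
model.graph.node.append(
    helper.make_node("Identity", ["Y__io_cast_src"], ["Y"], name="io_cast__io_cast_identity")
)
onnx.checker.check_model(model)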
19 changes: 19 additions & 0 deletions modelopt/onnx/autocast/precisionconverter.py
@@ -32,6 +32,7 @@

import modelopt.onnx.autocast.utils as utils
import modelopt.onnx.utils as onnx_utils
from modelopt.onnx.autocast.graphsanitizer import GraphSanitizer
from modelopt.onnx.autocast.logging_config import configure_logging, logger

configure_logging()
@@ -73,6 +74,9 @@ def __init__(
low_precision_type: str = "fp16",
init_conversion_max_bytes: int | None = None,
custom_ops: set[str] | None = None,
min_opset: int = 13,
max_ir_version: int | None = None,
trt_plugins: list[str] | None = [],
Comment on lines +77 to +79
⚠️ Potential issue | 🟠 Major

Replace mutable default argument with None.

The default value trt_plugins=[] is a mutable default argument: Python evaluates it once at function definition, so every call shares the same list and any mutation leaks across PrecisionConverter instances.

Apply this diff:

-        trt_plugins: list[str] | None = [],
+        trt_plugins: list[str] | None = None,

Then update line 1067 to handle the None case:

         graph_sanitizer = GraphSanitizer(
             self.model,
             self.min_opset,
-            trt_plugins=self.trt_plugins,
+            trt_plugins=self.trt_plugins if self.trt_plugins is not None else [],
             max_ir_version=self.max_ir_version,
         )
🤖 Prompt for AI Agents
In modelopt/onnx/autocast/precisionconverter.py around lines 77 to 79, change
the function signature to use trt_plugins: list[str] | None = None instead of a
mutable default list, and then at line 1067 update the code to treat a None
value as an empty list (e.g., set local_trt_plugins = trt_plugins or [] before
using it) so any subsequent iterations or mutations operate on a fresh list
rather than a shared default.
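For context, a minimal illustration of the pitfall being flagged (generic Python, not code from this repository):

def register(plugin, plugins=[]):  # the [] is evaluated once, at definition time
    plugins.append(plugin)
    return plugins

register("a")  # ['a']
register("b")  # ['a', 'b']: state leaked from the first call

def register_safe(plugin, plugins=None):
    plugins = plugins if plugins is not None else []  # fresh list on every call
    plugins.append(plugin)
    return plugins

register_safe("a")  # ['a']
register_safe("b")  # ['b']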

) -> None:
"""Initialize PrecisionConverter.

@@ -109,6 +113,9 @@ def __init__(
self.original_network_io.update(
{io.name: io.type.tensor_type.elem_type for io in self.model.graph.output}
)
self.min_opset = min_opset
self.max_ir_version = max_ir_version
self.trt_plugins = trt_plugins

def convert(
self,
@@ -132,6 +139,8 @@ def convert(
"AutoCast can only operate on valid ONNX models, but the input model is invalid. See log for details."
)

self._sanitize_model()

# Filter out nodes that are not allowed to be in low precision
# This is done here and not in NodeClassifier because it is required for the model to be valid
high_precision_nodes, low_precision_nodes = self._filter_unsupported_op_types(
@@ -1050,3 +1059,13 @@ def _is_foldable_constant_cast_pattern(self, node: onnx.NodeProto) -> bool:
get_consumer_nodes = utils.get_consumer_nodes(self.model, const_producer.output[0])
return len(get_consumer_nodes) == 1 and get_consumer_nodes[0] == node
return False

def _sanitize_model(self):
graph_sanitizer = GraphSanitizer(
self.model,
self.min_opset,
trt_plugins=self.trt_plugins,
max_ir_version=self.max_ir_version,
)
graph_sanitizer.sanitize()
self.model = graph_sanitizer.model
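End to end, the change means convert() now sanitizes before partitioning. A sketch of driving the converter with the new keyword arguments (assuming modelopt is installed; the toy model is illustrative, and import paths follow the modules shown in this diff):

import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

import modelopt.onnx.autocast.utils as utils
import modelopt.onnx.utils as onnx_utils
from modelopt.onnx.autocast.precisionconverter import PrecisionConverter

# Trivial single-Add model, enough to exercise the sanitize-then-convert path.
x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2])
y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2])
w = numpy_helper.from_array(np.ones(2, dtype=np.float32), name="W")
add = helper.make_node("Add", ["X", "W"], ["Y"], name="add")
model = helper.make_model(helper.make_graph([add], "demo", [x], [y], [w]))
model.opset_import[0].version = 20
model.ir_version = 10

model = onnx_utils.infer_shapes(model)
value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)

converter = PrecisionConverter(
    model, value_info_map, initializer_map, node_to_init_map,
    keep_io_types=True, low_precision_type="fp16",
    min_opset=13, max_ir_version=10, trt_plugins=[],
)
# convert() first runs _sanitize_model() (GraphSanitizer.sanitize(), which now
# includes sanitize_io_casts), then partitions nodes and inserts casts.
converted = converter.convert(high_precision_nodes=[], low_precision_nodes=["add"])
onnx.checker.check_model(converted)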
70 changes: 70 additions & 0 deletions tests/unit/onnx/autocast/test_precisionconverter.py
@@ -25,6 +25,7 @@

configure_logging("DEBUG")

LATEST_IR_VERSION_SUPPORTED_BY_ORT = 10

def low_precision_onnx_type(low_precision_type_str):
return TensorProto.FLOAT16 if low_precision_type_str == "fp16" else TensorProto.BFLOAT16
@@ -1101,3 +1102,72 @@ def test_multiple_output_node_casted_to_output(
high_precision_nodes=[], low_precision_nodes=["concat_1", "concat_2"]
)
onnx.checker.check_model(converted_model)

@pytest.fixture
def model_with_casted_input_to_output():
"""Create a model with an output produced by a Cast node."""
# Create input and outputs
x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [2, 3])  # Output produced by casting input X
y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [2, 3])  # Output produced by the Add chain

# Create constant value
const = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)

# Create constant node
const_node = helper.make_node(
"Constant",
[],
["const"],
name="const",
value=numpy_helper.from_array(const, name="const_value"),
)

# Create computation nodes
add1 = helper.make_node("Add", ["X", "const"], ["add1_out"], name="add1")
add2 = helper.make_node("Add", ["add1_out", "const"], ["Y2"], name="add2")

# Create cast node that feeds directly from input to output
cast_input = helper.make_node("Cast", ["X"], ["Y1"], name="cast_input", to=TensorProto.FLOAT)

graph = helper.make_graph(
[const_node, add1, add2, cast_input],
"model_with_casted_output",
[x],
[y1, y2],
[],
)

model = helper.make_model(graph, producer_name="model_with_casted_output")
model.opset_import[0].version = 20
model.ir_version = 10
onnx.checker.check_model(model)

model = onnx_utils.infer_shapes(model)
value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)

return model, value_info_map, initializer_map, node_to_init_map


@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
@pytest.mark.parametrize("keep_io_types", [True, False])
def test_casted_input_to_output_model(
model_with_casted_input_to_output, low_precision_type, keep_io_types
):
model, value_info_map, initializer_map, node_to_init_map = model_with_casted_input_to_output

converter = PrecisionConverter(
model,
value_info_map,
initializer_map,
node_to_init_map,
keep_io_types=keep_io_types,
low_precision_type=low_precision_type,
min_opset=22 if low_precision_type == "bf16" else 13,
max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT,
trt_plugins=[],
)
converted_model = converter.convert(
high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"]
)
onnx.checker.check_model(converted_model)