Skip to content

Commit 4028e2a

Browse files
committed
[Autocast] Fix edge case casting input directly to output
Update modelopt/onnx/autocast/precisionconverter.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: aboubezari <[email protected]> cleanup Signed-off-by: Ali Boubezari <[email protected]> Inject identity nodes in sanitizer; revert existing logic; update test Signed-off-by: Ali Boubezari <[email protected]> Inject identity nodes in sanitizer; revert existing logic; update test Signed-off-by: Ali Boubezari <[email protected]> move pass Signed-off-by: Ali Boubezari <[email protected]> call sanitizer in precision converter Signed-off-by: Ali Boubezari <[email protected]>
1 parent 17439e6 commit 4028e2a

File tree

3 files changed

+127
-0
lines changed

3 files changed

+127
-0
lines changed

modelopt/onnx/autocast/graphsanitizer.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def sanitize(self) -> None:
6363
self.ensure_graph_name_exists()
6464
onnx_utils.name_onnx_nodes(self.model.graph)
6565
self.replace_custom_domain_nodes()
66+
self.sanitize_io_casts()
6667
self.cleanup_model()
6768
self.set_ir_version(self.max_ir_version)
6869
self.convert_fp64_to_fp32()
@@ -343,6 +344,43 @@ def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None:
343344
logger.debug(f"Failed to match LayerNorm pattern at {mean_node.name}: {e!s}")
344345
return None
345346

347+
def sanitize_io_casts(self) -> None:
    """Break up Cast nodes that connect a graph input directly to a graph output.

    Such a Cast leaves no intermediate tensor for later conversion passes to
    rewire, so an Identity node is injected after it: the Cast is redirected to
    a fresh intermediate tensor and the Identity forwards that tensor to the
    original graph-output name.
    """
    graph_input_names = {tensor.name for tensor in self.model.graph.input}
    graph_output_names = {tensor.name for tensor in self.model.graph.output}

    injected_identities = []
    for cast_node in self.model.graph.node:
        if cast_node.op_type != "Cast" or not cast_node.input or not cast_node.output:
            continue
        if (
            cast_node.input[0] not in graph_input_names
            or cast_node.output[0] not in graph_output_names
        ):
            continue
        # Derive the intermediate name from the graph output (not the input) so
        # that multiple outputs cast from the same input do not collide.
        original_output = cast_node.output[0]
        intermediate_name = f"{original_output}__io_cast_src"
        injected_identities.append(
            helper.make_node(
                "Identity",
                inputs=[intermediate_name],
                outputs=[original_output],
                name=f"{cast_node.name}__io_cast_identity",
            )
        )
        # The Cast now produces the intermediate instead of the graph output.
        cast_node.output[0] = intermediate_name

    self.model.graph.node.extend(injected_identities)

    # The Identity nodes were appended at the end, so restore topological order.
    gs_graph = gs.import_onnx(self.model).cleanup().toposort()
    self.model = gs.export_onnx(gs_graph)
383+
346384
def _create_layernorm_node(self, pattern: dict) -> onnx.NodeProto:
347385
"""Create a LayerNormalization node with optional bias."""
348386
ln_name = f"LayerNorm_{pattern['mean_node'].name}"

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
import modelopt.onnx.autocast.utils as utils
3434
import modelopt.onnx.utils as onnx_utils
35+
from modelopt.onnx.autocast.graphsanitizer import GraphSanitizer
3536
from modelopt.onnx.autocast.logging_config import configure_logging, logger
3637

3738
configure_logging()
@@ -73,6 +74,9 @@ def __init__(
7374
low_precision_type: str = "fp16",
7475
init_conversion_max_bytes: int | None = None,
7576
custom_ops: set[str] | None = None,
77+
min_opset: int = 13,
78+
max_ir_version: int | None = None,
79+
trt_plugins: list[str] | None = [],
7680
) -> None:
7781
"""Initialize PrecisionConverter.
7882
@@ -109,6 +113,9 @@ def __init__(
109113
self.original_network_io.update(
110114
{io.name: io.type.tensor_type.elem_type for io in self.model.graph.output}
111115
)
116+
self.min_opset = min_opset
117+
self.max_ir_version = max_ir_version
118+
self.trt_plugins = trt_plugins
112119

113120
def convert(
114121
self,
@@ -132,6 +139,8 @@ def convert(
132139
"AutoCast can only operate on valid ONNX models, but the input model is invalid. See log for details."
133140
)
134141

142+
self._sanitize_model()
143+
135144
# Filter out nodes that are not allowed to be in low precision
136145
# This is done here and not in NodeClassifier because it is required for the model to be valid
137146
high_precision_nodes, low_precision_nodes = self._filter_unsupported_op_types(
@@ -1050,3 +1059,13 @@ def _is_foldable_constant_cast_pattern(self, node: onnx.NodeProto) -> bool:
10501059
get_consumer_nodes = utils.get_consumer_nodes(self.model, const_producer.output[0])
10511060
return len(get_consumer_nodes) == 1 and get_consumer_nodes[0] == node
10521061
return False
1062+
1063+
def _sanitize_model(self):
    """Run GraphSanitizer over the current model and adopt the sanitized graph."""
    sanitizer = GraphSanitizer(
        self.model,
        self.min_opset,
        trt_plugins=self.trt_plugins,
        max_ir_version=self.max_ir_version,
    )
    sanitizer.sanitize()
    # The sanitizer may rebuild the proto entirely, so replace our reference.
    self.model = sanitizer.model

tests/unit/onnx/autocast/test_precisionconverter.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
configure_logging("DEBUG")
2727

28+
LATEST_IR_VERSION_SUPPORTED_BY_ORT = 10
2829

2930
def low_precision_onnx_type(low_precision_type_str):
3031
return TensorProto.FLOAT16 if low_precision_type_str == "fp16" else TensorProto.BFLOAT16
@@ -1101,3 +1102,72 @@ def test_multiple_output_node_casted_to_output(
11011102
high_precision_nodes=[], low_precision_nodes=["concat_1", "concat_2"]
11021103
)
11031104
onnx.checker.check_model(converted_model)
1105+
1106+
@pytest.fixture
def model_with_casted_input_to_output():
    """Build a model where graph input "X" is cast directly to graph output "Y1".

    Returns the model plus the (value_info, initializer, node-to-init) mappings
    expected by PrecisionConverter.
    """
    const_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)

    nodes = [
        helper.make_node(
            "Constant",
            [],
            ["const"],
            name="const",
            value=numpy_helper.from_array(const_array, name="const_value"),
        ),
        helper.make_node("Add", ["X", "const"], ["add1_out"], name="add1"),
        helper.make_node("Add", ["add1_out", "const"], ["Y2"], name="add2"),
        # Edge case under test: input "X" flows through a single Cast into output "Y1".
        helper.make_node("Cast", ["X"], ["Y1"], name="cast_input", to=TensorProto.FLOAT),
    ]

    graph = helper.make_graph(
        nodes,
        "model_with_casted_output",
        [helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])],
        [
            helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [2, 3]),
            helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [2, 3]),
        ],
        [],
    )

    model = helper.make_model(graph, producer_name="model_with_casted_output")
    model.opset_import[0].version = 20
    model.ir_version = 10
    onnx.checker.check_model(model)

    model = onnx_utils.infer_shapes(model)
    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)

    return model, value_info_map, initializer_map, node_to_init_map
1150+
1151+
1152+
@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
@pytest.mark.parametrize("keep_io_types", [True, False])
def test_casted_input_to_output_model(
    model_with_casted_input_to_output, low_precision_type, keep_io_types
):
    """Conversion must tolerate a Cast wiring a graph input straight to an output."""
    model, value_info_map, initializer_map, node_to_init_map = model_with_casted_input_to_output

    # bf16 requires a newer opset than fp16.
    required_opset = 22 if low_precision_type == "bf16" else 13
    converter = PrecisionConverter(
        model,
        value_info_map,
        initializer_map,
        node_to_init_map,
        keep_io_types=keep_io_types,
        low_precision_type=low_precision_type,
        min_opset=required_opset,
        max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT,
        trt_plugins=[],
    )

    # The input->output Cast stays in high precision while the Adds are lowered;
    # the resulting model must still pass the ONNX checker.
    converted_model = converter.convert(
        high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"]
    )
    onnx.checker.check_model(converted_model)

0 commit comments

Comments
 (0)