Commit ae78b9f (1 parent: 6ef9954)

[Autocast] Fix edge case casting input directly to output (#305)

Signed-off-by: Ali Boubezari <[email protected]>

3 files changed, +126 -0 lines

modelopt/onnx/autocast/graphsanitizer.py

Lines changed: 35 additions & 0 deletions
@@ -63,6 +63,7 @@ def sanitize(self) -> None:
         self.ensure_graph_name_exists()
         onnx_utils.name_onnx_nodes(self.model.graph)
         self.replace_custom_domain_nodes()
+        self.sanitize_io_casts()
         self.cleanup_model()
         self.set_ir_version(self.max_ir_version)
         self.convert_fp64_to_fp32()
@@ -343,6 +344,40 @@ def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None:
             logger.debug(f"Failed to match LayerNorm pattern at {mean_node.name}: {e!s}")
             return None
 
+    def sanitize_io_casts(self) -> None:
+        """Handle the special case where an input is cast directly to an output.
+
+        Inject an Identity node after the Cast node.
+        """
+        model_input_names = {input.name for input in self.model.graph.input}
+        model_output_names = {output.name for output in self.model.graph.output}
+        insertions: list[tuple[int, onnx.NodeProto]] = []
+
+        for idx, node in enumerate(self.model.graph.node):
+            if (
+                node.op_type == "Cast"
+                and node.input
+                and node.output
+                and node.input[0] in model_input_names
+                and node.output[0] in model_output_names
+            ):
+                # Unique per graph output to avoid collisions when multiple outputs are cast from the same input
+                cast_output_name = node.output[0]
+                cast_new_output_name = f"{cast_output_name}__io_cast_src"
+                identity_node = helper.make_node(
+                    "Identity",
+                    inputs=[cast_new_output_name],
+                    outputs=[cast_output_name],
+                    name=f"{node.name}__io_cast_identity",
+                )
+                # Rewire Cast to produce the new intermediate
+                node.output[0] = cast_new_output_name
+                insertions.append((idx + 1, identity_node))
+
+        # Insert Identities in-order right after their corresponding Casts
+        for offset, (pos, id_node) in enumerate(insertions):
+            self.model.graph.node.insert(pos + offset, id_node)
+
     def _create_layernorm_node(self, pattern: dict) -> onnx.NodeProto:
         """Create a LayerNormalization node with optional bias."""
         ln_name = f"LayerNorm_{pattern['mean_node'].name}"
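
For illustration, here is a minimal standalone sketch (not part of the commit) of the rewiring sanitize_io_casts performs. It builds a toy model whose only node casts graph input X straight to graph output Y, then applies the same transformation: the Cast is redirected to a fresh intermediate tensor, and an Identity restores the original output name. The toy model, tensor names, and shapes are assumptions for this demo.

import onnx
from onnx import TensorProto, helper

# Toy model: a single Cast reads graph input X and writes graph output Y.
x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2])
y = helper.make_tensor_value_info("Y", TensorProto.FLOAT16, [2])
cast = helper.make_node("Cast", ["X"], ["Y"], name="cast0", to=TensorProto.FLOAT16)
model = helper.make_model(helper.make_graph([cast], "io_cast_demo", [x], [y]))

# Same rewiring as sanitize_io_casts: Cast now produces an intermediate
# tensor, and an inserted Identity re-creates the declared output name.
cast_node = model.graph.node[0]
new_name = f"{cast_node.output[0]}__io_cast_src"
identity = helper.make_node(
    "Identity",
    inputs=[new_name],
    outputs=[cast_node.output[0]],
    name=f"{cast_node.name}__io_cast_identity",
)
cast_node.output[0] = new_name
model.graph.node.insert(1, identity)

onnx.checker.check_model(model)
print([n.op_type for n in model.graph.node])  # ['Cast', 'Identity']

With the Identity buffer in place, later precision rewrites can rename or retype the Cast's output without touching a tensor that is simultaneously fed by a graph input and declared as a graph output, which is the edge case named in the commit title.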

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 19 additions & 0 deletions
@@ -32,6 +32,7 @@
 
 import modelopt.onnx.autocast.utils as utils
 import modelopt.onnx.utils as onnx_utils
+from modelopt.onnx.autocast.graphsanitizer import GraphSanitizer
 from modelopt.onnx.autocast.logging_config import configure_logging, logger
 
 configure_logging()
@@ -73,6 +74,9 @@ def __init__(
         low_precision_type: str = "fp16",
         init_conversion_max_bytes: int | None = None,
         custom_ops: set[str] | None = None,
+        min_opset: int = 13,
+        max_ir_version: int | None = None,
+        trt_plugins: list[str] | None = [],
     ) -> None:
         """Initialize PrecisionConverter.
 
@@ -109,6 +113,9 @@ def __init__(
         self.original_network_io.update(
             {io.name: io.type.tensor_type.elem_type for io in self.model.graph.output}
         )
+        self.min_opset = min_opset
+        self.max_ir_version = max_ir_version
+        self.trt_plugins = trt_plugins
 
     def convert(
         self,
@@ -132,6 +139,8 @@ def convert(
                 "AutoCast can only operate on valid ONNX models, but the input model is invalid. See log for details."
             )
 
+        self._sanitize_model()
+
         # Filter out nodes that are not allowed to be in low precision
         # This is done here and not in NodeClassifier because it is required for the model to be valid
         high_precision_nodes, low_precision_nodes = self._filter_unsupported_op_types(
@@ -1050,3 +1059,13 @@ def _is_foldable_constant_cast_pattern(self, node: onnx.NodeProto) -> bool:
             get_consumer_nodes = utils.get_consumer_nodes(self.model, const_producer.output[0])
             return len(get_consumer_nodes) == 1 and get_consumer_nodes[0] == node
         return False
+
+    def _sanitize_model(self):
+        graph_sanitizer = GraphSanitizer(
+            self.model,
+            self.min_opset,
+            trt_plugins=self.trt_plugins,
+            max_ir_version=self.max_ir_version,
+        )
+        graph_sanitizer.sanitize()
+        self.model = graph_sanitizer.model
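
Because convert() now calls _sanitize_model() up front, the sanitization shown above runs automatically inside the converter. As a hedged sketch, the equivalent standalone call would look like this; "model.onnx" is a placeholder path, 13 matches min_opset's default, and max_ir_version=10 mirrors the ORT limit used in the tests below.

import onnx
from modelopt.onnx.autocast.graphsanitizer import GraphSanitizer

# Mirrors what _sanitize_model does inside convert(): construct a
# GraphSanitizer, run it, and adopt the sanitized model.
model = onnx.load("model.onnx")  # placeholder path
sanitizer = GraphSanitizer(model, 13, trt_plugins=[], max_ir_version=10)
sanitizer.sanitize()
model = sanitizer.model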

tests/unit/onnx/autocast/test_precisionconverter.py

Lines changed: 72 additions & 0 deletions
@@ -25,6 +25,8 @@
 
 configure_logging("DEBUG")
 
+LATEST_IR_VERSION_SUPPORTED_BY_ORT = 10
+
 
 def low_precision_onnx_type(low_precision_type_str):
     return TensorProto.FLOAT16 if low_precision_type_str == "fp16" else TensorProto.BFLOAT16
@@ -1101,3 +1103,73 @@ def test_multiple_output_node_casted_to_output(
         high_precision_nodes=[], low_precision_nodes=["concat_1", "concat_2"]
     )
     onnx.checker.check_model(converted_model)
+
+
+@pytest.fixture
+def model_with_casted_input_to_output():
+    """Create a model with an output produced by a Cast node."""
+    # Create input and outputs
+    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
+    y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [2, 3])  # Cast output
+    y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [2, 3])  # Final output
+
+    # Create constant value
+    const = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
+
+    # Create constant node
+    const_node = helper.make_node(
+        "Constant",
+        [],
+        ["const"],
+        name="const",
+        value=numpy_helper.from_array(const, name="const_value"),
+    )
+
+    # Create computation nodes
+    add1 = helper.make_node("Add", ["X", "const"], ["add1_out"], name="add1")
+    add2 = helper.make_node("Add", ["add1_out", "const"], ["Y2"], name="add2")
+
+    # Create cast node that feeds directly from input to output
+    cast_input = helper.make_node("Cast", ["X"], ["Y1"], name="cast_input", to=TensorProto.FLOAT)
+
+    graph = helper.make_graph(
+        [const_node, add1, add2, cast_input],
+        "model_with_casted_output",
+        [x],
+        [y1, y2],
+        [],
+    )
+
+    model = helper.make_model(graph, producer_name="model_with_casted_output")
+    model.opset_import[0].version = 20
+    model.ir_version = 10
+    onnx.checker.check_model(model)
+
+    model = onnx_utils.infer_shapes(model)
+    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
+
+    return model, value_info_map, initializer_map, node_to_init_map
+
+
+@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
+@pytest.mark.parametrize("keep_io_types", [True, False])
+def test_casted_input_to_output_model(
+    model_with_casted_input_to_output, low_precision_type, keep_io_types
+):
+    model, value_info_map, initializer_map, node_to_init_map = model_with_casted_input_to_output
+
+    converter = PrecisionConverter(
+        model,
+        value_info_map,
+        initializer_map,
+        node_to_init_map,
+        keep_io_types=keep_io_types,
+        low_precision_type=low_precision_type,
+        min_opset=22 if low_precision_type == "bf16" else 13,
+        max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT,
+        trt_plugins=[],
+    )
+    converted_model = converter.convert(
+        high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"]
+    )
+    onnx.checker.check_model(converted_model)
