
Commit 03529fc

Inject identity nodes in sanitizer; revert existing logic; update test
Signed-off-by: Ali Boubezari <[email protected]>
1 parent 9363b09 commit 03529fc

3 files changed: +39, -36 lines

modelopt/onnx/autocast/graphsanitizer.py

Lines changed: 28 additions & 0 deletions
@@ -65,6 +65,7 @@ def sanitize(self) -> None:
         self.replace_custom_domain_nodes()
         self.cleanup_model()
         self.set_ir_version(self.max_ir_version)
+        self.sanitize_io_casts()
 
     def find_custom_nodes(self) -> None:
         """Find custom nodes in the model.
@@ -322,6 +323,33 @@ def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None:
             logger.debug(f"Failed to match LayerNorm pattern at {mean_node.name}: {e!s}")
             return None
 
+    def sanitize_io_casts(self) -> None:
+        """Handle the special case where an input is casted directly to an output.
+
+        Inject an identity node after the cast node.
+        """
+        model_input_names = {input.name for input in self.model.graph.input}
+        model_output_names = {output.name for output in self.model.graph.output}
+        nodes_to_add = []
+        for node in self.model.graph.node:
+            if node.op_type == "Cast":
+                if node.input[0] in model_input_names and node.output[0] in model_output_names:
+                    cast_input_name = node.input[0]
+                    cast_output_name = node.output[0]
+                    cast_new_output_name = cast_input_name + "_io_cast_identity"
+                    nodes_to_add.append(
+                        helper.make_node(
+                            "Identity",
+                            inputs=[cast_new_output_name],
+                            outputs=[cast_output_name],
+                            name=node.name + "_io_cast_identity",
+                        )
+                    )
+                    node.output[0] = cast_new_output_name
+
+        for node in nodes_to_add:
+            self.model.graph.node.append(node)
+
     def _create_layernorm_node(self, pattern: dict) -> onnx.NodeProto:
         """Create a LayerNormalization node with optional bias."""
         ln_name = f"LayerNorm_{pattern['mean_node'].name}"
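For orientation, here is a minimal sketch of what the new pass does to a graph whose input feeds a Cast that writes directly to a graph output. The toy model and all tensor/node names below ("x", "y", "cast_x", "io_cast_example") are illustrative, and the GraphSanitizer(model, min_opset) usage mirrors the call added to the test further down; treat this as a sketch under those assumptions, not a verbatim recipe.

    from onnx import TensorProto, helper

    from modelopt.onnx.autocast.graphsanitizer import GraphSanitizer

    # A toy graph: fp32 input "x" is cast straight to fp16 output "y".
    x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4])
    y = helper.make_tensor_value_info("y", TensorProto.FLOAT16, [1, 4])
    cast = helper.make_node("Cast", inputs=["x"], outputs=["y"], name="cast_x", to=TensorProto.FLOAT16)
    graph = helper.make_graph([cast], "io_cast_example", [x], [y])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])

    sanitizer = GraphSanitizer(model, 13)
    sanitizer.sanitize()
    model = sanitizer.model

    for node in model.graph.node:
        print(node.op_type, list(node.input), "->", list(node.output))
    # Expected shape of the result (other sanitizer passes may also run, but the
    # io-cast handling should leave the graph looking roughly like this):
    #   Cast     ['x'] -> ['x_io_cast_identity']
    #   Identity ['x_io_cast_identity'] -> ['y']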

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 1 addition & 34 deletions
@@ -586,42 +586,9 @@ def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
                     consumer.input[i] = input_tensor
 
     def _remove_preexisting_casts(self) -> None:
-        # First check for special case where an input is casted directly to an output
-        model_input_names = {input.name for input in self.model.graph.input}
-        model_output_names = {output.name for output in self.model.graph.output}
-        # Ensure that special casts that we add are not removed by the following logic
-        casts_to_skip = []
-        # Add casts as a separate step to avoid modifying the graph while iterating over it
-        casts_to_add = []
-        for node in self.model.graph.node:
-            if node.op_type == "Cast":
-                if node.input[0] in model_input_names and node.output[0] in model_output_names:
-                    # Create a special cast just for the input-output case.
-                    new_cast = helper.make_node(
-                        "Cast",
-                        name=node.name,
-                        inputs=[node.input[0]],
-                        outputs=[node.output[0]],
-                        to=utils.get_cast_to_type(node),
-                    )
-                    casts_to_skip.append(node.name)
-                    casts_to_add.append(new_cast)
-                    # Now adjust the old cast's name, consumers and producers
-                    node.name = f"{node.name}_io_special_case"
-                    node_new_output_name = f"{node.output[0]}_io_special_case"
-                    for consumer in utils.get_consumer_nodes(self.model, node.output[0]):
-                        for i, input_name in enumerate(consumer.input):
-                            if input_name == node.output[0]:
-                                consumer.input[i] = node_new_output_name
-                    node.output[0] = node_new_output_name
-
-        for cast in casts_to_add:
-            self.model.graph.node.append(cast)
-        casts_to_skip = set(casts_to_skip)
-
         nodes_to_remove = []
         for node in self.model.graph.node:
-            if node.op_type == "Cast" and node.name not in casts_to_skip:
+            if node.op_type == "Cast":
                 cast_from_type = self._get_tensor_type(node.input[0])
                 cast_to_type = utils.get_cast_to_type(node)
                 is_fp_cast = cast_to_type in [
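The special case can be dropped here because, once GraphSanitizer.sanitize() has run, a Cast fed by a graph input no longer writes directly to a graph output: the sanitizer renames the Cast's output and routes it through the injected Identity node. A rough before/after sketch (tensor names illustrative):

    # Before sanitization:  x (graph input) --Cast--> y (graph output)
    # After sanitization:   x (graph input) --Cast--> x_io_cast_identity --Identity--> y (graph output)

With that guarantee, _remove_preexisting_casts() can treat every Cast uniformly and the casts_to_skip bookkeeping is no longer needed.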

tests/unit/onnx/autocast/test_precisionconverter.py

Lines changed: 10 additions & 2 deletions
@@ -20,6 +20,7 @@
 
 import modelopt.onnx.autocast.utils as utils
 import modelopt.onnx.utils as onnx_utils
+from modelopt.onnx.autocast.graphsanitizer import GraphSanitizer
 from modelopt.onnx.autocast.logging_config import configure_logging
 from modelopt.onnx.autocast.precisionconverter import PrecisionConverter
 
@@ -1072,15 +1073,22 @@ def model_with_casted_input_to_output():
 
 
 @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
-def test_casted_input_to_output_model(model_with_casted_input_to_output, low_precision_type):
+@pytest.mark.parametrize("keep_io_types", [True, False])
+def test_casted_input_to_output_model(
+    model_with_casted_input_to_output, low_precision_type, keep_io_types
+):
     model, value_info_map, initializer_map, node_to_init_map = model_with_casted_input_to_output
 
+    min_opset = 22 if low_precision_type == "bf16" else 13
+    graph_sanitizer = GraphSanitizer(model, min_opset)
+    graph_sanitizer.sanitize()
+    model = graph_sanitizer.model
     converter = PrecisionConverter(
         model,
         value_info_map,
         initializer_map,
         node_to_init_map,
-        keep_io_types=True,
+        keep_io_types=keep_io_types,
         low_precision_type=low_precision_type,
     )
     converted_model = converter.convert(
