From 4028e2aa32d328e3cc227d8d3b3d10b81d72c007 Mon Sep 17 00:00:00 2001
From: Ali Boubezari
Date: Mon, 8 Sep 2025 17:01:35 -0700
Subject: [PATCH] [Autocast] Fix edge case casting input directly to output

Update modelopt/onnx/autocast/precisionconverter.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: aboubezari <126983138+aboubezari@users.noreply.github.com>

cleanup

Signed-off-by: Ali Boubezari

Inject identity nodes in sanitizer; revert existing logic; update test

Signed-off-by: Ali Boubezari

move pass

Signed-off-by: Ali Boubezari

call sanitizer in precision converter

Signed-off-by: Ali Boubezari
---
 modelopt/onnx/autocast/graphsanitizer.py     | 38 ++++++++++
 modelopt/onnx/autocast/precisionconverter.py | 20 +++++
 .../onnx/autocast/test_precisionconverter.py | 70 +++++++++++++++++++
 3 files changed, 128 insertions(+)

diff --git a/modelopt/onnx/autocast/graphsanitizer.py b/modelopt/onnx/autocast/graphsanitizer.py
index 49b78c749..1d5c78bd9 100644
--- a/modelopt/onnx/autocast/graphsanitizer.py
+++ b/modelopt/onnx/autocast/graphsanitizer.py
@@ -63,6 +63,7 @@ def sanitize(self) -> None:
         self.ensure_graph_name_exists()
         onnx_utils.name_onnx_nodes(self.model.graph)
         self.replace_custom_domain_nodes()
+        self.sanitize_io_casts()
         self.cleanup_model()
         self.set_ir_version(self.max_ir_version)
         self.convert_fp64_to_fp32()
@@ -343,6 +344,43 @@ def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None:
             logger.debug(f"Failed to match LayerNorm pattern at {mean_node.name}: {e!s}")
             return None
 
+    def sanitize_io_casts(self) -> None:
+        """Handle the special case where an input is cast directly to an output.
+
+        Inject an Identity node after the Cast, so no node maps an input straight to an output.
+        """
+        model_input_names = {input.name for input in self.model.graph.input}
+        model_output_names = {output.name for output in self.model.graph.output}
+        nodes_to_add = []
+        for node in self.model.graph.node:
+            if (
+                node.op_type == "Cast"
+                and node.input
+                and node.output
+                and node.input[0] in model_input_names
+                and node.output[0] in model_output_names
+            ):
+                # Unique per graph output to avoid collisions when multiple outputs are cast from the same input
+                cast_output_name = node.output[0]
+                cast_new_output_name = f"{cast_output_name}__io_cast_src"
+                nodes_to_add.append(
+                    helper.make_node(
+                        "Identity",
+                        inputs=[cast_new_output_name],
+                        outputs=[cast_output_name],
+                        name=f"{node.name}__io_cast_identity",
+                    )
+                )
+                # Rewire the Cast to produce the new intermediate tensor
+                node.output[0] = cast_new_output_name
+
+        for node in nodes_to_add:
+            self.model.graph.node.append(node)
+
+        # Make sure the graph is topologically sorted
+        gs_graph = gs.import_onnx(self.model).cleanup().toposort()
+        self.model = gs.export_onnx(gs_graph)
+
     def _create_layernorm_node(self, pattern: dict) -> onnx.NodeProto:
         """Create a LayerNormalization node with optional bias."""
         ln_name = f"LayerNorm_{pattern['mean_node'].name}"
diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py
index 210280067..ac2829669 100644
--- a/modelopt/onnx/autocast/precisionconverter.py
+++ b/modelopt/onnx/autocast/precisionconverter.py
@@ -32,6 +32,7 @@
 
 import modelopt.onnx.autocast.utils as utils
 import modelopt.onnx.utils as onnx_utils
+from modelopt.onnx.autocast.graphsanitizer import GraphSanitizer
 from modelopt.onnx.autocast.logging_config import configure_logging, logger
 
 configure_logging()
@@ -73,6 +74,9 @@ def __init__(
         low_precision_type: str = "fp16",
         init_conversion_max_bytes: int | None = None,
         custom_ops: set[str] | None = None,
+        min_opset: int = 13,
+        max_ir_version: int | None = None,
+        trt_plugins: list[str] | None = None,
     ) -> None:
         """Initialize PrecisionConverter.
 
@@ -109,6 +113,9 @@ def __init__(
         self.original_network_io.update(
             {io.name: io.type.tensor_type.elem_type for io in self.model.graph.output}
         )
+        self.min_opset = min_opset
+        self.max_ir_version = max_ir_version
+        self.trt_plugins = trt_plugins
 
     def convert(
         self,
@@ -132,6 +139,8 @@ def convert(
                 "AutoCast can only operate on valid ONNX models, but the input model is invalid. See log for details."
             )
 
+        self._sanitize_model()
+
         # Filter out nodes that are not allowed to be in low precision
         # This is done here and not in NodeClassifier because it is required for the model to be valid
         high_precision_nodes, low_precision_nodes = self._filter_unsupported_op_types(
@@ -1050,3 +1059,14 @@ def _is_foldable_constant_cast_pattern(self, node: onnx.NodeProto) -> bool:
             get_consumer_nodes = utils.get_consumer_nodes(self.model, const_producer.output[0])
             return len(get_consumer_nodes) == 1 and get_consumer_nodes[0] == node
         return False
+
+    def _sanitize_model(self):
+        """Run GraphSanitizer on the model and keep the sanitized copy."""
+        graph_sanitizer = GraphSanitizer(
+            self.model,
+            self.min_opset,
+            trt_plugins=self.trt_plugins,
+            max_ir_version=self.max_ir_version,
+        )
+        graph_sanitizer.sanitize()
+        self.model = graph_sanitizer.model
diff --git a/tests/unit/onnx/autocast/test_precisionconverter.py b/tests/unit/onnx/autocast/test_precisionconverter.py
index a85bed287..bc99464ad 100644
--- a/tests/unit/onnx/autocast/test_precisionconverter.py
+++ b/tests/unit/onnx/autocast/test_precisionconverter.py
@@ -25,6 +25,7 @@
 
 configure_logging("DEBUG")
 
+LATEST_IR_VERSION_SUPPORTED_BY_ORT = 10
 
 def low_precision_onnx_type(low_precision_type_str):
     return TensorProto.FLOAT16 if low_precision_type_str == "fp16" else TensorProto.BFLOAT16
@@ -1101,3 +1102,72 @@ def test_multiple_output_node_casted_to_output(
         high_precision_nodes=[], low_precision_nodes=["concat_1", "concat_2"]
     )
     onnx.checker.check_model(converted_model)
+
+@pytest.fixture
+def model_with_casted_input_to_output():
+    """Create a model where a graph input is cast directly to a graph output."""
+    # Create the input and outputs
+    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
+    y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [2, 3])  # Output fed directly by the Cast
+    y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [2, 3])  # Output of the Add chain
+
+    # Create the constant value
+    const = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
+
+    # Create the constant node
+    const_node = helper.make_node(
+        "Constant",
+        [],
+        ["const"],
+        name="const",
+        value=numpy_helper.from_array(const, name="const_value"),
+    )
+
+    # Create the computation nodes
+    add1 = helper.make_node("Add", ["X", "const"], ["add1_out"], name="add1")
+    add2 = helper.make_node("Add", ["add1_out", "const"], ["Y2"], name="add2")
+
+    # Create the Cast node that feeds directly from the input to an output
+    cast_input = helper.make_node("Cast", ["X"], ["Y1"], name="cast_input", to=TensorProto.FLOAT)
+
+    graph = helper.make_graph(
+        [const_node, add1, add2, cast_input],
+        "model_with_casted_output",
+        [x],
+        [y1, y2],
+        [],
+    )
+
+    model = helper.make_model(graph, producer_name="model_with_casted_output")
+    model.opset_import[0].version = 20
+    model.ir_version = 10
+    onnx.checker.check_model(model)
+
+    model = onnx_utils.infer_shapes(model)
+    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
+
+    return model, value_info_map, initializer_map, node_to_init_map
+
+
+@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
+@pytest.mark.parametrize("keep_io_types", [True, False])
+def test_casted_input_to_output_model(
+    model_with_casted_input_to_output, low_precision_type, keep_io_types
+):
+    model, value_info_map, initializer_map, node_to_init_map = model_with_casted_input_to_output
+
+    converter = PrecisionConverter(
+        model,
+        value_info_map,
+        initializer_map,
+        node_to_init_map,
+        keep_io_types=keep_io_types,
+        low_precision_type=low_precision_type,
+        min_opset=22 if low_precision_type == "bf16" else 13,
+        max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT,
+        trt_plugins=[],
+    )
+    converted_model = converter.convert(
+        high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"]
+    )
+    onnx.checker.check_model(converted_model)
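
For illustration, a minimal standalone sketch of the edge case this patch targets, using only the
onnx package. It is not part of the patch: the rewrite below mirrors what sanitize_io_casts does
rather than importing it, and the tensor and node names (X, Y, cast0, io_cast_edge_case) are
placeholders; only the __io_cast_src and __io_cast_identity suffixes come from the patch.

    # A Cast connecting a graph input straight to a graph output. The model is
    # valid ONNX (the checker passes before and after the rewrite); the patch
    # addresses how AutoCast's precision conversion handles this graph shape.
    import onnx
    from onnx import TensorProto, helper

    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
    y = helper.make_tensor_value_info("Y", TensorProto.FLOAT16, [2, 3])
    cast = helper.make_node("Cast", ["X"], ["Y"], name="cast0", to=TensorProto.FLOAT16)
    model = helper.make_model(helper.make_graph([cast], "io_cast_edge_case", [x], [y]))
    onnx.checker.check_model(model)

    # The same rewrite sanitize_io_casts applies: retarget the Cast to a fresh
    # intermediate tensor and let an Identity carry the original output name.
    cast.output[0] = "Y__io_cast_src"
    model.graph.node.append(
        helper.make_node("Identity", ["Y__io_cast_src"], ["Y"], name="cast0__io_cast_identity")
    )
    onnx.checker.check_model(model)

After the rewrite, the graph output is produced by the Identity and the Cast output is an internal
tensor, so later precision passes can insert casts around it without rewiring a node whose input
and output are both graph I/O.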