diff --git a/olive/olive_config.json b/olive/olive_config.json index 2748c39101..66613a779d 100644 --- a/olive/olive_config.json +++ b/olive/olive_config.json @@ -36,6 +36,14 @@ "supported_algorithms": [ ], "supported_quantization_encodings": [ ] }, + "CaptureLayerAnnotations": { + "module_path": "olive.passes.pytorch.capture_layer_annotations.CaptureLayerAnnotations", + "supported_providers": [ "*" ], + "supported_accelerators": [ "*" ], + "supported_precisions": [ "*" ], + "supported_algorithms": [ ], + "supported_quantization_encodings": [ ] + }, "ComposeOnnxModels": { "module_path": "olive.passes.onnx.compose.ComposeOnnxModels", "supported_providers": [ "*" ], diff --git a/olive/passes/onnx/conversion.py b/olive/passes/onnx/conversion.py index c715533309..ad47cb99fa 100644 --- a/olive/passes/onnx/conversion.py +++ b/olive/passes/onnx/conversion.py @@ -738,6 +738,14 @@ def _convert_model_on_device( ) if split_assignment_encoded: ir_model.metadata_props["split_assignments"] = split_assignment_encoded + + # apply layer annotations if present + layer_annotations = model_attributes.get("layer_annotations") + if layer_annotations: + from olive.passes.onnx.layer_annotation import annotate_ir_model + + annotate_ir_model(ir_model, layer_annotations) + output_model = ir_model_to_olive_model(ir_model, output_model_path, config) output_model.model_attributes = model_attributes diff --git a/olive/passes/onnx/layer_annotation.py b/olive/passes/onnx/layer_annotation.py new file mode 100644 index 0000000000..87250cef0c --- /dev/null +++ b/olive/passes/onnx/layer_annotation.py @@ -0,0 +1,114 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +import logging + +import onnx +from onnxscript import ir + +logger = logging.getLogger(__name__) + + +def _flatten_annotations(layer_annotations: dict[str, list[str]]) -> list[tuple[str, str]]: + """Flatten a dict of {layer_name: [substring, ...]} into [(substring, layer_name), ...].""" + return [(substring, layer_name) for layer_name, substrings in layer_annotations.items() for substring in substrings] + + +# --------------------------------------------------------------------------- +# ir.Model path (used by conversion.py) +# --------------------------------------------------------------------------- + + +def _annotate_ir_graph(graph: ir.Graph, substring_annotations: list[tuple[str, str]]) -> None: + """Annotate nodes in an ir.Graph, recursing into subgraphs.""" + for node in graph: + if node.name is None: + continue + + matched_annotation = None + for substring, annotation in substring_annotations: + if substring in node.name: + matched_annotation = annotation + break # If multiple annotations match, the first one in the list wins (consistent with ORT reference implementation) + + if matched_annotation is not None: + node.metadata_props["layer_ann"] = matched_annotation + + # Recurse into subgraphs for control-flow nodes (If, Loop, etc.) + for attr in node.attributes.values(): + if isinstance(attr, ir.Attr) and attr.type == ir.AttributeType.GRAPH: + _annotate_ir_graph(attr.value, substring_annotations) + elif isinstance(attr, ir.Attr) and attr.type == ir.AttributeType.GRAPHS: + for sub_graph in attr.value: + _annotate_ir_graph(sub_graph, substring_annotations) + + +def annotate_ir_model(model: ir.Model, layer_annotations: dict[str, list[str]]) -> None: + """Annotate an onnxscript ir.Model with layer annotations. + + For each node whose name contains a configured substring, a metadata property + ``layer_ann`` is set to the corresponding layer name. 
If multiple substrings + match, the first one in iteration order wins (consistent with the ORT reference + implementation). + + :param model: The onnxscript IR model to annotate. + :param layer_annotations: Mapping of layer name to list of node-name substrings. + """ + substring_annotations = _flatten_annotations(layer_annotations) + _annotate_ir_graph(model.graph, substring_annotations) + + +# --------------------------------------------------------------------------- +# onnx.ModelProto path (used by model_builder.py) +# --------------------------------------------------------------------------- + + +def _annotate_proto_graph(graph: onnx.GraphProto, substring_annotations: list[tuple[str, str]]) -> None: + """Annotate nodes in an onnx.GraphProto, recursing into subgraphs.""" + for node in graph.node: + if not node.name: + continue + + matched_annotation = None + for substring, annotation in substring_annotations: + if substring in node.name: + matched_annotation = annotation + break # If multiple annotations match, the first one in the list wins (consistent with ORT reference implementation) + + if matched_annotation is not None: + entry = None + for prop in node.metadata_props: + if prop.key == "layer_ann": + entry = prop + break + + if entry: + entry.value = matched_annotation + else: + entry = node.metadata_props.add() + entry.key = "layer_ann" + entry.value = matched_annotation + + # Recurse into subgraphs for control-flow nodes (If, Loop, etc.) + for attr in node.attribute: + if attr.type == onnx.AttributeProto.GRAPH: + _annotate_proto_graph(attr.g, substring_annotations) + elif attr.type == onnx.AttributeProto.GRAPHS: + for sub_graph in attr.graphs: + _annotate_proto_graph(sub_graph, substring_annotations) + + +def annotate_proto_model(model: onnx.ModelProto, layer_annotations: dict[str, list[str]]) -> None: + """Annotate an onnx.ModelProto with layer annotations. 
+ + For each node whose name contains a configured substring, a metadata property + ``layer_ann`` is set to the corresponding layer name. If multiple substrings + match, the first one in iteration order wins (consistent with the ORT reference + implementation). + + :param model: The ONNX ModelProto to annotate. + :param layer_annotations: Mapping of layer name to list of node-name substrings. + """ + substring_annotations = _flatten_annotations(layer_annotations) + _annotate_proto_graph(model.graph, substring_annotations) diff --git a/olive/passes/onnx/model_builder.py b/olive/passes/onnx/model_builder.py index 5f8dceabc9..978744ec1c 100644 --- a/olive/passes/onnx/model_builder.py +++ b/olive/passes/onnx/model_builder.py @@ -283,17 +283,25 @@ def _run_for_config( **extra_args, ) - # add split information if present - split_assignments = model_attributes.get("split_assignments") - if not metadata_only and split_assignments: - # NOTE: currently the model builder renames modules to it's own naming convention - # so the assignments for the renamed modules won't match - split_assignment_str = ";".join([f"{k}={v}" for k, v in split_assignments.items()]) - - # load the model and set the split_assignments as model properties - # without the external data so that they can be used as is with the resaved model + # Apply post-processing annotations (split assignments and/or layer annotations) + # in a single load/save cycle to avoid redundant disk I/O. 
+ split_assignments = model_attributes.get("split_assignments") if not metadata_only else None + layer_annotations = model_attributes.get("layer_annotations") if not metadata_only else None + + if split_assignments or layer_annotations: model_proto = onnx.load(output_model_filepath, load_external_data=False) - onnx.helper.set_model_props(model_proto, {"split_assignments": split_assignment_str}) + + if split_assignments: + # NOTE: currently the model builder renames modules to its own naming convention + # so the assignments for the renamed modules won't match + split_assignment_str = ";".join([f"{k}={v}" for k, v in split_assignments.items()]) + onnx.helper.set_model_props(model_proto, {"split_assignments": split_assignment_str}) + + if layer_annotations: + from olive.passes.onnx.layer_annotation import annotate_proto_model + + annotate_proto_model(model_proto, layer_annotations) + onnx.save(model_proto, output_model_filepath) except Exception: # if model building fails, clean up the intermediate files in the cache_dir diff --git a/olive/passes/pytorch/capture_layer_annotations.py b/olive/passes/pytorch/capture_layer_annotations.py new file mode 100644 index 0000000000..eac40c8e46 --- /dev/null +++ b/olive/passes/pytorch/capture_layer_annotations.py @@ -0,0 +1,76 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging +from copy import deepcopy +from typing import Union + +from olive.hardware.accelerator import AcceleratorSpec +from olive.model import HfModelHandler, PyTorchModelHandler +from olive.passes import Pass +from olive.passes.pass_config import BasePassConfig, PassConfigParam + +logger = logging.getLogger(__name__) + + +class CaptureLayerAnnotations(Pass): + """Capture layer annotation metadata for an ONNX model. 
+ + Given a mapping of layer names to node-name substrings, attaches a + ``layer_annotations`` dictionary to the model attributes. Downstream + ONNX conversion passes will read this attribute and annotate each ONNX + node whose name contains a matching substring with a ``layer_ann`` + metadata property. + """ + + @classmethod + def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassConfigParam]: + return { + "layer_annotations": PassConfigParam( + type_=dict, + required=True, + description=( + "Mapping of layer name to a list of node-name substrings. " + 'For example: {"encoder": ["attn", "mlp"], "decoder": ["cross_attn"]}. ' + "During ONNX conversion every node whose name contains a listed substring " + "will receive a metadata property 'layer_ann' set to the layer name." + ), + ), + } + + @classmethod + def validate_config( + cls, + config: type[BasePassConfig], + accelerator_spec: AcceleratorSpec, + ) -> bool: + if not super().validate_config(config, accelerator_spec): + return False + + if not config.layer_annotations: + logger.info("layer_annotations must be a non-empty dictionary.") + return False + + for key, value in config.layer_annotations.items(): + if not isinstance(key, str) or not key: + logger.info("layer_annotations keys must be non-empty strings, got %r.", key) + return False + if not isinstance(value, list) or not value: + logger.info("layer_annotations[%r] must be a non-empty list of strings, got %r.", key, type(value)) + return False + if not all(isinstance(s, str) and s for s in value): + logger.info("layer_annotations[%r] must contain only non-empty strings.", key) + return False + + return True + + def _run_for_config( + self, model: Union[HfModelHandler, PyTorchModelHandler], config: type[BasePassConfig], output_model_path: str + ) -> Union[HfModelHandler, PyTorchModelHandler]: + model.model = None + output_model = deepcopy(model) + output_model.model_attributes = model_attributes = output_model.model_attributes or {} + 
model_attributes["layer_annotations"] = config.layer_annotations + + return output_model diff --git a/test/passes/onnx/test_layer_annotation.py b/test/passes/onnx/test_layer_annotation.py new file mode 100644 index 0000000000..a7d7de780b --- /dev/null +++ b/test/passes/onnx/test_layer_annotation.py @@ -0,0 +1,160 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from onnxscript import ir + +from olive.passes.onnx.layer_annotation import _flatten_annotations, annotate_ir_model, annotate_proto_model + + +class TestFlattenAnnotations: + """Test the _flatten_annotations helper function.""" + + def test_flatten_single_layer_single_substring(self): + """Test flattening with a single layer and single substring.""" + annotations = {"layer1": ["sub1"]} + result = _flatten_annotations(annotations) + assert result == [("sub1", "layer1")] + + def test_flatten_single_layer_multiple_substrings(self): + """Test flattening with a single layer and multiple substrings.""" + annotations = {"layer1": ["sub1", "sub2", "sub3"]} + result = _flatten_annotations(annotations) + assert result == [("sub1", "layer1"), ("sub2", "layer1"), ("sub3", "layer1")] + + def test_flatten_multiple_layers(self): + """Test flattening with multiple layers.""" + annotations = {"layer1": ["sub1", "sub2"], "layer2": ["sub3"]} + result = _flatten_annotations(annotations) + assert len(result) == 3 + assert ("sub1", "layer1") in result + assert ("sub2", "layer1") in result + assert ("sub3", "layer2") in result + + def test_flatten_empty(self): + """Test flattening with empty annotations.""" + annotations = {} + result = _flatten_annotations(annotations) + assert result == [] + + +class TestAnnotateIrModel: + """Test the annotate_ir_model function.""" + + def test_annotate_simple_graph(self): + """Test annotating 
a simple graph with no subgraphs.""" + node1 = ir.Node("", "Add", inputs=[], name="linear_1") + node2 = ir.Node("", "Mul", inputs=[], name="sigmoid_2") + graph = ir.Graph([], [], nodes=[node1, node2], name="test_graph", opset_imports={"": 11}) + model = ir.Model(graph, ir_version=8) + + # Annotate with layer annotations + layer_annotations = { + "layer_linear": ["linear"], + "layer_sigmoid": ["sigmoid"], + } + annotate_ir_model(model, layer_annotations) + + # Check annotations were applied + assert model.graph[0].metadata_props["layer_ann"] == "layer_linear" + assert model.graph[1].metadata_props["layer_ann"] == "layer_sigmoid" + + def test_annotate_no_matching_substrings(self): + """Test that nodes without matching substrings are not annotated.""" + node1 = ir.Node("", "Add", inputs=[], name="add_1") + graph = ir.Graph([], [], nodes=[node1], name="test_graph", opset_imports={"": 11}) + model = ir.Model(graph, ir_version=8) + + layer_annotations = {"layer_other": ["other"]} + annotate_ir_model(model, layer_annotations) + + # Node should not have layer_ann metadata + assert "layer_ann" not in model.graph[0].metadata_props + + def test_annotate_first_match_wins(self): + """Test that the first matching substring wins when multiple match.""" + node = ir.Node("", "Add", inputs=[], name="linear_sigmoid_1") + graph = ir.Graph([], [], nodes=[node], name="test_graph", opset_imports={"": 11}) + model = ir.Model(graph, ir_version=8) + + # both "linear" and "sigmoid" are substrings; "linear" comes first + layer_annotations = { + "layer_linear": ["linear"], + "layer_sigmoid": ["sigmoid"], + } + annotate_ir_model(model, layer_annotations) + + assert model.graph[0].metadata_props["layer_ann"] == "layer_linear" + + def test_annotate_node_with_none_name(self): + """Test that nodes without a name are skipped.""" + node1 = ir.Node("", "Add", inputs=[]) + node1.name = None + node2 = ir.Node("", "Mul", inputs=[], name="linear_1") + graph = ir.Graph([], [], nodes=[node1, node2], 
name="test_graph", opset_imports={"": 11}) + model = ir.Model(graph, ir_version=8) + + layer_annotations = {"layer_linear": ["linear"]} + annotate_ir_model(model, layer_annotations) + + # First node should not have annotation (no name) + assert "layer_ann" not in model.graph[0].metadata_props + # Second node should have annotation + assert model.graph[1].metadata_props["layer_ann"] == "layer_linear" + + def test_annotate_empty_annotations(self): + """Test annotating with empty layer_annotations.""" + node = ir.Node("", "Add", inputs=[], name="linear_1") + graph = ir.Graph([], [], nodes=[node], name="test_graph", opset_imports={"": 11}) + model = ir.Model(graph, ir_version=8) + + annotate_ir_model(model, {}) + + # Node should not have annotation + assert "layer_ann" not in model.graph[0].metadata_props + + def test_annotate_empty_graph(self): + """Test annotating an empty graph.""" + graph = ir.Graph([], [], nodes=[], name="test_graph", opset_imports={"": 11}) + model = ir.Model(graph, ir_version=8) + + layer_annotations = {"layer_linear": ["linear"]} + # Should not raise any errors + annotate_ir_model(model, layer_annotations) + + +class TestAnnotateProtoModel: + """Test the annotate_proto_model function (onnx.ModelProto path).""" + + @staticmethod + def _make_proto_model(node_names): + """Build a minimal onnx.ModelProto with nodes having the given names.""" + from onnx import TensorProto, helper + + nodes = [ + helper.make_node("Relu", inputs=["x"], outputs=[f"y_{i}"], name=name) for i, name in enumerate(node_names) + ] + graph = helper.make_graph(nodes, "test", [helper.make_tensor_value_info("x", TensorProto.FLOAT, [1])], []) + return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)]) + + def test_annotate_proto_simple(self): + model_proto = self._make_proto_model(["linear_1", "sigmoid_2"]) + annotate_proto_model(model_proto, {"layer_linear": ["linear"], "layer_sigmoid": ["sigmoid"]}) + + props = {n.name: {p.key: p.value for p in 
n.metadata_props} for n in model_proto.graph.node} + assert props["linear_1"]["layer_ann"] == "layer_linear" + assert props["sigmoid_2"]["layer_ann"] == "layer_sigmoid" + + def test_annotate_proto_no_match(self): + model_proto = self._make_proto_model(["add_1"]) + annotate_proto_model(model_proto, {"layer_other": ["other"]}) + + assert len(model_proto.graph.node[0].metadata_props) == 0 + + def test_annotate_proto_skips_unnamed_nodes(self): + model_proto = self._make_proto_model(["", "linear_1"]) + annotate_proto_model(model_proto, {"layer_linear": ["linear"]}) + + assert len(model_proto.graph.node[0].metadata_props) == 0 + props = {p.key: p.value for p in model_proto.graph.node[1].metadata_props} + assert props["layer_ann"] == "layer_linear" diff --git a/test/passes/onnx/test_model_builder.py b/test/passes/onnx/test_model_builder.py index feaca585fa..ba62005e4b 100644 --- a/test/passes/onnx/test_model_builder.py +++ b/test/passes/onnx/test_model_builder.py @@ -4,6 +4,7 @@ # -------------------------------------------------------------------------- from pathlib import Path +import onnx import pytest from olive.model import ONNXModelHandler @@ -62,3 +63,40 @@ def test_model_builder_olive_quant(tmp_path, embeds, group_size): assert isinstance(output_model, ONNXModelHandler) assert Path(output_model.model_path).exists() assert Path(output_folder / "genai_config.json").exists() + + +@pytest.mark.parametrize("layer_annotations", [True, False]) +def test_model_builder_layer_annotations(tmp_path, layer_annotations): + """Test that layer annotations are correctly applied to the output ONNX model.""" + input_model = make_local_tiny_llama(tmp_path / "input_model", "hf") + + if layer_annotations: + # Create layer annotations to be applied + # Keys are layer names, values are lists of node-name substrings to match + annotations = { + "embedding_layer": ["embed_tokens"], + "norm_layer": ["norm"], + } + input_model.model_attributes = {"layer_annotations": annotations} + + p = 
create_pass_from_dict( + ModelBuilder, + {"precision": "fp32"}, + disable_search=True, + ) + output_folder = tmp_path / "output_model" + + # execute the pass + output_model = p.run(input_model, output_folder) + + # assert + assert isinstance(output_model, ONNXModelHandler) + assert Path(output_model.model_path).exists() + + if layer_annotations: + # Verify that metadata properties were applied to nodes + model_proto = onnx.load(output_model.model_path, load_external_data=False) + node_names_with_metadata = {node.name for node in model_proto.graph.node if node.metadata_props} + assert len(node_names_with_metadata) > 0, ( + "Expected nodes with metadata_props when layer_annotations are provided" + ) diff --git a/test/passes/pytorch/test_capture_layer_annotations.py b/test/passes/pytorch/test_capture_layer_annotations.py new file mode 100644 index 0000000000..04cc7daf9e --- /dev/null +++ b/test/passes/pytorch/test_capture_layer_annotations.py @@ -0,0 +1,102 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +import pytest +import torch + +from olive.model import PyTorchModelHandler +from olive.passes.olive_pass import create_pass_from_dict +from olive.passes.pytorch.capture_layer_annotations import CaptureLayerAnnotations + + +class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.attn = torch.nn.Linear(4, 4) + self.mlp = torch.nn.Linear(4, 4) + + def forward(self, x): + return self.mlp(self.attn(x)) + + +def _make_input_model(): + return PyTorchModelHandler(model_loader=lambda _: SimpleModel()) + + +class TestCaptureLayerAnnotations: + def test_layer_annotations_stored_in_model_attributes(self, tmp_path): + annotations = {"encoder": ["attn", "mlp"], "decoder": ["cross_attn"]} + p = create_pass_from_dict(CaptureLayerAnnotations, {"layer_annotations": annotations}, disable_search=True) + + out = p.run(_make_input_model(), str(tmp_path)) + + assert out.model_attributes is not None + assert out.model_attributes["layer_annotations"] == annotations + + def test_output_is_deep_copy(self, tmp_path): + annotations = {"layer0": ["sub1"]} + p = create_pass_from_dict(CaptureLayerAnnotations, {"layer_annotations": annotations}, disable_search=True) + input_model = _make_input_model() + + out = p.run(input_model, str(tmp_path)) + + assert out is not input_model + assert input_model.model_attributes is None or "layer_annotations" not in (input_model.model_attributes or {}) + + def test_preserves_existing_model_attributes(self, tmp_path): + annotations = {"enc": ["attn"]} + p = create_pass_from_dict(CaptureLayerAnnotations, {"layer_annotations": annotations}, disable_search=True) + input_model = _make_input_model() + input_model.model_attributes = {"some_key": "some_value"} + + out = p.run(input_model, str(tmp_path)) + + assert out.model_attributes["some_key"] == "some_value" + assert out.model_attributes["layer_annotations"] == annotations + + def 
test_validate_config_rejects_empty_annotations(self): + from olive.hardware import DEFAULT_CPU_ACCELERATOR + + config = CaptureLayerAnnotations.generate_config(DEFAULT_CPU_ACCELERATOR, {"layer_annotations": {}}) + assert CaptureLayerAnnotations.validate_config(config, DEFAULT_CPU_ACCELERATOR) is False + + def test_validate_config_accepts_non_empty_annotations(self): + from olive.hardware import DEFAULT_CPU_ACCELERATOR + + config = CaptureLayerAnnotations.generate_config( + DEFAULT_CPU_ACCELERATOR, {"layer_annotations": {"enc": ["attn"]}} + ) + assert CaptureLayerAnnotations.validate_config(config, DEFAULT_CPU_ACCELERATOR) is True + + @pytest.mark.parametrize( + "annotations", + [ + {"encoder": ["attn"]}, + {"a": ["x"], "b": ["y", "z"]}, + ], + ) + def test_various_annotation_mappings(self, annotations, tmp_path): + p = create_pass_from_dict(CaptureLayerAnnotations, {"layer_annotations": annotations}, disable_search=True) + + out = p.run(_make_input_model(), str(tmp_path)) + + assert out.model_attributes["layer_annotations"] == annotations + + @pytest.mark.parametrize( + "bad_annotations", + [ + {"enc": "attn"}, # value is a string, not a list + {"enc": []}, # empty list + {"": ["attn"]}, # empty key + {"enc": [""]}, # empty string in list + {"enc": [123]}, # non-string in list + ], + ) + def test_validate_config_rejects_malformed_annotations(self, bad_annotations): + from olive.hardware import DEFAULT_CPU_ACCELERATOR + + config = CaptureLayerAnnotations.generate_config( + DEFAULT_CPU_ACCELERATOR, {"layer_annotations": bad_annotations} + ) + assert CaptureLayerAnnotations.validate_config(config, DEFAULT_CPU_ACCELERATOR) is False