-
Notifications
You must be signed in to change notification settings - Fork 283
Add an ability to annotate a model with layer annotations. #2361
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
yuslepukhin
wants to merge
6
commits into
main
Choose a base branch
from
yuslepukhin/convert_layer_annotate
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+524
−10
Open
Changes from 3 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
868fa69
Implement layer annotation logic
yuslepukhin f3a18f3
Add unit tests
yuslepukhin 956b56d
Lint
yuslepukhin 954f5b4
Add more test per code review
yuslepukhin 2250d69
Address more review comments
yuslepukhin 413a177
Address more issues
yuslepukhin File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,109 @@ | ||
| # ------------------------------------------------------------------------- | ||
| # Copyright (c) Microsoft Corporation. All rights reserved. | ||
| # Licensed under the MIT License. | ||
| # -------------------------------------------------------------------------- | ||
| import logging | ||
|
|
||
| import onnx | ||
| from onnxscript import ir | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
def _flatten_annotations(layer_annotations: dict[str, list[str]]) -> list[tuple[str, str]]:
    """Flatten a dict of {layer_name: [substring, ...]} into [(substring, layer_name), ...].

    Pairs are emitted in dict iteration order, and in list order within each layer.

    A layer whose value is a bare string rather than a list is treated as a single
    substring. Iterating the string directly would yield one pair per character,
    which would match far more node names than intended.

    :param layer_annotations: Mapping of layer name to list of node-name substrings.
    :return: List of ``(substring, layer_name)`` pairs.
    """
    flattened: list[tuple[str, str]] = []
    for layer_name, substrings in layer_annotations.items():
        # Guard against {"layer": "attn"}: a str is iterable, but per character.
        candidates = [substrings] if isinstance(substrings, str) else substrings
        flattened.extend((substring, layer_name) for substring in candidates)
    return flattened
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # ir.Model path (used by conversion.py) | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
def _annotate_ir_graph(graph: ir.Graph, substring_annotations: list[tuple[str, str]]) -> None:
    """Annotate nodes in an ir.Graph, recursing into subgraphs.

    A node whose name contains one of the substrings gets its ``layer_ann``
    metadata property set to the paired annotation. When several substrings
    match, the last pair in ``substring_annotations`` wins.

    Unnamed nodes can never match, but their subgraphs are still visited so that
    named nodes nested under an unnamed control-flow node are annotated too.

    :param graph: Graph whose nodes are annotated in place.
    :param substring_annotations: ``(substring, annotation)`` pairs; later pairs
        override earlier ones.
    """
    for node in graph:
        # Only the name match is skipped for unnamed nodes; the subgraph
        # recursion below must still run for them.
        if node.name is not None:
            matched_annotation = None
            for substring, annotation in substring_annotations:
                if substring in node.name:
                    # No break: keep scanning so the last match wins.
                    matched_annotation = annotation

            if matched_annotation is not None:
                node.metadata_props["layer_ann"] = matched_annotation

        # Recurse into subgraphs for control-flow nodes (If, Loop, etc.)
        for attr in node.attributes.values():
            if not isinstance(attr, ir.Attr):
                # Only ir.Attr instances are inspected for graph values.
                continue
            if attr.type == ir.AttributeType.GRAPH:
                _annotate_ir_graph(attr.value, substring_annotations)
            elif attr.type == ir.AttributeType.GRAPHS:
                for sub_graph in attr.value:
                    _annotate_ir_graph(sub_graph, substring_annotations)
|
|
||
|
|
||
def annotate_ir_model(model: ir.Model, layer_annotations: dict[str, list[str]]) -> None:
    """Write ``layer_ann`` metadata onto the nodes of an onnxscript ir.Model.

    The mapping is flattened into (substring, layer name) pairs and applied to
    the model's main graph. Every node whose name contains one of the substrings
    gets a ``layer_ann`` metadata property holding the matching layer name. When
    more than one substring matches, the pair that comes last in iteration order
    takes precedence (consistent with the ORT reference implementation).

    :param model: The onnxscript IR model to annotate (modified in place).
    :param layer_annotations: Mapping of layer name to list of node-name substrings.
    """
    pairs = _flatten_annotations(layer_annotations)
    _annotate_ir_graph(model.graph, pairs)
|
|
||
yuslepukhin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # onnx.ModelProto path (used by model_builder.py) | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
def _annotate_proto_graph(graph: onnx.GraphProto, substring_annotations: list[tuple[str, str]]) -> None:
    """Tag matching nodes of an onnx.GraphProto with ``layer_ann``, descending into nested graphs.

    For each node, every ``(substring, annotation)`` pair is tested against the
    node name; the last matching pair decides the annotation. The value is stored
    in the node's ``metadata_props`` under the key ``layer_ann``, overwriting an
    existing entry with that key or appending a new one.

    :param graph: Graph whose nodes are annotated in place.
    :param substring_annotations: ``(substring, annotation)`` pairs; later pairs
        override earlier ones.
    """
    for proto_node in graph.node:
        layer = None
        for needle, candidate in substring_annotations:
            if needle in proto_node.name:
                # Deliberately no break: the last match wins.
                layer = candidate

        if layer is not None:
            for existing in proto_node.metadata_props:
                if existing.key == "layer_ann":
                    # Reuse the first entry that already carries the key.
                    existing.value = layer
                    break
            else:
                # No entry with this key yet: append one.
                created = proto_node.metadata_props.add()
                created.key = "layer_ann"
                created.value = layer

        # Control-flow nodes (If, Loop, etc.) carry their bodies as graph attributes.
        for attribute in proto_node.attribute:
            if attribute.type == onnx.AttributeProto.GRAPH:
                nested_graphs = [attribute.g]
            elif attribute.type == onnx.AttributeProto.GRAPHS:
                nested_graphs = attribute.graphs
            else:
                continue
            for nested in nested_graphs:
                _annotate_proto_graph(nested, substring_annotations)
|
|
||
|
|
||
def annotate_proto_model(model: onnx.ModelProto, layer_annotations: dict[str, list[str]]) -> None:
    """Write ``layer_ann`` metadata onto the nodes of an onnx.ModelProto.

    The mapping is flattened into (substring, layer name) pairs and applied to
    the model's main graph. Every node whose name contains one of the substrings
    gets a ``layer_ann`` metadata property holding the matching layer name. When
    more than one substring matches, the pair that comes last in iteration order
    takes precedence (consistent with the ORT reference implementation).

    :param model: The ONNX ModelProto to annotate (modified in place).
    :param layer_annotations: Mapping of layer name to list of node-name substrings.
    """
    pairs = _flatten_annotations(layer_annotations)
    _annotate_proto_graph(model.graph, pairs)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| # ------------------------------------------------------------------------- | ||
| # Copyright (c) Microsoft Corporation. All rights reserved. | ||
| # Licensed under the MIT License. | ||
| # -------------------------------------------------------------------------- | ||
| import logging | ||
| from copy import deepcopy | ||
| from typing import Union | ||
|
|
||
| from olive.hardware.accelerator import AcceleratorSpec | ||
| from olive.model import HfModelHandler, PyTorchModelHandler | ||
| from olive.passes import Pass | ||
| from olive.passes.pass_config import BasePassConfig, PassConfigParam | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
class CaptureLayerAnnotations(Pass):
    """Capture layer annotation metadata for an ONNX model.

    The pass takes a mapping of layer names to node-name substrings and stores it
    as a ``layer_annotations`` dictionary in the model attributes. Downstream
    ONNX conversion passes will read this attribute and annotate each ONNX node
    whose name contains a matching substring with a ``layer_ann`` metadata
    property.
    """

    @classmethod
    def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassConfigParam]:
        """Declare the single, required ``layer_annotations`` parameter."""
        layer_annotations_param = PassConfigParam(
            type_=dict,
            required=True,
            description=(
                "Mapping of layer name to a list of node-name substrings. "
                'For example: {"encoder": ["attn", "mlp"], "decoder": ["cross_attn"]}. '
                "During ONNX conversion every node whose name contains a listed substring "
                "will receive a metadata property 'layer_ann' set to the layer name."
            ),
        )
        return {"layer_annotations": layer_annotations_param}

    @classmethod
    def validate_config(
        cls,
        config: type[BasePassConfig],
        accelerator_spec: AcceleratorSpec,
    ) -> bool:
        """Accept the config only if the base checks pass and ``layer_annotations`` is non-empty."""
        base_ok = super().validate_config(config, accelerator_spec)
        if not base_ok:
            return False

        if config.layer_annotations:
            return True

        logger.info("layer_annotations must be a non-empty dictionary.")
        return False

    def _run_for_config(
        self, model: Union[HfModelHandler, PyTorchModelHandler], config: type[BasePassConfig], output_model_path: str
    ) -> Union[HfModelHandler, PyTorchModelHandler]:
        """Return a deep copy of ``model`` whose attributes include ``layer_annotations``.

        ``output_model_path`` is not used by this pass.
        """
        # The loaded model is dropped from the input handler before copying.
        # NOTE(review): presumably so deepcopy does not duplicate it; confirm.
        model.model = None
        output_model = deepcopy(model)

        # Start from the existing attributes, or an empty dict when there are none.
        attributes = output_model.model_attributes or {}
        output_model.model_attributes = attributes
        attributes["layer_annotations"] = config.layer_annotations

        return output_model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,84 @@ | ||
| # ------------------------------------------------------------------------- | ||
| # Copyright (c) Microsoft Corporation. All rights reserved. | ||
| # Licensed under the MIT License. | ||
| # -------------------------------------------------------------------------- | ||
| import pytest | ||
| import torch | ||
|
|
||
| from olive.model import PyTorchModelHandler | ||
| from olive.passes.olive_pass import create_pass_from_dict | ||
| from olive.passes.pytorch.capture_layer_annotations import CaptureLayerAnnotations | ||
|
|
||
|
|
||
class SimpleModel(torch.nn.Module):
    """Tiny two-layer module: a 4->4 linear ``attn`` followed by a 4->4 linear ``mlp``."""

    def __init__(self):
        super().__init__()
        # Creation order is kept (attn, then mlp) so parameter initialization is unchanged.
        self.attn = torch.nn.Linear(in_features=4, out_features=4)
        self.mlp = torch.nn.Linear(in_features=4, out_features=4)

    def forward(self, x):
        """Apply ``attn`` and then ``mlp`` to ``x``."""
        hidden = self.attn(x)
        return self.mlp(hidden)
|
|
||
|
|
||
def _make_input_model():
    """Build a PyTorchModelHandler whose loader returns a new SimpleModel."""

    def _load_simple_model(_):
        # The loader's single argument is ignored.
        return SimpleModel()

    return PyTorchModelHandler(model_loader=_load_simple_model)
|
|
||
|
|
||
class TestCaptureLayerAnnotations:
    """Tests for the CaptureLayerAnnotations pass."""

    @staticmethod
    def _build_pass(annotations):
        # Every test builds the pass the same way, with search disabled.
        return create_pass_from_dict(
            CaptureLayerAnnotations, {"layer_annotations": annotations}, disable_search=True
        )

    def test_layer_annotations_stored_in_model_attributes(self, tmp_path):
        expected = {"encoder": ["attn", "mlp"], "decoder": ["cross_attn"]}
        layer_pass = self._build_pass(expected)

        result = layer_pass.run(_make_input_model(), str(tmp_path))

        assert result.model_attributes is not None
        assert result.model_attributes["layer_annotations"] == expected

    def test_output_is_deep_copy(self, tmp_path):
        layer_pass = self._build_pass({"layer0": ["sub1"]})
        source_model = _make_input_model()

        result = layer_pass.run(source_model, str(tmp_path))

        assert result is not source_model
        # The input handler must not have picked up the annotations.
        assert "layer_annotations" not in (source_model.model_attributes or {})

    def test_preserves_existing_model_attributes(self, tmp_path):
        expected = {"enc": ["attn"]}
        layer_pass = self._build_pass(expected)
        source_model = _make_input_model()
        source_model.model_attributes = {"some_key": "some_value"}

        result = layer_pass.run(source_model, str(tmp_path))

        assert result.model_attributes["some_key"] == "some_value"
        assert result.model_attributes["layer_annotations"] == expected

    def test_validate_config_rejects_empty_annotations(self):
        from olive.hardware import DEFAULT_CPU_ACCELERATOR

        empty_config = CaptureLayerAnnotations.generate_config(DEFAULT_CPU_ACCELERATOR, {"layer_annotations": {}})
        is_valid = CaptureLayerAnnotations.validate_config(empty_config, DEFAULT_CPU_ACCELERATOR)
        assert is_valid is False

    def test_validate_config_accepts_non_empty_annotations(self):
        from olive.hardware import DEFAULT_CPU_ACCELERATOR

        filled_config = CaptureLayerAnnotations.generate_config(
            DEFAULT_CPU_ACCELERATOR, {"layer_annotations": {"enc": ["attn"]}}
        )
        is_valid = CaptureLayerAnnotations.validate_config(filled_config, DEFAULT_CPU_ACCELERATOR)
        assert is_valid is True

    @pytest.mark.parametrize(
        "annotations",
        [
            {"encoder": ["attn"]},
            {"a": ["x"], "b": ["y", "z"]},
        ],
    )
    def test_various_annotation_mappings(self, annotations, tmp_path):
        layer_pass = self._build_pass(annotations)

        result = layer_pass.run(_make_input_model(), str(tmp_path))

        assert result.model_attributes["layer_annotations"] == annotations
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.