Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions olive/olive_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@
"supported_algorithms": [ ],
"supported_quantization_encodings": [ ]
},
"CaptureLayerAnnotations": {
"module_path": "olive.passes.pytorch.capture_layer_annotations.CaptureLayerAnnotations",
"supported_providers": [ "*" ],
"supported_accelerators": [ "*" ],
"supported_precisions": [ "*" ],
"supported_algorithms": [ ],
"supported_quantization_encodings": [ ]
},
"ComposeOnnxModels": {
"module_path": "olive.passes.onnx.compose.ComposeOnnxModels",
"supported_providers": [ "*" ],
Expand Down
8 changes: 8 additions & 0 deletions olive/passes/onnx/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,14 @@ def _convert_model_on_device(
)
if split_assignment_encoded:
ir_model.metadata_props["split_assignments"] = split_assignment_encoded

# apply layer annotations if present
layer_annotations = model_attributes.get("layer_annotations")
if layer_annotations:
from olive.passes.onnx.layer_annotation import annotate_ir_model

annotate_ir_model(ir_model, layer_annotations)

output_model = ir_model_to_olive_model(ir_model, output_model_path, config)

output_model.model_attributes = model_attributes
Expand Down
109 changes: 109 additions & 0 deletions olive/passes/onnx/layer_annotation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging

import onnx
from onnxscript import ir

logger = logging.getLogger(__name__)


def _flatten_annotations(layer_annotations: dict[str, list[str]]) -> list[tuple[str, str]]:
"""Flatten a dict of {layer_name: [substring, ...]} into [(substring, layer_name), ...]."""
return [(substring, layer_name) for layer_name, substrings in layer_annotations.items() for substring in substrings]


# ---------------------------------------------------------------------------
# ir.Model path (used by conversion.py)
# ---------------------------------------------------------------------------


def _annotate_ir_graph(graph: ir.Graph, substring_annotations: list[tuple[str, str]]) -> None:
"""Annotate nodes in an ir.Graph, recursing into subgraphs."""
for node in graph:
if node.name is None:
continue

matched_annotation = None
for substring, annotation in substring_annotations:
if substring in node.name:
matched_annotation = annotation

if matched_annotation is not None:
node.metadata_props["layer_ann"] = matched_annotation

# Recurse into subgraphs for control-flow nodes (If, Loop, etc.)
for attr in node.attributes.values():
if isinstance(attr, ir.Attr) and attr.type == ir.AttributeType.GRAPH:
_annotate_ir_graph(attr.value, substring_annotations)
elif isinstance(attr, ir.Attr) and attr.type == ir.AttributeType.GRAPHS:
for sub_graph in attr.value:
_annotate_ir_graph(sub_graph, substring_annotations)


def annotate_ir_model(model: ir.Model, layer_annotations: dict[str, list[str]]) -> None:
"""Annotate an onnxscript ir.Model with layer annotations.

For each node whose name contains a configured substring, a metadata property
``layer_ann`` is set to the corresponding layer name. If multiple substrings
match, the last one in iteration order wins (consistent with the ORT reference
implementation).

:param model: The onnxscript IR model to annotate.
:param layer_annotations: Mapping of layer name to list of node-name substrings.
"""
substring_annotations = _flatten_annotations(layer_annotations)
_annotate_ir_graph(model.graph, substring_annotations)


# ---------------------------------------------------------------------------
# onnx.ModelProto path (used by model_builder.py)
# ---------------------------------------------------------------------------


def _annotate_proto_graph(graph: onnx.GraphProto, substring_annotations: list[tuple[str, str]]) -> None:
"""Annotate nodes in an onnx.GraphProto, recursing into subgraphs."""
for node in graph.node:
matched_annotation = None
for substring, annotation in substring_annotations:
if substring in node.name:
matched_annotation = annotation

if matched_annotation is not None:
entry = None
for prop in node.metadata_props:
if prop.key == "layer_ann":
entry = prop
break

if entry:
entry.value = matched_annotation
else:
entry = node.metadata_props.add()
entry.key = "layer_ann"
entry.value = matched_annotation

# Recurse into subgraphs for control-flow nodes (If, Loop, etc.)
for attr in node.attribute:
if attr.type == onnx.AttributeProto.GRAPH:
_annotate_proto_graph(attr.g, substring_annotations)
elif attr.type == onnx.AttributeProto.GRAPHS:
for sub_graph in attr.graphs:
_annotate_proto_graph(sub_graph, substring_annotations)


def annotate_proto_model(model: onnx.ModelProto, layer_annotations: dict[str, list[str]]) -> None:
"""Annotate an onnx.ModelProto with layer annotations.

For each node whose name contains a configured substring, a metadata property
``layer_ann`` is set to the corresponding layer name. If multiple substrings
match, the last one in iteration order wins (consistent with the ORT reference
implementation).

:param model: The ONNX ModelProto to annotate.
:param layer_annotations: Mapping of layer name to list of node-name substrings.
"""
substring_annotations = _flatten_annotations(layer_annotations)
_annotate_proto_graph(model.graph, substring_annotations)
9 changes: 9 additions & 0 deletions olive/passes/onnx/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,15 @@ def _run_for_config(
model_proto = onnx.load(output_model_filepath, load_external_data=False)
onnx.helper.set_model_props(model_proto, {"split_assignments": split_assignment_str})
onnx.save(model_proto, output_model_filepath)

# apply layer annotations if present
layer_annotations = model_attributes.get("layer_annotations")
if not metadata_only and layer_annotations:
from olive.passes.onnx.layer_annotation import annotate_proto_model

model_proto = onnx.load(output_model_filepath, load_external_data=False)
annotate_proto_model(model_proto, layer_annotations)
onnx.save(model_proto, output_model_filepath)
except Exception:
# if model building fails, clean up the intermediate files in the cache_dir
cache_dir = Path(HF_HUB_CACHE)
Expand Down
65 changes: 65 additions & 0 deletions olive/passes/pytorch/capture_layer_annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from copy import deepcopy
from typing import Union

from olive.hardware.accelerator import AcceleratorSpec
from olive.model import HfModelHandler, PyTorchModelHandler
from olive.passes import Pass
from olive.passes.pass_config import BasePassConfig, PassConfigParam

logger = logging.getLogger(__name__)


class CaptureLayerAnnotations(Pass):
"""Capture layer annotation metadata for an ONNX model.

Given a mapping of layer names to node-name substrings, attaches a
``layer_annotations`` dictionary to the model attributes. Downstream
ONNX conversion passes will read this attribute and annotate each ONNX
node whose name contains a matching substring with a ``layer_ann``
metadata property.
"""

@classmethod
def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassConfigParam]:
return {
"layer_annotations": PassConfigParam(
type_=dict,
required=True,
description=(
"Mapping of layer name to a list of node-name substrings. "
'For example: {"encoder": ["attn", "mlp"], "decoder": ["cross_attn"]}. '
"During ONNX conversion every node whose name contains a listed substring "
"will receive a metadata property 'layer_ann' set to the layer name."
),
),
}

@classmethod
def validate_config(
cls,
config: type[BasePassConfig],
accelerator_spec: AcceleratorSpec,
) -> bool:
if not super().validate_config(config, accelerator_spec):
return False

if not config.layer_annotations:
logger.info("layer_annotations must be a non-empty dictionary.")
return False

return True

def _run_for_config(
self, model: Union[HfModelHandler, PyTorchModelHandler], config: type[BasePassConfig], output_model_path: str
) -> Union[HfModelHandler, PyTorchModelHandler]:
model.model = None
output_model = deepcopy(model)
output_model.model_attributes = model_attributes = output_model.model_attributes or {}
model_attributes["layer_annotations"] = config.layer_annotations

return output_model
84 changes: 84 additions & 0 deletions test/passes/pytorch/test_capture_layer_annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import pytest
import torch

from olive.model import PyTorchModelHandler
from olive.passes.olive_pass import create_pass_from_dict
from olive.passes.pytorch.capture_layer_annotations import CaptureLayerAnnotations


class SimpleModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.attn = torch.nn.Linear(4, 4)
self.mlp = torch.nn.Linear(4, 4)

def forward(self, x):
return self.mlp(self.attn(x))


def _make_input_model():
return PyTorchModelHandler(model_loader=lambda _: SimpleModel())


class TestCaptureLayerAnnotations:
def test_layer_annotations_stored_in_model_attributes(self, tmp_path):
annotations = {"encoder": ["attn", "mlp"], "decoder": ["cross_attn"]}
p = create_pass_from_dict(CaptureLayerAnnotations, {"layer_annotations": annotations}, disable_search=True)

out = p.run(_make_input_model(), str(tmp_path))

assert out.model_attributes is not None
assert out.model_attributes["layer_annotations"] == annotations

def test_output_is_deep_copy(self, tmp_path):
annotations = {"layer0": ["sub1"]}
p = create_pass_from_dict(CaptureLayerAnnotations, {"layer_annotations": annotations}, disable_search=True)
input_model = _make_input_model()

out = p.run(input_model, str(tmp_path))

assert out is not input_model
assert input_model.model_attributes is None or "layer_annotations" not in (input_model.model_attributes or {})

def test_preserves_existing_model_attributes(self, tmp_path):
annotations = {"enc": ["attn"]}
p = create_pass_from_dict(CaptureLayerAnnotations, {"layer_annotations": annotations}, disable_search=True)
input_model = _make_input_model()
input_model.model_attributes = {"some_key": "some_value"}

out = p.run(input_model, str(tmp_path))

assert out.model_attributes["some_key"] == "some_value"
assert out.model_attributes["layer_annotations"] == annotations

def test_validate_config_rejects_empty_annotations(self):
from olive.hardware import DEFAULT_CPU_ACCELERATOR

config = CaptureLayerAnnotations.generate_config(DEFAULT_CPU_ACCELERATOR, {"layer_annotations": {}})
assert CaptureLayerAnnotations.validate_config(config, DEFAULT_CPU_ACCELERATOR) is False

def test_validate_config_accepts_non_empty_annotations(self):
from olive.hardware import DEFAULT_CPU_ACCELERATOR

config = CaptureLayerAnnotations.generate_config(
DEFAULT_CPU_ACCELERATOR, {"layer_annotations": {"enc": ["attn"]}}
)
assert CaptureLayerAnnotations.validate_config(config, DEFAULT_CPU_ACCELERATOR) is True

@pytest.mark.parametrize(
"annotations",
[
{"encoder": ["attn"]},
{"a": ["x"], "b": ["y", "z"]},
],
)
def test_various_annotation_mappings(self, annotations, tmp_path):
p = create_pass_from_dict(CaptureLayerAnnotations, {"layer_annotations": annotations}, disable_search=True)

out = p.run(_make_input_model(), str(tmp_path))

assert out.model_attributes["layer_annotations"] == annotations
Loading