Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions olive/olive_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@
"supported_algorithms": [ ],
"supported_quantization_encodings": [ ]
},
"CaptureLayerAnnotations": {
"module_path": "olive.passes.pytorch.capture_layer_annotations.CaptureLayerAnnotations",
"supported_providers": [ "*" ],
"supported_accelerators": [ "*" ],
"supported_precisions": [ "*" ],
"supported_algorithms": [ ],
"supported_quantization_encodings": [ ]
},
"ComposeOnnxModels": {
"module_path": "olive.passes.onnx.compose.ComposeOnnxModels",
"supported_providers": [ "*" ],
Expand Down
8 changes: 8 additions & 0 deletions olive/passes/onnx/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,14 @@ def _convert_model_on_device(
)
if split_assignment_encoded:
ir_model.metadata_props["split_assignments"] = split_assignment_encoded

# apply layer annotations if present
layer_annotations = model_attributes.get("layer_annotations")
if layer_annotations:
from olive.passes.onnx.layer_annotation import annotate_ir_model

annotate_ir_model(ir_model, layer_annotations)

output_model = ir_model_to_olive_model(ir_model, output_model_path, config)

output_model.model_attributes = model_attributes
Expand Down
114 changes: 114 additions & 0 deletions olive/passes/onnx/layer_annotation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging

import onnx
from onnxscript import ir

logger = logging.getLogger(__name__)


def _flatten_annotations(layer_annotations: dict[str, list[str]]) -> list[tuple[str, str]]:
"""Flatten a dict of {layer_name: [substring, ...]} into [(substring, layer_name), ...]."""
return [(substring, layer_name) for layer_name, substrings in layer_annotations.items() for substring in substrings]


# ---------------------------------------------------------------------------
# ir.Model path (used by conversion.py)
# ---------------------------------------------------------------------------


def _annotate_ir_graph(graph: ir.Graph, substring_annotations: list[tuple[str, str]]) -> None:
"""Annotate nodes in an ir.Graph, recursing into subgraphs."""
for node in graph:
if node.name is None:
continue

matched_annotation = None
for substring, annotation in substring_annotations:
if substring in node.name:
matched_annotation = annotation
break # If multiple annotations match, the first one in the list wins (consistent with ORT reference implementation)

if matched_annotation is not None:
node.metadata_props["layer_ann"] = matched_annotation

# Recurse into subgraphs for control-flow nodes (If, Loop, etc.)
for attr in node.attributes.values():
if isinstance(attr, ir.Attr) and attr.type == ir.AttributeType.GRAPH:
_annotate_ir_graph(attr.value, substring_annotations)
elif isinstance(attr, ir.Attr) and attr.type == ir.AttributeType.GRAPHS:
for sub_graph in attr.value:
_annotate_ir_graph(sub_graph, substring_annotations)


def annotate_ir_model(model: ir.Model, layer_annotations: dict[str, list[str]]) -> None:
    """Annotate an onnxscript ir.Model with layer annotations.

    Every node whose name contains one of a layer's configured substrings gets a
    ``layer_ann`` metadata property set to that layer name. When several substrings
    match, the first one in iteration order wins (consistent with the ORT reference
    implementation).

    :param model: The onnxscript IR model to annotate.
    :param layer_annotations: Mapping of layer name to list of node-name substrings.
    """
    _annotate_ir_graph(model.graph, _flatten_annotations(layer_annotations))


# ---------------------------------------------------------------------------
# onnx.ModelProto path (used by model_builder.py)
# ---------------------------------------------------------------------------


def _annotate_proto_graph(graph: onnx.GraphProto, substring_annotations: list[tuple[str, str]]) -> None:
"""Annotate nodes in an onnx.GraphProto, recursing into subgraphs."""
for node in graph.node:
if not node.name:
continue

matched_annotation = None
for substring, annotation in substring_annotations:
if substring in node.name:
matched_annotation = annotation
break # If multiple annotations match, the first one in the list wins (consistent with ORT reference implementation)

if matched_annotation is not None:
entry = None
for prop in node.metadata_props:
if prop.key == "layer_ann":
entry = prop
break

if entry:
entry.value = matched_annotation
else:
entry = node.metadata_props.add()
entry.key = "layer_ann"
entry.value = matched_annotation

# Recurse into subgraphs for control-flow nodes (If, Loop, etc.)
for attr in node.attribute:
if attr.type == onnx.AttributeProto.GRAPH:
_annotate_proto_graph(attr.g, substring_annotations)
elif attr.type == onnx.AttributeProto.GRAPHS:
for sub_graph in attr.graphs:
_annotate_proto_graph(sub_graph, substring_annotations)


def annotate_proto_model(model: onnx.ModelProto, layer_annotations: dict[str, list[str]]) -> None:
    """Annotate an onnx.ModelProto with layer annotations.

    Every node whose name contains one of a layer's configured substrings gets a
    ``layer_ann`` metadata property set to that layer name. When several substrings
    match, the first one in iteration order wins (consistent with the ORT reference
    implementation).

    :param model: The ONNX ModelProto to annotate.
    :param layer_annotations: Mapping of layer name to list of node-name substrings.
    """
    _annotate_proto_graph(model.graph, _flatten_annotations(layer_annotations))
28 changes: 18 additions & 10 deletions olive/passes/onnx/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,17 +283,25 @@ def _run_for_config(
**extra_args,
)

# add split information if present
split_assignments = model_attributes.get("split_assignments")
if not metadata_only and split_assignments:
# NOTE: currently the model builder renames modules to it's own naming convention
# so the assignments for the renamed modules won't match
split_assignment_str = ";".join([f"{k}={v}" for k, v in split_assignments.items()])

# load the model and set the split_assignments as model properties
# without the external data so that they can be used as is with the resaved model
# Apply post-processing annotations (split assignments and/or layer annotations)
# in a single load/save cycle to avoid redundant disk I/O.
split_assignments = model_attributes.get("split_assignments") if not metadata_only else None
layer_annotations = model_attributes.get("layer_annotations") if not metadata_only else None

if split_assignments or layer_annotations:
model_proto = onnx.load(output_model_filepath, load_external_data=False)
onnx.helper.set_model_props(model_proto, {"split_assignments": split_assignment_str})

if split_assignments:
# NOTE: currently the model builder renames modules to it's own naming convention
# so the assignments for the renamed modules won't match
split_assignment_str = ";".join([f"{k}={v}" for k, v in split_assignments.items()])
onnx.helper.set_model_props(model_proto, {"split_assignments": split_assignment_str})

if layer_annotations:
from olive.passes.onnx.layer_annotation import annotate_proto_model

annotate_proto_model(model_proto, layer_annotations)

onnx.save(model_proto, output_model_filepath)
except Exception:
# if model building fails, clean up the intermediate files in the cache_dir
Expand Down
76 changes: 76 additions & 0 deletions olive/passes/pytorch/capture_layer_annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from copy import deepcopy
from typing import Union

from olive.hardware.accelerator import AcceleratorSpec
from olive.model import HfModelHandler, PyTorchModelHandler
from olive.passes import Pass
from olive.passes.pass_config import BasePassConfig, PassConfigParam

logger = logging.getLogger(__name__)


class CaptureLayerAnnotations(Pass):
    """Capture layer annotation metadata for an ONNX model.

    Given a mapping of layer names to node-name substrings, attaches a
    ``layer_annotations`` dictionary to the model attributes. Downstream
    ONNX conversion passes will read this attribute and annotate each ONNX
    node whose name contains a matching substring with a ``layer_ann``
    metadata property.
    """

    @classmethod
    def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassConfigParam]:
        return {
            "layer_annotations": PassConfigParam(
                type_=dict,
                required=True,
                description=(
                    "Mapping of layer name to a list of node-name substrings. "
                    'For example: {"encoder": ["attn", "mlp"], "decoder": ["cross_attn"]}. '
                    "During ONNX conversion every node whose name contains a listed substring "
                    "will receive a metadata property 'layer_ann' set to the layer name."
                ),
            ),
        }

    @classmethod
    def validate_config(
        cls,
        config: type[BasePassConfig],
        accelerator_spec: AcceleratorSpec,
    ) -> bool:
        # Base-class validation must pass before the pass-specific checks run.
        if not super().validate_config(config, accelerator_spec):
            return False

        annotations = config.layer_annotations
        if not annotations:
            logger.info("layer_annotations must be a non-empty dictionary.")
            return False

        # Each entry must map a non-empty string key to a non-empty list of non-empty strings.
        for layer_name, substrings in annotations.items():
            if not (isinstance(layer_name, str) and layer_name):
                logger.info("layer_annotations keys must be non-empty strings, got %r.", layer_name)
                return False
            if not (isinstance(substrings, list) and substrings):
                logger.info("layer_annotations[%r] must be a non-empty list of strings, got %r.", layer_name, type(substrings))
                return False
            if any(not (isinstance(s, str) and s) for s in substrings):
                logger.info("layer_annotations[%r] must contain only non-empty strings.", layer_name)
                return False

        return True

    def _run_for_config(
        self, model: Union[HfModelHandler, PyTorchModelHandler], config: type[BasePassConfig], output_model_path: str
    ) -> Union[HfModelHandler, PyTorchModelHandler]:
        # Drop the loaded model object before copying so the deepcopy stays cheap;
        # only the handler metadata is needed downstream.
        model.model = None
        output_model = deepcopy(model)
        attributes = output_model.model_attributes or {}
        attributes["layer_annotations"] = config.layer_annotations
        output_model.model_attributes = attributes
        return output_model
Loading
Loading