Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions backends/xnnpack/recipes/xnnpack_recipe_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from executorch.export import (
BackendRecipeProvider,
ExportRecipe,
LoweringRecipe,
QuantizationRecipe,
RecipeType,
)
Expand Down Expand Up @@ -88,12 +89,19 @@ def create_recipe(
)
return None

def _get_xnnpack_lowering_recipe(
self, precision_type: Optional[ConfigPrecisionType] = None
) -> LoweringRecipe:
return LoweringRecipe(
partitioners=[XnnpackPartitioner(precision_type=precision_type)],
edge_compile_config=get_xnnpack_edge_compile_config(),
)

def _build_fp32_recipe(self, recipe_type: RecipeType) -> ExportRecipe:
return ExportRecipe(
name=recipe_type.value,
edge_compile_config=get_xnnpack_edge_compile_config(),
lowering_recipe=self._get_xnnpack_lowering_recipe(),
executorch_backend_config=get_xnnpack_executorch_backend_config(),
partitioners=[XnnpackPartitioner()],
)

def _build_quantized_recipe(
Expand All @@ -120,9 +128,8 @@ def _build_quantized_recipe(
return ExportRecipe(
name=recipe_type.value,
quantization_recipe=quant_recipe,
edge_compile_config=get_xnnpack_edge_compile_config(),
lowering_recipe=self._get_xnnpack_lowering_recipe(precision_type),
executorch_backend_config=get_xnnpack_executorch_backend_config(),
partitioners=[XnnpackPartitioner(config_precision=precision_type)],
)

def _build_int8da_intx_weight_recipe(
Expand Down Expand Up @@ -150,9 +157,8 @@ def _build_int8da_intx_weight_recipe(
return ExportRecipe(
name=recipe_type.value,
quantization_recipe=quant_recipe,
edge_compile_config=get_xnnpack_edge_compile_config(),
lowering_recipe=self._get_xnnpack_lowering_recipe(),
executorch_backend_config=get_xnnpack_executorch_backend_config(),
partitioners=[XnnpackPartitioner()],
)

def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:
Expand Down
6 changes: 4 additions & 2 deletions export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
"""

from .export import export, ExportSession
from .recipe import ExportRecipe, QuantizationRecipe, RecipeType
from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe, RecipeType
from .recipe_provider import BackendRecipeProvider
from .recipe_registry import recipe_registry

from .stages import StageType

__all__ = [
"StageType",
"ExportRecipe",
"LoweringRecipe",
"QuantizationRecipe",
"ExportSession",
"export",
Expand Down
42 changes: 35 additions & 7 deletions export/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from tabulate import tabulate
from torch import nn

from .recipe import ExportRecipe, QuantizationRecipe
from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe
from .stages import (
EdgeTransformAndLowerStage,
ExecutorchStage,
Expand Down Expand Up @@ -145,6 +145,10 @@ def __init__(
self._export_recipe.quantization_recipe
)

self._lowering_recipe: Optional[LoweringRecipe] = (
self._export_recipe.lowering_recipe
)

# Stages to run
self._pipeline_stages = pipeline_stages or self._get_default_pipeline()

Expand Down Expand Up @@ -189,18 +193,42 @@ def _build_default_stages(self) -> Dict[StageType, Stage]:
stage = TorchExportStage(pre_edge_passes)
elif stage_type == StageType.TO_EDGE_TRANSFORM_AND_LOWER:
stage = EdgeTransformAndLowerStage(
partitioners=self._export_recipe.partitioners,
transform_passes=self._export_recipe.edge_transform_passes,
compile_config=self._export_recipe.edge_compile_config,
partitioners=(
self._lowering_recipe.partitioners
if self._lowering_recipe
else None
),
transform_passes=(
self._lowering_recipe.edge_transform_passes
if self._lowering_recipe
else None
),
compile_config=(
self._lowering_recipe.edge_compile_config
if self._lowering_recipe
else None
),
)
elif stage_type == StageType.TO_EDGE:
stage = ToEdgeStage(
edge_compile_config=self._export_recipe.edge_compile_config
edge_compile_config=(
self._lowering_recipe.edge_compile_config
if self._lowering_recipe
else None
),
)
elif stage_type == StageType.TO_BACKEND:
stage = ToBackendStage(
partitioners=self._export_recipe.partitioners,
transform_passes=self._export_recipe.edge_transform_passes,
partitioners=(
self._lowering_recipe.partitioners
if self._lowering_recipe
else None
),
transform_passes=(
self._lowering_recipe.edge_transform_passes
if self._lowering_recipe
else None
),
)
elif stage_type == StageType.TO_EXECUTORCH:
stage = ExecutorchStage(self._export_recipe.executorch_backend_config)
Expand Down
36 changes: 24 additions & 12 deletions export/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,26 @@ def get_quantizers(self) -> Optional[List[Quantizer]]:
return self.quantizers


@dataclass
class LoweringRecipe:
"""
Configuration recipe for lowering and partitioning.

This class holds the configuration parameters for lowering a model
to backend-specific representations.

Attributes:
partitioners: Optional list of partitioners for model partitioning
edge_transform_passes: Optional sequence of transformation passes to apply
edge_compile_config: Optional edge compilation configuration
"""

partitioners: Optional[List[Partitioner]] = None
edge_transform_passes: Optional[Sequence[PassType]] = None
# pyre-ignore[11]: Type not defined
edge_compile_config: Optional[EdgeCompileConfig] = None


@experimental(
"This API and all of its related functionality such as ExportSession and ExportRecipe are experimental."
)
Expand All @@ -101,27 +121,19 @@ class ExportRecipe:
Attributes:
name: Optional name for the recipe
quantization_recipe: Optional quantization recipe for model quantization
edge_compile_config: Optional edge compilation configuration
lowering_recipe: Optional lowering recipe for model lowering and partitioning
executorch_backend_config: Optional backend configuration for ExecuTorch
pre_edge_transform_passes: Optional function to apply transformation passes
before edge lowering
edge_transform_passes: Optional sequence of transformation passes to apply
during edge lowering
transform_check_ir_validity: Whether to check IR validity during transformation
partitioners: Optional list of partitioners for model partitioning
executorch_backend_config: Optional backend configuration for ExecuTorch
mode: Export mode (debug or release)
"""

name: Optional[str] = None
quantization_recipe: Optional[QuantizationRecipe] = None
# pyre-ignore[11]: Type not defined
edge_compile_config: Optional[EdgeCompileConfig] = None
pre_edge_transform_passes: Optional[Sequence[PassType]] = None
edge_transform_passes: Optional[Sequence[PassType]] = None
transform_check_ir_validity: bool = True
partitioners: Optional[List[Partitioner]] = None
lowering_recipe: Optional[LoweringRecipe] = None
# pyre-ignore[11]: Type not defined
executorch_backend_config: Optional[ExecutorchBackendConfig] = None
pre_edge_transform_passes: Optional[Sequence[PassType]] = None
mode: Mode = Mode.RELEASE

@classmethod
Expand Down
47 changes: 47 additions & 0 deletions export/tests/test_export_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import torch
from executorch.export import ExportRecipe, ExportSession
from executorch.export.recipe import LoweringRecipe, QuantizationRecipe
from executorch.export.stages import PipelineArtifact, StageType


Expand Down Expand Up @@ -329,3 +330,49 @@ def test_save_to_pte_invalid_name(self) -> None:

with self.assertRaises(AssertionError):
session.save_to_pte(None) # pyre-ignore


class TestExportSessionPipelineBuilding(unittest.TestCase):
"""Test pipeline building and stage configuration."""

def setUp(self) -> None:
self.model = SimpleTestModel()
self.example_inputs = [(torch.randn(2, 10),)]

def test_pipeline_building_with_all_recipes(self) -> None:
"""Test pipeline building with quantization and lowering recipes."""
# Create comprehensive recipes
quant_recipe = QuantizationRecipe(
ao_base_config=[Mock()],
quantizers=[Mock()],
)
lowering_recipe = LoweringRecipe(
partitioners=[Mock()],
edge_transform_passes=[Mock()],
edge_compile_config=Mock(),
)
recipe = ExportRecipe(
name="comprehensive_test",
quantization_recipe=quant_recipe,
lowering_recipe=lowering_recipe,
pre_edge_transform_passes=[Mock()],
executorch_backend_config=Mock(),
)

session = ExportSession(
model=self.model,
example_inputs=self.example_inputs,
export_recipe=recipe,
)

registered_stages = session.get_all_registered_stages()

self.assertEqual(len(registered_stages), 5)
expected_types = [
StageType.SOURCE_TRANSFORM,
StageType.QUANTIZE,
StageType.TORCH_EXPORT,
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
StageType.TO_EXECUTORCH,
]
self.assertListEqual(list(registered_stages.keys()), expected_types)
Loading