Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions export/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ python_library(
deps = [
":recipe",
"//executorch/runtime:runtime",
":recipe_registry"
]
)

Expand All @@ -35,5 +36,30 @@ python_library(
deps = [
":export",
":recipe",
":recipe_registry",
":recipe_provider"
],
)


python_library(
name = "recipe_registry",
srcs = [
"recipe_registry.py",
],
deps = [
":recipe",
":recipe_provider"
],
)


python_library(
name = "recipe_provider",
srcs = [
"recipe_provider.py",
],
deps = [
":recipe",
]
)
13 changes: 10 additions & 3 deletions export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
ExecuTorch export module.

Expand All @@ -12,13 +14,18 @@
export management.
"""

# pyre-strict

from .export import export, ExportSession
from .recipe import ExportRecipe
from .recipe import ExportRecipe, QuantizationRecipe, RecipeType
from .recipe_provider import BackendRecipeProvider
from .recipe_registry import recipe_registry


__all__ = [
"ExportRecipe",
"QuantizationRecipe",
"ExportSession",
"export",
"BackendRecipeProvider",
"recipe_registry",
"RecipeType",
]
21 changes: 14 additions & 7 deletions export/export.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

Expand All @@ -10,11 +16,14 @@
ExecutorchProgramManager,
to_edge_transform_and_lower,
)
from executorch.exir.program._program import _transform
from executorch.exir.schema import Program
from executorch.extension.export_util.utils import save_pte_program
from executorch.runtime import Runtime, Verification
from tabulate import tabulate
from torch import nn

from torch._export.pass_base import PassType
from torch.export import ExportedProgram
from torchao.quantization import quantize_
from torchao.quantization.pt2e import allow_exported_model_train_eval
Expand Down Expand Up @@ -95,9 +104,7 @@ class ExportStage(Stage):

def __init__(
self,
pre_edge_transform_passes: Optional[
Callable[[ExportedProgram], ExportedProgram]
] = None,
pre_edge_transform_passes: Optional[List[PassType]] = None,
) -> None:
self._exported_program: Dict[str, ExportedProgram] = {}
self._pre_edge_transform_passes = pre_edge_transform_passes
Expand Down Expand Up @@ -153,10 +160,10 @@ def run(
)

# Apply pre-edge transform passes if available
if self._pre_edge_transform_passes is not None:
for pre_edge_transform_pass in self._pre_edge_transform_passes:
self._exported_program[method_name] = pre_edge_transform_pass(
self._exported_program[method_name]
if pre_edge_transform_passes := self._pre_edge_transform_passes or []:
for pass_ in pre_edge_transform_passes:
self._exported_program[method_name] = _transform(
self._exported_program[method_name], pass_
)

def get_artifacts(self) -> Dict[str, ExportedProgram]:
Expand Down
89 changes: 68 additions & 21 deletions export/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,19 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from enum import Enum, EnumMeta
from typing import List, Optional, Sequence

from executorch.exir._warnings import experimental

from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.capture import EdgeCompileConfig, ExecutorchBackendConfig
from executorch.exir.pass_manager import PassType
from torchao.core.config import AOBaseConfig
from torchao.quantization.pt2e.quantizer import Quantizer


"""
Export recipe definitions for ExecuTorch.
Expand All @@ -11,18 +24,29 @@
for ExecuTorch models, including export configurations and quantization recipes.
"""

from dataclasses import dataclass
from enum import Enum
from typing import Callable, List, Optional, Sequence

from executorch.exir._warnings import experimental
class RecipeTypeMeta(EnumMeta, ABCMeta):
"""Metaclass that combines EnumMeta and ABCMeta"""

from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.capture import EdgeCompileConfig, ExecutorchBackendConfig
from executorch.exir.pass_manager import PassType
from torch.export import ExportedProgram
from torchao.core.config import AOBaseConfig
from torchao.quantization.pt2e.quantizer import Quantizer
pass


class RecipeType(Enum, metaclass=RecipeTypeMeta):
"""
Base recipe type class that backends can extend to define their own recipe types.
Backends should create their own enum classes that inherit from RecipeType:
"""

@classmethod
@abstractmethod
def get_backend_name(cls) -> str:
"""
Return the backend name for this recipe type.

Returns:
str: The backend name (e.g., "xnnpack", "qnn", etc.)
"""
pass


class Mode(str, Enum):
Expand Down Expand Up @@ -52,7 +76,7 @@ class QuantizationRecipe:
quantizers: Optional[List[Quantizer]] = None
ao_base_config: Optional[List[AOBaseConfig]] = None

def get_quantizers(self) -> Optional[Quantizer]:
def get_quantizers(self) -> Optional[List[Quantizer]]:
"""
Get the quantizer associated with this recipe.

Expand Down Expand Up @@ -89,17 +113,40 @@ class ExportRecipe:

name: Optional[str] = None
quantization_recipe: Optional[QuantizationRecipe] = None
edge_compile_config: Optional[EdgeCompileConfig] = (
None # pyre-ignore[11]: Type not defined
)
pre_edge_transform_passes: Optional[
Callable[[ExportedProgram], ExportedProgram]
| List[Callable[[ExportedProgram], ExportedProgram]]
] = None
# pyre-ignore[11]: Type not defined
edge_compile_config: Optional[EdgeCompileConfig] = None
pre_edge_transform_passes: Optional[Sequence[PassType]] = None
edge_transform_passes: Optional[Sequence[PassType]] = None
transform_check_ir_validity: bool = True
partitioners: Optional[List[Partitioner]] = None
executorch_backend_config: Optional[ExecutorchBackendConfig] = (
None # pyre-ignore[11]: Type not defined
)
# pyre-ignore[11]: Type not defined
executorch_backend_config: Optional[ExecutorchBackendConfig] = None
mode: Mode = Mode.RELEASE

@classmethod
def get_recipe(cls, recipe: "RecipeType", **kwargs) -> "ExportRecipe":
"""
Get an export recipe from backend. Backend is automatically determined based on the
passed recipe type.

Args:
recipe: The type of recipe to create
**kwargs: Recipe-specific parameters

Returns:
ExportRecipe configured for the specified recipe type
"""
from .recipe_registry import recipe_registry

if not isinstance(recipe, RecipeType):
raise ValueError(f"Invalid recipe type: {recipe}")

backend = recipe.get_backend_name()
export_recipe = recipe_registry.create_recipe(recipe, backend, **kwargs)
if export_recipe is None:
supported = recipe_registry.get_supported_recipes(backend)
raise ValueError(
f"Recipe '{recipe.value}' not supported by '{backend}'. "
f"Supported: {[r.value for r in supported]}"
)
return export_recipe
60 changes: 60 additions & 0 deletions export/recipe_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
Recipe registry for managing backend recipe providers.

This module provides the registry system for backend recipe providers and
the abstract interface that all backends must implement.
"""

from abc import ABC, abstractmethod
from typing import Any, Optional, Sequence

from .recipe import ExportRecipe, RecipeType


class BackendRecipeProvider(ABC):
"""
Abstract recipe provider that all backends must implement
"""

@property
@abstractmethod
def backend_name(self) -> str:
"""
Name of the backend (ex: 'xnnpack', 'qnn' etc)
"""
pass

@abstractmethod
def get_supported_recipes(self) -> Sequence[RecipeType]:
"""
Get list of supported recipes.
"""
pass

@abstractmethod
def create_recipe(
self, recipe_type: RecipeType, **kwargs: Any
) -> Optional[ExportRecipe]:
"""
Create a recipe for the given type.
Returns None if the recipe is not supported by this backend.

Args:
recipe_type: The type of recipe to create
**kwargs: Recipe-specific parameters (ex: group_size)

Returns:
ExportRecipe if supported, None otherwise
"""
pass

def supports_recipe(self, recipe_type: RecipeType) -> bool:
return recipe_type in self.get_supported_recipes()
86 changes: 86 additions & 0 deletions export/recipe_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Recipe registry for managing backend recipe providers.

This module provides the registry system for backend recipe providers and
the abstract interface that all backends must implement.
"""

from typing import Any, Dict, Optional, Sequence

from .recipe import ExportRecipe, RecipeType
from .recipe_provider import BackendRecipeProvider


class RecipeRegistry:
"""Global registry for all backend recipe providers"""

_instance = None
_initialized = False

def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance

def __init__(self) -> None:
# Only initialize once to avoid resetting state on subsequent calls
if not RecipeRegistry._initialized:
self._providers: Dict[str, BackendRecipeProvider] = {}
RecipeRegistry._initialized = True

def register_backend_recipe_provider(self, provider: BackendRecipeProvider) -> None:
"""
Register a backend recipe provider
"""
self._providers[provider.backend_name] = provider

def create_recipe(
self, recipe_type: RecipeType, backend: str, **kwargs: Any
) -> Optional[ExportRecipe]:
"""
Create a recipe for a specific backend.

Args:
recipe_type: The type of recipe to create
backend: Backend name
**kwargs: Recipe-specific parameters

Returns:
ExportRecipe if supported, None if not supported
"""
if backend not in self._providers:
raise ValueError(
f"Backend '{backend}' not available. Available: {list(self._providers.keys())}"
)

return self._providers[backend].create_recipe(recipe_type, **kwargs)

def get_supported_recipes(self, backend: str) -> Sequence[RecipeType]:
"""
Get list of recipes supported by a backend.

Args:
backend: Backend name

Returns:
List of supported recipe types
"""
if backend not in self._providers:
raise ValueError(f"Backend '{backend}' not available")
return self._providers[backend].get_supported_recipes()

def list_backends(self) -> Sequence[str]:
"""
Get list of all registered backends
"""
return list(self._providers.keys())


# initialize recipe registry
recipe_registry = RecipeRegistry()
Loading
Loading