Skip to content

Commit bd198d4

Browse files
Support for unified recipe registry and user interfaces (#12248)
Summary: Implements the RFC: #12248 1. `BackendRecipeProvider` -> Abstract interface that all backends must implement while providing recipes. 1. `recipe_registry` -> Singleton registry that maintains `BackendRecipeProviders` 1. `RecipeType` -> Abstract enum, backends extend this to provide support for specific recipes. 1. `RecipeBuilder` -> Helps building the recipes 1. `ExportRecipe` will have two class methods a. `get_recipe` -> Queries registry to get recipe w/ specific backend b. `recipe_builder` -> To get a recipe builder, starting with a base recipe to further customize Differential Revision: D78034047
1 parent bf4acb3 commit bd198d4

File tree

10 files changed

+923
-25
lines changed

10 files changed

+923
-25
lines changed

export/TARGETS

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ python_library(
2424
deps = [
2525
":recipe",
2626
"//executorch/runtime:runtime",
27+
":recipe_registry"
2728
]
2829
)
2930

@@ -35,5 +36,30 @@ python_library(
3536
deps = [
3637
":export",
3738
":recipe",
39+
":recipe_registry",
40+
":recipe_provider"
3841
],
3942
)
43+
44+
45+
python_library(
46+
name = "recipe_registry",
47+
srcs = [
48+
"recipe_registry.py",
49+
],
50+
deps = [
51+
":recipe",
52+
":recipe_provider"
53+
],
54+
)
55+
56+
57+
python_library(
58+
name = "recipe_provider",
59+
srcs = [
60+
"recipe_provider.py",
61+
],
62+
deps = [
63+
":recipe",
64+
]
65+
)

export/__init__.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-strict
8+
79
"""
810
ExecuTorch export module.
911
@@ -12,13 +14,19 @@
1214
export management.
1315
"""
1416

15-
# pyre-strict
16-
1717
from .export import export, ExportSession
18-
from .recipe import ExportRecipe
18+
from .recipe import ExportRecipe, QuantizationRecipe, RecipeBuilder, RecipeType
19+
from .recipe_provider import BackendRecipeProvider
20+
from .recipe_registry import recipe_registry
21+
1922

2023
__all__ = [
2124
"ExportRecipe",
25+
"QuantizationRecipe",
2226
"ExportSession",
2327
"export",
28+
"BackendRecipeProvider",
29+
"recipe_registry",
30+
"RecipeBuilder",
31+
"RecipeType",
2432
]

export/export.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
17
from abc import ABC, abstractmethod
28
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
39

@@ -10,11 +16,14 @@
1016
ExecutorchProgramManager,
1117
to_edge_transform_and_lower,
1218
)
19+
from executorch.exir.program._program import _transform
1320
from executorch.exir.schema import Program
1421
from executorch.extension.export_util.utils import save_pte_program
1522
from executorch.runtime import Runtime, Verification
1623
from tabulate import tabulate
1724
from torch import nn
25+
26+
from torch._export.pass_base import PassType
1827
from torch.export import ExportedProgram
1928
from torchao.quantization import quantize_
2029
from torchao.quantization.pt2e import allow_exported_model_train_eval
@@ -95,9 +104,7 @@ class ExportStage(Stage):
95104

96105
def __init__(
97106
self,
98-
pre_edge_transform_passes: Optional[
99-
Callable[[ExportedProgram], ExportedProgram]
100-
] = None,
107+
pre_edge_transform_passes: Optional[List[PassType]] = None,
101108
) -> None:
102109
self._exported_program: Dict[str, ExportedProgram] = {}
103110
self._pre_edge_transform_passes = pre_edge_transform_passes
@@ -153,10 +160,10 @@ def run(
153160
)
154161

155162
# Apply pre-edge transform passes if available
156-
if self._pre_edge_transform_passes is not None:
157-
for pre_edge_transform_pass in self._pre_edge_transform_passes:
158-
self._exported_program[method_name] = pre_edge_transform_pass(
159-
self._exported_program[method_name]
163+
if pre_edge_transform_passes := self._pre_edge_transform_passes or []:
164+
for pass_ in pre_edge_transform_passes:
165+
self._exported_program[method_name] = _transform(
166+
self._exported_program[method_name], pass_
160167
)
161168

162169
def get_artifacts(self) -> Dict[str, ExportedProgram]:

export/recipe.py

Lines changed: 172 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,30 @@
1313

1414
from dataclasses import dataclass
1515
from enum import Enum
16-
from typing import Callable, List, Optional, Sequence
16+
from typing import Any, List, Optional, Sequence
1717

1818
from executorch.exir._warnings import experimental
1919

2020
from executorch.exir.backend.partitioner import Partitioner
2121
from executorch.exir.capture import EdgeCompileConfig, ExecutorchBackendConfig
2222
from executorch.exir.pass_manager import PassType
23-
from torch.export import ExportedProgram
2423
from torchao.core.config import AOBaseConfig
2524
from torchao.quantization.pt2e.quantizer import Quantizer
2625

2726

27+
class RecipeType(Enum):
28+
"""
29+
Base recipe type class that backends can extend to define their own recipe types.
30+
Backends should create their own enum classes that inherit from RecipeType:
31+
Example:
32+
class MyBackendRecipeType(RecipeType):
33+
FP32 = "fp32"
34+
INT8_DYNAMIC = "int8_dynamic"
35+
"""
36+
37+
pass
38+
39+
2840
class Mode(str, Enum):
2941
"""
3042
Export mode enumeration.
@@ -52,7 +64,7 @@ class QuantizationRecipe:
5264
quantizers: Optional[List[Quantizer]] = None
5365
ao_base_config: Optional[List[AOBaseConfig]] = None
5466

55-
def get_quantizers(self) -> Optional[Quantizer]:
67+
def get_quantizers(self) -> Optional[List[Quantizer]]:
5668
"""
5769
Get the quantizer associated with this recipe.
5870
@@ -89,17 +101,164 @@ class ExportRecipe:
89101

90102
name: Optional[str] = None
91103
quantization_recipe: Optional[QuantizationRecipe] = None
92-
edge_compile_config: Optional[EdgeCompileConfig] = (
93-
None # pyre-ignore[11]: Type not defined
94-
)
95-
pre_edge_transform_passes: Optional[
96-
Callable[[ExportedProgram], ExportedProgram]
97-
| List[Callable[[ExportedProgram], ExportedProgram]]
98-
] = None
104+
# pyre-ignore[11]: Type not defined
105+
edge_compile_config: Optional[EdgeCompileConfig] = None
106+
pre_edge_transform_passes: Optional[Sequence[PassType]] = None
99107
edge_transform_passes: Optional[Sequence[PassType]] = None
100108
transform_check_ir_validity: bool = True
101109
partitioners: Optional[List[Partitioner]] = None
102-
executorch_backend_config: Optional[ExecutorchBackendConfig] = (
103-
None # pyre-ignore[11]: Type not defined
104-
)
110+
# pyre-ignore[11]: Type not defined
111+
executorch_backend_config: Optional[ExecutorchBackendConfig] = None
105112
mode: Mode = Mode.RELEASE
113+
114+
@classmethod
115+
def get_recipe(
116+
cls, recipe_type: "RecipeType", backend: str, **kwargs
117+
) -> "ExportRecipe":
118+
"""Get a clean base recipe from backend"""
119+
from .recipe_registry import recipe_registry
120+
121+
recipe = recipe_registry.create_recipe(recipe_type, backend, **kwargs)
122+
if recipe is None:
123+
supported = recipe_registry.get_supported_recipes(backend)
124+
raise ValueError(
125+
f"Recipe '{recipe_type.value}' not supported by '{backend}'. "
126+
f"Supported: {[r.value for r in supported]}"
127+
)
128+
return recipe
129+
130+
@classmethod
131+
def recipe_builder(
132+
cls, recipe_type: "RecipeType", backend: str, **kwargs
133+
) -> "RecipeBuilder":
134+
"""Create a recipe builder from a base recipe"""
135+
from .recipe_registry import recipe_registry
136+
137+
base_recipe = recipe_registry.create_recipe(recipe_type, backend, **kwargs)
138+
if base_recipe is None:
139+
supported = recipe_registry.get_supported_recipes(backend)
140+
raise ValueError(
141+
f"Recipe '{recipe_type.value}' not supported by '{backend}'. "
142+
f"Supported: {[r.value for r in supported]}"
143+
)
144+
return RecipeBuilder(base_recipe)
145+
146+
147+
class RecipeBuilder:
148+
def __init__(
149+
self, base_recipe: Optional[ExportRecipe] = None, name: str = "custom_recipe"
150+
):
151+
"""
152+
Initialize builder with optional base recipe.
153+
154+
Args:
155+
base_recipe: Existing recipe to customize or None to start fresh
156+
name: Name for the recipe
157+
"""
158+
if base_recipe is not None:
159+
# Initialize from existing recipe
160+
self._name = base_recipe.name
161+
self._quantization_recipe = base_recipe.quantization_recipe
162+
self._partitioners = base_recipe.partitioners
163+
self._edge_compile_config = base_recipe.edge_compile_config
164+
self._executorch_backend_config = base_recipe.executorch_backend_config
165+
self._pre_edge_transform_passes = base_recipe.pre_edge_transform_passes
166+
self._edge_transform_passes = base_recipe.edge_transform_passes
167+
self._transform_check_ir_validity = getattr(
168+
base_recipe, "transform_check_ir_validity", True
169+
)
170+
self._mode = getattr(base_recipe, "mode", None)
171+
else:
172+
self._name = name
173+
self._quantization_recipe: Optional[QuantizationRecipe] = None
174+
self._partitioners: Optional[List[Any]] = None
175+
self._edge_compile_config: Optional[Any] = None
176+
self._executorch_backend_config: Optional[Any] = None
177+
self._pre_edge_transform_passes: Optional[Sequence[PassType]] = None
178+
self._edge_transform_passes: Optional[Sequence[PassType]] = None
179+
self._transform_check_ir_validity: bool = True
180+
self._mode = None
181+
182+
def with_name(self, name: str) -> "RecipeBuilder":
183+
"""Set recipe name"""
184+
self._name = name
185+
return self
186+
187+
def with_quantization(
188+
self, quantization_recipe: QuantizationRecipe
189+
) -> "RecipeBuilder":
190+
"""Set quantization recipe"""
191+
self._quantization_recipe = quantization_recipe
192+
return self
193+
194+
def with_partitioners(self, partitioners: List[Any]) -> "RecipeBuilder":
195+
"""Set partitioners"""
196+
self._partitioners = partitioners
197+
return self
198+
199+
def with_pre_edge_passes(self, passes: List[PassType]) -> "RecipeBuilder":
200+
"""Set pre-edge transform passes"""
201+
self._pre_edge_transform_passes = passes
202+
return self
203+
204+
def with_edge_passes(self, passes: Sequence[PassType]) -> "RecipeBuilder":
205+
"""Set edge transform passes"""
206+
self._edge_transform_passes = passes
207+
return self
208+
209+
def with_edge_compile_config(self, config: Any) -> "RecipeBuilder":
210+
"""Set edge compile config"""
211+
self._edge_compile_config = config
212+
return self
213+
214+
def with_backend_config(self, config: Any) -> "RecipeBuilder":
215+
"""Set backend config"""
216+
self._executorch_backend_config = config
217+
return self
218+
219+
def with_ir_validity_check(self, check: bool) -> "RecipeBuilder":
220+
"""Set IR validity check"""
221+
self._transform_check_ir_validity = check
222+
return self
223+
224+
def with_mode(self, mode: Any) -> "RecipeBuilder":
225+
"""Set export mode"""
226+
self._mode = mode
227+
return self
228+
229+
def add_pre_edge_pass(self, pass_: PassType) -> "RecipeBuilder":
230+
"""Add a single pre-edge transform pass to existing passes"""
231+
if self._pre_edge_transform_passes is None:
232+
self._pre_edge_transform_passes = []
233+
self._pre_edge_transform_passes = list(self._pre_edge_transform_passes) + [
234+
pass_
235+
]
236+
return self
237+
238+
def add_edge_pass(self, pass_: PassType) -> "RecipeBuilder":
239+
"""Add a single edge transform pass to existing passes"""
240+
if self._edge_transform_passes is None:
241+
self._edge_transform_passes = []
242+
self._edge_transform_passes = list(self._edge_transform_passes) + [pass_]
243+
return self
244+
245+
def add_partitioner(self, partitioner: Any) -> "RecipeBuilder":
246+
"""Add a partitioner to existing partitioners"""
247+
if self._partitioners is None:
248+
self._partitioners = []
249+
self._partitioners = list(self._partitioners) + [partitioner]
250+
return self
251+
252+
def build(self) -> ExportRecipe:
253+
"""Build the final recipe"""
254+
return ExportRecipe(
255+
name=self._name,
256+
quantization_recipe=self._quantization_recipe,
257+
partitioners=self._partitioners,
258+
edge_compile_config=self._edge_compile_config,
259+
executorch_backend_config=self._executorch_backend_config,
260+
pre_edge_transform_passes=self._pre_edge_transform_passes,
261+
edge_transform_passes=self._edge_transform_passes,
262+
transform_check_ir_validity=self._transform_check_ir_validity,
263+
mode=self._mode,
264+
)

export/recipe_provider.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
"""
10+
Recipe registry for managing backend recipe providers.
11+
12+
This module provides the registry system for backend recipe providers and
13+
the abstract interface that all backends must implement.
14+
"""
15+
16+
from abc import ABC, abstractmethod
17+
from typing import Any, Optional, Sequence
18+
19+
from .recipe import ExportRecipe, RecipeType
20+
21+
22+
class BackendRecipeProvider(ABC):
23+
"""
24+
Abstract recipe provider that all backends must implement
25+
"""
26+
27+
@property
28+
@abstractmethod
29+
def backend_name(self) -> str:
30+
"""
31+
Name of the backend (ex: 'xnnpack', 'qnn' etc)
32+
"""
33+
pass
34+
35+
@abstractmethod
36+
def get_supported_recipes(self) -> Sequence[RecipeType]:
37+
"""
38+
Get list of supported recipes.
39+
"""
40+
pass
41+
42+
@abstractmethod
43+
def create_recipe(
44+
self, recipe_type: RecipeType, **kwargs: Any
45+
) -> Optional[ExportRecipe]:
46+
"""
47+
Create a recipe for the given type.
48+
Returns None if the recipe is not supported by this backend.
49+
50+
Args:
51+
recipe_type: The type of recipe to create
52+
**kwargs: Recipe-specific parameters (ex: group_size)
53+
54+
Returns:
55+
ExportRecipe if supported, None otherwise
56+
"""
57+
pass
58+
59+
def supports_recipe(self, recipe_type: RecipeType) -> bool:
60+
return recipe_type in self.get_supported_recipes()

0 commit comments

Comments
 (0)