|
3 | 3 | #
|
4 | 4 | # This source code is licensed under the BSD-style license found in the
|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
| 6 | +from abc import ABCMeta, abstractmethod |
| 7 | +from dataclasses import dataclass |
| 8 | +from enum import Enum, EnumMeta |
| 9 | +from typing import List, Optional, Sequence |
| 10 | + |
| 11 | +from executorch.exir._warnings import experimental |
| 12 | + |
| 13 | +from executorch.exir.backend.partitioner import Partitioner |
| 14 | +from executorch.exir.capture import EdgeCompileConfig, ExecutorchBackendConfig |
| 15 | +from executorch.exir.pass_manager import PassType |
| 16 | +from torchao.core.config import AOBaseConfig |
| 17 | +from torchao.quantization.pt2e.quantizer import Quantizer |
| 18 | + |
6 | 19 |
|
7 | 20 | """
|
8 | 21 | Export recipe definitions for ExecuTorch.
|
|
11 | 24 | for ExecuTorch models, including export configurations and quantization recipes.
|
12 | 25 | """
|
13 | 26 |
|
14 |
| -from dataclasses import dataclass |
15 |
| -from enum import Enum |
16 |
| -from typing import Callable, List, Optional, Sequence |
17 | 27 |
|
18 |
| -from executorch.exir._warnings import experimental |
| 28 | +class RecipeTypeMeta(EnumMeta, ABCMeta): |
| 29 | + """Metaclass that combines EnumMeta and ABCMeta""" |
19 | 30 |
|
20 |
| -from executorch.exir.backend.partitioner import Partitioner |
21 |
| -from executorch.exir.capture import EdgeCompileConfig, ExecutorchBackendConfig |
22 |
| -from executorch.exir.pass_manager import PassType |
23 |
| -from torch.export import ExportedProgram |
24 |
| -from torchao.core.config import AOBaseConfig |
25 |
| -from torchao.quantization.pt2e.quantizer import Quantizer |
| 31 | + pass |
| 32 | + |
| 33 | + |
| 34 | +class RecipeType(Enum, metaclass=RecipeTypeMeta): |
| 35 | + """ |
| 36 | + Base recipe type class that backends can extend to define their own recipe types. |
| 37 | + Backends should create their own enum classes that inherit from RecipeType: |
| 38 | + """ |
| 39 | + |
| 40 | + @classmethod |
| 41 | + @abstractmethod |
| 42 | + def get_backend_name(cls) -> str: |
| 43 | + """ |
| 44 | + Return the backend name for this recipe type. |
| 45 | +
|
| 46 | + Returns: |
| 47 | + str: The backend name (e.g., "xnnpack", "qnn", etc.) |
| 48 | + """ |
| 49 | + pass |
26 | 50 |
|
27 | 51 |
|
28 | 52 | class Mode(str, Enum):
|
@@ -52,7 +76,7 @@ class QuantizationRecipe:
|
52 | 76 | quantizers: Optional[List[Quantizer]] = None
|
53 | 77 | ao_base_config: Optional[List[AOBaseConfig]] = None
|
54 | 78 |
|
55 |
| - def get_quantizers(self) -> Optional[Quantizer]: |
| 79 | + def get_quantizers(self) -> Optional[List[Quantizer]]: |
56 | 80 | """
|
57 | 81 | Get the quantizer associated with this recipe.
|
58 | 82 |
|
@@ -89,17 +113,40 @@ class ExportRecipe:
|
89 | 113 |
|
90 | 114 | name: Optional[str] = None
|
91 | 115 | quantization_recipe: Optional[QuantizationRecipe] = None
|
92 |
| - edge_compile_config: Optional[EdgeCompileConfig] = ( |
93 |
| - None # pyre-ignore[11]: Type not defined |
94 |
| - ) |
95 |
| - pre_edge_transform_passes: Optional[ |
96 |
| - Callable[[ExportedProgram], ExportedProgram] |
97 |
| - | List[Callable[[ExportedProgram], ExportedProgram]] |
98 |
| - ] = None |
| 116 | + # pyre-ignore[11]: Type not defined |
| 117 | + edge_compile_config: Optional[EdgeCompileConfig] = None |
| 118 | + pre_edge_transform_passes: Optional[Sequence[PassType]] = None |
99 | 119 | edge_transform_passes: Optional[Sequence[PassType]] = None
|
100 | 120 | transform_check_ir_validity: bool = True
|
101 | 121 | partitioners: Optional[List[Partitioner]] = None
|
102 |
| - executorch_backend_config: Optional[ExecutorchBackendConfig] = ( |
103 |
| - None # pyre-ignore[11]: Type not defined |
104 |
| - ) |
| 122 | + # pyre-ignore[11]: Type not defined |
| 123 | + executorch_backend_config: Optional[ExecutorchBackendConfig] = None |
105 | 124 | mode: Mode = Mode.RELEASE
|
| 125 | + |
| 126 | + @classmethod |
| 127 | + def get_recipe(cls, recipe: "RecipeType", **kwargs) -> "ExportRecipe": |
| 128 | + """ |
| 129 | + Get an export recipe from backend. Backend is automatically determined based on the |
| 130 | + passed recipe type. |
| 131 | +
|
| 132 | + Args: |
| 133 | + recipe: The type of recipe to create |
| 134 | + **kwargs: Recipe-specific parameters |
| 135 | +
|
| 136 | + Returns: |
| 137 | + ExportRecipe configured for the specified recipe type |
| 138 | + """ |
| 139 | + from .recipe_registry import recipe_registry |
| 140 | + |
| 141 | + if not isinstance(recipe, RecipeType): |
| 142 | + raise ValueError(f"Invalid recipe type: {recipe}") |
| 143 | + |
| 144 | + backend = recipe.get_backend_name() |
| 145 | + export_recipe = recipe_registry.create_recipe(recipe, backend, **kwargs) |
| 146 | + if export_recipe is None: |
| 147 | + supported = recipe_registry.get_supported_recipes(backend) |
| 148 | + raise ValueError( |
| 149 | + f"Recipe '{recipe.value}' not supported by '{backend}'. " |
| 150 | + f"Supported: {[r.value for r in supported]}" |
| 151 | + ) |
| 152 | + return export_recipe |
0 commit comments