|
13 | 13 |
|
14 | 14 | from dataclasses import dataclass |
15 | 15 | from enum import Enum |
16 | | -from typing import Callable, List, Optional, Sequence |
| 16 | +from typing import Any, List, Optional, Sequence |
17 | 17 |
|
18 | 18 | from executorch.exir._warnings import experimental |
19 | 19 |
|
20 | 20 | from executorch.exir.backend.partitioner import Partitioner |
21 | 21 | from executorch.exir.capture import EdgeCompileConfig, ExecutorchBackendConfig |
22 | 22 | from executorch.exir.pass_manager import PassType |
23 | | -from torch.export import ExportedProgram |
24 | 23 | from torchao.core.config import AOBaseConfig |
25 | 24 | from torchao.quantization.pt2e.quantizer import Quantizer |
26 | 25 |
|
27 | 26 |
|
| 27 | +class RecipeType(Enum): |
| 28 | + """ |
| 29 | + Base recipe type class that backends can extend to define their own recipe types. |
| 30 | + Backends should create their own enum classes that inherit from RecipeType: |
| 31 | + Example: |
| 32 | + class MyBackendRecipeType(RecipeType): |
| 33 | + FP32 = "fp32" |
| 34 | + INT8_DYNAMIC = "int8_dynamic" |
| 35 | + """ |
| 36 | + |
| 37 | + pass |
| 38 | + |
| 39 | + |
28 | 40 | class Mode(str, Enum): |
29 | 41 | """ |
30 | 42 | Export mode enumeration. |
@@ -52,7 +64,7 @@ class QuantizationRecipe: |
52 | 64 | quantizers: Optional[List[Quantizer]] = None |
53 | 65 | ao_base_config: Optional[List[AOBaseConfig]] = None |
54 | 66 |
|
55 | | - def get_quantizers(self) -> Optional[Quantizer]: |
| 67 | + def get_quantizers(self) -> Optional[List[Quantizer]]: |
56 | 68 | """ |
57 | 69 | Get the quantizer associated with this recipe. |
58 | 70 |
|
@@ -89,17 +101,164 @@ class ExportRecipe: |
89 | 101 |
|
90 | 102 | name: Optional[str] = None |
91 | 103 | quantization_recipe: Optional[QuantizationRecipe] = None |
92 | | - edge_compile_config: Optional[EdgeCompileConfig] = ( |
93 | | - None # pyre-ignore[11]: Type not defined |
94 | | - ) |
95 | | - pre_edge_transform_passes: Optional[ |
96 | | - Callable[[ExportedProgram], ExportedProgram] |
97 | | - | List[Callable[[ExportedProgram], ExportedProgram]] |
98 | | - ] = None |
| 104 | + # pyre-ignore[11]: Type not defined |
| 105 | + edge_compile_config: Optional[EdgeCompileConfig] = None |
| 106 | + pre_edge_transform_passes: Optional[Sequence[PassType]] = None |
99 | 107 | edge_transform_passes: Optional[Sequence[PassType]] = None |
100 | 108 | transform_check_ir_validity: bool = True |
101 | 109 | partitioners: Optional[List[Partitioner]] = None |
102 | | - executorch_backend_config: Optional[ExecutorchBackendConfig] = ( |
103 | | - None # pyre-ignore[11]: Type not defined |
104 | | - ) |
| 110 | + # pyre-ignore[11]: Type not defined |
| 111 | + executorch_backend_config: Optional[ExecutorchBackendConfig] = None |
105 | 112 | mode: Mode = Mode.RELEASE |
| 113 | + |
| 114 | + @classmethod |
| 115 | + def get_recipe( |
| 116 | + cls, recipe_type: "RecipeType", backend: str, **kwargs |
| 117 | + ) -> "ExportRecipe": |
| 118 | + """Get a clean base recipe from backend""" |
| 119 | + from .recipe_registry import recipe_registry |
| 120 | + |
| 121 | + recipe = recipe_registry.create_recipe(recipe_type, backend, **kwargs) |
| 122 | + if recipe is None: |
| 123 | + supported = recipe_registry.get_supported_recipes(backend) |
| 124 | + raise ValueError( |
| 125 | + f"Recipe '{recipe_type.value}' not supported by '{backend}'. " |
| 126 | + f"Supported: {[r.value for r in supported]}" |
| 127 | + ) |
| 128 | + return recipe |
| 129 | + |
| 130 | + @classmethod |
| 131 | + def recipe_builder( |
| 132 | + cls, recipe_type: "RecipeType", backend: str, **kwargs |
| 133 | + ) -> "RecipeBuilder": |
| 134 | + """Create a recipe builder from a base recipe""" |
| 135 | + from .recipe_registry import recipe_registry |
| 136 | + |
| 137 | + base_recipe = recipe_registry.create_recipe(recipe_type, backend, **kwargs) |
| 138 | + if base_recipe is None: |
| 139 | + supported = recipe_registry.get_supported_recipes(backend) |
| 140 | + raise ValueError( |
| 141 | + f"Recipe '{recipe_type.value}' not supported by '{backend}'. " |
| 142 | + f"Supported: {[r.value for r in supported]}" |
| 143 | + ) |
| 144 | + return RecipeBuilder(base_recipe) |
| 145 | + |
| 146 | + |
| 147 | +class RecipeBuilder: |
| 148 | + def __init__( |
| 149 | + self, base_recipe: Optional[ExportRecipe] = None, name: str = "custom_recipe" |
| 150 | + ): |
| 151 | + """ |
| 152 | + Initialize builder with optional base recipe. |
| 153 | +
|
| 154 | + Args: |
| 155 | + base_recipe: Existing recipe to customize or None to start fresh |
| 156 | + name: Name for the recipe |
| 157 | + """ |
| 158 | + if base_recipe is not None: |
| 159 | + # Initialize from existing recipe |
| 160 | + self._name = base_recipe.name |
| 161 | + self._quantization_recipe = base_recipe.quantization_recipe |
| 162 | + self._partitioners = base_recipe.partitioners |
| 163 | + self._edge_compile_config = base_recipe.edge_compile_config |
| 164 | + self._executorch_backend_config = base_recipe.executorch_backend_config |
| 165 | + self._pre_edge_transform_passes = base_recipe.pre_edge_transform_passes |
| 166 | + self._edge_transform_passes = base_recipe.edge_transform_passes |
| 167 | + self._transform_check_ir_validity = getattr( |
| 168 | + base_recipe, "transform_check_ir_validity", True |
| 169 | + ) |
| 170 | + self._mode = getattr(base_recipe, "mode", None) |
| 171 | + else: |
| 172 | + self._name = name |
| 173 | + self._quantization_recipe: Optional[QuantizationRecipe] = None |
| 174 | + self._partitioners: Optional[List[Any]] = None |
| 175 | + self._edge_compile_config: Optional[Any] = None |
| 176 | + self._executorch_backend_config: Optional[Any] = None |
| 177 | + self._pre_edge_transform_passes: Optional[Sequence[PassType]] = None |
| 178 | + self._edge_transform_passes: Optional[Sequence[PassType]] = None |
| 179 | + self._transform_check_ir_validity: bool = True |
| 180 | + self._mode = None |
| 181 | + |
| 182 | + def with_name(self, name: str) -> "RecipeBuilder": |
| 183 | + """Set recipe name""" |
| 184 | + self._name = name |
| 185 | + return self |
| 186 | + |
| 187 | + def with_quantization( |
| 188 | + self, quantization_recipe: QuantizationRecipe |
| 189 | + ) -> "RecipeBuilder": |
| 190 | + """Set quantization recipe""" |
| 191 | + self._quantization_recipe = quantization_recipe |
| 192 | + return self |
| 193 | + |
| 194 | + def with_partitioners(self, partitioners: List[Any]) -> "RecipeBuilder": |
| 195 | + """Set partitioners""" |
| 196 | + self._partitioners = partitioners |
| 197 | + return self |
| 198 | + |
| 199 | + def with_pre_edge_passes(self, passes: List[PassType]) -> "RecipeBuilder": |
| 200 | + """Set pre-edge transform passes""" |
| 201 | + self._pre_edge_transform_passes = passes |
| 202 | + return self |
| 203 | + |
| 204 | + def with_edge_passes(self, passes: Sequence[PassType]) -> "RecipeBuilder": |
| 205 | + """Set edge transform passes""" |
| 206 | + self._edge_transform_passes = passes |
| 207 | + return self |
| 208 | + |
| 209 | + def with_edge_compile_config(self, config: Any) -> "RecipeBuilder": |
| 210 | + """Set edge compile config""" |
| 211 | + self._edge_compile_config = config |
| 212 | + return self |
| 213 | + |
| 214 | + def with_backend_config(self, config: Any) -> "RecipeBuilder": |
| 215 | + """Set backend config""" |
| 216 | + self._executorch_backend_config = config |
| 217 | + return self |
| 218 | + |
| 219 | + def with_ir_validity_check(self, check: bool) -> "RecipeBuilder": |
| 220 | + """Set IR validity check""" |
| 221 | + self._transform_check_ir_validity = check |
| 222 | + return self |
| 223 | + |
| 224 | + def with_mode(self, mode: Any) -> "RecipeBuilder": |
| 225 | + """Set export mode""" |
| 226 | + self._mode = mode |
| 227 | + return self |
| 228 | + |
| 229 | + def add_pre_edge_pass(self, pass_: PassType) -> "RecipeBuilder": |
| 230 | + """Add a single pre-edge transform pass to existing passes""" |
| 231 | + if self._pre_edge_transform_passes is None: |
| 232 | + self._pre_edge_transform_passes = [] |
| 233 | + self._pre_edge_transform_passes = list(self._pre_edge_transform_passes) + [ |
| 234 | + pass_ |
| 235 | + ] |
| 236 | + return self |
| 237 | + |
| 238 | + def add_edge_pass(self, pass_: PassType) -> "RecipeBuilder": |
| 239 | + """Add a single edge transform pass to existing passes""" |
| 240 | + if self._edge_transform_passes is None: |
| 241 | + self._edge_transform_passes = [] |
| 242 | + self._edge_transform_passes = list(self._edge_transform_passes) + [pass_] |
| 243 | + return self |
| 244 | + |
| 245 | + def add_partitioner(self, partitioner: Any) -> "RecipeBuilder": |
| 246 | + """Add a partitioner to existing partitioners""" |
| 247 | + if self._partitioners is None: |
| 248 | + self._partitioners = [] |
| 249 | + self._partitioners = list(self._partitioners) + [partitioner] |
| 250 | + return self |
| 251 | + |
| 252 | + def build(self) -> ExportRecipe: |
| 253 | + """Build the final recipe""" |
| 254 | + return ExportRecipe( |
| 255 | + name=self._name, |
| 256 | + quantization_recipe=self._quantization_recipe, |
| 257 | + partitioners=self._partitioners, |
| 258 | + edge_compile_config=self._edge_compile_config, |
| 259 | + executorch_backend_config=self._executorch_backend_config, |
| 260 | + pre_edge_transform_passes=self._pre_edge_transform_passes, |
| 261 | + edge_transform_passes=self._edge_transform_passes, |
| 262 | + transform_check_ir_validity=self._transform_check_ir_validity, |
| 263 | + mode=self._mode, |
| 264 | + ) |
0 commit comments