-
Notifications
You must be signed in to change notification settings - Fork 691
Arm backend: Add recipe infrastructure #14849
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Copyright 2025 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# @noautodeps | ||
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") | ||
|
||
runtime.python_library( | ||
name = "arm_recipe_types", | ||
srcs = ["arm_recipe_types.py"], | ||
deps = [ | ||
"//executorch/export:lib", | ||
], | ||
) | ||
|
||
runtime.python_library( | ||
name = "arm_recipe", | ||
srcs = ["recipe.py"], | ||
deps = [ | ||
"//executorch/backends/arm:arm_compile_spec", | ||
"//executorch/backends/arm:_factory", | ||
"//executorch/exir:pass_manager", | ||
"//executorch/exir:lib", | ||
"//executorch/export:lib", | ||
], | ||
) | ||
|
||
runtime.python_library( | ||
name = "arm_recipe_provider", | ||
srcs = ["arm_recipe_provider.py"], | ||
deps = [ | ||
":arm_recipe", | ||
":arm_recipe_types", | ||
"//executorch/backends/arm/quantizer:lib", | ||
"//executorch/backends/arm/test:common", | ||
"//executorch/backends/arm:_factory", | ||
"//executorch/export:lib", | ||
], | ||
) | ||
|
||
runtime.python_library( | ||
name = "lib", | ||
srcs = ["__init__.py"], | ||
deps = [ | ||
":arm_recipe", | ||
":arm_recipe_provider", | ||
":arm_recipe_types", | ||
"//executorch/export:lib", | ||
], | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Copyright 2025 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from executorch.export import recipe_registry # type: ignore[import-untyped] | ||
|
||
from .recipe import ArmExportRecipe, TargetRecipe # noqa # usort: skip | ||
from .arm_recipe_types import ArmRecipeType # noqa # usort: skip | ||
from .arm_recipe_provider import ArmRecipeProvider # noqa # usort: skip | ||
|
||
# Auto-register Arm recipe provider | ||
recipe_registry.register_backend_recipe_provider(ArmRecipeProvider()) | ||
|
||
__all__ = ["ArmRecipeType", "ArmExportRecipe", "TargetRecipe"] |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,145 @@ | ||||||
# Copyright 2025 Arm Limited and/or its affiliates. | ||||||
# | ||||||
# This source code is licensed under the BSD-style license found in the | ||||||
# LICENSE file in the root directory of this source tree. | ||||||
|
||||||
# pyre-strict | ||||||
|
||||||
from typing import Any, Callable, Optional, Sequence | ||||||
|
||||||
from executorch.backends.arm.quantizer import ( | ||||||
get_symmetric_quantization_config, | ||||||
TOSAQuantizer, | ||||||
) | ||||||
from executorch.backends.arm.recipe import ArmExportRecipe, ArmRecipeType, TargetRecipe | ||||||
from executorch.backends.arm.test import common | ||||||
from executorch.backends.arm.util._factory import create_quantizer | ||||||
from executorch.export import ( # type: ignore[import-untyped] | ||||||
BackendRecipeProvider, | ||||||
ExportRecipe, | ||||||
QuantizationRecipe, | ||||||
RecipeType, | ||||||
) | ||||||
|
||||||
QuantizerConfigurator = Callable[[TOSAQuantizer], None] | ||||||
|
||||||
|
||||||
def global_int8_per_channel(quantizer: TOSAQuantizer): | ||||||
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True)) | ||||||
|
||||||
|
||||||
def global_int8_per_tensor(quantizer: TOSAQuantizer): | ||||||
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False)) | ||||||
|
||||||
|
||||||
class ArmRecipeProvider(BackendRecipeProvider): | ||||||
@property | ||||||
def backend_name(self) -> str: | ||||||
return ArmRecipeType.get_backend_name() | ||||||
|
||||||
def get_supported_recipes(self) -> Sequence[RecipeType]: | ||||||
return list(ArmRecipeType) | ||||||
|
||||||
@classmethod | ||||||
def build_export_recipe( | ||||||
cls, | ||||||
recipe_type: RecipeType, | ||||||
target_recipe: TargetRecipe, | ||||||
quantization_configurators: Optional[list[QuantizerConfigurator]] = None, | ||||||
) -> ArmExportRecipe: | ||||||
|
||||||
if quantization_configurators is not None: | ||||||
quantizer = create_quantizer(target_recipe.compile_spec) | ||||||
for configure in quantization_configurators: | ||||||
configure(quantizer) | ||||||
quantization_recipe = QuantizationRecipe([quantizer]) | ||||||
else: | ||||||
quantization_recipe = None | ||||||
|
||||||
return ArmExportRecipe( | ||||||
name=str(recipe_type), | ||||||
target_recipe=target_recipe, | ||||||
quantization_recipe=quantization_recipe, | ||||||
) | ||||||
|
||||||
def create_recipe( | ||||||
self, recipe_type: RecipeType, **kwargs: Any | ||||||
) -> Optional[ExportRecipe]: | ||||||
"""Create arm recipe""" | ||||||
return create_recipe(recipe_type, **kwargs) | ||||||
|
||||||
|
||||||
def create_recipe(recipe_type: RecipeType, **kwargs: Any) -> ArmExportRecipe: | ||||||
"""Create an ArmExportRecipe depending on the ArmRecipeType enum, with some kwargs. See documentation for | ||||||
the ArmRecipeType for the available kwargs.""" | ||||||
|
||||||
match recipe_type: | ||||||
case ArmRecipeType.TOSA_FP: | ||||||
return ArmRecipeProvider.build_export_recipe( | ||||||
recipe_type, | ||||||
TargetRecipe(common.get_tosa_compile_spec("TOSA-1.0+FP", **kwargs)), | ||||||
) | ||||||
case ArmRecipeType.TOSA_INT8_STATIC_PER_TENSOR: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need any kwarg validation, to let user know if they've configured the recipes wrong? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good idea, will do that in an upcoming patch. |
||||||
return ArmRecipeProvider.build_export_recipe( | ||||||
recipe_type, | ||||||
TargetRecipe(common.get_tosa_compile_spec("TOSA-1.0+INT", **kwargs)), | ||||||
[global_int8_per_tensor], | ||||||
) | ||||||
case ArmRecipeType.TOSA_INT8_STATIC_PER_CHANNEL: | ||||||
return ArmRecipeProvider.build_export_recipe( | ||||||
recipe_type, | ||||||
TargetRecipe(common.get_tosa_compile_spec("TOSA-1.0+INT", **kwargs)), | ||||||
[global_int8_per_channel], | ||||||
) | ||||||
case ArmRecipeType.ETHOSU_U55_INT8_STATIC_PER_CHANNEL: | ||||||
return ArmRecipeProvider.build_export_recipe( | ||||||
recipe_type, | ||||||
TargetRecipe(common.get_u55_compile_spec(**kwargs)), | ||||||
[global_int8_per_channel], | ||||||
) | ||||||
case ArmRecipeType.ETHOSU_U55_INT8_STATIC_PER_TENSOR: | ||||||
return ArmRecipeProvider.build_export_recipe( | ||||||
recipe_type, | ||||||
TargetRecipe(common.get_u55_compile_spec(**kwargs)), | ||||||
[global_int8_per_tensor], | ||||||
) | ||||||
case ArmRecipeType.ETHOSU_U85_INT8_STATIC_PER_TENSOR: | ||||||
return ArmRecipeProvider.build_export_recipe( | ||||||
recipe_type, | ||||||
TargetRecipe(common.get_u85_compile_spec(**kwargs)), | ||||||
[global_int8_per_tensor], | ||||||
) | ||||||
case ArmRecipeType.ETHOSU_U85_INT8_STATIC_PER_CHANNEL: | ||||||
return ArmRecipeProvider.build_export_recipe( | ||||||
recipe_type, | ||||||
TargetRecipe(common.get_u85_compile_spec(**kwargs)), | ||||||
[global_int8_per_channel], | ||||||
) | ||||||
|
||||||
case ArmRecipeType.VGF_FP: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit
Suggested change
|
||||||
return ArmRecipeProvider.build_export_recipe( | ||||||
recipe_type, | ||||||
TargetRecipe(common.get_vgf_compile_spec("TOSA-1.0+FP", **kwargs)), | ||||||
) | ||||||
case ArmRecipeType.VGF_INT8_STATIC_PER_TENSOR: | ||||||
return ArmRecipeProvider.build_export_recipe( | ||||||
recipe_type, | ||||||
TargetRecipe(common.get_vgf_compile_spec("TOSA-1.0+INT", **kwargs)), | ||||||
[global_int8_per_tensor], | ||||||
) | ||||||
case ArmRecipeType.VGF_INT8_STATIC_PER_CHANNEL: | ||||||
return ArmRecipeProvider.build_export_recipe( | ||||||
recipe_type, | ||||||
TargetRecipe(common.get_vgf_compile_spec("TOSA-1.0+INT", **kwargs)), | ||||||
[global_int8_per_channel], | ||||||
) | ||||||
case ArmRecipeType.CUSTOM: | ||||||
if "recipe" not in kwargs or not isinstance( | ||||||
kwargs["recipe"], ArmExportRecipe | ||||||
): | ||||||
raise ValueError( | ||||||
"ArmRecipeType.CUSTOM requires a kwarg 'recipe' that provides the ArmExportRecipe" | ||||||
) | ||||||
return kwargs["recipe"] | ||||||
case _: | ||||||
raise ValueError(f"Unsupported recipe type {recipe_type}") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# Copyright 2025 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
from executorch.export import RecipeType # type: ignore[import-untyped] | ||
|
||
|
||
class ArmRecipeType(RecipeType): | ||
"""Arm-specific recipe types""" | ||
|
||
TOSA_FP = "arm_tosa_fp" | ||
""" Kwargs: | ||
- custom_path: str=None, | ||
- tosa_debug_mode: TosaCompileSpec.DebugMode | None = None, | ||
""" | ||
TOSA_INT8_STATIC_PER_TENSOR = "arm_tosa_int8_static_per_tensor" | ||
""" Kwargs: | ||
- custom_path: str=None, | ||
- tosa_debug_mode: TosaCompileSpec.DebugMode | None = None, | ||
""" | ||
TOSA_INT8_STATIC_PER_CHANNEL = "arm_tosa_int8_static_per_channel" | ||
""" Kwargs: | ||
- custom_path: str=None, | ||
- tosa_debug_mode: TosaCompileSpec.DebugMode | None = None, | ||
""" | ||
ETHOSU_U55_INT8_STATIC_PER_CHANNEL = "arm_ethosu_u55_int_static_per_channel" | ||
""" Kwargs: | ||
- macs: int = 128, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we need these comments? I am afraid they will get out of sync without being tested.. we should have this in validation as you are planning.. |
||
- system_config: str = "Ethos_U55_High_End_Embedded", | ||
- memory_mode: str = "Shared_Sram", | ||
- extra_flags: str = "--debug-force-regor --output-format=raw", | ||
- custom_path: Optional[str] = None, | ||
- config: Optional[str] = None, | ||
- tosa_debug_mode: EthosUCompileSpec.DebugMode | None = None, | ||
""" | ||
ETHOSU_U55_INT8_STATIC_PER_TENSOR = "arm_ethosu_u55_int_static_per_channel" | ||
""" Kwargs: | ||
- macs: int = 128, | ||
- system_config: str = "Ethos_U55_High_End_Embedded", | ||
- memory_mode: str = "Shared_Sram", | ||
- extra_flags: str = "--debug-force-regor --output-format=raw", | ||
- custom_path: Optional[str] = None, | ||
- config: Optional[str] = None, | ||
- tosa_debug_mode: EthosUCompileSpec.DebugMode | None = None, | ||
""" | ||
ETHOSU_U85_INT8_STATIC_PER_TENSOR = "arm_ethosu_u85_int_static_per_tensor" | ||
""" Kwargs: | ||
- macs: int = 128, | ||
- system_config="Ethos_U85_SYS_DRAM_Mid", | ||
- memory_mode="Shared_Sram", | ||
- extra_flags="--output-format=raw", | ||
- custom_path: Optional[str] = None, | ||
- config: Optional[str] = None, | ||
- tosa_debug_mode: EthosUCompileSpec.DebugMode | None = None, | ||
""" | ||
ETHOSU_U85_INT8_STATIC_PER_CHANNEL = "arm_ethosu_u85_int_static_per_channel" | ||
""" Kwargs: | ||
- macs: int = 128, | ||
- system_config="Ethos_U85_SYS_DRAM_Mid", | ||
- memory_mode="Shared_Sram", | ||
- extra_flags="--output-format=raw", | ||
- custom_path: Optional[str] = None, | ||
- config: Optional[str] = None, | ||
- tosa_debug_mode: EthosUCompileSpec.DebugMode | None = None, | ||
""" | ||
|
||
VGF_FP = "arm_vgf_fp" | ||
""" Kwargs: | ||
- compiler_flags: Optional[str] = "", | ||
- custom_path=None, | ||
- tosa_debug_mode: VgfCompileSpec.DebugMode | None = None, | ||
""" | ||
VGF_INT8_STATIC_PER_TENSOR = "arm_vgf_int8_static_per_tensor" | ||
""" Kwargs: | ||
- compiler_flags: Optional[str] = "", | ||
- custom_path=None, | ||
- tosa_debug_mode: VgfCompileSpec.DebugMode | None = None, | ||
""" | ||
VGF_INT8_STATIC_PER_CHANNEL = "arm_vgf_int8_static_per_channel" | ||
""" Kwargs: | ||
- compiler_flags: Optional[str] = "", | ||
- custom_path=None, | ||
- tosa_debug_mode: VgfCompileSpec.DebugMode | None = None, | ||
""" | ||
CUSTOM = "Provide your own ArmRecipeType to the kwarg 'recipe'." | ||
|
||
@classmethod | ||
def get_backend_name(cls) -> str: | ||
return "arm" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Copyright 2025 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from dataclasses import dataclass, field | ||
from typing import List | ||
|
||
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec | ||
from executorch.backends.arm.util._factory import create_partitioner | ||
from executorch.exir import EdgeCompileConfig | ||
from executorch.exir.pass_manager import PassType | ||
from executorch.export import ( # type: ignore[import-untyped] | ||
ExportRecipe, | ||
LoweringRecipe, | ||
QuantizationRecipe, | ||
) | ||
|
||
|
||
@dataclass | ||
class TargetRecipe: | ||
"""Contains target-level export configuration.""" | ||
|
||
compile_spec: ArmCompileSpec | ||
edge_compile_config: EdgeCompileConfig = field( | ||
default_factory=lambda: EdgeCompileConfig(_check_ir_validity=False) | ||
) | ||
edge_transform_passes: List[PassType] = field(default_factory=lambda: []) | ||
|
||
|
||
class ArmExportRecipe(ExportRecipe): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add at least a basic test to make this PR self containing.. |
||
"""Wraps ExportRecipe to provide the constructor we want and easy access to some variables.""" | ||
|
||
def __init__( | ||
self, | ||
name, | ||
target_recipe: TargetRecipe, | ||
quantization_recipe: QuantizationRecipe | None, | ||
): | ||
self.compile_spec = target_recipe.compile_spec | ||
self.edge_transform_passes = target_recipe.edge_transform_passes | ||
|
||
lowering_recipe = LoweringRecipe( | ||
[create_partitioner(self.compile_spec)], | ||
edge_transform_passes=[lambda _, __: target_recipe.edge_transform_passes], | ||
edge_compile_config=target_recipe.edge_compile_config, | ||
) | ||
super().__init__( | ||
name=name, | ||
quantization_recipe=quantization_recipe, | ||
lowering_recipe=lowering_recipe, | ||
executorch_backend_config=None, | ||
pipeline_stages=None, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, thanks for adding these recipes, is it possible to add a couple of model tests using any of these recipes, just to make sure?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I plan on adding tests using recipes in an upcoming patch.