diff --git a/backends/xnnpack/TARGETS b/backends/xnnpack/TARGETS
index d5c6d6303d2..62a703bddb7 100644
--- a/backends/xnnpack/TARGETS
+++ b/backends/xnnpack/TARGETS
@@ -36,7 +36,10 @@ runtime.python_library(
     ],
     deps = [
         ":xnnpack_preprocess",
+        "//executorch/export:lib",
         "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
         "//executorch/backends/xnnpack/utils:xnnpack_utils",
+        "//executorch/backends/xnnpack/recipes:xnnpack_recipe_provider",
+        "//executorch/backends/xnnpack/recipes:xnnpack_recipe_types",
     ],
 )
diff --git a/backends/xnnpack/__init__.py b/backends/xnnpack/__init__.py
index 6f4aafa8348..01b73101c86 100644
--- a/backends/xnnpack/__init__.py
+++ b/backends/xnnpack/__init__.py
@@ -4,11 +4,18 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from executorch.export import recipe_registry
+
 # Exposed Partitioners in XNNPACK Package
 from .partition.xnnpack_partitioner import (
     XnnpackDynamicallyQuantizedPartitioner,
     XnnpackPartitioner,
 )
+from .recipes.xnnpack_recipe_provider import XNNPACKRecipeProvider
+from .recipes.xnnpack_recipe_types import XNNPackRecipeType
+
+# Auto-register XNNPACK recipe provider
+recipe_registry.register_backend_recipe_provider(XNNPACKRecipeProvider())
 
 # Exposed Configs in XNNPACK Package
 from .utils.configs import (
@@ -23,11 +30,11 @@
 # XNNPACK Backend
 from .xnnpack_preprocess import XnnpackBackend
 
-
 __all__ = [
     "XnnpackDynamicallyQuantizedPartitioner",
     "XnnpackPartitioner",
     "XnnpackBackend",
+    "XNNPackRecipeType",
     "capture_graph_for_xnnpack",
     "get_xnnpack_capture_config",
     "get_xnnpack_edge_compile_config",
diff --git a/backends/xnnpack/recipes/TARGETS b/backends/xnnpack/recipes/TARGETS
new file mode 100644
index 00000000000..60968a5085d
--- /dev/null
+++ b/backends/xnnpack/recipes/TARGETS
@@ -0,0 +1,35 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+oncall("executorch")
+
+runtime.python_library(
+    name = "xnnpack_recipe_provider",
+    srcs = [
+        "xnnpack_recipe_provider.py",
+    ],
+    visibility = [
+        "//executorch/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/export:lib",
+        "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
+        "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
+        ":xnnpack_recipe_types",
+    ],
+)
+
+runtime.python_library(
+    name = "xnnpack_recipe_types",
+    srcs = [
+        "xnnpack_recipe_types.py",
+    ],
+    visibility = [
+        "//executorch/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        "//executorch/export:lib",
+    ],
+)
diff --git a/backends/xnnpack/recipes/xnnpack_recipe_provider.py b/backends/xnnpack/recipes/xnnpack_recipe_provider.py
new file mode 100644
index 00000000000..19b30eb8f50
--- /dev/null
+++ b/backends/xnnpack/recipes/xnnpack_recipe_provider.py
@@ -0,0 +1,184 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import Any, Optional, Sequence
+
+import torch
+
+from executorch.backends.xnnpack.partition.config.xnnpack_config import (
+    ConfigPrecisionType,
+)
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
+
+from executorch.backends.xnnpack.recipes.xnnpack_recipe_types import XNNPackRecipeType
+from executorch.backends.xnnpack.utils.configs import (
+    get_xnnpack_edge_compile_config,
+    get_xnnpack_executorch_backend_config,
+)
+from executorch.export import (
+    BackendRecipeProvider,
+    ExportRecipe,
+    QuantizationRecipe,
+    RecipeType,
+)
+from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig
+
+
+class XNNPACKRecipeProvider(BackendRecipeProvider):
+    @property
+    def backend_name(self) -> str:
+        return "xnnpack"
+
+    def get_supported_recipes(self) -> Sequence[RecipeType]:
+        return list(XNNPackRecipeType)
+
+    def create_recipe(
+        self, recipe_type: RecipeType, **kwargs: Any
+    ) -> Optional[ExportRecipe]:
+        """Create XNNPACK recipe"""
+
+        if recipe_type not in self.get_supported_recipes():
+            return None
+
+        # Validate kwargs
+        self._validate_recipe_kwargs(recipe_type, **kwargs)
+
+        if recipe_type == XNNPackRecipeType.FP32:
+            return self._build_fp32_recipe(recipe_type)
+
+        elif recipe_type == XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL:
+            return self._build_quantized_recipe(
+                recipe_type, is_per_channel=True, is_dynamic=True
+            )
+
+        elif recipe_type == XNNPackRecipeType.INT8_DYNAMIC_PER_TENSOR:
+            return self._build_quantized_recipe(
+                recipe_type, is_per_channel=False, is_dynamic=True
+            )
+
+        elif recipe_type == XNNPackRecipeType.INT8_STATIC_PER_CHANNEL:
+            return self._build_quantized_recipe(
+                recipe_type, is_per_channel=True, is_dynamic=False
+            )
+
+        elif recipe_type == XNNPackRecipeType.INT8_STATIC_PER_TENSOR:
+            return self._build_quantized_recipe(
+                recipe_type, is_per_channel=False, is_dynamic=False
+            )
+
+        elif recipe_type == XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL:
+            return self._build_int8da_intx_weight_recipe(
+                recipe_type=recipe_type,
+                is_per_channel=True,
+                weight_dtype=torch.int4,
+            )
+
+        elif recipe_type == XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR:
+            group_size = kwargs.get("group_size", 32)
+            return self._build_int8da_intx_weight_recipe(
+                recipe_type=recipe_type,
+                is_per_channel=False,
+                weight_dtype=torch.int4,
+                group_size=group_size,
+            )
+        return None
+
+    def _build_fp32_recipe(self, recipe_type: RecipeType) -> ExportRecipe:
+        return ExportRecipe(
+            name=recipe_type.value,
+            edge_compile_config=get_xnnpack_edge_compile_config(),
+            executorch_backend_config=get_xnnpack_executorch_backend_config(),
+            partitioners=[XnnpackPartitioner()],
+        )
+
+    def _build_quantized_recipe(
+        self,
+        recipe_type: RecipeType,
+        is_per_channel: bool = True,
+        is_dynamic: bool = True,
+        is_qat: bool = False,
+    ) -> ExportRecipe:
+        quantizer = XNNPACKQuantizer()
+        operator_config = get_symmetric_quantization_config(
+            is_per_channel=is_per_channel, is_dynamic=is_dynamic, is_qat=is_qat
+        )
+        quantizer.set_global(operator_config)
+
+        quant_recipe = QuantizationRecipe(quantizers=[quantizer])
+
+        precision_type = (
+            ConfigPrecisionType.DYNAMIC_QUANT
+            if is_dynamic
+            else ConfigPrecisionType.STATIC_QUANT
+        )
+
+        return ExportRecipe(
+            name=recipe_type.value,
+            quantization_recipe=quant_recipe,
+            edge_compile_config=get_xnnpack_edge_compile_config(),
+            executorch_backend_config=get_xnnpack_executorch_backend_config(),
+            partitioners=[XnnpackPartitioner(config_precision=precision_type)],
+        )
+
+    def _build_int8da_intx_weight_recipe(
+        self,
+        recipe_type: RecipeType,
+        is_per_channel: bool = True,
+        weight_dtype: torch.dtype = torch.int4,
+        group_size: int = 32,
+    ) -> ExportRecipe:
+        if is_per_channel:
+            weight_granularity = PerAxis(axis=0)
+        else:
+            weight_granularity = PerGroup(group_size=group_size)
+
+        config = Int8DynamicActivationIntxWeightConfig(
+            weight_dtype=weight_dtype,
+            weight_granularity=weight_granularity,
+        )
+
+        quant_recipe = QuantizationRecipe(
+            quantizers=None,
+            ao_base_config=[config],
+        )
+
+        return ExportRecipe(
+            name=recipe_type.value,
+            quantization_recipe=quant_recipe,
+            edge_compile_config=get_xnnpack_edge_compile_config(),
+            executorch_backend_config=get_xnnpack_executorch_backend_config(),
+            partitioners=[XnnpackPartitioner()],
+        )
+
+    def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:
+        if recipe_type == XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR:
+            expected_keys = {"group_size"}
+            unexpected = set(kwargs.keys()) - expected_keys
+            if unexpected:
+                raise ValueError(
+                    f"Recipe '{recipe_type.value}' only accepts 'group_size' parameter. "
+                    f"Unexpected parameters: {list(unexpected)}"
+                )
+            if "group_size" in kwargs:
+                group_size = kwargs["group_size"]
+                if not isinstance(group_size, int):
+                    raise ValueError(
+                        f"Parameter 'group_size' must be an integer, got {type(group_size).__name__}: {group_size}"
+                    )
+        elif kwargs:
+            # All other recipes don't expect any kwargs
+            unexpected = list(kwargs.keys())
+            raise ValueError(
+                f"Recipe '{recipe_type.value}' does not accept any parameters. "
+                f"Unexpected parameters: {unexpected}"
+            )
diff --git a/backends/xnnpack/recipes/xnnpack_recipe_types.py b/backends/xnnpack/recipes/xnnpack_recipe_types.py
new file mode 100644
index 00000000000..ec7183eb005
--- /dev/null
+++ b/backends/xnnpack/recipes/xnnpack_recipe_types.py
@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from executorch.export import RecipeType
+
+
+class XNNPackRecipeType(RecipeType):
+    """XNNPACK-specific recipe types"""
+
+    FP32 = "fp32"
+    # INT8 Dynamic Quantization
+    INT8_DYNAMIC_PER_CHANNEL = "int8_dynamic_per_channel"
+    INT8_DYNAMIC_PER_TENSOR = "int8_dynamic_per_tensor"
+    # INT8 Dynamic Activations INT4 Weight Quantization, Axis = 0
+    INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL = "int8da_int4w_per_channel"
+    # INT8 Dynamic Activations INT4 Weight Quantization, default group_size = 32
+    # can be overridden by group_size kwarg
+    INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR = "int8da_int4w_per_tensor"
+    # INT8 Static Activations INT4 Weight Quantization
+    INT8_STATIC_ACT_INT4_WEIGHT_PER_CHANNEL = "int8a_int4w_per_channel"
+    INT8_STATIC_ACT_INT4_WEIGHT_PER_TENSOR = "int8a_int4w_per_tensor"
+    # INT8 Static Quantization, needs calibration dataset
+    INT8_STATIC_PER_CHANNEL = "int8_static_per_channel"
+    INT8_STATIC_PER_TENSOR = "int8_static_per_tensor"
+
+    @classmethod
+    def get_backend_name(cls) -> str:
+        return "xnnpack"
diff --git a/backends/xnnpack/test/TARGETS b/backends/xnnpack/test/TARGETS
index bd3dddd0985..e024721b556 100644
--- a/backends/xnnpack/test/TARGETS
+++ b/backends/xnnpack/test/TARGETS
@@ -94,3 +94,18 @@ runtime.python_test(
         "libtorch",
     ],
 )
+
+runtime.python_test(
+    name = "test_xnnpack_recipes",
+    srcs = glob([
+        "recipes/*.py",
+    ]),
+    deps = [
+        "//executorch/backends/xnnpack:xnnpack_delegate",
+        "//executorch/export:lib",
+        "//pytorch/vision:torchvision",  # @manual
+        "//executorch/backends/xnnpack/test/tester:tester",
+        "//executorch/examples/models:models",  # @manual
+        "//executorch/examples/xnnpack:models",  # @manual
+    ],
+)
diff --git a/backends/xnnpack/test/recipes/test_xnnpack_recipes.py b/backends/xnnpack/test/recipes/test_xnnpack_recipes.py
new file mode 100644
index 00000000000..198bf7f1679
--- /dev/null
+++ b/backends/xnnpack/test/recipes/test_xnnpack_recipes.py
@@ -0,0 +1,249 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import unittest
+
+import torch
+from executorch.backends.xnnpack.recipes.xnnpack_recipe_provider import (
+    XNNPACKRecipeProvider,
+)
+from executorch.backends.xnnpack.recipes.xnnpack_recipe_types import XNNPackRecipeType
+from executorch.backends.xnnpack.test.tester import Tester
+from executorch.examples.models import MODEL_NAME_TO_MODEL
+from executorch.examples.models.model_factory import EagerModelFactory
+from executorch.examples.xnnpack import MODEL_NAME_TO_OPTIONS, QuantType
+from executorch.exir.schema import DelegateCall, Program
+from executorch.export import export, ExportRecipe
+from torch import nn
+from torch.testing._internal.common_quantization import TestHelperModules
+
+
+class TestXnnpackRecipes(unittest.TestCase):
+    def setUp(self) -> None:
+        torch._dynamo.reset()
+        super().setUp()
+
+    def tearDown(self) -> None:
+        super().tearDown()
+
+    def check_fully_delegated(self, program: Program) -> None:
+        instructions = program.execution_plan[0].chains[0].instructions
+        assert instructions is not None
+        self.assertEqual(len(instructions), 1)
+        self.assertIsInstance(instructions[0].instr_args, DelegateCall)
+
+    def test_basic_recipe(self) -> None:
+        m_eager = TestHelperModules.TwoLinearModule().eval()
+        example_inputs = [(torch.randn(9, 8),)]
+        session = export(
+            model=m_eager,
+            example_inputs=example_inputs,
+            export_recipe=ExportRecipe.get_recipe(XNNPackRecipeType.FP32),
+        )
+        self.assertTrue(
+            torch.allclose(
+                session.run_method("forward", example_inputs[0])[0],
+                m_eager(*example_inputs[0]),
+                atol=1e-3,
+            )
+        )
+        self.check_fully_delegated(session.get_executorch_program())
+
+    def test_int8_dynamic_quant_recipe(self) -> None:
+        test_cases = [
+            ExportRecipe.get_recipe(XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL),
+            ExportRecipe.get_recipe(XNNPackRecipeType.INT8_DYNAMIC_PER_TENSOR),
+        ]
+
+        for export_recipe in test_cases:
+            with self.subTest(export_recipe=export_recipe):
+                with torch.no_grad():
+                    m_eager = TestHelperModules.TwoLinearModule().eval()
+                    example_inputs = [(torch.randn(9, 8),)]
+                    session = export(
+                        model=m_eager,
+                        example_inputs=example_inputs,
+                        export_recipe=export_recipe,
+                    )
+                    self.assertTrue(
+                        torch.allclose(
+                            session.run_method("forward", example_inputs[0])[0],
+                            m_eager(*example_inputs[0]),
+                            atol=1e-3,
+                        )
+                    )
+                    self.check_fully_delegated(session.get_executorch_program())
+
+    def test_int8_static_quant_recipe(self) -> None:
+        test_cases = [
+            ExportRecipe.get_recipe(XNNPackRecipeType.INT8_STATIC_PER_CHANNEL),
+            ExportRecipe.get_recipe(XNNPackRecipeType.INT8_STATIC_PER_TENSOR),
+        ]
+
+        for export_recipe in test_cases:
+            with self.subTest(export_recipe=export_recipe):
+                with torch.no_grad():
+                    m_eager = TestHelperModules.TwoLinearModule().eval()
+                    example_inputs = [(torch.randn(9, 8),)]
+                    session = export(
+                        model=m_eager,
+                        example_inputs=example_inputs,
+                        export_recipe=export_recipe,
+                    )
+                    self.assertTrue(
+                        torch.allclose(
+                            session.run_method("forward", example_inputs[0])[0],
+                            m_eager(*example_inputs[0]),
+                            atol=1e-3,
+                        )
+                    )
+                    self.check_fully_delegated(session.get_executorch_program())
+
+    def test_8a4w_recipe(self) -> None:
+        class SimpleLinearModel(nn.Module):
+            def __init__(self) -> None:
+                super(SimpleLinearModel, self).__init__()
+                self.layer1 = nn.Linear(32, 2)
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                x = self.layer1(x)
+                return x
+
+        test_cases = [
+            ExportRecipe.get_recipe(
+                XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL,
+            ),
+            ExportRecipe.get_recipe(
+                XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
+                group_size=32,
+            ),
+        ]
+
+        for export_recipe in test_cases:
+            with self.subTest(export_recipe=export_recipe):
+                model = SimpleLinearModel()
+                example_inputs = [(torch.randn(1, 32),)]
+                session = export(
+                    model=model,
+                    example_inputs=example_inputs,
+                    export_recipe=export_recipe,
+                )
+                self.assertTrue(
+                    torch.allclose(
+                        session.run_method("forward", example_inputs[0])[0],
+                        model(*example_inputs[0]),
+                        atol=1e-2,
+                    )
+                )
+                self.check_fully_delegated(session.get_executorch_program())
+
+    def _get_recipe_for_quant_type(self, quant_type: QuantType) -> XNNPackRecipeType:
+        # Map QuantType to corresponding recipe name.
+        if quant_type == QuantType.STATIC_PER_CHANNEL:
+            return XNNPackRecipeType.INT8_STATIC_PER_CHANNEL
+        elif quant_type == QuantType.DYNAMIC_PER_CHANNEL:
+            return XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL
+        elif quant_type == QuantType.STATIC_PER_TENSOR:
+            return XNNPackRecipeType.INT8_STATIC_PER_TENSOR
+        elif quant_type == QuantType.NONE:
+            return XNNPackRecipeType.FP32
+        else:
+            raise ValueError(f"Unsupported QuantType: {quant_type}")
+
+    def _test_model_with_factory(self, model_name: str) -> None:
+        if model_name not in MODEL_NAME_TO_MODEL:
+            self.skipTest(f"Model {model_name} not found in MODEL_NAME_TO_MODEL")
+            return
+
+        if model_name not in MODEL_NAME_TO_OPTIONS:
+            self.skipTest(f"Model {model_name} not found in MODEL_NAME_TO_OPTIONS")
+            return
+
+        # Create model using factory
+        model, example_inputs, _example_kwarg_inputs, dynamic_shapes = (
+            EagerModelFactory.create_model(*MODEL_NAME_TO_MODEL[model_name])
+        )
+        model = model.eval()
+
+        # Get the appropriate recipe based on quantization type
+        options = MODEL_NAME_TO_OPTIONS[model_name]
+        recipe_name = self._get_recipe_for_quant_type(options.quantization)
+
+        # Export with recipe
+        session = export(
+            model=model,
+            example_inputs=[example_inputs],
+            export_recipe=ExportRecipe.get_recipe(recipe_name),
+            dynamic_shapes=dynamic_shapes,
+        )
+
+        # Verify outputs match
+        Tester._assert_outputs_equal(
+            session.run_method("forward", example_inputs)[0],
+            model(*example_inputs),
+            atol=1e-3,
+        )
+
+    def test_all_models_with_recipes(self) -> None:
+        models_to_test = [
+            "linear",
+            "add",
+            "add_mul",
+            "ic3",
+            "mv2",
+            "mv3",
+            "resnet18",
+            "resnet50",
+            "vit",
+            "w2l",
+            "llama2",
+        ]
+        for model_name in models_to_test:
+            with self.subTest(model=model_name):
+                self._test_model_with_factory(model_name)
+
+    def test_validate_recipe_kwargs_fp32(self) -> None:
+        provider = XNNPACKRecipeProvider()
+
+        with self.assertRaises(ValueError) as cm:
+            provider.create_recipe(XNNPackRecipeType.FP32, invalid_param=123)
+
+        error_msg = str(cm.exception)
+        self.assertIn("Recipe 'fp32' does not accept any parameters", error_msg)
+
+    def test_validate_recipe_kwargs_int4_tensor_with_valid_group_size(
+        self,
+    ) -> None:
+        provider = XNNPACKRecipeProvider()
+
+        # Should not raise any exception
+        recipe_w_default_group = provider.create_recipe(
+            XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
+        )
+        self.assertIsNotNone(recipe_w_default_group)
+
+        recipe = provider.create_recipe(
+            XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR, group_size=64
+        )
+        self.assertIsNotNone(recipe)
+
+    def test_validate_recipe_kwargs_int4_tensor_with_invalid_group_size(
+        self,
+    ) -> None:
+        provider = XNNPACKRecipeProvider()
+
+        with self.assertRaises(ValueError) as cm:
+            provider.create_recipe(
+                XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
+                group_size="32",  # String instead of int
+            )
+
+        error_msg = str(cm.exception)
+        self.assertIn(
+            "Parameter 'group_size' must be an integer, got str: 32", error_msg
+        )
diff --git a/examples/xnnpack/targets.bzl b/examples/xnnpack/targets.bzl
index ce9575e8cca..e710c478250 100644
--- a/examples/xnnpack/targets.bzl
+++ b/examples/xnnpack/targets.bzl
@@ -14,6 +14,7 @@ def define_common_targets():
         ],
         visibility = [
             "//executorch/examples/xnnpack/...",
+            "//executorch/backends/xnnpack/test/...",
         ],
         deps = [
             "//executorch/examples/models:models",  # @manual
diff --git a/export/TARGETS b/export/TARGETS
index 77d6b07795e..defb508b33a 100644
--- a/export/TARGETS
+++ b/export/TARGETS
@@ -1,12 +1,16 @@
-load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 
 oncall("executorch")
 
-python_library(
+runtime.python_library(
    name = "recipe",
    srcs = [
        "recipe.py",
    ],
+    visibility = [
+        "//executorch/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
    deps = [
        "//caffe2:torch",
        "//executorch/exir/backend:backend_api",
@@ -16,11 +20,15 @@
 ]
 )
 
-python_library(
+runtime.python_library(
    name = "export",
    srcs = [
        "export.py",
    ],
+    visibility = [
+        "//executorch/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
    deps = [
        ":recipe",
        "//executorch/runtime:runtime",
@@ -28,11 +36,15 @@
 ]
 )
 
-python_library(
+runtime.python_library(
    name = "lib",
    srcs = [
        "__init__.py",
    ],
+    visibility = [
+        "//executorch/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
    deps = [
        ":export",
        ":recipe",
@@ -41,12 +53,15 @@
 ],
 )
 
-
-python_library(
+runtime.python_library(
    name = "recipe_registry",
    srcs = [
        "recipe_registry.py",
    ],
+    visibility = [
+        "//executorch/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
    deps = [
        ":recipe",
        ":recipe_provider"
@@ -54,7 +69,7 @@
 )
 
 
-python_library(
+runtime.python_library(
    name = "recipe_provider",
    srcs = [
        "recipe_provider.py",
diff --git a/src/executorch/examples/xnnpack b/src/executorch/examples/xnnpack
new file mode 120000
index 00000000000..ce7b138dfc6
--- /dev/null
+++ b/src/executorch/examples/xnnpack
@@ -0,0 +1 @@
+../../../examples/xnnpack
\ No newline at end of file
diff --git a/src/executorch/export b/src/executorch/export
new file mode 120000
index 00000000000..1773c569c7d
--- /dev/null
+++ b/src/executorch/export
@@ -0,0 +1 @@
+../../export
\ No newline at end of file
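
For reference, the snippet below is a minimal usage sketch (not part of the patch itself) showing how the recipes registered by this change are meant to be consumed. It mirrors the calls exercised in test_xnnpack_recipes.py above; SmallLinear is a hypothetical placeholder model used only for illustration.

import torch
from torch import nn

# Importing anything under executorch.backends.xnnpack runs its __init__.py,
# which registers XNNPACKRecipeProvider with the recipe registry (see the
# __init__.py hunk in this patch), so ExportRecipe.get_recipe can resolve
# XNNPACK recipe types.
from executorch.backends.xnnpack.recipes.xnnpack_recipe_types import XNNPackRecipeType
from executorch.export import export, ExportRecipe


class SmallLinear(nn.Module):  # placeholder model, not part of the patch
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(8, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


example_inputs = [(torch.randn(2, 8),)]
session = export(
    model=SmallLinear().eval(),
    example_inputs=example_inputs,
    export_recipe=ExportRecipe.get_recipe(XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL),
)
# Run the lowered program, as the new tests do with session.run_method.
outputs = session.run_method("forward", example_inputs[0])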