diff --git a/backends/transforms/duplicate_dynamic_quant_chain.py b/backends/transforms/duplicate_dynamic_quant_chain.py
index 2ca65eec45f..fecc3543ca2 100644
--- a/backends/transforms/duplicate_dynamic_quant_chain.py
+++ b/backends/transforms/duplicate_dynamic_quant_chain.py
@@ -8,6 +8,7 @@
 import operator
 
 import torch
+from executorch.exir.program._program import _update_exported_program_graph_module
 from torch.ao.quantization.pt2e.utils import (
     _filter_sym_size_users,
     _is_valid_annotation,
@@ -194,3 +195,11 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         graph_module.graph.eliminate_dead_code()
         graph_module.recompile()
         return PassResult(graph_module, True)
+
+
+def duplicate_dynamic_quant_chain_pass(
+    ep: torch.export.ExportedProgram,
+) -> torch.export.ExportedProgram:
+    res = DuplicateDynamicQuantChainPass()(ep.graph_module)
+    assert res is not None
+    return _update_exported_program_graph_module(ep, res.graph_module)
diff --git a/backends/xnnpack/TARGETS b/backends/xnnpack/TARGETS
index d5c6d6303d2..aa91a86c221 100644
--- a/backends/xnnpack/TARGETS
+++ b/backends/xnnpack/TARGETS
@@ -38,5 +38,6 @@ runtime.python_library(
         ":xnnpack_preprocess",
         "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
         "//executorch/backends/xnnpack/utils:xnnpack_utils",
+        "//executorch/backends/xnnpack/recipes:xnnpack_recipes",
     ],
 )
diff --git a/backends/xnnpack/__init__.py b/backends/xnnpack/__init__.py
index 6f4aafa8348..8ea2c6ff3f0 100644
--- a/backends/xnnpack/__init__.py
+++ b/backends/xnnpack/__init__.py
@@ -9,6 +9,7 @@
     XnnpackDynamicallyQuantizedPartitioner,
     XnnpackPartitioner,
 )
+from .recipes.recipes import get_xnnpack_recipe
 
 # Exposed Configs in XNNPACK Package
 from .utils.configs import (
@@ -23,12 +24,12 @@
 # XNNPACK Backend
 from .xnnpack_preprocess import XnnpackBackend
 
-
 __all__ = [
     "XnnpackDynamicallyQuantizedPartitioner",
     "XnnpackPartitioner",
     "XnnpackBackend",
     "capture_graph_for_xnnpack",
+    "get_xnnpack_recipe",
     "get_xnnpack_capture_config",
     "get_xnnpack_edge_compile_config",
     "get_xnnpack_executorch_backend_config",
diff --git a/backends/xnnpack/recipes/TARGETS b/backends/xnnpack/recipes/TARGETS
new file mode 100644
index 00000000000..c070b52e920
--- /dev/null
+++ b/backends/xnnpack/recipes/TARGETS
@@ -0,0 +1,19 @@
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+
+
+oncall("executorch")
+
+python_library(
+    name = "xnnpack_recipes",
+    srcs = [
+        "recipes.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir:lib",
+        "//executorch/export:recipe",
+        "//executorch/backends/transforms:duplicate_dynamic_quant_chain",
+        "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
+        "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
+    ],
+)
diff --git a/backends/xnnpack/recipes/recipes.py b/backends/xnnpack/recipes/recipes.py
new file mode 100644
index 00000000000..6278351eda5
--- /dev/null
+++ b/backends/xnnpack/recipes/recipes.py
@@ -0,0 +1,80 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+from typing import Any, Callable
+
+from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
+    duplicate_dynamic_quant_chain_pass,
+)
+
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
+from executorch.export.recipe import ExportRecipe, QuantizationRecipe
+from torchao.quantization.quant_api import int8_dynamic_activation_int4_weight
+
+
+def get_generic_fp32_cpu_recipe() -> ExportRecipe:
+    return ExportRecipe(
+        name="fp32_recipe",
+        quantization_recipe=None,
+        partitioners=[XnnpackPartitioner()],
+    )
+
+
+def get_dynamic_quant_recipe() -> ExportRecipe:
+    # Create quantizer
+    quantizer = XNNPACKQuantizer()
+    operator_config = get_symmetric_quantization_config(
+        is_per_channel=True, is_dynamic=True
+    )
+    quantizer.set_global(operator_config)
+
+    # Create quantization recipe
+    quant_recipe = QuantizationRecipe(
+        quantizer=quantizer,
+    )
+
+    # Create export recipe
+    return ExportRecipe(
+        name="dynamic_quant_recipe",
+        quantization_recipe=quant_recipe,
+        partitioners=[XnnpackPartitioner()],
+        pre_edge_transform_passes=duplicate_dynamic_quant_chain_pass,
+    )
+
+
+def get_8a4w_config(group_size: int = 32) -> ExportRecipe:
+    # Create quantization recipe
+    quant_recipe = QuantizationRecipe(
+        quantizer=None,
+        ao_base_config=[
+            int8_dynamic_activation_int4_weight(group_size=group_size),
+        ],
+    )
+
+    # Create export recipe
+    return ExportRecipe(
+        name="8a4w_quant_recipe",
+        quantization_recipe=quant_recipe,
+        partitioners=[XnnpackPartitioner()],
+    )
+
+
+RECIPE_MAP: dict[str, Callable[..., ExportRecipe]] = {
+    "FP32_CPU_ACCELERATED_RECIPE": get_generic_fp32_cpu_recipe,
+    "DYNAMIC_QUANT_CPU_ACCELERATED_RECIPE": get_dynamic_quant_recipe,
+    "8A4W_CPU_ACCELERATED_RECIPE": get_8a4w_config,
+}
+
+
+def get_xnnpack_recipe(recipe_name: str, **kwargs: Any) -> ExportRecipe:
+    assert recipe_name in RECIPE_MAP, f"Recipe {recipe_name} not found."
+    return RECIPE_MAP[recipe_name](**kwargs)
diff --git a/backends/xnnpack/targets.bzl b/backends/xnnpack/targets.bzl
index aee5104b17a..3aa1a3622a3 100644
--- a/backends/xnnpack/targets.bzl
+++ b/backends/xnnpack/targets.bzl
@@ -67,6 +67,7 @@ def define_common_targets():
             "//executorch/extension/threadpool:threadpool",
             "//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix,
             "//executorch/runtime/executor:pte_data_map" + aten_suffix,
+            "//executorch/backends/xnnpack/recipes:xnnpack_recipes",
         ],
         # XnnpackBackend.cpp needs to compile with executor as whole
         # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
diff --git a/backends/xnnpack/test/TARGETS b/backends/xnnpack/test/TARGETS
index bd3dddd0985..14b7b38de73 100644
--- a/backends/xnnpack/test/TARGETS
+++ b/backends/xnnpack/test/TARGETS
@@ -94,3 +94,14 @@ runtime.python_test(
         "libtorch",
     ],
 )
+
+runtime.python_test(
+    name = "test_xnnpack_recipes",
+    srcs = glob([
+        "recipes/*.py",
+    ]),
+    deps = [
+        "//executorch/backends/xnnpack:xnnpack_delegate",
+        "//executorch/export:lib",
+    ],
+)
diff --git a/backends/xnnpack/test/recipes/test_xnnpack_recipes.py b/backends/xnnpack/test/recipes/test_xnnpack_recipes.py
new file mode 100644
index 00000000000..0f547bfe1e6
--- /dev/null
+++ b/backends/xnnpack/test/recipes/test_xnnpack_recipes.py
@@ -0,0 +1,94 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import unittest
+
+import torch
+from executorch.backends.xnnpack import get_xnnpack_recipe
+from executorch.exir.schema import DelegateCall, Program
+from executorch.export import export
+from torch import nn
+from torch.testing._internal.common_quantization import TestHelperModules
+
+
+class TestXnnpackRecipes(unittest.TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+
+    def tearDown(self) -> None:
+        super().tearDown()
+
+    def check_fully_delegated(self, program: Program) -> None:
+        instructions = program.execution_plan[0].chains[0].instructions
+        assert instructions is not None
+        self.assertEqual(len(instructions), 1)
+        self.assertIsInstance(instructions[0].instr_args, DelegateCall)
+
+    def test_basic_recipe(self) -> None:
+        m_eager = TestHelperModules.TwoLinearModule().eval()
+        example_inputs = [(torch.randn(9, 8),)]
+        session = export(
+            model=m_eager,
+            example_inputs=example_inputs,
+            export_recipe=get_xnnpack_recipe("FP32_CPU_ACCELERATED_RECIPE"),
+        )
+        self.assertTrue(
+            torch.allclose(
+                session.run_method("forward", example_inputs[0])[0],
+                m_eager(*example_inputs[0]),
+            )
+        )
+        self.check_fully_delegated(session.get_executorch_program())
+
+    def test_dynamic_quant_recipe(self) -> None:
+        with torch.no_grad():
+            m_eager = TestHelperModules.TwoLinearModule().eval()
+            example_inputs = [(torch.randn(9, 8),)]
+            session = export(
+                model=m_eager,
+                example_inputs=example_inputs,
+                export_recipe=get_xnnpack_recipe(
+                    "DYNAMIC_QUANT_CPU_ACCELERATED_RECIPE"
+                ),
+            )
+            self.assertTrue(
+                torch.allclose(
+                    session.run_method("forward", example_inputs[0])[0],
+                    m_eager(*example_inputs[0]),
+                    atol=1e-1,
+                )
+            )
+            self.check_fully_delegated(session.get_executorch_program())
+
+    def test_8a4w_recipe(self) -> None:
+        class SimpleLinearModel(nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.layer1 = nn.Linear(32, 2)
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                x = self.layer1(x)
+                return x
+
+        model = SimpleLinearModel()
+        example_inputs = [(torch.randn(1, 32),)]
+        session = export(
+            model=model,
+            example_inputs=example_inputs,
+            export_recipe=get_xnnpack_recipe(
+                "8A4W_CPU_ACCELERATED_RECIPE", group_size=32
+            ),
+        )
+        self.assertTrue(
+            torch.allclose(
+                session.run_method("forward", example_inputs[0])[0],
+                model(*example_inputs[0]),
+                atol=1e-1,
+            )
+        )
+        self.check_fully_delegated(session.get_executorch_program())
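
Usage note (reviewer aid): a minimal end-to-end sketch of the recipe entry point this patch adds, mirroring the flow exercised in test_xnnpack_recipes.py above. It assumes only the executorch.export session API used by those tests (export, run_method, get_executorch_program); TinyModel and its shapes are hypothetical stand-ins, not part of the change.

```python
import torch
from torch import nn

from executorch.backends.xnnpack import get_xnnpack_recipe
from executorch.export import export


class TinyModel(nn.Module):
    """Hypothetical toy model; any exportable nn.Module works here."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(8, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = TinyModel().eval()
example_inputs = [(torch.randn(2, 8),)]

# Look up a recipe by name; "8A4W_CPU_ACCELERATED_RECIPE" additionally
# accepts group_size, forwarded via **kwargs in get_xnnpack_recipe.
session = export(
    model=model,
    example_inputs=example_inputs,
    export_recipe=get_xnnpack_recipe("DYNAMIC_QUANT_CPU_ACCELERATED_RECIPE"),
)

# Sanity-check numerics against eager mode, then grab the lowered program.
outputs = session.run_method("forward", example_inputs[0])
program = session.get_executorch_program()
```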