Add some basic xnnpack recipes #10035
base: main
Changes from all commits: e42c965, 910ac24, fab9fe8, 6ae1d97, 3c52f3a, be8de30, a02decd, 62592eb, 3e56ded, 2b1e447, ec902d5
New file (Buck build target for the recipes library; the file path is not shown in this extract), @@ -0,0 +1,19 @@:

```bzl
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

oncall("executorch")

python_library(
    name = "xnnpack_recipes",
    srcs = [
        "recipes.py",
    ],
    deps = [
        "//caffe2:torch",
        "//executorch/exir:lib",
        "//executorch/export:recipe",
        "//executorch/backends/transforms:duplicate_dynamic_quant_chain",
        "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
        "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
    ],
)
```
New file `recipes.py` (the library source named in the Buck target above), @@ -0,0 +1,80 @@:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
from typing import Any, Callable

from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
    duplicate_dynamic_quant_chain_pass,
)

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from executorch.export.recipe import ExportRecipe, QuantizationRecipe
from torchao.quantization.quant_api import int8_dynamic_activation_int4_weight


def get_generic_fp32_cpu_recipe() -> ExportRecipe:
    return ExportRecipe(
        name="fp32_recipe",
        quantization_recipe=None,
        partitioners=[XnnpackPartitioner()],
    )


def get_dynamic_quant_recipe() -> ExportRecipe:
    # Create quantizer
    quantizer = XNNPACKQuantizer()
    operator_config = get_symmetric_quantization_config(
        is_per_channel=True, is_dynamic=True
    )
    quantizer.set_global(operator_config)

    # Create quantization recipe
    quant_recipe = QuantizationRecipe(
        quantizer=quantizer,
    )

    # Create export recipe
    return ExportRecipe(
        name="dynamic_quant_recipe",
        quantization_recipe=quant_recipe,
        partitioners=[XnnpackPartitioner()],
        pre_edge_transform_passes=duplicate_dynamic_quant_chain_pass,
    )


def get_8a4w_config(group_size: int = 32) -> ExportRecipe:
    # Create quantization recipe
    quant_recipe = QuantizationRecipe(
        quantizer=None,
        ao_base_config=[
            # Pass the group_size argument through instead of hard-coding 32.
            int8_dynamic_activation_int4_weight(group_size=group_size),
        ],
    )

    # Create export recipe
    return ExportRecipe(
        name="8a4w_quant_recipe",
        quantization_recipe=quant_recipe,
        partitioners=[XnnpackPartitioner()],
    )


# Callable[..., ExportRecipe] so that recipes may accept optional
# kwargs (e.g. group_size), which get_xnnpack_recipe forwards below.
RECIPE_MAP: dict[str, Callable[..., ExportRecipe]] = {
    "FP32_CPU_ACCELERATED_RECIPE": get_generic_fp32_cpu_recipe,
    "DYNAMIC_QUANT_CPU_ACCELERATED_RECIPE": get_dynamic_quant_recipe,
    "8A4W_CPU_ACCELERATED_RECIPE": get_8a4w_config,
}


def get_xnnpack_recipe(recipe_name: str, **kwargs: Any) -> ExportRecipe:
    assert recipe_name in RECIPE_MAP, f"Recipe {recipe_name} not found."
    return RECIPE_MAP[recipe_name](**kwargs)
```

Review comments:

> On `get_dynamic_quant_recipe`: organizationally maybe quant recipes can be in a separate folder?

> On lines +72 to +73 (the first two `RECIPE_MAP` entries): nit, with a suggested change (suggestion body not shown in this extract).
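A minimal usage sketch of the lookup entry point, mirroring the calls in the test file below; the model and inputs here are illustrative stand-ins, not part of the PR:

```python
import torch
from executorch.backends.xnnpack import get_xnnpack_recipe
from executorch.export import export

# Any ExecuTorch-exportable eager model works; a bare Linear keeps it short.
model = torch.nn.Linear(8, 4).eval()
example_inputs = [(torch.randn(2, 8),)]

# Resolve a recipe by name and run the full export pipeline with it.
session = export(
    model=model,
    example_inputs=example_inputs,
    export_recipe=get_xnnpack_recipe("FP32_CPU_ACCELERATED_RECIPE"),
)
output = session.run_method("forward", example_inputs[0])[0]
```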
New test file, @@ -0,0 +1,94 @@:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import torch
from executorch.backends.xnnpack import get_xnnpack_recipe
from executorch.exir.schema import DelegateCall, Program
from executorch.export import export
from torch import nn
from torch.testing._internal.common_quantization import TestHelperModules


class TestXnnpackRecipes(unittest.TestCase):
    def setUp(self) -> None:
        super().setUp()

    def tearDown(self) -> None:
        super().tearDown()

    def check_fully_delegated(self, program: Program) -> None:
        # A fully delegated program has a single DelegateCall instruction.
        instructions = program.execution_plan[0].chains[0].instructions
        assert instructions is not None
        self.assertEqual(len(instructions), 1)
        self.assertIsInstance(instructions[0].instr_args, DelegateCall)

    def test_basic_recipe(self) -> None:
        m_eager = TestHelperModules.TwoLinearModule().eval()
        example_inputs = [(torch.randn(9, 8),)]
        session = export(
            model=m_eager,
            example_inputs=example_inputs,
            export_recipe=get_xnnpack_recipe("FP32_CPU_ACCELERATED_RECIPE"),
        )
        self.assertTrue(
            torch.allclose(
                session.run_method("forward", example_inputs[0])[0],
                m_eager(*example_inputs[0]),
            )
        )
        self.check_fully_delegated(session.get_executorch_program())

    def test_dynamic_quant_recipe(self) -> None:
        with torch.no_grad():
            m_eager = TestHelperModules.TwoLinearModule().eval()
            example_inputs = [(torch.randn(9, 8),)]
            session = export(
                model=m_eager,
                example_inputs=example_inputs,
                export_recipe=get_xnnpack_recipe(
                    "DYNAMIC_QUANT_CPU_ACCELERATED_RECIPE"
                ),
            )
            self.assertTrue(
                torch.allclose(
                    session.run_method("forward", example_inputs[0])[0],
                    m_eager(*example_inputs[0]),
                    atol=1e-1,
                )
            )
            self.check_fully_delegated(session.get_executorch_program())

    def test_8a4w_recipe(self) -> None:
        class SimpleLinearModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layer1 = nn.Linear(32, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.layer1(x)

        model = SimpleLinearModel()
        example_inputs = [(torch.randn(1, 32),)]
        session = export(
            model=model,
            example_inputs=example_inputs,
            export_recipe=get_xnnpack_recipe(
                "8A4W_CPU_ACCELERATED_RECIPE", group_size=32
            ),
        )
        self.assertTrue(
            torch.allclose(
                session.run_method("forward", example_inputs[0])[0],
                model(*example_inputs[0]),
                atol=1e-1,
            )
        )
        self.check_fully_delegated(session.get_executorch_program())
```

Review comments:

> On lines +20 to +24 (`setUp`/`tearDown`): do we need these?
> if namespaced to XNNPACK then cpu may not be needed?
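A hedged sketch of what that last suggestion might look like in `recipes.py` — the shorter names below are hypothetical, not part of the PR:

```python
# Hypothetical renaming: the recipes are already namespaced under the XNNPACK
# backend via get_xnnpack_recipe, so a CPU_ACCELERATED suffix may be redundant.
RECIPE_MAP: dict[str, Callable[..., ExportRecipe]] = {
    "FP32": get_generic_fp32_cpu_recipe,
    "DYNAMIC_QUANT": get_dynamic_quant_recipe,
    "8A4W": get_8a4w_config,
}

# Call sites would then shorten to:
#   session = export(..., export_recipe=get_xnnpack_recipe("DYNAMIC_QUANT"))
```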