Skip to content

Commit ef6bd56

Browse files
author
morelos
committed
Update base for Update on "[ET-VK] Creating get_symmetric_quantization_config"
# Context Eventually, dynamic quantization will be enabled in the vulkan_quantizer (particularly 8bit dyn act with 8bit weights). In order to enable this functionality, we need to use an approach similar to XNNPack's for defining the quantization config. This diff aims to align with the XNNPack quantizer logic and also migrate away from the old static quantization config logic. # Changes A few noticeable changes are that we migrate away from `get_linear_weight_only_qcs_xnn_qconfig`, and we now define a symmetric config that has parameters to define whether it's dynamically quantized or not. Furthermore, we also incorporate bits_to_range so that we can automatically designate the min and max quant ranges without having to set them during initialization. We also change some wording that previously referred only to static quantization, as we are now enabling dynamic quantization as well. Furthermore, we update other internal codebases that call our existing legacy config, and move them to the more universal symmetric config. Since this follows the same naming scheme as XNNPack, I have decided to just add aliases in cases where it's being imported directly along with XNNPack. Differential Revision: [D78291249](https://our.internmc.facebook.com/intern/diff/D78291249/) [ghstack-poisoned]
2 parents c6fca42 + f35de65 commit ef6bd56

37 files changed

+2894
-73
lines changed

CMakeLists.txt

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,6 @@
4848
cmake_minimum_required(VERSION 3.24)
4949
project(executorch)
5050

51-
# MARK: - Start EXECUTORCH_H12025_BUILD_MIGRATION
52-
5351
include(${PROJECT_SOURCE_DIR}/tools/cmake/common/preset.cmake)
5452
include(${PROJECT_SOURCE_DIR}/tools/cmake/Utils.cmake)
5553
include(CMakeDependentOption)
@@ -82,6 +80,7 @@ announce_configured_options(BUCK2)
8280

8381
announce_configured_options(CMAKE_CXX_COMPILER_ID)
8482
announce_configured_options(CMAKE_TOOLCHAIN_FILE)
83+
announce_configured_options(BUILD_TESTING)
8584

8685
load_build_preset()
8786
include(${PROJECT_SOURCE_DIR}/tools/cmake/preset/default.cmake)
@@ -97,11 +96,6 @@ else()
9796
endif()
9897
announce_configured_options(CCACHE_PROGRAM)
9998

100-
# Print all the configs that were called with announce_configured_options.
101-
print_configured_options()
102-
103-
# MARK: - End EXECUTORCH_H12025_BUILD_MIGRATION
104-
10599
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
106100

107101
# Setup RPATH. See
@@ -750,3 +744,6 @@ if(EXECUTORCH_BUILD_ANDROID_JNI)
750744
endif()
751745

752746
include(Test.cmake)
747+
748+
# Print all the configs that were called with announce_configured_options.
749+
print_configured_options()

CMakePresets.json

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
},
99
{
1010
"name": "macos",
11-
"displayName": "Build everything buildable on macOS",
11+
"displayName": "Build ExecuTorch for macOS",
1212
"inherits": ["common"],
1313
"generator": "Xcode",
1414
"cacheVariables": {
@@ -25,7 +25,7 @@
2525
},
2626
{
2727
"name": "ios",
28-
"displayName": "Build everything buildable on iOS",
28+
"displayName": "Build ExecuTorch for iOS",
2929
"inherits": ["common"],
3030
"generator": "Xcode",
3131
"cacheVariables": {
@@ -42,7 +42,7 @@
4242
},
4343
{
4444
"name": "ios-simulator",
45-
"displayName": "Build everything buildable on iOS simulator",
45+
"displayName": "Build ExecuTorch for iOS Simulator",
4646
"inherits": ["common"],
4747
"generator": "Xcode",
4848
"cacheVariables": {
@@ -59,7 +59,7 @@
5959
},
6060
{
6161
"name": "linux",
62-
"displayName": "Build everything buildable on Linux",
62+
"displayName": "Build ExecuTorch for Linux",
6363
"inherits": ["common"],
6464
"cacheVariables": {
6565
"CMAKE_SYSTEM_NAME": "Linux",
@@ -88,29 +88,21 @@
8888
{
8989
"name": "llm",
9090
"displayName": "Build LLM libraries",
91-
"inherits": [
92-
"common"
93-
],
91+
"inherits": ["common"],
9492
"cacheVariables": {
9593
"EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/llm.cmake",
9694
"CMAKE_OSX_DEPLOYMENT_TARGET": "12.0"
9795
},
9896
"condition": {
9997
"type": "inList",
10098
"string": "${hostSystemName}",
101-
"list": [
102-
"Darwin",
103-
"Linux",
104-
"Windows"
105-
]
99+
"list": ["Darwin", "Linux", "Windows"]
106100
}
107101
},
108102
{
109103
"name": "zephyr",
110-
"displayName": "Build everything buildable on Zephyr RTOS",
111-
"inherits": [
112-
"common"
113-
],
104+
"displayName": "Build ExecuTorch for Zephyr RTOS",
105+
"inherits": ["common"],
114106
"cacheVariables": {
115107
"EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/zephyr.cmake",
116108
"CMAKE_TOOLCHAIN_FILE": "${sourceDir}/examples/zephyr/x86_64-linux-arm-zephyr-eabi-gcc.cmake"

backends/nxp/nxp_backend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,8 @@ def preprocess(
174174
# Otherwise, we get violation that this op is not part of ATen Core ops.
175175
edge_program._verifiers = [
176176
EXIREdgeDialectVerifier(
177-
class_only=True, core_aten_ops_exception_list=[torch.ops.aten.max_pool2d.default]
177+
class_only=True,
178+
core_aten_ops_exception_list=[torch.ops.aten.max_pool2d.default],
178179
)
179180
]
180181

backends/qualcomm/qnn_preprocess.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,11 @@ def preprocess_multimethod(
178178

179179
if len(py_op_wrapper_list) == len(edge_programs.values()):
180180
qnn_context_binary = qnn_manager.Compile(graph_name, py_op_wrapper_list)
181+
if option.saver:
182+
# TODO: Currently, only the first method is saved. Update this logic if saving multiple methods becomes necessary in the future.
183+
exit(
184+
f"Record all QNN API calls from saver backend at: {option.saver_output_dir}"
185+
)
181186
assert (
182187
len(qnn_context_binary) != 0
183188
), "Failed to generate Qnn context binary."

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3384,6 +3384,38 @@ def test_qnn_backend_rewrite_prepared_observer(self):
33843384
quantized_module = convert_pt2e(prepared)
33853385
self.lower_module_and_test_output(quantized_module, sample_input)
33863386

3387+
def test_qnn_backend_saver_backend(self):
3388+
backend_options = generate_htp_compiler_spec(use_fp16=False)
3389+
TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec(
3390+
soc_model=self.chipset_table[TestQNN.model],
3391+
backend_options=backend_options,
3392+
saver=True,
3393+
)
3394+
module = Relu() # noqa: F405
3395+
sample_input = (torch.randn([2, 5, 1, 3]),)
3396+
module = self.get_qdq_module(module, sample_input)
3397+
3398+
from executorch.backends.qualcomm.serialization.qc_schema_serialize import (
3399+
flatbuffer_to_option,
3400+
option_to_flatbuffer,
3401+
)
3402+
3403+
with tempfile.TemporaryDirectory() as tmp_dir:
3404+
option = flatbuffer_to_option(TestQNN.compiler_specs[0].value)
3405+
option.saver_output_dir = f"{tmp_dir}/saver_output"
3406+
TestQNN.compiler_specs[0].value = option_to_flatbuffer(option)
3407+
3408+
with self.assertRaises(SystemExit):
3409+
self.lower_module_and_test_output(module, sample_input)
3410+
self.assertTrue(
3411+
os.path.isfile(f"{tmp_dir}/saver_output/params.bin"),
3412+
"failed to find params.bin",
3413+
)
3414+
self.assertTrue(
3415+
os.path.isfile(f"{tmp_dir}/saver_output/saver_output.c"),
3416+
"failed to find saver_output.c",
3417+
)
3418+
33873419
def test_qnn_backend_skip_node_id_partitioner(self):
33883420
module = SimpleModel() # noqa: F405
33893421
sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
@@ -5022,6 +5054,40 @@ def test_swin_transformer(self):
50225054
self.assertGreaterEqual(msg["top_1"], 60)
50235055
self.assertGreaterEqual(msg["top_5"], 80)
50245056

5057+
def test_t5(self):
5058+
if not self.required_envs([self.qa_dataset]):
5059+
self.skipTest("missing required envs")
5060+
cmds = [
5061+
"python",
5062+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/t5/t5.py",
5063+
"--dataset",
5064+
self.sentence_dataset,
5065+
"--artifact",
5066+
self.artifact_dir,
5067+
"--build_folder",
5068+
self.build_folder,
5069+
"--device",
5070+
self.device,
5071+
"--model",
5072+
self.model,
5073+
"--ip",
5074+
self.ip,
5075+
"--port",
5076+
str(self.port),
5077+
]
5078+
if self.host:
5079+
cmds.extend(["--host", self.host])
5080+
5081+
p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
5082+
with Listener((self.ip, self.port)) as listener:
5083+
conn = listener.accept()
5084+
p.communicate()
5085+
msg = json.loads(conn.recv())
5086+
if "Error" in msg:
5087+
self.fail(msg["Error"])
5088+
else:
5089+
self.assertGreaterEqual(msg["f1"], 0.7)
5090+
50255091
def test_whisper(self):
50265092
if not self.required_envs():
50275093
self.skipTest("missing required envs")

backends/qualcomm/tests/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ class TestQNN(unittest.TestCase):
183183
executorch_root: str = ""
184184
artifact_dir: str = ""
185185
image_dataset: str = ""
186+
qa_dataset: str = ""
186187
sentence_dataset: str = ""
187188
pretrained_weight: str = ""
188189
enable_profile: bool = False

backends/vulkan/op_registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,7 @@ def register_transfer_ops(features: OpFeatures):
688688
exir_ops.edge.aten.full_like.default,
689689
exir_ops.edge.aten.ones.default,
690690
exir_ops.edge.aten.ones_like.default,
691+
exir_ops.edge.aten.scalar_tensor.default,
691692
exir_ops.edge.aten.upsample_nearest2d.vec,
692693
exir_ops.edge.aten.upsample_bilinear2d.vec,
693694
exir_ops.edge.aten.zeros.default,

backends/vulkan/quantizer/TARGETS

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,18 @@ load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
33
oncall("executorch")
44

55
python_library(
6-
name = "vulkan_quantizer_utils",
7-
srcs = [
8-
"vulkan_quantizer_utils.py",
9-
],
6+
name = "vulkan_quantizer",
7+
srcs = ["vulkan_quantizer.py"],
108
deps = [
9+
":vulkan_quantizer_utils",
1110
"//caffe2:torch",
12-
"//pytorch/ao/torchao/quantization/pt2e:quantizer",
13-
"//pytorch/ao/torchao/quantization/pt2e:utils",
1411
],
1512
)
1613

1714
python_library(
18-
name = "vulkan_quantizer",
19-
srcs = [
20-
"vulkan_quantizer.py",
21-
],
15+
name = "vulkan_quantizer_utils",
16+
srcs = ["vulkan_quantizer_utils.py"],
2217
deps = [
23-
":vulkan_quantizer_utils",
2418
"//caffe2:torch",
25-
"//pytorch/ao/torchao/quantization/pt2e:quantizer",
2619
],
2720
)

backends/vulkan/quantizer/vulkan_quantizer_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# pyre-strict
88

9-
from typing import Callable, Optional
9+
from typing import Callable, Optional, Tuple
1010

1111
import torch
1212
from torch.fx import Node
@@ -48,7 +48,7 @@ def decorator(annotator: AnnotatorType) -> None:
4848
return decorator
4949

5050

51-
def _is_annotated(nodes: list[Node]):
51+
def _is_annotated(nodes: list[Node]) -> bool:
5252
"""
5353
Given a list of nodes (that represents an operator pattern),
5454
check if any of the node is annotated, return True if any of the node
@@ -63,7 +63,7 @@ def _is_annotated(nodes: list[Node]):
6363
return annotated
6464

6565

66-
def _mark_nodes_as_annotated(nodes: list[Node]):
66+
def _mark_nodes_as_annotated(nodes: list[Node]) -> None:
6767
for node in nodes:
6868
if node is not None:
6969
if "quantization_annotation" not in node.meta:
@@ -119,7 +119,7 @@ def _annotate_linear(
119119
return annotated_partitions
120120

121121

122-
def _is_share_obs_or_fq_op(op: Callable) -> bool:
122+
def _is_share_obs_or_fq_op(op: Callable[..., torch.Tensor]) -> bool:
123123
return op in [
124124
torch.ops.aten.relu.default,
125125
torch.ops.aten.hardtanh.default,

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,14 @@ vkapi::ScalarType ComputeGraph::dtype_of(const ValueRef idx) const {
273273
return val.toConstTensor().dtype();
274274
} else if (val.isTensorRef()) {
275275
return val.toConstTensorRef().dtype;
276+
} else if (val.isBool()) {
277+
return vkapi::ScalarType::Bool;
278+
} else if (val.isDouble()) {
279+
// We downcast anyway in the shader and we want to avoid having to
280+
// write special cases there.
281+
return vkapi::ScalarType::Float;
282+
} else if (val.isInt()) {
283+
return vkapi::ScalarType::Int;
276284
}
277285
VK_THROW("Could not get dtype of value with type ", val.type());
278286
}

0 commit comments

Comments
 (0)