
Commit 959c258

Author: morelos (committed)
Update on "[ET-VK] Migrate off of xnnpack_quantizer_utils"

# Context

Eventually, as the vulkan_quantizer file expands, we will need to migrate to a custom utils file and stop depending on xnnpack_quantizer_utils. We migrate only the minimal set of functions necessary for vulkan_quantizer to work.

# Changes

We create a new file `vulkan_quantizer_utils.py` and migrate `vulkan_quantizer` off of `xnnpack_quantizer_utils.py`. No modifications were needed for the code to work separately from the xnnpack utils, apart from `bits_to_range`, which avoids having to specify the quantization ranges every time.

Differential Revision: [D78290055](https://our.internmc.facebook.com/intern/diff/D78290055/)

[ghstack-poisoned]
2 parents c6fca42 + 1fd704c commit 959c258

37 files changed (+2894, -73 lines)
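
The commit message above refers to a `bits_to_range` helper that is not part of the excerpted diffs below. As a rough orientation only, such a helper maps a bit width to a signed quantization range so call sites do not repeat the literal bounds. A minimal sketch, assuming a symmetric signed integer range (the actual name, signature, and values live in `vulkan_quantizer_utils.py` and may differ):

```python
from typing import Tuple


def bits_to_range(bits: int) -> Tuple[int, int]:
    """Hypothetical sketch: (quant_min, quant_max) for a signed dtype of the
    given bit width, e.g. bits_to_range(8) == (-128, 127)."""
    return (-(2 ** (bits - 1)), 2 ** (bits - 1) - 1)
```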

CMakeLists.txt

Lines changed: 4 additions & 7 deletions
```diff
@@ -48,8 +48,6 @@
 cmake_minimum_required(VERSION 3.24)
 project(executorch)
 
-# MARK: - Start EXECUTORCH_H12025_BUILD_MIGRATION
-
 include(${PROJECT_SOURCE_DIR}/tools/cmake/common/preset.cmake)
 include(${PROJECT_SOURCE_DIR}/tools/cmake/Utils.cmake)
 include(CMakeDependentOption)
@@ -82,6 +80,7 @@ announce_configured_options(BUCK2)
 
 announce_configured_options(CMAKE_CXX_COMPILER_ID)
 announce_configured_options(CMAKE_TOOLCHAIN_FILE)
+announce_configured_options(BUILD_TESTING)
 
 load_build_preset()
 include(${PROJECT_SOURCE_DIR}/tools/cmake/preset/default.cmake)
@@ -97,11 +96,6 @@ else()
 endif()
 announce_configured_options(CCACHE_PROGRAM)
 
-# Print all the configs that were called with announce_configured_options.
-print_configured_options()
-
-# MARK: - End EXECUTORCH_H12025_BUILD_MIGRATION
-
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
 
 # Setup RPATH. See
@@ -750,3 +744,6 @@ if(EXECUTORCH_BUILD_ANDROID_JNI)
 endif()
 
 include(Test.cmake)
+
+# Print all the configs that were called with announce_configured_options.
+print_configured_options()
```

CMakePresets.json

Lines changed: 8 additions & 16 deletions
```diff
@@ -8,7 +8,7 @@
     },
     {
       "name": "macos",
-      "displayName": "Build everything buildable on macOS",
+      "displayName": "Build ExecuTorch for macOS",
       "inherits": ["common"],
       "generator": "Xcode",
       "cacheVariables": {
@@ -25,7 +25,7 @@
     },
     {
       "name": "ios",
-      "displayName": "Build everything buildable on iOS",
+      "displayName": "Build ExecuTorch for iOS",
       "inherits": ["common"],
       "generator": "Xcode",
       "cacheVariables": {
@@ -42,7 +42,7 @@
     },
     {
       "name": "ios-simulator",
-      "displayName": "Build everything buildable on iOS simulator",
+      "displayName": "Build ExecuTorch for iOS Simulator",
       "inherits": ["common"],
       "generator": "Xcode",
       "cacheVariables": {
@@ -59,7 +59,7 @@
     },
     {
       "name": "linux",
-      "displayName": "Build everything buildable on Linux",
+      "displayName": "Build ExecuTorch for Linux",
       "inherits": ["common"],
       "cacheVariables": {
         "CMAKE_SYSTEM_NAME": "Linux",
@@ -88,29 +88,21 @@
     {
       "name": "llm",
       "displayName": "Build LLM libraries",
-      "inherits": [
-        "common"
-      ],
+      "inherits": ["common"],
       "cacheVariables": {
         "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/llm.cmake",
         "CMAKE_OSX_DEPLOYMENT_TARGET": "12.0"
       },
       "condition": {
         "type": "inList",
         "string": "${hostSystemName}",
-        "list": [
-          "Darwin",
-          "Linux",
-          "Windows"
-        ]
+        "list": ["Darwin", "Linux", "Windows"]
       }
     },
     {
       "name": "zephyr",
-      "displayName": "Build everything buildable on Zephyr RTOS",
-      "inherits": [
-        "common"
-      ],
+      "displayName": "Build ExecuTorch for Zephyr RTOS",
+      "inherits": ["common"],
       "cacheVariables": {
         "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/zephyr.cmake",
         "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/examples/zephyr/x86_64-linux-arm-zephyr-eabi-gcc.cmake"
```

backends/nxp/nxp_backend.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -174,7 +174,8 @@ def preprocess(
         # Otherwise, we get violation that this op is not part of ATen Core ops.
         edge_program._verifiers = [
             EXIREdgeDialectVerifier(
-                class_only=True, core_aten_ops_exception_list=[torch.ops.aten.max_pool2d.default]
+                class_only=True,
+                core_aten_ops_exception_list=[torch.ops.aten.max_pool2d.default],
             )
         ]
 
```

backends/qualcomm/qnn_preprocess.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -178,6 +178,11 @@ def preprocess_multimethod(
 
         if len(py_op_wrapper_list) == len(edge_programs.values()):
             qnn_context_binary = qnn_manager.Compile(graph_name, py_op_wrapper_list)
+            if option.saver:
+                # TODO: Currently, only the first method is saved. Update this logic if saving multiple methods becomes necessary in the future.
+                exit(
+                    f"Record all QNN API calls from saver backend at: {option.saver_output_dir}"
+                )
             assert (
                 len(qnn_context_binary) != 0
             ), "Failed to generate Qnn context binary."
```

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 66 additions & 0 deletions
```diff
@@ -3384,6 +3384,38 @@ def test_qnn_backend_rewrite_prepared_observer(self):
         quantized_module = convert_pt2e(prepared)
         self.lower_module_and_test_output(quantized_module, sample_input)
 
+    def test_qnn_backend_saver_backend(self):
+        backend_options = generate_htp_compiler_spec(use_fp16=False)
+        TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec(
+            soc_model=self.chipset_table[TestQNN.model],
+            backend_options=backend_options,
+            saver=True,
+        )
+        module = Relu()  # noqa: F405
+        sample_input = (torch.randn([2, 5, 1, 3]),)
+        module = self.get_qdq_module(module, sample_input)
+
+        from executorch.backends.qualcomm.serialization.qc_schema_serialize import (
+            flatbuffer_to_option,
+            option_to_flatbuffer,
+        )
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            option = flatbuffer_to_option(TestQNN.compiler_specs[0].value)
+            option.saver_output_dir = f"{tmp_dir}/saver_output"
+            TestQNN.compiler_specs[0].value = option_to_flatbuffer(option)
+
+            with self.assertRaises(SystemExit):
+                self.lower_module_and_test_output(module, sample_input)
+            self.assertTrue(
+                os.path.isfile(f"{tmp_dir}/saver_output/params.bin"),
+                "failed to find params.bin",
+            )
+            self.assertTrue(
+                os.path.isfile(f"{tmp_dir}/saver_output/saver_output.c"),
+                "failed to find saver_output.c",
+            )
+
     def test_qnn_backend_skip_node_id_partitioner(self):
         module = SimpleModel()  # noqa: F405
         sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
@@ -5022,6 +5054,40 @@ def test_swin_transformer(self):
         self.assertGreaterEqual(msg["top_1"], 60)
         self.assertGreaterEqual(msg["top_5"], 80)
 
+    def test_t5(self):
+        if not self.required_envs([self.qa_dataset]):
+            self.skipTest("missing required envs")
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/t5/t5.py",
+            "--dataset",
+            self.sentence_dataset,
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--device",
+            self.device,
+            "--model",
+            self.model,
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+        ]
+        if self.host:
+            cmds.extend(["--host", self.host])
+
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            if "Error" in msg:
+                self.fail(msg["Error"])
+            else:
+                self.assertGreaterEqual(msg["f1"], 0.7)
+
     def test_whisper(self):
         if not self.required_envs():
             self.skipTest("missing required envs")
```

backends/qualcomm/tests/utils.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -183,6 +183,7 @@ class TestQNN(unittest.TestCase):
     executorch_root: str = ""
     artifact_dir: str = ""
     image_dataset: str = ""
+    qa_dataset: str = ""
     sentence_dataset: str = ""
     pretrained_weight: str = ""
     enable_profile: bool = False
```

backends/vulkan/op_registry.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -688,6 +688,7 @@ def register_transfer_ops(features: OpFeatures):
         exir_ops.edge.aten.full_like.default,
         exir_ops.edge.aten.ones.default,
         exir_ops.edge.aten.ones_like.default,
+        exir_ops.edge.aten.scalar_tensor.default,
         exir_ops.edge.aten.upsample_nearest2d.vec,
         exir_ops.edge.aten.upsample_bilinear2d.vec,
         exir_ops.edge.aten.zeros.default,
```

backends/vulkan/quantizer/TARGETS

Lines changed: 5 additions & 12 deletions
```diff
@@ -3,25 +3,18 @@ load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
 oncall("executorch")
 
 python_library(
-    name = "vulkan_quantizer_utils",
-    srcs = [
-        "vulkan_quantizer_utils.py",
-    ],
+    name = "vulkan_quantizer",
+    srcs = ["vulkan_quantizer.py"],
     deps = [
+        ":vulkan_quantizer_utils",
         "//caffe2:torch",
-        "//pytorch/ao/torchao/quantization/pt2e:quantizer",
-        "//pytorch/ao/torchao/quantization/pt2e:utils",
     ],
 )
 
 python_library(
-    name = "vulkan_quantizer",
-    srcs = [
-        "vulkan_quantizer.py",
-    ],
+    name = "vulkan_quantizer_utils",
+    srcs = ["vulkan_quantizer_utils.py"],
     deps = [
-        ":vulkan_quantizer_utils",
         "//caffe2:torch",
-        "//pytorch/ao/torchao/quantization/pt2e:quantizer",
     ],
 )
```

backends/vulkan/quantizer/vulkan_quantizer_utils.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -6,7 +6,7 @@
 
 # pyre-strict
 
-from typing import Callable, Optional
+from typing import Callable, Optional, Tuple
 
 import torch
 from torch.fx import Node
@@ -48,7 +48,7 @@ def decorator(annotator: AnnotatorType) -> None:
     return decorator
 
 
-def _is_annotated(nodes: list[Node]):
+def _is_annotated(nodes: list[Node]) -> bool:
     """
     Given a list of nodes (that represents an operator pattern),
     check if any of the node is annotated, return True if any of the node
@@ -63,7 +63,7 @@ def _is_annotated(nodes: list[Node]):
     return annotated
 
 
-def _mark_nodes_as_annotated(nodes: list[Node]):
+def _mark_nodes_as_annotated(nodes: list[Node]) -> None:
     for node in nodes:
         if node is not None:
             if "quantization_annotation" not in node.meta:
@@ -119,7 +119,7 @@ def _annotate_linear(
     return annotated_partitions
 
 
-def _is_share_obs_or_fq_op(op: Callable) -> bool:
+def _is_share_obs_or_fq_op(op: Callable[..., torch.Tensor]) -> bool:
     return op in [
         torch.ops.aten.relu.default,
         torch.ops.aten.hardtanh.default,
```
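
For readers coming from the xnnpack utils: annotators built on these helpers typically check `_is_annotated` before attaching a `QuantizationAnnotation` to `node.meta`, and then call `_mark_nodes_as_annotated` so later annotators skip the pattern. A minimal sketch of that guard pattern, assuming the torchao pt2e `QuantizationAnnotation` API and a hypothetical single-node annotator (neither is taken from this diff):

```python
import torch
from torch.fx import Node
from torchao.quantization.pt2e.quantizer import QuantizationAnnotation

from executorch.backends.vulkan.quantizer.vulkan_quantizer_utils import (
    _is_annotated,
    _mark_nodes_as_annotated,
)


def _annotate_relu_node(node: Node, quantization_config) -> None:
    # Hypothetical annotator: handle a single aten.relu node that no earlier
    # annotator has claimed yet.
    if node.target != torch.ops.aten.relu.default or _is_annotated([node]):
        return
    node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={node.args[0]: quantization_config.input_activation},
        output_qspec=quantization_config.output_activation,
        _annotated=True,
    )
    # Mark the pattern as handled so subsequent annotators leave it alone.
    _mark_nodes_as_annotated([node])
```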

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 8 additions & 0 deletions
```diff
@@ -273,6 +273,14 @@ vkapi::ScalarType ComputeGraph::dtype_of(const ValueRef idx) const {
     return val.toConstTensor().dtype();
   } else if (val.isTensorRef()) {
     return val.toConstTensorRef().dtype;
+  } else if (val.isBool()) {
+    return vkapi::ScalarType::Bool;
+  } else if (val.isDouble()) {
+    // We downcast anyway in the shader and we want to avoid having to
+    // write special cases there.
+    return vkapi::ScalarType::Float;
+  } else if (val.isInt()) {
+    return vkapi::ScalarType::Int;
   }
   VK_THROW("Could not get dtype of value with type ", val.type());
 }
```
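
Expressed as a Python analogue (a sketch for illustration only; the authoritative logic is the C++ above), the new branches map non-tensor scalar values to Vulkan dtypes as follows, with doubles reported as `Float` because the shaders downcast to 32-bit anyway:

```python
def scalar_dtype(value) -> str:
    # Mirrors the new dtype_of branches; the strings stand in for
    # vkapi::ScalarType values.
    if isinstance(value, bool):
        # bool is checked before int, since Python bools are also ints.
        return "Bool"
    if isinstance(value, float):
        # Python floats are doubles; report Float to avoid special-casing
        # 64-bit scalars in the shaders.
        return "Float"
    if isinstance(value, int):
        return "Int"
    raise TypeError(f"Could not get dtype of value with type {type(value)}")
```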
