
Commit e2e6bf2

morelos committed
Update on "[ET-VK][Ops] torchao.dequantize_affine vulkan impl and shader and cleanup"
# Changes

* Implement the `torchao.dequantize_affine` operator in the Vulkan backend with comprehensive texture and buffer storage support
* Add a block-wise dequantization mode in the `dequantize_texture.glsl` and `dequantize_buffer.glsl` shaders for configurable tensor-block dequantization
* Extend the dequantization infrastructure in `Dequantize.cpp` to handle affine transformations with configurable block sizes and quantization parameters
* Support integer-to-floating-point conversion with precise reconstruction of original values

BE: Improved the documentation in the shader logic to be more detailed and clear.

# Motivation

The existing Vulkan quantization infrastructure lacked support for the `torchao.dequantize_affine` operator, which is essential for completing the quantization-dequantization cycle in dynamic quantization workflows. The `dequantize_affine` operator provides flexible block-wise dequantization that reconstructs floating-point values from quantized integer blocks, enabling:

* **Block-wise Dequantization**: Reconstructs floating-point values from configurable tensor blocks using separate scale and zero-point parameters, enabling precise recovery of original data distributions
* **Affine Transformation**: Uses the formula `value = (qvalue - zero_point) * scale` for accurate integer-to-floating-point mapping
* **TorchAO Integration**: Seamless compatibility with TorchAO quantization workflows; completes the quantization-dequantization round trip

# Operator Description

The `dequantize_affine` operator converts n-bit integer tensor values back to floating-point representations using pre-computed quantization parameters (scale and zero_point) applied to configurable tensor blocks. Block-wise dequantization divides tensors into blocks and applies separate dequantization parameters to each block, allowing fine-grained reconstruction of the original floating-point precision.

The dequantization formula is: `value = (qvalue - zero_point) * scale`

**Storage Requirements**: Scale and zero_point tensors must use buffer storage with width-packed layout. Input/output tensors support both buffer and texture storage with standard axis mapping. Input tensors must be integer types (kByte, kChar, kInt).
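For concreteness, a minimal NumPy sketch of this formula; `dequantize_affine_ref` and its arguments are illustrative names, not the actual torchao or Vulkan API:

```python
# Minimal reference sketch of the affine dequantization formula above.
# Names are illustrative only, not the torchao/Vulkan API.
import numpy as np

def dequantize_affine_ref(q: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    # value = (qvalue - zero_point) * scale
    return (q.astype(np.int32) - zero_point) * scale

q = np.array([128, 130, 125], dtype=np.uint8)  # quantized values (kByte)
print(dequantize_affine_ref(q, scale=0.05, zero_point=128))  # -> [0.0, 0.1, -0.15]
```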
# Block-wise Dequantization Implementation

Block-wise dequantization enables fine-grained reconstruction by dividing tensors into blocks and applying separate dequantization parameters to each block.

The implementation uses the same key data structures computed in `Dequantize.cpp`:

* **`block_size_vec`**: WHCN-ordered block dimensions converted from the PyTorch NCHW layout (e.g., [3,3,2,1] for 3×3×2×1 blocks)
* **`tensor_size_whcn`**: Input tensor dimensions converted to WHCN layout using `utils::make_whcn_ivec4()`
* **`num_blocks_vec`**: Number of blocks per dimension, calculated as `tensor_size_whcn / block_size_vec`
* **`block_stride_vec`**: Pre-computed linear strides for block-grid indexing, `{1, #W, #W*#H, #W*#H*#C}`, enabling efficient block ID calculation

The block coordinate calculation uses `bcoord = tidx / blockSize`, where `tidx` is the tensor coordinate in WHCN layout. The linear block ID is then computed as:

`block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w`
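To make the indexing concrete, here is a small Python sketch of the block-grid math above; the vector names mirror the ones computed in `Dequantize.cpp`, but the code is illustrative, not the actual C++/GLSL:

```python
# Sketch of the WHCN block-grid indexing described above.
tensor_size_whcn = [8, 6, 4, 2]   # W, H, C, N
block_size_vec   = [4, 3, 2, 1]   # block dims in WHCN order

# num_blocks_vec = tensor_size_whcn / block_size_vec (per component)
num_blocks_vec = [t // b for t, b in zip(tensor_size_whcn, block_size_vec)]  # [2, 2, 2, 2]

# block_stride_vec = {1, #W, #W*#H, #W*#H*#C}
nW, nH, nC, _ = num_blocks_vec
block_stride_vec = [1, nW, nW * nH, nW * nH * nC]  # [1, 2, 4, 8]

def block_id_for(tidx):
    # bcoord = tidx / blockSize (integer division per component)
    bcoord = [t // b for t, b in zip(tidx, block_size_vec)]
    # block_id = dot(bcoord, blockStride)
    return sum(c * s for c, s in zip(bcoord, block_stride_vec))

print(block_id_for([5, 4, 3, 1]))  # bcoord = [1, 1, 1, 1] -> 1 + 2 + 4 + 8 = 15
```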
# Shader Algorithm Overview

## Texture Storage Implementation (`dequantize_texture.glsl`)

**Workgroup Configuration**:
- **Global WG Size**: Default sizing based on texture dimensions
- **Local WG Size**: Default, with special handling for batch-dimension dequantization (the Z dimension is set to 1 for proper workgroup dispatching when `global_workgroup_size[2] > 1`)

**Block-wise Mode Algorithm**: The shader processes 3D texture positions, where each position represents a texel containing 4 width-packed integer components. For each texel at position `pos`, it calculates a base tensor index `base_tidx = ivec4(pos.x * 4, pos.y, pos.z, 0)` to account for width-packing. For each of the 4 components in the texel, it computes the actual tensor coordinate `tidx = ivec4(base_tidx.x + i, base_tidx.y, (foldedZ % C_total), (foldedZ / C_total))`, where `foldedZ = pos.z` handles batch-channel folding in 4D tensors and `C_total = numBlocks.z * blockSize.z` is the total channel dimension.

The block coordinate is calculated by integer division, `bcoord = tidx / blockSize`, and the linear block ID uses the pre-computed strides: `block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w`. Each integer component is dequantized using its corresponding block's parameters, `value = dequantize_val(qvalue, t_scale[block_id], t_zero_point[block_id])`, where `dequantize_val()` applies the formula `(qvalue - zero_point) * scale`. The reconstructed floating-point values are written to the output texel, with proper type handling for double-precision outputs.

## Buffer Storage Implementation (`dequantize_buffer.glsl`)

**Workgroup Configuration**:
- **Global WG Size**: Default sizing based on buffer element count
- **Local WG Size**: Default sizing without special constraints

**Block-wise Mode Algorithm**: The shader processes linear buffer indices, using `gl_GlobalInvocationID.x` as the output buffer index. It converts this to tensor coordinates using `bufi_to_tidx(out_bufi, t_out_strides, out_dim_order)`, which handles the buffer-to-tensor index mapping with proper stride calculations. For each element, it computes the block coordinate directly, `bcoord = out_tidx / blockSize`, where `out_tidx` is the 4D tensor coordinate in WHCN layout; the linear block ID uses the same pre-computed stride approach as above. The quantized integer value is loaded via the corresponding input buffer index, `qvalue = t_in[in_bufi]`, where `in_bufi = tidx_to_bufi(out_tidx, t_in_strides)`. Dequantization then applies the block-specific parameters, `value = dequantize_val(qvalue, t_scale[block_id], t_zero_point[block_id])`, to reconstruct the original floating-point value. (A CPU-side sketch of this buffer-path loop follows the message.)

**Future Improvements**: Dynamic workgroup sizing based on block dimensions.

Differential Revision: [D78435552](https://our.internmc.facebook.com/intern/diff/D78435552/)

cc SS-JIA manuelcandales cbilgin

[ghstack-poisoned]
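As referenced in the buffer-path description above, here is a minimal CPU-side Python sketch of that block-wise loop, assuming contiguous WHCN indexing so that `bufi_to_tidx`/`tidx_to_bufi` reduce to divmod arithmetic (the real shaders use stride- and dim-order-aware helpers; all names here mirror the shader bindings but are illustrative):

```python
# CPU-side sketch of the buffer-path block-wise dequantization loop.
# Assumes a contiguous WHCN buffer with width fastest; illustrative only.
import numpy as np

def dequantize_buffer_ref(t_in, t_scale, t_zero_point, size_whcn, block_size, block_stride):
    W, H, C, N = size_whcn
    t_out = np.empty(t_in.size, dtype=np.float32)
    for out_bufi in range(t_in.size):
        # bufi_to_tidx: flat index -> (w, h, c, n), width fastest
        rem = out_bufi
        w, rem = rem % W, rem // W
        h, rem = rem % H, rem // H
        c, n = rem % C, rem // C
        tidx = (w, h, c, n)
        # bcoord = out_tidx / blockSize; block_id = dot(bcoord, blockStride)
        block_id = sum((t // b) * s for t, b, s in zip(tidx, block_size, block_stride))
        # value = (qvalue - zero_point) * scale
        t_out[out_bufi] = (int(t_in[out_bufi]) - t_zero_point[block_id]) * t_scale[block_id]
    return t_out

# Example: 2x2x2x2 tensor (WHCN) with 2x2x1x1 blocks -> a 1x1x2x2 block grid (4 blocks),
# block_stride = {1, #W, #W*#H, #W*#H*#C} = (1, 1, 1, 2)
q = np.arange(16, dtype=np.int8)
scales = np.array([0.5, 0.5, 1.0, 1.0])
zeros = np.array([0, 0, 8, 8])
print(dequantize_buffer_ref(q, scales, zeros, (2, 2, 2, 2), (2, 2, 1, 1), (1, 1, 1, 2)))
```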
2 parents d691080 + 2978cb5 commit e2e6bf2


547 files changed, +28002 -6226 lines changed


.ci/docker/build.sh

Lines changed: 8 additions & 1 deletion

@@ -7,9 +7,11 @@
 
 set -exu
 
-IMAGE_NAME="$1"
+FULL_IMAGE_NAME="$1"
 shift
 
+IMAGE_NAME=$(echo "${FULL_IMAGE_NAME}" | sed 's/ci-image://')
+
 echo "Building ${IMAGE_NAME} Docker image"
 
 OS=ubuntu
@@ -41,6 +43,10 @@ case "${IMAGE_NAME}" in
     ARM_SDK=yes
     CLANG_VERSION=12
     ;;
+  executorch-ubuntu-22.04-zephyr-sdk)
+    ZEPHYR_SDK=yes
+    GCC_VERSION=11
+    ;;
   executorch-ubuntu-22.04-qnn-sdk)
     QNN_SDK=yes
     CLANG_VERSION=12
@@ -85,6 +91,7 @@ docker build \
   --build-arg "LINTRUNNER=${LINTRUNNER:-}" \
   --build-arg "BUILD_DOCS=${BUILD_DOCS}" \
   --build-arg "ARM_SDK=${ARM_SDK:-}" \
+  --build-arg "ZEPHYR_SDK=${ZEPHYR_SDK:-}" \
   --build-arg "QNN_SDK=${QNN_SDK:-}" \
   --build-arg "MEDIATEK_SDK=${MEDIATEK_SDK:-}" \
   --build-arg "ANDROID_NDK_VERSION=${ANDROID_NDK_VERSION:-}" \
Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-a3942627f5ac048e06b4b1d703b0a6a53bf6da5b
+eea657ddbdeb1118943a92fb73c289985c3ee1ba

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-90f1e7bed15ca5e48c61c5b6dc5ad4810524f82f
+6fc0ad22f0a07b6f38d138861c56a765d5a9bb02
Lines changed: 87 additions & 0 deletions

@@ -0,0 +1,87 @@
+#!/bin/bash
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+set -ex
+
+# shellcheck source=/dev/null
+source "$(dirname "${BASH_SOURCE[0]}")/utils.sh"
+
+# Double check that the Zephyr SDK flag is set
+[ -n "${ZEPHYR_SDK}" ]
+
+install_prerequiresites() {
+  rm /var/lib/dpkg/info/libc-bin.*
+  apt-get clean
+  apt-get -y update
+  apt-get install -y libc-bin
+  apt-get -y update
+  apt-get clean
+  apt-get install --no-install-recommends -y dos2unix
+  apt-get install --no-install-recommends -y ca-certificates
+  apt-get install -y --reinstall libc-bin
+  apt-get install --no-install-recommends -y file
+  apt-get install --no-install-recommends -y locales
+  apt-get install --no-install-recommends -y git
+  apt-get install --no-install-recommends -y build-essential
+  apt-get install --no-install-recommends -y cmake
+  apt-get install --no-install-recommends -y ninja-build gperf
+  apt-get install --no-install-recommends -y device-tree-compiler
+  apt-get install --no-install-recommends -y wget
+  apt-get install --no-install-recommends -y curl
+  apt-get install --no-install-recommends -y xz-utils
+  apt-get install --no-install-recommends -y dos2unix
+  apt-get install --no-install-recommends -y vim
+  apt-get install --no-install-recommends -y nano
+  apt-get install --no-install-recommends -y mc
+  apt-get install --no-install-recommends -y openssh-server
+  apt-get install -y gdb
+
+  # Zephyr SDK relies on python 3.12
+  apt install software-properties-common -y
+  add-apt-repository ppa:deadsnakes/ppa -y
+  apt update
+  apt install -y python3.12 python3.12-dev python3.12-venv python3-pip
+  update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1
+
+  # Upgrade cmake to 3.24
+  apt update
+  apt install cmake
+  apt install software-properties-common lsb-release
+  apt update
+  test -f /usr/share/doc/kitware-archive-keyring/copyright || \
+    wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | gpg --dearmor - | tee /usr/share/keyrings/kitware-archive-keyring.gpg >/dev/null
+  echo "deb [signed-by=/usr/share/keyrings/kitware-archive-keyring.gpg] https://apt.kitware.com/ubuntu/ $(lsb_release -cs) main" | tee /etc/apt/sources.list.d/kitware.list > /dev/null
+  apt update
+  apt install cmake
+
+  # Install additional required software for Zephyr
+  apt install --no-install-recommends -y ccache \
+    dfu-util \
+    python3-setuptools \
+    python3-tk \
+    python3-wheel \
+    make \
+    gcc \
+    libsdl2-dev \
+    libmagic1 \
+    xterm \
+    telnet \
+    net-tools
+  apt install --no-install-recommends -y gcc-multilib g++-multilib
+  apt-get clean -y
+  apt-get autoremove --purge -y
+  rm -rf /var/lib/apt/lists/*
+  wget https://apt.kitware.com/kitware-archive.sh && \
+    chmod +x kitware-archive.sh && \
+    ./kitware-archive.sh && \
+    rm -f kitware-archive.sh
+  pip_install --no-cache-dir west
+  pip_install pyelftools
+}
+
+install_prerequiresites

.ci/docker/ubuntu/Dockerfile

Lines changed: 6 additions & 0 deletions

@@ -84,6 +84,12 @@ RUN rm install_android.sh
 
 ARG ARM_SDK
 
+ARG ZEPHYR_SDK
+COPY ./common/install_zephyr.sh install_zephyr.sh
+COPY ./common/utils.sh utils.sh
+RUN if [ -n "${ZEPHYR_SDK}" ]; then bash ./install_zephyr.sh; fi
+RUN rm install_zephyr.sh utils.sh
+
 ARG QNN_SDK
 
 ARG MEDIATEK_SDK

.ci/scripts/build-qnn-sdk.sh

Lines changed: 2 additions & 0 deletions

@@ -33,6 +33,8 @@ set_up_aot() {
     -DEXECUTORCH_BUILD_DEVTOOLS=ON \
     -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
     -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
+    -DEXECUTORCH_BUILD_EXTENSION_EXTENSION_LLM=ON \
+    -DEXECUTORCH_BUILD_EXTENSION_EXTENSION_LLM_RUNNER=ON \
    -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \
     -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
     -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \

.ci/scripts/build_llama_android.sh

Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@ install_executorch_and_backend_lib() {
     echo "Installing executorch and xnnpack backend"
     clean_executorch_install_folders
     mkdir cmake-android-out
-    ANDROID_NDK=/opt/ndk
+    ANDROID_NDK=${ANDROID_NDK:-/opt/ndk}
     BUCK2=buck2
     ANDROID_ABI=arm64-v8a
     cmake --preset llm \

.ci/scripts/gather_benchmark_configs.py

Lines changed: 98 additions & 18 deletions

@@ -16,16 +16,65 @@
 from examples.models import MODEL_NAME_TO_MODEL
 
 
-# Device pools for AWS Device Farm
+DEVICE_POOLS_REGEX = re.compile(r"(?P<device_name>[^\+]+)\+(?P<variant>[^\+]+)")
+# Device pools for AWS Device Farm. Initially, I chose to distribute models to these pools
+# round-robin for simplicity. For a public pool, only one per device type is needed because
+# AWS will scale the number of devices there for us. However, for private pools, we need to
+# manually maintain multiple pools of the same device to evenly distribute models there.
+# The pool ARNs are extracted from the output of the following command:
+#   aws devicefarm list-device-pools \
+#     --arn arn:aws:devicefarm:us-west-2:308535385114:project:02a2cf0f-6d9b-45ee-ba1a-a086587469e6 \
+#     --region us-west-2
 DEVICE_POOLS = {
-    "apple_iphone_15": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/3b5acd2e-92e2-4778-b651-7726bafe129d",
-    "apple_iphone_15+ios_18": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/12c8b15c-8d03-4e07-950d-0a627e7595b4",
-    "samsung_galaxy_s22": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/e59f866a-30aa-4aa1-87b7-4510e5820dfa",
-    "samsung_galaxy_s22_private": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/ea6b049d-1508-4233-9a56-5d9eacbe1078",
-    "samsung_galaxy_s24": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/98f8788c-2e25-4a3c-8bb2-0d1e8897c0db",
-    "google_pixel_8_pro": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/d65096ab-900b-4521-be8b-a3619b69236a",
-    "google_pixel_3_private_rooted": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/98d23ca8-ea9e-4fb7-b725-d402017b198d",
-    "apple_iphone_15_private": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/55929353-2f28-4ee5-bdff-d1a95f58cb28",
+    "apple_iphone_15": {
+        "public": [
+            "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/3b5acd2e-92e2-4778-b651-7726bafe129d",
+        ],
+        "ios_18_public": [
+            "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/12c8b15c-8d03-4e07-950d-0a627e7595b4",
+        ],
+        "private": [
+            "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/55929353-2f28-4ee5-bdff-d1a95f58cb28",
+        ],
+        "plus_private": [
+            "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/767bfb3e-a00e-4d92-998b-4eafdcf7213b",
+        ],
+        "pro_private": [
+            "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/1394f34c-2981-4c55-aaa2-246871ac713b",
+            "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/099e8def-4609-4383-8787-76b88e500c1d",
+            "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/d6707270-b009-479e-a83a-7bdb255f9de5",
+        ],
+    },
+    "samsung_galaxy_s22": {
+        "public": [
+            "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/e59f866a-30aa-4aa1-87b7-4510e5820dfa",
+        ],
+        "private": [
+            "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/ea6b049d-1508-4233-9a56-5d9eacbe1078",
+            "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/1fa924a1-5aff-475b-8f4d-f7c6d8de4fe9",
+        ],
+        "ultra_private": [
+            "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/5f79d72e-e229-4f9c-962f-5d37196fcfe7",
+        ],
+    },
+    "samsung_galaxy_s24": {
+        "public": [
+            "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/98f8788c-2e25-4a3c-8bb2-0d1e8897c0db",
+        ],
+        "ultra_private": [
+            "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/5f79d72e-e229-4f9c-962f-5d37196fcfe7",
+        ],
+    },
+    "google_pixel_8": {
+        "pro_public": [
+            "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/d65096ab-900b-4521-be8b-a3619b69236a",
+        ],
+    },
+    "google_pixel_3": {
+        "rooted_private": [
+            "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/98d23ca8-ea9e-4fb7-b725-d402017b198d",
+        ],
+    },
 }
 
 # Predefined benchmark configurations
@@ -318,25 +367,56 @@ def get_benchmark_configs() -> Dict[str, Dict]:  # noqa: C901
 
     # Add configurations for each valid device
     for device in devices:
+        # Parse the device name
+        m = re.match(DEVICE_POOLS_REGEX, device)
+        if not m:
+            logging.warning(
+                f"Invalid device name: {device} is not in DEVICE_NAME+VARIANT format. Skipping."
+            )
+            continue
+
+        device_name = m.group("device_name")
+        variant = m.group("variant")
+
+        if device_name not in DEVICE_POOLS:
+            logging.warning(f"Unsupported device '{device}'. Skipping.")
+            continue
+
+        if variant not in DEVICE_POOLS[device_name]:
+            logging.warning(
+                f"Unsupported {device}'s variant '{variant}'. Skipping."
+            )
+            continue
+
+        device_pool_count = len(DEVICE_POOLS[device_name][variant])
+        if not device_pool_count:
+            logging.warning(
+                f"No device pool defined for {device}'s variant '{variant}'. Skipping."
+            )
+            continue
+
+        device_pool_index = 0
         for config in configs:
-            if config == "llama3_coreml_ane" and not device.endswith("+ios_18"):
-                device = f"{device}+ios_18"
+            if config == "llama3_coreml_ane" and "ios_18" not in variant:
+                variant = "ios_18_public"
                 logging.info(
-                    f"Benchmark config '{config}' only works on iOS 18+, auto-upgraded device pool to '{device}'"
+                    f"Benchmark config '{config}' only works on iOS 18+, auto-upgraded device variant to '{variant}'"
                 )
 
-            if device not in DEVICE_POOLS:
-                logging.warning(f"Unsupported device '{device}'. Skipping.")
-                continue
-
             record = {
                 "model": model_name,
                 "config": config,
-                "device_name": device,
-                "device_arn": DEVICE_POOLS[device],
+                "device_name": device_name,
+                "variant": variant,
+                "device_arn": DEVICE_POOLS[device_name][variant][
+                    device_pool_index % device_pool_count
+                ],
             }
             benchmark_configs["include"].append(record)
 
+            # Distribute configs to pools of the same device round-robin
+            device_pool_index += 1
+
     set_output("benchmark_configs", json.dumps(benchmark_configs))
 
 