
Commit ef5e756
Author: morelos
[ET-VK][Ops] linear_qta8a_qga4w_qta8o test framework
# Context

This test framework establishes the foundation for validating the `linear_qta8a_qga4w_qta8o` operator implementation as part of enabling dynamic quantization. The motivation is to advance beyond weight-only quantization to linear operations with both activations and weights quantized, enabling true integer arithmetic throughout the matrix multiplication for improved performance on GPU hardware. The current weight-only quantized linear implementations in ET-VK dequantize weights to floating point before computation, missing the performance benefits of integer arithmetic.

The operator nomenclature breaks down as follows:

- **qta8a**: quantized per-token affine 8-bit activation inputs
- **qga4w**: quantized per-group affine 4-bit weights
- **qta8o**: quantized per-token affine 8-bit outputs

# Changes

The reference implementation (`linear_qta8a_qga4w_qta8o_4bit_dequant_impl`) provides a baseline for validating the GPU shader implementation through a deliberately simplified computation path. The quantized int8 input tensor is dequantized using the standard affine transformation `(quantized_input.to(at::kFloat) - input_zero_point) * input_scale`. After dequantization, the implementation performs a standard floating-point linear operation, `at::linear(x_float, weights_dequantized)`, and then manually quantizes the result using `at::round(linear_result / output_scale) + output_zero_point`, clamping to the int8 range [-128, 127]. This dequantize → compute → quantize sequence provides a clear reference against which the GPU's integer arithmetic implementation can be validated; a worked example of the arithmetic follows below.

Differential Revision: [D77173442](https://our.internmc.facebook.com/intern/diff/D77173442/)

ghstack-source-id: 292879819
Pull Request resolved: #12005

Parent commit: 85cf6ce
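To make the per-token affine arithmetic described above concrete, here is a minimal standalone sketch of the round trip for a single value. The scale and zero-point values are illustrative only; they are not taken from the commit:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

int main() {
  // Dequantize one int8 activation: (q - zero_point) * scale.
  const int8_t q_in = 34;
  const float in_scale = 0.05f;
  const int in_zp = 4;
  const float x = (q_in - in_zp) * in_scale; // (34 - 4) * 0.05 = 1.5

  // Stand-in for the floating-point linear computation in the reference path.
  const float y = x * 2.0f; // 3.0

  // Quantize back to int8: round(y / scale) + zero_point, then clamp.
  const float out_scale = 0.1f;
  const int out_zp = -2;
  float q = std::round(y / out_scale) + out_zp; // round(30.0) - 2 = 28
  q = std::clamp(q, -128.0f, 127.0f);
  const int8_t q_out = static_cast<int8_t>(q);
  return q_out == 28 ? 0 : 1; // exits 0: the round trip behaves as described
}
```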

File tree: 2 files changed, +148 −0
backends/vulkan/test/op_tests/linear_qta8a_qga4w_qta8o_test.cpp

New file: 142 additions, 0 deletions
```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <gtest/gtest.h>

#include <ATen/ATen.h>

#include <executorch/backends/vulkan/runtime/api/api.h>
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include "test_utils.h"

#include <algorithm> // std::clamp
#include <cassert>
#include <cmath> // std::round

// Unpacks two 4-bit values per byte into an [N, K] int tensor. The high
// nibble holds the even (first) element and the low nibble the odd (second)
// element; e.g. a packed byte 0x2B unpacks to 2 followed by 11.
at::Tensor unpack_weights_4x2(const at::Tensor& weights_4x2) {
  std::vector<int64_t> weights_shape(weights_4x2.sizes().vec());
  weights_shape[1] *= 2;

  at::Tensor weights_unpacked =
      at::empty(weights_shape, at::device(at::kCPU).dtype(at::kInt));

  const int64_t N = weights_unpacked.size(0);
  const int64_t K = weights_unpacked.size(1);

  for (int n = 0; n < N; n++) {
    for (int k = 0; k < K; k += 2) {
      const uint8_t packed_val = weights_4x2[n][k / 2].item().to<uint8_t>();
      const uint8_t second_val = packed_val & 0x0F;
      const uint8_t first_val = (packed_val & 0xF0) >> 4;

      weights_unpacked[n][k] = int(first_val);
      weights_unpacked[n][k + 1] = int(second_val);
    }
  }

  return weights_unpacked;
}

// Reference implementation: dequantize -> float linear -> quantize.
at::Tensor linear_qta8a_qga4w_qta8o_4bit_dequant_impl(
    const at::Tensor& quantized_input,
    const at::Tensor& input_scale,
    const at::Tensor& input_zero_point,
    const at::Tensor& weights_4x2,
    const int64_t group_size,
    const at::Tensor& weight_scales_and_zeros,
    const at::Tensor& output_scale,
    const at::Tensor& output_zero_point) {
  // Calculate number of input tokens
  int64_t input_num_tokens = 1;
  for (size_t i = 0; i < quantized_input.sizes().size() - 1; i++) {
    input_num_tokens *= quantized_input.size(i);
  }

  // Manually dequantize the char tensor using per-token quantization
  at::Tensor x_float = at::zeros_like(quantized_input, at::kFloat);

  // Apply per-token dequantization
  auto input_accessor = quantized_input.accessor<int8_t, 3>();
  auto output_accessor = x_float.accessor<float, 3>();

  for (int64_t token_idx = 0; token_idx < input_num_tokens; token_idx++) {
    float scale_val = input_scale[token_idx].item<float>();
    int zero_point_val = input_zero_point[token_idx].item<int>();

    // Calculate batch and sequence indices for this token
    int64_t b = token_idx / quantized_input.size(1);
    int64_t m = token_idx % quantized_input.size(1);

    // Apply dequantization for all features in this token
    for (int64_t k = 0; k < quantized_input.size(-1); k++) {
      float dequant_val =
          (input_accessor[b][m][k] - zero_point_val) * scale_val;
      output_accessor[b][m][k] = dequant_val;
    }
  }

  std::vector<int64_t> weights_shape(weights_4x2.sizes().vec());
  weights_shape[1] *= 2;

  at::Tensor weights_dequantized =
      at::empty(weights_shape, at::device(at::kCPU).dtype(at::kFloat));

  const int64_t N = weights_dequantized.size(0);
  const int64_t K = weights_dequantized.size(1);

  // Dequantize the packed 4-bit weights with per-group scales and zeros;
  // the stored nibbles are unsigned [0, 15] with an implicit offset of 8.
  for (int n = 0; n < N; n++) {
    for (int k = 0; k < K; k += 2) {
      const int group_idx = k / group_size;
      const uint8_t packed_val = weights_4x2[n][k / 2].item().to<uint8_t>();
      const uint8_t second_val = packed_val & 0x0F;
      const uint8_t first_val = (packed_val & 0xF0) >> 4;

      const float scale =
          weight_scales_and_zeros[group_idx][n][0].item().to<float>();
      const float zero =
          weight_scales_and_zeros[group_idx][n][1].item().to<float>();

      weights_dequantized[n][k] = (float(first_val) - 8.0) * scale + zero;
      weights_dequantized[n][k + 1] = (float(second_val) - 8.0) * scale + zero;
    }
  }

  at::Tensor linear_result = at::linear(x_float, weights_dequantized);

  // Calculate number of output tokens
  int64_t output_num_tokens = 1;
  for (size_t i = 0; i < linear_result.sizes().size() - 1; i++) {
    output_num_tokens *= linear_result.size(i);
  }

  // Quantize the result manually using per-token quantization
  at::Tensor quantized_result = at::zeros_like(linear_result, at::kChar);

  // Apply per-token quantization
  auto linear_accessor = linear_result.accessor<float, 3>();
  auto quant_accessor = quantized_result.accessor<int8_t, 3>();

  for (int64_t token_idx = 0; token_idx < output_num_tokens; token_idx++) {
    float scale_val = output_scale[token_idx].item<float>();
    int zero_point_val = output_zero_point[token_idx].item<int>();

    // Calculate batch and sequence indices for this token
    int64_t b = token_idx / linear_result.size(1);
    int64_t m = token_idx % linear_result.size(1);

    // Apply quantization for all features in this token
    for (int64_t n = 0; n < linear_result.size(-1); n++) {
      float quant_val =
          std::round(linear_accessor[b][m][n] / scale_val) + zero_point_val;
      quant_val = std::clamp(quant_val, -128.0f, 127.0f);
      quant_accessor[b][m][n] = static_cast<int8_t>(quant_val);
    }
  }

  return quantized_result;
}
```
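For orientation, the following is a hypothetical smoke test showing how the reference implementation might be invoked; the shapes, group size, and quantization parameters are illustrative assumptions and are not taken from the commit:

```cpp
// Hypothetical usage sketch (not part of the commit): drives the reference
// implementation with small illustrative shapes.
TEST(linear_qta8a_qga4w_qta8o_test, reference_impl_smoke) {
  const int64_t B = 2, M = 3, K = 32, N = 16, group_size = 16;
  const int64_t num_tokens = B * M;

  // Per-token int8 activations with per-token scale/zero-point.
  at::Tensor quantized_input = at::randint(
      -128, 128, {B, M, K}, at::device(at::kCPU).dtype(at::kChar));
  at::Tensor input_scale = at::rand({num_tokens}, at::kFloat) * 0.1 + 0.01;
  at::Tensor input_zero_point = at::randint(
      -10, 10, {num_tokens}, at::device(at::kCPU).dtype(at::kInt));

  // Packed 4-bit weights: two nibbles per byte, so K/2 bytes per row.
  at::Tensor weights_4x2 = at::randint(
      0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte));
  // Per-group scale (index 0) and zero (index 1) per output channel.
  at::Tensor weight_scales_and_zeros =
      at::rand({K / group_size, N, 2}, at::kFloat);

  at::Tensor output_scale = at::rand({num_tokens}, at::kFloat) * 0.1 + 0.01;
  at::Tensor output_zero_point = at::randint(
      -10, 10, {num_tokens}, at::device(at::kCPU).dtype(at::kInt));

  at::Tensor ref = linear_qta8a_qga4w_qta8o_4bit_dequant_impl(
      quantized_input,
      input_scale,
      input_zero_point,
      weights_4x2,
      group_size,
      weight_scales_and_zeros,
      output_scale,
      output_zero_point);

  // The reference path should produce an int8 tensor of shape [B, M, N].
  EXPECT_EQ(ref.scalar_type(), at::kChar);
  EXPECT_EQ(ref.dim(), 3);
  EXPECT_EQ(ref.size(0), B);
  EXPECT_EQ(ref.size(1), M);
  EXPECT_EQ(ref.size(2), N);
}
```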

backends/vulkan/test/op_tests/targets.bzl

6 additions, 0 deletions

```diff
@@ -210,6 +210,12 @@ def define_common_targets(is_fbcode = False):
             ":test_utils",
         ]
     )
+    define_test_targets(
+        "linear_qta8a_qga4w_qta8o_test",
+        extra_deps = [
+            ":test_utils",
+        ]
+    )
     define_test_targets(
         "rotary_embedding_test",
         extra_deps = [
```
