
Commit 6ec938b

Author: morelos (committed)
[ET-VK][Ops] linear_qta8a_qga4w test framework
Pull Request resolved: #12005

# Context

This test framework establishes the foundation for validating the `linear_qta8a_qga4w` operator implementation as part of enabling dynamic quantization. The motivation is to advance beyond weight-only quantization to linear operations with both activations and weights quantized, enabling true integer arithmetic throughout the matrix multiplication for improved performance on GPU hardware. The current weight-only quantized linear implementations in ET-VK dequantize weights to floating point before computation, missing the performance benefits of integer arithmetic.

Operator nomenclature:

- **qta8a**: quantized per-token affine 8-bit activation inputs
- **qga4w**: quantized per-group affine 4-bit weights

# Changes

The reference implementation (`linear_qta8a_qga4w_4bit_dequant_impl`) provides a baseline for validating the GPU shader implementation through a deliberately simplified computation path. The quantized int8 input tensor is dequantized using the standard affine transformation `(quantized_input.to(at::kFloat) - input_zero_point) * input_scale`. After dequantization, the implementation performs a standard floating-point linear operation, `at::linear(x_float, weights_dequantized)`. This two-stage dequantize → compute path provides a clear reference against which the GPU's integer-arithmetic implementation can be validated.

ghstack-source-id: 294197247
@exported-using-ghexport

Differential Revision: [D77173442](https://our.internmc.facebook.com/intern/diff/D77173442/)
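As a minimal sketch of the affine relations the reference code below relies on (the notation here is added for exposition and is not part of the diff; `G` is the group size and `t = b*M + m` is the token index):

$$x_{b,m,k} = s^{x}_{t}\,\bigl(q^{x}_{b,m,k} - z^{x}_{t}\bigr), \qquad w_{n,k} = s^{w}_{\lfloor k/G \rfloor,\,n}\,\bigl((q^{w}_{n,k} - 8) - z^{w}_{\lfloor k/G \rfloor,\,n}\bigr)$$

$$y_{b,m,n} = \sum_{k} x_{b,m,k}\, w_{n,k} = s^{x}_{t} \sum_{k} s^{w}_{\lfloor k/G \rfloor,\,n}\,\bigl(q^{x}_{b,m,k} - z^{x}_{t}\bigr)\bigl((q^{w}_{n,k} - 8) - z^{w}_{\lfloor k/G \rfloor,\,n}\bigr)$$

The second form is the same sum that the `linear_qta8a_qga4w_quantized_matmul` reference below evaluates term by term, and it is what an integer-arithmetic kernel can compute with the scales factored out of the inner accumulation per group.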
1 parent a7091bf commit 6ec938b

2 files changed: +257, -0 lines
Lines changed: 251 additions & 0 deletions
@@ -0,0 +1,251 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <gtest/gtest.h>

#include <ATen/ATen.h>

#include <executorch/backends/vulkan/runtime/api/api.h>
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include "test_utils.h"

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

class VulkanLinearQTA8AQGA4WTest : public ::testing::Test {
 public:
  void SetUp() override {
    if (!vkcompute::api::context()
             ->adapter_ptr()
             ->has_full_int8_buffers_support()) {
      GTEST_SKIP();
    }
  }

  void TearDown() override {
    // Clean up any resources if needed
  }
};

at::Tensor unpack_weights_4x2(const at::Tensor& weights_4x2) {
  std::vector<int64_t> weights_shape(weights_4x2.sizes().vec());
  weights_shape[1] *= 2;

  at::Tensor weights_unpacked =
      at::empty(weights_shape, at::device(at::kCPU).dtype(at::kInt));

  const int64_t N = weights_unpacked.size(0);
  const int64_t K = weights_unpacked.size(1);

  for (int n = 0; n < N; n++) {
    for (int k = 0; k < K; k += 2) {
      const uint8_t packed_val = weights_4x2[n][k / 2].item().to<uint8_t>();
      const uint8_t second_val = packed_val & 0x0F;
      const uint8_t first_val = (packed_val & 0xF0) >> 4;

      weights_unpacked[n][k] = int(first_val);
      weights_unpacked[n][k + 1] = int(second_val);
    }
  }

  return weights_unpacked;
}

at::Tensor dequantize_pergroup_weights(
    const at::Tensor& weights_4x2,
    const int64_t group_size,
    const at::Tensor& weight_scales,
    const at::Tensor& weight_zeros) {
  // First unpack the 4-bit weights to 8-bit integers
  at::Tensor weights_unpacked = unpack_weights_4x2(weights_4x2);

  // Now dequantize using per-group quantization parameters
  std::vector<int64_t> weights_shape(weights_unpacked.sizes().vec());
  at::Tensor weights_dequantized =
      at::empty(weights_shape, at::device(at::kCPU).dtype(at::kFloat));

  const int64_t N = weights_dequantized.size(0);
  const int64_t K = weights_dequantized.size(1);

  for (int n = 0; n < N; n++) {
    for (int k = 0; k < K; k++) {
      const int group_idx = k / group_size;
      const float scale = weight_scales[group_idx][n].item().to<float>();
      const int zero = weight_zeros[group_idx][n].item().to<int>();

      // Apply proper quantization paradigm: ((int_val - 8) - zero) * scale
      weights_dequantized[n][k] =
          ((float(weights_unpacked[n][k].item().to<int>()) - 8.0f) -
           float(zero)) *
          scale;
    }
  }

  return weights_dequantized;
}

// Quantized matrix multiplication following quantization.md paradigms
at::Tensor linear_qta8a_qga4w_quantized_matmul(
    const at::Tensor& quantized_input, // [B, M, K] int8 quantized input
    const at::Tensor& input_scale, // [B*M] per-token input scales
    const at::Tensor& input_zero_point, // [B*M] per-token input zero points
    const at::Tensor& weights_4x2, // [N, K/2] 4-bit packed weights
    const int64_t group_size, // Group size for weight quantization
    const at::Tensor& weight_scales, // [K/group_size, N] weight scales
    const at::Tensor& weight_zeros) { // [K/group_size, N] weight zeros

  const int64_t B = quantized_input.size(0);
  const int64_t M = quantized_input.size(1);
  const int64_t K = quantized_input.size(2);
  const int64_t N = weights_4x2.size(0);

  // Create output tensor for floating point results
  at::Tensor float_output =
      at::zeros({B, M, N}, at::device(at::kCPU).dtype(at::kFloat));

  // Accessors for efficient access
  auto input_accessor = quantized_input.accessor<int8_t, 3>();
  auto output_accessor = float_output.accessor<float, 3>();
  auto weights_accessor = weights_4x2.accessor<uint8_t, 2>();
  auto weight_scales_accessor = weight_scales.accessor<float, 2>();
  auto weight_zeros_accessor = weight_zeros.accessor<int32_t, 2>();
  auto input_scale_accessor = input_scale.accessor<float, 1>();
  auto input_zero_accessor = input_zero_point.accessor<int32_t, 1>();

  // Perform quantized matrix multiplication following quantization.md equation
  // (5): result_real_value = lhs_scale * rhs_scale * Sum_over_k(
  //          (lhs_quantized_value[k] - lhs_zero_point) *
  //          (rhs_quantized_value[k] - rhs_zero_point)
  //      )
  for (int64_t b = 0; b < B; b++) {
    for (int64_t m = 0; m < M; m++) {
      const int64_t token_idx = b * M + m;
      const float lhs_scale =
          input_scale_accessor[token_idx]; // Per-token input scale
      const int32_t lhs_zero_point =
          input_zero_accessor[token_idx]; // Per-token input zero point

      for (int64_t n = 0; n < N; n++) {
        float result_real_value = 0.0f;

        for (int64_t k = 0; k < K; k++) {
          // Get per-group weight quantization parameters
          const int64_t group_idx = k / group_size;
          const float rhs_scale =
              weight_scales_accessor[group_idx][n]; // Per-group weight scale
          const int32_t rhs_zero_point =
              weight_zeros_accessor[group_idx]
                                   [n]; // Per-group weight zero point

          // Unpack the 4-bit weight for this position
          const uint8_t packed_val = weights_accessor[n][k / 2];
          uint8_t weight_4bit;
          if (k % 2 == 0) {
            weight_4bit = (packed_val & 0xF0) >> 4; // First weight in pair
          } else {
            weight_4bit = packed_val & 0x0F; // Second weight in pair
          }

          // Get quantized values
          const int32_t lhs_quantized_value =
              static_cast<int32_t>(input_accessor[b][m][k]);
          // Convert 4-bit weight to signed: subtract 8 to get range [-8, 7]
          const int32_t rhs_quantized_value =
              static_cast<int32_t>(weight_4bit) - 8;

          // Apply proper quantization paradigm from quantization.md equation
          // (3): real_value = scale * (quantized_value - zero_point) Following
          // equation (5): result = lhs_scale * rhs_scale *
          //     (lhs_quantized - lhs_zero) * (rhs_quantized - rhs_zero)
          const float lhs_diff =
              static_cast<float>(lhs_quantized_value - lhs_zero_point);
          const float rhs_diff =
              static_cast<float>(rhs_quantized_value - rhs_zero_point);

          result_real_value += lhs_scale * rhs_scale * lhs_diff * rhs_diff;
        }

        output_accessor[b][m][n] = result_real_value;
      }
    }
  }

  return float_output;
}

at::Tensor linear_qta8a_qga4w_4bit_dequant_impl(
    const at::Tensor& quantized_input,
    const at::Tensor& input_scale,
    const at::Tensor& input_zero_point,
    const at::Tensor& weights_4x2,
    const int64_t group_size,
    const at::Tensor& weight_scales,
    const at::Tensor& weight_zeros) {
  // Calculate number of input tokens
  int64_t input_num_tokens = 1;
  for (size_t i = 0; i < quantized_input.sizes().size() - 1; i++) {
    input_num_tokens *= quantized_input.size(i);
  }

  // Manually dequantize the char tensor using per-token quantization
  at::Tensor x_float = at::zeros_like(quantized_input, at::kFloat);

  // Apply per-token dequantization
  auto input_accessor = quantized_input.accessor<int8_t, 3>();
  auto output_accessor = x_float.accessor<float, 3>();

  for (int64_t token_idx = 0; token_idx < input_num_tokens; token_idx++) {
    float scale_val = input_scale[token_idx].item<float>();
    int zero_point_val = input_zero_point[token_idx].item<int>();

    // Calculate batch and sequence indices for this token
    int64_t b = token_idx / quantized_input.size(1);
    int64_t m = token_idx % quantized_input.size(1);

    // Apply dequantization for all features in this token
    for (int64_t k = 0; k < quantized_input.size(-1); k++) {
      float dequant_val =
          (input_accessor[b][m][k] - zero_point_val) * scale_val;
      output_accessor[b][m][k] = dequant_val;
    }
  }

  std::vector<int64_t> weights_shape(weights_4x2.sizes().vec());
  weights_shape[1] *= 2;

  at::Tensor weights_dequantized =
      at::empty(weights_shape, at::device(at::kCPU).dtype(at::kFloat));

  const int64_t N = weights_dequantized.size(0);
  const int64_t K = weights_dequantized.size(1);

  for (int n = 0; n < N; n++) {
    for (int k = 0; k < K; k += 2) {
      const int group_idx = k / group_size;
      const uint8_t packed_val = weights_4x2[n][k / 2].item().to<uint8_t>();
      const uint8_t second_val = packed_val & 0x0F;
      const uint8_t first_val = (packed_val & 0xF0) >> 4;

      const float scale = weight_scales[group_idx][n].item().to<float>();
      const int zero = weight_zeros[group_idx][n].item().to<int>();

      weights_dequantized[n][k] =
          ((float(first_val) - 8.0) - float(zero)) * scale;
      weights_dequantized[n][k + 1] =
          ((float(second_val) - 8.0) - float(zero)) * scale;
    }
  }

  at::Tensor linear_result = at::linear(x_float, weights_dequantized);

  return linear_result;
}
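
The listing above defines only the fixture and the two reference helpers. As an illustration of how they could be exercised, a consistency check between the two reference paths might look like the sketch below (a hypothetical test, not part of this commit; the test name, shapes, random data setup, and tolerances are all illustrative):

```cpp
// Hypothetical sketch (not part of this commit): cross-check the integer
// reference path against the dequantize-then-linear reference path.
TEST_F(VulkanLinearQTA8AQGA4WTest, ReferenceImplsAgreeSketch) {
  const int64_t B = 1, M = 4, K = 64, N = 32, group_size = 16;

  // Per-token int8 activations with per-token scales / zero points.
  at::Tensor quantized_input = at::randint(
      -128, 128, {B, M, K}, at::device(at::kCPU).dtype(at::kChar));
  at::Tensor input_scale =
      at::rand({B * M}, at::device(at::kCPU).dtype(at::kFloat)) * 0.1 + 0.01;
  at::Tensor input_zero_point =
      at::randint(-10, 10, {B * M}, at::device(at::kCPU).dtype(at::kInt));

  // Per-group 4-bit weights, two values packed per byte, plus group params.
  at::Tensor weights_4x2 = at::randint(
      0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte));
  at::Tensor weight_scales =
      at::rand({K / group_size, N}, at::device(at::kCPU).dtype(at::kFloat)) *
      0.05;
  at::Tensor weight_zeros =
      at::zeros({K / group_size, N}, at::device(at::kCPU).dtype(at::kInt));

  at::Tensor integer_path = linear_qta8a_qga4w_quantized_matmul(
      quantized_input, input_scale, input_zero_point, weights_4x2, group_size,
      weight_scales, weight_zeros);
  at::Tensor dequant_path = linear_qta8a_qga4w_4bit_dequant_impl(
      quantized_input, input_scale, input_zero_point, weights_4x2, group_size,
      weight_scales, weight_zeros);

  // Both paths evaluate the same affine math, so they should agree up to
  // float accumulation-order differences (tolerances are illustrative).
  EXPECT_TRUE(at::allclose(integer_path, dequant_path, 1e-3, 1e-3));
}
```

A GPU-backed test would presumably swap one side of this comparison for the output of the Vulkan `linear_qta8a_qga4w` op dispatched through `ComputeGraph`.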

backends/vulkan/test/op_tests/targets.bzl

Lines changed: 6 additions & 0 deletions
@@ -210,6 +210,12 @@ def define_common_targets(is_fbcode = False):
             ":test_utils",
         ]
     )
+    define_test_targets(
+        "linear_qta8a_qga4w_test",
+        extra_deps = [
+            ":test_utils",
+        ]
+    )
     define_test_targets(
         "rotary_embedding_test",
         extra_deps = [