
Commit bc7106e

Author: morelos (committed)
[ET-VK][Ops] linear_qta8a_qga4w test framework
Pull Request resolved: #12005

# Context

This test framework establishes the foundation for validating the `linear_qta8a_qga4w` operator implementation as part of enabling dynamic quantization. The motivation stems from advancing beyond weight-only quantization to linear operations with both activations and weights quantized, enabling true integer arithmetic throughout the matrix multiplication for improved performance on GPU hardware. The current weight-only quantized linear implementations in ET-VK dequantize weights to floating point before computation, missing the performance benefits of integer arithmetic.

Operator nomenclature breakdown:

- **qta8a**: Quantized per-token affine 8-bit activation inputs
- **qga4w**: Quantized per-group affine 4-bit weights

# Changes

The reference implementation (`linear_qta8a_qga4w_4bit_dequant_impl`) provides a baseline for validating the GPU shader implementation through a deliberately simplified computation path. The quantized int8 input tensor is dequantized using the standard affine transformation `(quantized_input.to(at::kFloat) - input_zero_point) * input_scale`. After dequantization, the implementation performs a standard floating point linear operation, `at::linear(x_float, weights_dequantized)`. This two-stage dequantize → compute approach provides a clear reference against which the GPU's integer arithmetic implementation can be validated.

ghstack-source-id: 294631565
@exported-using-ghexport

Differential Revision: [D77173442](https://our.internmc.facebook.com/intern/diff/D77173442/)
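As a minimal sketch of this two-stage reference path (ATen only; the function name `dequant_then_linear` and the variable names are illustrative, not the test's actual API — the real baseline is `linear_qta8a_qga4w_4bit_dequant_impl` in the diff below):

```cpp
#include <ATen/ATen.h>

// Illustrative sketch of dequantize -> compute: per-token affine
// dequantization of an int8 input, then a plain float linear op.
// Assumes quantized_input is [B, M, K] int8 and input_scale /
// input_zero_point hold one entry per token (B*M total).
at::Tensor dequant_then_linear(
    const at::Tensor& quantized_input,
    const at::Tensor& input_scale,
    const at::Tensor& input_zero_point,
    const at::Tensor& weights_dequantized) { // [N, K] float
  const int64_t B = quantized_input.size(0);
  const int64_t M = quantized_input.size(1);
  // real_value = (quantized_value - zero_point) * scale, broadcast per token
  at::Tensor x_float =
      (quantized_input.to(at::kFloat) -
       input_zero_point.to(at::kFloat).reshape({B, M, 1})) *
      input_scale.reshape({B, M, 1});
  return at::linear(x_float, weights_dequantized);
}
```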
1 parent 6669637 commit bc7106e

File tree

2 files changed: +191 -4 lines changed

backends/vulkan/test/op_tests/linear_weight_int4_test.cpp renamed to backends/vulkan/test/op_tests/quantized_linear_test.cpp

Lines changed: 190 additions & 3 deletions
@@ -18,6 +18,36 @@
 
 #include <cassert>
 
+class VulkanLinearQCS4WTest : public ::testing::Test {
+ public:
+  void SetUp() override {
+    if (!vkcompute::api::context()
+             ->adapter_ptr()
+             ->supports_int16_shader_types()) {
+      GTEST_SKIP();
+    }
+  }
+
+  void TearDown() override {
+    // Clean up any resources if needed
+  }
+};
+
+class VulkanLinearQTA8AQGA4WTest : public ::testing::Test {
+ public:
+  void SetUp() override {
+    if (!vkcompute::api::context()
+             ->adapter_ptr()
+             ->has_full_int8_buffers_support()) {
+      GTEST_SKIP();
+    }
+  }
+
+  void TearDown() override {
+    // Clean up any resources if needed
+  }
+};
+
 //
 // Reference Implementations
 //
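With these fixtures in place, any `TEST_F` case bound to them is skipped automatically on devices that lack the required capability, since `SetUp()` calls `GTEST_SKIP()`. A minimal hypothetical case (not part of this diff) would be:

```cpp
// Hypothetical: the body runs only when the adapter reports full
// int8 buffer support; otherwise the fixture's SetUp() skips it.
TEST_F(VulkanLinearQTA8AQGA4WTest, fixture_smoke) {
  SUCCEED();
}
```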
@@ -149,6 +179,163 @@ at::Tensor linear_qcs4w_reference_impl(
   return out.reshape(out_shape);
 }
 
+// Quantized matrix multiplication following quantization.md paradigms
+at::Tensor linear_qta8a_qga4w_quantized_matmul(
+    const at::Tensor& quantized_input, // [B, M, K] int8 quantized input
+    const at::Tensor& input_scale, // [B*M] per-token input scales
+    const at::Tensor& input_zero_point, // [B*M] per-token input zero points
+    const at::Tensor& weights_4x2, // [N, K/2] 4-bit packed weights
+    const int64_t group_size, // Group size for weight quantization
+    const at::Tensor& weight_scales, // [K/group_size, N] weight scales
+    const at::Tensor& weight_zeros) { // [K/group_size, N] weight zeros
+  const int64_t B = quantized_input.size(0);
+  const int64_t M = quantized_input.size(1);
+  const int64_t K = quantized_input.size(2);
+  const int64_t N = weights_4x2.size(0);
+
+  // Create output tensor for floating point results
+  at::Tensor float_output =
+      at::zeros({B, M, N}, at::device(at::kCPU).dtype(at::kFloat));
+
+  // Accessors for efficient access
+  auto input_accessor = quantized_input.accessor<int8_t, 3>();
+  auto output_accessor = float_output.accessor<float, 3>();
+  auto weights_accessor = weights_4x2.accessor<uint8_t, 2>();
+  auto weight_scales_accessor = weight_scales.accessor<float, 2>();
+  auto weight_zeros_accessor = weight_zeros.accessor<int32_t, 2>();
+  auto input_scale_accessor = input_scale.accessor<float, 1>();
+  auto input_zero_accessor = input_zero_point.accessor<int32_t, 1>();
+
+  // Perform quantized matrix multiplication following quantization.md
+  // equation (5):
+  //   result_real_value = lhs_scale * rhs_scale * Sum_over_k(
+  //       (lhs_quantized_value[k] - lhs_zero_point) *
+  //       (rhs_quantized_value[k] - rhs_zero_point))
+  for (int64_t b = 0; b < B; b++) {
+    for (int64_t m = 0; m < M; m++) {
+      const int64_t token_idx = b * M + m;
+      const float lhs_scale =
+          input_scale_accessor[token_idx]; // Per-token input scale
+      const int32_t lhs_zero_point =
+          input_zero_accessor[token_idx]; // Per-token input zero point
+
+      for (int64_t n = 0; n < N; n++) {
+        float result_real_value = 0.0f;
+
+        for (int64_t k = 0; k < K; k++) {
+          // Get per-group weight quantization parameters
+          const int64_t group_idx = k / group_size;
+          const float rhs_scale =
+              weight_scales_accessor[group_idx][n]; // Per-group weight scale
+          const int32_t rhs_zero_point =
+              weight_zeros_accessor[group_idx][n]; // Per-group weight zero point
+
+          // Unpack the 4-bit weight for this position
+          const uint8_t packed_val = weights_accessor[n][k / 2];
+          uint8_t weight_4bit;
+          if (k % 2 == 0) {
+            weight_4bit = (packed_val & 0xF0) >> 4; // First weight in pair
+          } else {
+            weight_4bit = packed_val & 0x0F; // Second weight in pair
+          }
+
+          // Get quantized values
+          const int32_t lhs_quantized_value =
+              static_cast<int32_t>(input_accessor[b][m][k]);
+          // Convert 4-bit weight to signed: subtract 8 to get range [-8, 7]
+          const int32_t rhs_quantized_value =
+              static_cast<int32_t>(weight_4bit) - 8;
+
+          // Apply the quantization paradigm from quantization.md equation (3):
+          //   real_value = scale * (quantized_value - zero_point)
+          // Following equation (5):
+          //   result = lhs_scale * rhs_scale *
+          //            (lhs_quantized - lhs_zero) * (rhs_quantized - rhs_zero)
+          const float lhs_diff =
+              static_cast<float>(lhs_quantized_value - lhs_zero_point);
+          const float rhs_diff =
+              static_cast<float>(rhs_quantized_value - rhs_zero_point);
+
+          result_real_value += lhs_scale * rhs_scale * lhs_diff * rhs_diff;
+        }
+
+        output_accessor[b][m][n] = result_real_value;
+      }
+    }
+  }
+
+  return float_output;
+}
+
+at::Tensor linear_qta8a_qga4w_4bit_dequant_impl(
+    const at::Tensor& quantized_input,
+    const at::Tensor& input_scale,
+    const at::Tensor& input_zero_point,
+    const at::Tensor& weights_4x2,
+    const int64_t group_size,
+    const at::Tensor& weight_scales,
+    const at::Tensor& weight_zeros) {
+  // Calculate the number of input tokens
+  int64_t input_num_tokens = 1;
+  for (size_t i = 0; i < quantized_input.sizes().size() - 1; i++) {
+    input_num_tokens *= quantized_input.size(i);
+  }
+
+  // Manually dequantize the char tensor using per-token quantization
+  at::Tensor x_float = at::zeros_like(quantized_input, at::kFloat);
+
+  // Apply per-token dequantization
+  auto input_accessor = quantized_input.accessor<int8_t, 3>();
+  auto output_accessor = x_float.accessor<float, 3>();
+
+  for (int64_t token_idx = 0; token_idx < input_num_tokens; token_idx++) {
+    float scale_val = input_scale[token_idx].item<float>();
+    int zero_point_val = input_zero_point[token_idx].item<int>();
+
+    // Calculate batch and sequence indices for this token
+    int64_t b = token_idx / quantized_input.size(1);
+    int64_t m = token_idx % quantized_input.size(1);
+
+    // Apply dequantization for all features in this token
+    for (int64_t k = 0; k < quantized_input.size(-1); k++) {
+      float dequant_val =
+          (input_accessor[b][m][k] - zero_point_val) * scale_val;
+      output_accessor[b][m][k] = dequant_val;
+    }
+  }
+
+  std::vector<int64_t> weights_shape(weights_4x2.sizes().vec());
+  weights_shape[1] *= 2;
+
+  at::Tensor weights_dequantized =
+      at::empty(weights_shape, at::device(at::kCPU).dtype(at::kFloat));
+
+  const int64_t N = weights_dequantized.size(0);
+  const int64_t K = weights_dequantized.size(1);
+
+  for (int n = 0; n < N; n++) {
+    for (int k = 0; k < K; k += 2) {
+      const int group_idx = k / group_size;
+      const uint8_t packed_val = weights_4x2[n][k / 2].item().to<uint8_t>();
+      const uint8_t second_val = packed_val & 0x0F;
+      const uint8_t first_val = (packed_val & 0xF0) >> 4;
+
+      const float scale = weight_scales[group_idx][n].item().to<float>();
+      const int zero = weight_zeros[group_idx][n].item().to<int>();
+
+      weights_dequantized[n][k] =
+          ((float(first_val) - 8.0) - float(zero)) * scale;
+      weights_dequantized[n][k + 1] =
+          ((float(second_val) - 8.0) - float(zero)) * scale;
+    }
+  }
+
+  at::Tensor linear_result = at::linear(x_float, weights_dequantized);
+
+  return linear_result;
+}
+
 //
 // Test functions
 //
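For orientation, the two reference functions above can be exercised against each other; both use the same affine dequantization convention, so their outputs should agree up to float rounding. A hypothetical smoke check inside a gtest body (illustrative shapes B=1, M=4, K=64, N=32, group_size=32, chosen to satisfy the parameter comments above — not values taken from this diff's test functions):

```cpp
// Hypothetical consistency check between the two reference paths.
at::Tensor q_in = at::randint(-128, 128, {1, 4, 64}, at::kChar);
at::Tensor in_scale = at::rand({4}, at::kFloat) * 0.1 + 0.01;   // per token
at::Tensor in_zp = at::randint(-10, 10, {4}, at::kInt);         // per token
at::Tensor w_packed = at::randint(0, 256, {32, 32}, at::kByte); // [N, K/2]
at::Tensor w_scales = at::rand({2, 32}, at::kFloat) * 0.05;     // [K/gs, N]
at::Tensor w_zeros = at::zeros({2, 32}, at::kInt);              // [K/gs, N]

at::Tensor out_quantized = linear_qta8a_qga4w_quantized_matmul(
    q_in, in_scale, in_zp, w_packed, /*group_size=*/32, w_scales, w_zeros);
at::Tensor out_dequant = linear_qta8a_qga4w_4bit_dequant_impl(
    q_in, in_scale, in_zp, w_packed, /*group_size=*/32, w_scales, w_zeros);
EXPECT_TRUE(
    at::allclose(out_quantized, out_dequant, /*rtol=*/1e-4, /*atol=*/1e-4));
```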
@@ -425,15 +612,15 @@ TEST(VulkanLinearQGA4WTest, test_vulkan_impl_gemm) {
       /*N = */ 256);
 }
 
-TEST(VulkanLinearQCS4WTest, test_reference_impl) {
+TEST_F(VulkanLinearQCS4WTest, test_reference_impl) {
   test_reference_linear_qcs4w(
       /*B = */ 1,
       /*M = */ 4,
       /*K = */ 128,
       /*N = */ 32);
 }
 
-TEST(VulkanLinearQCS4WTest, test_vulkan_impl_small_m) {
+TEST_F(VulkanLinearQCS4WTest, test_vulkan_impl_small_m) {
   test_vulkan_linear_qcs4w(
       /*B = */ 1,
       /*M = */ 4,
@@ -447,7 +634,7 @@ TEST(VulkanLinearQCS4WTest, test_vulkan_impl_small_m) {
       /*N = */ 256);
 }
 
-TEST(VulkanLinearQCS4WTest, test_vulkan_impl_gemm) {
+TEST_F(VulkanLinearQCS4WTest, test_vulkan_impl_gemm) {
   test_vulkan_linear_qcs4w(
       /*B = */ 1,
       /*M = */ 32,

backends/vulkan/test/op_tests/targets.bzl

Lines changed: 1 addition & 1 deletion
@@ -205,7 +205,7 @@ def define_common_targets(is_fbcode = False):
         ]
     )
     define_test_targets(
-        "linear_weight_int4_test",
+        "quantized_linear_test",
         extra_deps = [
             ":test_utils",
         ]
