Commit fc435fa

[ET-VK][Ops] linear_qta8a_qga4w_qta8o test framework (#12375)
# Context

This test framework establishes the foundation for validating the `linear_qta8a_qga4w_qta8o` operator implementation as part of enabling dynamic quantization. The motivation stems from advancing beyond weight-only quantization to fully activation- and weight-quantized linear operations, enabling true integer arithmetic throughout the matrix multiplication process for improved performance on GPU hardware. The current weight-only quantized linear implementations in ET-VK dequantize weights to floating point before computation, missing the performance benefits of integer arithmetic.

The operator nomenclature breaks down as follows:

- **qta8a**: Quantized per-token affine 8-bit activation inputs
- **qga4w**: Quantized per-group affine 4-bit weights
- **qta8o**: Quantized per-token affine 8-bit outputs

# Changes

The reference implementation (`linear_qta8a_qga4w_qta8o_4bit_dequant_impl`) provides a baseline for validating the GPU shader implementation through a deliberately simplified computation path. The quantized int8 input tensor is dequantized using the standard affine transformation `(quantized_input.to(at::kFloat) - input_zero_point) * input_scale`. After dequantization, the implementation performs a standard floating point linear operation, `at::linear(x_float, weights_dequantized)`, and then manually quantizes the result using `at::round(linear_result / output_scale) + output_zero_point`, clamping to the int8 range [-128, 127]. This dequantize → compute → quantize path provides a clear reference against which the GPU's integer arithmetic implementation can be validated.

Differential Revision: [D77173442](https://our.internmc.facebook.com/intern/diff/D77173442/)

[ghstack-poisoned]
1 parent 5b483ab commit fc435fa
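
The reference path described in the commit message can be summarized in a short sketch. This is illustrative only, not code from this commit: it assumes ATen, collapses the per-token quantization parameters into scalars, and the function name `dequant_compute_quant_sketch` is hypothetical.

```cpp
#include <ATen/ATen.h>

// Illustrative sketch of the dequantize -> compute -> quantize reference
// path. The actual reference implementation uses per-token scale/zero-point
// tensors; scalars are used here for brevity.
at::Tensor dequant_compute_quant_sketch(
    const at::Tensor& quantized_input, // int8 activations
    double input_scale,
    int64_t input_zero_point,
    const at::Tensor& weights_dequantized, // float weights
    double output_scale,
    int64_t output_zero_point) {
  // Stage 1: affine dequantization of the int8 activations
  at::Tensor x_float =
      (quantized_input.to(at::kFloat) - input_zero_point) * input_scale;

  // Stage 2: standard floating point linear operation
  at::Tensor linear_result = at::linear(x_float, weights_dequantized);

  // Stage 3: requantize the result, clamped to the int8 range [-128, 127]
  at::Tensor quantized =
      at::round(linear_result / output_scale) + output_zero_point;
  return at::clamp(quantized, -128, 127).to(at::kChar);
}
```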

File tree

2 files changed: +190 -4 lines changed

backends/vulkan/test/op_tests/linear_weight_int4_test.cpp renamed to backends/vulkan/test/op_tests/quantized_linear_test.cpp

Lines changed: 189 additions & 3 deletions
@@ -18,6 +18,36 @@
 
 #include <cassert>
 
+class VulkanLinearQCS4WTest : public ::testing::Test {
+ public:
+  void SetUp() override {
+    if (!vkcompute::api::context()
+             ->adapter_ptr()
+             ->supports_int16_shader_types()) {
+      GTEST_SKIP();
+    }
+  }
+
+  void TearDown() override {
+    // Clean up any resources if needed
+  }
+};
+
+class VulkanLinearQTA8AQGA4WTest : public ::testing::Test {
+ public:
+  void SetUp() override {
+    if (!vkcompute::api::context()
+             ->adapter_ptr()
+             ->has_full_int8_buffers_support()) {
+      GTEST_SKIP();
+    }
+  }
+
+  void TearDown() override {
+    // Clean up any resources if needed
+  }
+};
+
 //
 // Reference Implementations
 //
@@ -149,6 +179,162 @@ at::Tensor linear_qcs4w_reference_impl(
   return out.reshape(out_shape);
 }
 
+at::Tensor linear_qta8a_qga4w_quantized_matmul(
+    const at::Tensor& quantized_input, // [B, M, K] int8 quantized input
+    const at::Tensor& input_scale, // [B*M] per-token input scales
+    const at::Tensor& input_zero_point, // [B*M] per-token input zero points
+    const at::Tensor& weights_4x2, // [N, K/2] 4-bit packed weights
+    const int64_t group_size, // Group size for weight quantization
+    const at::Tensor& weight_scales, // [K/group_size, N] weight scales
+    const at::Tensor& weight_zeros) { // [K/group_size, N] weight zeros
+  const int64_t B = quantized_input.size(0);
+  const int64_t M = quantized_input.size(1);
+  const int64_t K = quantized_input.size(2);
+  const int64_t N = weights_4x2.size(0);
+
+  // Create output tensor for floating point results
+  at::Tensor float_output =
+      at::zeros({B, M, N}, at::device(at::kCPU).dtype(at::kFloat));
+
+  // Accessors for efficient access
+  auto input_accessor = quantized_input.accessor<int8_t, 3>();
+  auto output_accessor = float_output.accessor<float, 3>();
+  auto weights_accessor = weights_4x2.accessor<uint8_t, 2>();
+  auto weight_scales_accessor = weight_scales.accessor<float, 2>();
+  auto weight_zeros_accessor = weight_zeros.accessor<int32_t, 2>();
+  auto input_scale_accessor = input_scale.accessor<float, 1>();
+  auto input_zero_accessor = input_zero_point.accessor<int32_t, 1>();
+
+  // Perform quantized matrix multiplication following quantization.md
+  // equation (5):
+  //   result_real_value = lhs_scale * rhs_scale * Sum_over_k(
+  //       (lhs_quantized_value[k] - lhs_zero_point) *
+  //       (rhs_quantized_value[k] - rhs_zero_point))
+  for (int64_t b = 0; b < B; b++) {
+    for (int64_t m = 0; m < M; m++) {
+      const int64_t token_idx = b * M + m;
+      const float lhs_scale =
+          input_scale_accessor[token_idx]; // Per-token input scale
+      const int32_t lhs_zero_point =
+          input_zero_accessor[token_idx]; // Per-token input zero point
+
+      for (int64_t n = 0; n < N; n++) {
+        float result_real_value = 0.0f;
+
+        for (int64_t k = 0; k < K; k++) {
+          // Get per-group weight quantization parameters
+          const int64_t group_idx = k / group_size;
+          const float rhs_scale =
+              weight_scales_accessor[group_idx][n]; // Per-group weight scale
+          const int32_t rhs_zero_point =
+              weight_zeros_accessor[group_idx][n]; // Per-group weight zero point
+
+          // Unpack the 4-bit weight for this position
+          const uint8_t packed_val = weights_accessor[n][k / 2];
+          uint8_t weight_4bit;
+          if (k % 2 == 0) {
+            weight_4bit = (packed_val & 0xF0) >> 4; // First weight in pair
+          } else {
+            weight_4bit = packed_val & 0x0F; // Second weight in pair
+          }
+
+          // Get quantized values
+          const int32_t lhs_quantized_value =
+              static_cast<int32_t>(input_accessor[b][m][k]);
+          // Convert 4-bit weight to signed: subtract 8 to get range [-8, 7]
+          const int32_t rhs_quantized_value =
+              static_cast<int32_t>(weight_4bit) - 8;
+
+          // Apply the quantization paradigm from quantization.md equation
+          // (3): real_value = scale * (quantized_value - zero_point).
+          // Following equation (5): result = lhs_scale * rhs_scale *
+          //   (lhs_quantized - lhs_zero) * (rhs_quantized - rhs_zero)
+          const float lhs_diff =
+              static_cast<float>(lhs_quantized_value - lhs_zero_point);
+          const float rhs_diff =
+              static_cast<float>(rhs_quantized_value - rhs_zero_point);
+
+          result_real_value += lhs_scale * rhs_scale * lhs_diff * rhs_diff;
+        }
+
+        output_accessor[b][m][n] = result_real_value;
+      }
+    }
+  }
+
+  return float_output;
+}
+
+at::Tensor linear_qta8a_qga4w_4bit_dequant_impl(
+    const at::Tensor& quantized_input,
+    const at::Tensor& input_scale,
+    const at::Tensor& input_zero_point,
+    const at::Tensor& weights_4x2,
+    const int64_t group_size,
+    const at::Tensor& weight_scales,
+    const at::Tensor& weight_zeros) {
+  // Calculate the number of input tokens
+  int64_t input_num_tokens = 1;
+  for (size_t i = 0; i < quantized_input.sizes().size() - 1; i++) {
+    input_num_tokens *= quantized_input.size(i);
+  }
+
+  // Manually dequantize the char tensor using per-token quantization
+  at::Tensor x_float = at::zeros_like(quantized_input, at::kFloat);
+
+  // Apply per-token dequantization
+  auto input_accessor = quantized_input.accessor<int8_t, 3>();
+  auto output_accessor = x_float.accessor<float, 3>();
+
+  for (int64_t token_idx = 0; token_idx < input_num_tokens; token_idx++) {
+    float scale_val = input_scale[token_idx].item<float>();
+    int zero_point_val = input_zero_point[token_idx].item<int>();
+
+    // Calculate batch and sequence indices for this token
+    int64_t b = token_idx / quantized_input.size(1);
+    int64_t m = token_idx % quantized_input.size(1);
+
+    // Apply dequantization for all features in this token
+    for (int64_t k = 0; k < quantized_input.size(-1); k++) {
+      float dequant_val =
+          (input_accessor[b][m][k] - zero_point_val) * scale_val;
+      output_accessor[b][m][k] = dequant_val;
+    }
+  }
+
+  std::vector<int64_t> weights_shape(weights_4x2.sizes().vec());
+  weights_shape[1] *= 2;
+
+  at::Tensor weights_dequantized =
+      at::empty(weights_shape, at::device(at::kCPU).dtype(at::kFloat));
+
+  const int64_t N = weights_dequantized.size(0);
+  const int64_t K = weights_dequantized.size(1);
+
+  for (int n = 0; n < N; n++) {
+    for (int k = 0; k < K; k += 2) {
+      const int group_idx = k / group_size;
+      const uint8_t packed_val = weights_4x2[n][k / 2].item().to<uint8_t>();
+      const uint8_t second_val = packed_val & 0x0F;
+      const uint8_t first_val = (packed_val & 0xF0) >> 4;
+
+      const float scale = weight_scales[group_idx][n].item().to<float>();
+      const int zero = weight_zeros[group_idx][n].item().to<int>();
+
+      weights_dequantized[n][k] =
+          ((float(first_val) - 8.0) - float(zero)) * scale;
+      weights_dequantized[n][k + 1] =
+          ((float(second_val) - 8.0) - float(zero)) * scale;
+    }
+  }
+
+  at::Tensor linear_result = at::linear(x_float, weights_dequantized);
+
+  return linear_result;
+}
+
 //
 // Test functions
 //
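
As a rough usage sketch (not part of the diff), the expected shapes and dtypes can be read off the accessors above: `at::kChar` activations of shape [B, M, K] with one scale/zero point per token, byte-packed 4-bit weights of shape [N, K/2], and per-group weight parameters of shape [K / group_size, N]. The driver below is hypothetical and assumes the reference function above is in scope.

```cpp
// Hypothetical driver exercising linear_qta8a_qga4w_quantized_matmul with
// randomly generated inputs; the dtypes match the accessors used above.
at::Tensor run_reference_matmul_example() {
  const int64_t B = 1, M = 4, K = 64, N = 32, group_size = 32;
  const int64_t num_tokens = B * M;

  // int8 activations with per-token affine parameters
  at::Tensor input = at::randint(
      -128, 128, {B, M, K}, at::device(at::kCPU).dtype(at::kChar));
  at::Tensor input_scale =
      at::rand({num_tokens}, at::kFloat) * 0.1 + 0.01;
  at::Tensor input_zero_point = at::randint(
      -10, 10, {num_tokens}, at::device(at::kCPU).dtype(at::kInt));

  // Two 4-bit weights packed per byte: [N, K/2] uint8
  at::Tensor weights_4x2 = at::randint(
      0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte));

  // Per-group weight parameters laid out as [K / group_size, N]
  at::Tensor weight_scales =
      at::rand({K / group_size, N}, at::kFloat) * 0.05;
  at::Tensor weight_zeros = at::zeros(
      {K / group_size, N}, at::device(at::kCPU).dtype(at::kInt));

  return linear_qta8a_qga4w_quantized_matmul(
      input,
      input_scale,
      input_zero_point,
      weights_4x2,
      group_size,
      weight_scales,
      weight_zeros);
}
```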
@@ -425,15 +611,15 @@ TEST(VulkanLinearQGA4WTest, test_vulkan_impl_gemm) {
       /*N = */ 256);
 }
 
-TEST(VulkanLinearQCS4WTest, test_reference_impl) {
+TEST_F(VulkanLinearQCS4WTest, test_reference_impl) {
   test_reference_linear_qcs4w(
       /*B = */ 1,
       /*M = */ 4,
       /*K = */ 128,
       /*N = */ 32);
 }
 
-TEST(VulkanLinearQCS4WTest, test_vulkan_impl_small_m) {
+TEST_F(VulkanLinearQCS4WTest, test_vulkan_impl_small_m) {
   test_vulkan_linear_qcs4w(
       /*B = */ 1,
       /*M = */ 4,
@@ -447,7 +633,7 @@ TEST(VulkanLinearQCS4WTest, test_vulkan_impl_small_m) {
       /*N = */ 256);
 }
 
-TEST(VulkanLinearQCS4WTest, test_vulkan_impl_gemm) {
+TEST_F(VulkanLinearQCS4WTest, test_vulkan_impl_gemm) {
   test_vulkan_linear_qcs4w(
       /*B = */ 1,
       /*M = */ 32,

backends/vulkan/test/op_tests/targets.bzl

Lines changed: 1 addition & 1 deletion
@@ -205,7 +205,7 @@ def define_common_targets(is_fbcode = False):
         ]
     )
     define_test_targets(
-        "linear_weight_int4_test",
+        "quantized_linear_test",
         extra_deps = [
             ":test_utils",
         ]
