Commit b048cb8

morelos committed
[ET-VK][Ops] linear_qta8a_qga4w test framework
Pull Request resolved: #12005

# Context

This test framework establishes the foundation for validating the `linear_qta8a_qga4w` operator implementation as part of enabling dynamic quantization. The motivation stems from advancing beyond weight-only quantization to fully quantized activation-and-weight linear operations, enabling true integer arithmetic throughout the matrix multiplication for improved performance on GPU hardware. The current weight-only quantized linear implementations in ET-VK dequantize weights to floating point before computation, missing the performance benefits of integer arithmetic.

The operator nomenclature breaks down as:
- **qta8a**: quantized per-token affine 8-bit activation inputs
- **qga4w**: quantized per-group affine 4-bit weights

# Changes

The reference implementation (`linear_qta8a_qga4w_4bit_dequant_impl`) provides a baseline for validating the GPU shader implementation through a deliberately simplified computation path. The quantized int8 input tensor is dequantized using the standard affine transformation `(quantized_input.to(at::kFloat) - input_zero_point) * input_scale`. After dequantization, the implementation performs a standard floating-point linear operation, `at::linear(x_float, weights_dequantized)`. This two-stage dequantize → compute approach provides a clear reference against which the GPU's integer arithmetic implementation can be validated.

ghstack-source-id: 294213573
@exported-using-ghexport

Differential Revision: [D77173442](https://our.internmc.facebook.com/intern/diff/D77173442/)
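For illustration, a minimal ATen sketch of this two-stage reference path (per-token affine dequantization followed by a float `at::linear`); the helper name and shapes here are assumptions for the example, and the weights are taken as already dequantized:

```cpp
#include <ATen/ATen.h>

// Hypothetical sketch of the dequantize -> compute reference path.
// The function name and shapes are illustrative assumptions.
at::Tensor dequant_then_linear_sketch(
    const at::Tensor& quantized_input,       // [M, K] int8
    const at::Tensor& input_scale,           // [M] float, per-token scales
    const at::Tensor& input_zero_point,      // [M] int32, per-token zero points
    const at::Tensor& weights_dequantized) { // [N, K] float
  // Affine dequantization: (q - zero_point) * scale, broadcast per token.
  at::Tensor x_float =
      (quantized_input.to(at::kFloat) -
       input_zero_point.to(at::kFloat).unsqueeze(-1)) *
      input_scale.unsqueeze(-1);
  // Standard floating point linear against the dequantized weights.
  return at::linear(x_float, weights_dequantized);
}
```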
1 parent a7091bf commit b048cb8

2 files changed: +173 −1 lines changed

backends/vulkan/test/op_tests/linear_weight_int4_test.cpp renamed to backends/vulkan/test/op_tests/quantized_linear_test.cpp

Lines changed: 172 additions & 0 deletions
@@ -18,6 +18,21 @@
 
 #include <cassert>
 
+class VulkanLinearQTA8AQGA4WTest : public ::testing::Test {
+ public:
+  void SetUp() override {
+    if (!vkcompute::api::context()
+             ->adapter_ptr()
+             ->has_full_int8_buffers_support()) {
+      GTEST_SKIP();
+    }
+  }
+
+  void TearDown() override {
+    // Clean up any resources if needed
+  }
+};
+
 //
 // Reference Implementations
 //
@@ -149,6 +164,163 @@ at::Tensor linear_qcs4w_reference_impl(
   return out.reshape(out_shape);
 }
 
+// Quantized matrix multiplication following quantization.md paradigms
+at::Tensor linear_qta8a_qga4w_quantized_matmul(
+    const at::Tensor& quantized_input, // [B, M, K] int8 quantized input
+    const at::Tensor& input_scale, // [B*M] per-token input scales
+    const at::Tensor& input_zero_point, // [B*M] per-token input zero points
+    const at::Tensor& weights_4x2, // [N, K/2] 4-bit packed weights
+    const int64_t group_size, // Group size for weight quantization
+    const at::Tensor& weight_scales, // [K/group_size, N] weight scales
+    const at::Tensor& weight_zeros) { // [K/group_size, N] weight zeros
+
+  const int64_t B = quantized_input.size(0);
+  const int64_t M = quantized_input.size(1);
+  const int64_t K = quantized_input.size(2);
+  const int64_t N = weights_4x2.size(0);
+
+  // Create output tensor for floating point results
+  at::Tensor float_output =
+      at::zeros({B, M, N}, at::device(at::kCPU).dtype(at::kFloat));
+
+  // Accessors for efficient access
+  auto input_accessor = quantized_input.accessor<int8_t, 3>();
+  auto output_accessor = float_output.accessor<float, 3>();
+  auto weights_accessor = weights_4x2.accessor<uint8_t, 2>();
+  auto weight_scales_accessor = weight_scales.accessor<float, 2>();
+  auto weight_zeros_accessor = weight_zeros.accessor<int32_t, 2>();
+  auto input_scale_accessor = input_scale.accessor<float, 1>();
+  auto input_zero_accessor = input_zero_point.accessor<int32_t, 1>();
+
+  // Perform quantized matrix multiplication following quantization.md equation
+  // (5): result_real_value = lhs_scale * rhs_scale * Sum_over_k(
+  //   (lhs_quantized_value[k] - lhs_zero_point) *
+  //   (rhs_quantized_value[k] - rhs_zero_point)
+  // )
+  for (int64_t b = 0; b < B; b++) {
+    for (int64_t m = 0; m < M; m++) {
+      const int64_t token_idx = b * M + m;
+      const float lhs_scale =
+          input_scale_accessor[token_idx]; // Per-token input scale
+      const int32_t lhs_zero_point =
+          input_zero_accessor[token_idx]; // Per-token input zero point
+
+      for (int64_t n = 0; n < N; n++) {
+        float result_real_value = 0.0f;
+
+        for (int64_t k = 0; k < K; k++) {
+          // Get per-group weight quantization parameters
+          const int64_t group_idx = k / group_size;
+          const float rhs_scale =
+              weight_scales_accessor[group_idx][n]; // Per-group weight scale
+          const int32_t rhs_zero_point =
+              weight_zeros_accessor[group_idx]
+                                   [n]; // Per-group weight zero point
+
+          // Unpack the 4-bit weight for this position
+          const uint8_t packed_val = weights_accessor[n][k / 2];
+          uint8_t weight_4bit;
+          if (k % 2 == 0) {
+            weight_4bit = (packed_val & 0xF0) >> 4; // First weight in pair
+          } else {
+            weight_4bit = packed_val & 0x0F; // Second weight in pair
+          }
+
+          // Get quantized values
+          const int32_t lhs_quantized_value =
+              static_cast<int32_t>(input_accessor[b][m][k]);
+          // Convert 4-bit weight to signed: subtract 8 to get range [-8, 7]
+          const int32_t rhs_quantized_value =
+              static_cast<int32_t>(weight_4bit) - 8;
+
+          // Apply proper quantization paradigm from quantization.md equation
+          // (3): real_value = scale * (quantized_value - zero_point) Following
+          // equation (5): result = lhs_scale * rhs_scale *
+          //   (lhs_quantized - lhs_zero) * (rhs_quantized - rhs_zero)
+          const float lhs_diff =
+              static_cast<float>(lhs_quantized_value - lhs_zero_point);
+          const float rhs_diff =
+              static_cast<float>(rhs_quantized_value - rhs_zero_point);
+
+          result_real_value += lhs_scale * rhs_scale * lhs_diff * rhs_diff;
+        }
+
+        output_accessor[b][m][n] = result_real_value;
+      }
+    }
+  }
+
+  return float_output;
+}
+
+at::Tensor linear_qta8a_qga4w_4bit_dequant_impl(
+    const at::Tensor& quantized_input,
+    const at::Tensor& input_scale,
+    const at::Tensor& input_zero_point,
+    const at::Tensor& weights_4x2,
+    const int64_t group_size,
+    const at::Tensor& weight_scales,
+    const at::Tensor& weight_zeros) {
+  // Calculate number of input tokens
+  int64_t input_num_tokens = 1;
+  for (size_t i = 0; i < quantized_input.sizes().size() - 1; i++) {
+    input_num_tokens *= quantized_input.size(i);
+  }
+
+  // Manually dequantize the char tensor using per-token quantization
+  at::Tensor x_float = at::zeros_like(quantized_input, at::kFloat);
+
+  // Apply per-token dequantization
+  auto input_accessor = quantized_input.accessor<int8_t, 3>();
+  auto output_accessor = x_float.accessor<float, 3>();
+
+  for (int64_t token_idx = 0; token_idx < input_num_tokens; token_idx++) {
+    float scale_val = input_scale[token_idx].item<float>();
+    int zero_point_val = input_zero_point[token_idx].item<int>();
+
+    // Calculate batch and sequence indices for this token
+    int64_t b = token_idx / quantized_input.size(1);
+    int64_t m = token_idx % quantized_input.size(1);
+
+    // Apply dequantization for all features in this token
+    for (int64_t k = 0; k < quantized_input.size(-1); k++) {
+      float dequant_val =
+          (input_accessor[b][m][k] - zero_point_val) * scale_val;
+      output_accessor[b][m][k] = dequant_val;
+    }
+  }
+
+  std::vector<int64_t> weights_shape(weights_4x2.sizes().vec());
+  weights_shape[1] *= 2;
+
+  at::Tensor weights_dequantized =
+      at::empty(weights_shape, at::device(at::kCPU).dtype(at::kFloat));
+
+  const int64_t N = weights_dequantized.size(0);
+  const int64_t K = weights_dequantized.size(1);
+
+  for (int n = 0; n < N; n++) {
+    for (int k = 0; k < K; k += 2) {
+      const int group_idx = k / group_size;
+      const uint8_t packed_val = weights_4x2[n][k / 2].item().to<uint8_t>();
+      const uint8_t second_val = packed_val & 0x0F;
+      const uint8_t first_val = (packed_val & 0xF0) >> 4;
+
+      const float scale = weight_scales[group_idx][n].item().to<float>();
+      const int zero = weight_zeros[group_idx][n].item().to<int>();
+
+      weights_dequantized[n][k] =
+          ((float(first_val) - 8.0) - float(zero)) * scale;
+      weights_dequantized[n][k + 1] =
+          ((float(second_val) - 8.0) - float(zero)) * scale;
+    }
+  }
+
+  at::Tensor linear_result = at::linear(x_float, weights_dequantized);
+
+  return linear_result;
+}
+
 //
 // Test functions
 //
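The two reference paths above should agree in exact arithmetic: expanding the product of the dequantized values recovers equation (5). A hypothetical cross-check, as it might appear inside a gtest body; all sizes, tolerances, and the random-data setup are illustrative assumptions, not the test file's actual cases:

```cpp
// Hypothetical cross-check; sizes and tolerances are assumptions.
const int64_t B = 1, M = 4, K = 64, N = 32, group_size = 32;

at::Tensor quantized_input =
    at::randint(-128, 128, {B, M, K}, at::kChar); // int8 activations
at::Tensor input_scale = at::rand({B * M}, at::kFloat) * 0.1;
at::Tensor input_zero_point = at::randint(-10, 10, {B * M}, at::kInt);
at::Tensor weights_4x2 =
    at::randint(0, 256, {N, K / 2}, at::kByte); // packed 4-bit pairs
at::Tensor weight_scales = at::rand({K / group_size, N}, at::kFloat);
at::Tensor weight_zeros = at::randint(-8, 8, {K / group_size, N}, at::kInt);

at::Tensor out_quantized = linear_qta8a_qga4w_quantized_matmul(
    quantized_input, input_scale, input_zero_point,
    weights_4x2, group_size, weight_scales, weight_zeros);
at::Tensor out_dequant = linear_qta8a_qga4w_4bit_dequant_impl(
    quantized_input, input_scale, input_zero_point,
    weights_4x2, group_size, weight_scales, weight_zeros);

// Integer-arithmetic path vs. dequantize -> linear path.
EXPECT_TRUE(
    at::allclose(out_quantized, out_dequant, /*rtol=*/1e-4, /*atol=*/1e-4));
```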

backends/vulkan/test/op_tests/targets.bzl

Lines changed: 1 addition & 1 deletion
@@ -205,7 +205,7 @@ def define_common_targets(is_fbcode = False):
         ]
     )
     define_test_targets(
-        "linear_weight_int4_test",
+        "quantized_linear_test",
         extra_deps = [
             ":test_utils",
         ]
