 
 #include <cassert>
 
+class VulkanLinearQCS4WTest : public ::testing::Test {
+ public:
+  void SetUp() override {
+    if (!vkcompute::api::context()
+             ->adapter_ptr()
+             ->supports_int16_shader_types()) {
+      GTEST_SKIP();
+    }
+  }
+
+  void TearDown() override {
+    // Clean up any resources if needed
+  }
+};
+
+class VulkanLinearQTA8AQGA4WTest : public ::testing::Test {
+ public:
+  void SetUp() override {
+    if (!vkcompute::api::context()
+             ->adapter_ptr()
+             ->has_full_int8_buffers_support()) {
+      GTEST_SKIP();
+    }
+  }
+
+  void TearDown() override {
+    // Clean up any resources if needed
+  }
+};
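+
+// Note: both fixtures skip (rather than fail) on devices that lack the
+// required feature; the suites below are declared with TEST_F so these
+// guards apply.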
+
 //
 // Reference Implementations
 //
@@ -149,6 +179,162 @@ at::Tensor linear_qcs4w_reference_impl(
   return out.reshape(out_shape);
 }
 
+at::Tensor linear_qta8a_qga4w_quantized_matmul(
+    const at::Tensor& quantized_input, // [B, M, K] int8 quantized input
+    const at::Tensor& input_scale, // [B*M] per-token input scales
+    const at::Tensor& input_zero_point, // [B*M] per-token input zero points
+    const at::Tensor& weights_4x2, // [N, K/2] 4-bit packed weights
+    const int64_t group_size, // Group size for weight quantization
+    const at::Tensor& weight_scales, // [K/group_size, N] weight scales
+    const at::Tensor& weight_zeros) { // [K/group_size, N] weight zeros
+
+  const int64_t B = quantized_input.size(0);
+  const int64_t M = quantized_input.size(1);
+  const int64_t K = quantized_input.size(2);
+  const int64_t N = weights_4x2.size(0);
+
+  // Create output tensor for floating point results
+  at::Tensor float_output =
+      at::zeros({B, M, N}, at::device(at::kCPU).dtype(at::kFloat));
+
+  // Accessors for efficient access
+  auto input_accessor = quantized_input.accessor<int8_t, 3>();
+  auto output_accessor = float_output.accessor<float, 3>();
+  auto weights_accessor = weights_4x2.accessor<uint8_t, 2>();
+  auto weight_scales_accessor = weight_scales.accessor<float, 2>();
+  auto weight_zeros_accessor = weight_zeros.accessor<int32_t, 2>();
+  auto input_scale_accessor = input_scale.accessor<float, 1>();
+  auto input_zero_accessor = input_zero_point.accessor<int32_t, 1>();
+
+  // Perform quantized matrix multiplication following quantization.md
+  // equation (5):
+  //   result_real_value = lhs_scale * rhs_scale * Sum_over_k(
+  //       (lhs_quantized_value[k] - lhs_zero_point) *
+  //       (rhs_quantized_value[k] - rhs_zero_point))
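+  // For example, with lhs_scale = 0.5, rhs_scale = 0.25, lhs_zero_point = 2
+  // and rhs_zero_point = 1, one term with lhs_quantized = 10 and
+  // rhs_quantized = 3 contributes 0.5 * 0.25 * (10 - 2) * (3 - 1) = 2.0.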
+  for (int64_t b = 0; b < B; b++) {
+    for (int64_t m = 0; m < M; m++) {
+      const int64_t token_idx = b * M + m;
+      const float lhs_scale =
+          input_scale_accessor[token_idx]; // Per-token input scale
+      const int32_t lhs_zero_point =
+          input_zero_accessor[token_idx]; // Per-token input zero point
+
+      for (int64_t n = 0; n < N; n++) {
+        float result_real_value = 0.0f;
+
+        for (int64_t k = 0; k < K; k++) {
+          // Get per-group weight quantization parameters
+          const int64_t group_idx = k / group_size;
+          const float rhs_scale =
+              weight_scales_accessor[group_idx][n]; // Per-group weight scale
+          const int32_t rhs_zero_point =
+              weight_zeros_accessor[group_idx][n]; // Per-group weight zero
+
+          // Unpack the 4-bit weight for this position
+          const uint8_t packed_val = weights_accessor[n][k / 2];
+          uint8_t weight_4bit;
+          if (k % 2 == 0) {
+            weight_4bit = (packed_val & 0xF0) >> 4; // First weight in pair
+          } else {
+            weight_4bit = packed_val & 0x0F; // Second weight in pair
+          }
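+          // e.g. packed_val 0xAB unpacks to 0xA (10) at even k and 0xB (11)
+          // at odd k; the signed conversion below maps these to 2 and 3.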
+
+          // Get quantized values
+          const int32_t lhs_quantized_value =
+              static_cast<int32_t>(input_accessor[b][m][k]);
+          // Convert 4-bit weight to signed: subtract 8 to get range [-8, 7]
+          const int32_t rhs_quantized_value =
+              static_cast<int32_t>(weight_4bit) - 8;
+
+          // Apply the quantization paradigm from quantization.md equation (3):
+          //   real_value = scale * (quantized_value - zero_point)
+          // following equation (5):
+          //   result = lhs_scale * rhs_scale *
+          //       (lhs_quantized - lhs_zero) * (rhs_quantized - rhs_zero)
+          const float lhs_diff =
+              static_cast<float>(lhs_quantized_value - lhs_zero_point);
+          const float rhs_diff =
+              static_cast<float>(rhs_quantized_value - rhs_zero_point);
+
+          result_real_value += lhs_scale * rhs_scale * lhs_diff * rhs_diff;
+        }
+
+        output_accessor[b][m][n] = result_real_value;
+      }
+    }
+  }
+
+  return float_output;
+}
+
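+// A minimal shape sketch for the reference above, with hypothetical sizes
+// (any B, M, N work, as long as K is even and divisible by group_size):
+//
+//   const int64_t B = 1, M = 4, K = 64, N = 8, group_size = 32;
+//   at::Tensor q_in = at::randint(-128, 128, {B, M, K}, at::kChar);
+//   at::Tensor in_scales = at::rand({B * M}, at::kFloat);
+//   at::Tensor in_zeros = at::zeros({B * M}, at::kInt);
+//   at::Tensor w_packed = at::randint(0, 256, {N, K / 2}, at::kByte);
+//   at::Tensor w_scales = at::rand({K / group_size, N}, at::kFloat);
+//   at::Tensor w_zeros = at::zeros({K / group_size, N}, at::kInt);
+//   at::Tensor out = linear_qta8a_qga4w_quantized_matmul(
+//       q_in, in_scales, in_zeros, w_packed, group_size, w_scales, w_zeros);
+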
+at::Tensor linear_qta8a_qga4w_4bit_dequant_impl(
+    const at::Tensor& quantized_input,
+    const at::Tensor& input_scale,
+    const at::Tensor& input_zero_point,
+    const at::Tensor& weights_4x2,
+    const int64_t group_size,
+    const at::Tensor& weight_scales,
+    const at::Tensor& weight_zeros) {
+  // Calculate number of input tokens
+  int64_t input_num_tokens = 1;
+  for (size_t i = 0; i < quantized_input.sizes().size() - 1; i++) {
+    input_num_tokens *= quantized_input.size(i);
+  }
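+  // e.g. a [2, 3, K] input has 2 * 3 = 6 tokens, matching the length of
+  // input_scale and input_zero_point.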
+
+  // Manually dequantize the char tensor using per-token quantization
+  at::Tensor x_float = at::zeros_like(quantized_input, at::kFloat);
+
+  // Apply per-token dequantization
+  auto input_accessor = quantized_input.accessor<int8_t, 3>();
+  auto output_accessor = x_float.accessor<float, 3>();
+
+  for (int64_t token_idx = 0; token_idx < input_num_tokens; token_idx++) {
+    float scale_val = input_scale[token_idx].item<float>();
+    int zero_point_val = input_zero_point[token_idx].item<int>();
+
+    // Calculate batch and sequence indices for this token
+    int64_t b = token_idx / quantized_input.size(1);
+    int64_t m = token_idx % quantized_input.size(1);
+
+    // Apply dequantization for all features in this token
+    for (int64_t k = 0; k < quantized_input.size(-1); k++) {
+      float dequant_val =
+          (input_accessor[b][m][k] - zero_point_val) * scale_val;
+      output_accessor[b][m][k] = dequant_val;
+    }
+  }
+
+  std::vector<int64_t> weights_shape(weights_4x2.sizes().vec());
+  weights_shape[1] *= 2;
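+  // Each packed byte holds two 4-bit values, so packed [N, K/2] becomes a
+  // logical [N, K] weight matrix.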
+
+  at::Tensor weights_dequantized =
+      at::empty(weights_shape, at::device(at::kCPU).dtype(at::kFloat));
+
+  const int64_t N = weights_dequantized.size(0);
+  const int64_t K = weights_dequantized.size(1);
+
+  for (int64_t n = 0; n < N; n++) {
+    for (int64_t k = 0; k < K; k += 2) {
+      const int64_t group_idx = k / group_size;
+      const uint8_t packed_val = weights_4x2[n][k / 2].item().to<uint8_t>();
+      const uint8_t second_val = packed_val & 0x0F;
+      const uint8_t first_val = (packed_val & 0xF0) >> 4;
+
+      const float scale = weight_scales[group_idx][n].item().to<float>();
+      const int zero = weight_zeros[group_idx][n].item().to<int>();
+
+      weights_dequantized[n][k] =
+          ((float(first_val) - 8.0f) - float(zero)) * scale;
+      weights_dequantized[n][k + 1] =
+          ((float(second_val) - 8.0f) - float(zero)) * scale;
+    }
+  }
+
+  at::Tensor linear_result = at::linear(x_float, weights_dequantized);
+
+  return linear_result;
+}
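+
+// Note: this dequantize-then-at::linear path and the integer-domain
+// reference above evaluate the same equation (5), so their outputs are
+// expected to agree up to floating-point rounding when used as references
+// for the Vulkan implementation.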
+
 //
 // Test functions
 //
@@ -425,15 +611,15 @@ TEST(VulkanLinearQGA4WTest, test_vulkan_impl_gemm) {
       /* N = */ 256);
 }
 
-TEST(VulkanLinearQCS4WTest, test_reference_impl) {
+TEST_F(VulkanLinearQCS4WTest, test_reference_impl) {
   test_reference_linear_qcs4w(
       /* B = */ 1,
       /* M = */ 4,
       /* K = */ 128,
       /* N = */ 32);
 }
 
-TEST(VulkanLinearQCS4WTest, test_vulkan_impl_small_m) {
+TEST_F(VulkanLinearQCS4WTest, test_vulkan_impl_small_m) {
   test_vulkan_linear_qcs4w(
       /* B = */ 1,
       /* M = */ 4,
@@ -447,7 +633,7 @@ TEST(VulkanLinearQCS4WTest, test_vulkan_impl_small_m) {
       /* N = */ 256);
 }
 
-TEST(VulkanLinearQCS4WTest, test_vulkan_impl_gemm) {
+TEST_F(VulkanLinearQCS4WTest, test_vulkan_impl_gemm) {
   test_vulkan_linear_qcs4w(
       /* B = */ 1,
       /* M = */ 32,