
#include <cassert>

21+ class VulkanLinearQCS4WTest : public ::testing::Test {
22+ public:
23+ void SetUp () override {
24+ if (!vkcompute::api::context ()
25+ ->adapter_ptr ()
26+ ->supports_int16_shader_types ()) {
27+ GTEST_SKIP ();
28+ }
29+ }
30+
31+ void TearDown () override {
32+ // Clean up any resources if needed
33+ }
34+ };
35+
36+ class VulkanLinearQTA8AQGA4WTest : public ::testing::Test {
37+ public:
38+ void SetUp () override {
39+ if (!vkcompute::api::context ()
40+ ->adapter_ptr ()
41+ ->has_full_int8_buffers_support ()) {
42+ GTEST_SKIP ();
43+ }
44+ }
45+
46+ void TearDown () override {
47+ // Clean up any resources if needed
48+ }
49+ };
50+
//
// Reference Implementations
//
@@ -149,6 +179,163 @@ at::Tensor linear_qcs4w_reference_impl(
149179 return out.reshape (out_shape);
150180}
151181
182+ // Quantized matrix multiplication following quantization.md paradigms
183+ at::Tensor linear_qta8a_qga4w_quantized_matmul (
184+ const at::Tensor& quantized_input, // [B, M, K] int8 quantized input
185+ const at::Tensor& input_scale, // [B*M] per-token input scales
186+ const at::Tensor& input_zero_point, // [B*M] per-token input zero points
187+ const at::Tensor& weights_4x2, // [N, K/2] 4-bit packed weights
188+ const int64_t group_size, // Group size for weight quantization
189+ const at::Tensor& weight_scales, // [K/group_size, N] weight scales
190+ const at::Tensor& weight_zeros) { // [K/group_size, N] weight zeros
191+
192+ const int64_t B = quantized_input.size (0 );
193+ const int64_t M = quantized_input.size (1 );
194+ const int64_t K = quantized_input.size (2 );
195+ const int64_t N = weights_4x2.size (0 );
196+
197+ // Create output tensor for floating point results
198+ at::Tensor float_output =
199+ at::zeros ({B, M, N}, at::device (at::kCPU ).dtype (at::kFloat ));
200+
201+ // Accessors for efficient access
202+ auto input_accessor = quantized_input.accessor <int8_t , 3 >();
203+ auto output_accessor = float_output.accessor <float , 3 >();
204+ auto weights_accessor = weights_4x2.accessor <uint8_t , 2 >();
205+ auto weight_scales_accessor = weight_scales.accessor <float , 2 >();
206+ auto weight_zeros_accessor = weight_zeros.accessor <int32_t , 2 >();
207+ auto input_scale_accessor = input_scale.accessor <float , 1 >();
208+ auto input_zero_accessor = input_zero_point.accessor <int32_t , 1 >();
209+
210+ // Perform quantized matrix multiplication following quantization.md equation
211+ // (5): result_real_value = lhs_scale * rhs_scale * Sum_over_k(
212+ // (lhs_quantized_value[k] - lhs_zero_point) *
213+ // (rhs_quantized_value[k] - rhs_zero_point)
214+ // )
215+ for (int64_t b = 0 ; b < B; b++) {
216+ for (int64_t m = 0 ; m < M; m++) {
217+ const int64_t token_idx = b * M + m;
218+ const float lhs_scale =
219+ input_scale_accessor[token_idx]; // Per-token input scale
220+ const int32_t lhs_zero_point =
221+ input_zero_accessor[token_idx]; // Per-token input zero point
222+
223+ for (int64_t n = 0 ; n < N; n++) {
224+ float result_real_value = 0 .0f ;
225+
226+ for (int64_t k = 0 ; k < K; k++) {
227+ // Get per-group weight quantization parameters
228+ const int64_t group_idx = k / group_size;
229+ const float rhs_scale =
230+ weight_scales_accessor[group_idx][n]; // Per-group weight scale
231+ const int32_t rhs_zero_point =
232+ weight_zeros_accessor[group_idx]
233+ [n]; // Per-group weight zero point
234+
235+ // Unpack the 4-bit weight for this position
236+ const uint8_t packed_val = weights_accessor[n][k / 2 ];
237+ uint8_t weight_4bit;
238+ if (k % 2 == 0 ) {
239+ weight_4bit = (packed_val & 0xF0 ) >> 4 ; // First weight in pair
240+ } else {
241+ weight_4bit = packed_val & 0x0F ; // Second weight in pair
242+ }
243+
244+ // Get quantized values
245+ const int32_t lhs_quantized_value =
246+ static_cast <int32_t >(input_accessor[b][m][k]);
247+ // Convert 4-bit weight to signed: subtract 8 to get range [-8, 7]
248+ const int32_t rhs_quantized_value =
249+ static_cast <int32_t >(weight_4bit) - 8 ;
250+
251+ // Apply proper quantization paradigm from quantization.md equation
252+ // (3): real_value = scale * (quantized_value - zero_point) Following
253+ // equation (5): result = lhs_scale * rhs_scale *
254+ // (lhs_quantized - lhs_zero) * (rhs_quantized - rhs_zero)
255+ const float lhs_diff =
256+ static_cast <float >(lhs_quantized_value - lhs_zero_point);
257+ const float rhs_diff =
258+ static_cast <float >(rhs_quantized_value - rhs_zero_point);
259+
260+ result_real_value += lhs_scale * rhs_scale * lhs_diff * rhs_diff;
261+ }
262+
263+ output_accessor[b][m][n] = result_real_value;
264+ }
265+ }
266+ }
267+
268+ return float_output;
269+ }
270+
271+ at::Tensor linear_qta8a_qga4w_4bit_dequant_impl (
272+ const at::Tensor& quantized_input,
273+ const at::Tensor& input_scale,
274+ const at::Tensor& input_zero_point,
275+ const at::Tensor& weights_4x2,
276+ const int64_t group_size,
277+ const at::Tensor& weight_scales,
278+ const at::Tensor& weight_zeros) {
279+ // Calculate number of input tokens
280+ int64_t input_num_tokens = 1 ;
281+ for (size_t i = 0 ; i < quantized_input.sizes ().size () - 1 ; i++) {
282+ input_num_tokens *= quantized_input.size (i);
283+ }
284+
285+ // Manually dequantize the char tensor using per-token quantization
286+ at::Tensor x_float = at::zeros_like (quantized_input, at::kFloat );
287+
288+ // Apply per-token dequantization
289+ auto input_accessor = quantized_input.accessor <int8_t , 3 >();
290+ auto output_accessor = x_float.accessor <float , 3 >();
291+
292+ for (int64_t token_idx = 0 ; token_idx < input_num_tokens; token_idx++) {
293+ float scale_val = input_scale[token_idx].item <float >();
294+ int zero_point_val = input_zero_point[token_idx].item <int >();
295+
296+ // Calculate batch and sequence indices for this token
297+ int64_t b = token_idx / quantized_input.size (1 );
298+ int64_t m = token_idx % quantized_input.size (1 );
299+
300+ // Apply dequantization for all features in this token
301+ for (int64_t k = 0 ; k < quantized_input.size (-1 ); k++) {
302+ float dequant_val =
303+ (input_accessor[b][m][k] - zero_point_val) * scale_val;
304+ output_accessor[b][m][k] = dequant_val;
305+ }
306+ }
307+
308+ std::vector<int64_t > weights_shape (weights_4x2.sizes ().vec ());
309+ weights_shape[1 ] *= 2 ;
310+
311+ at::Tensor weights_dequantized =
312+ at::empty (weights_shape, at::device (at::kCPU ).dtype (at::kFloat ));
313+
314+ const int64_t N = weights_dequantized.size (0 );
315+ const int64_t K = weights_dequantized.size (1 );
316+
317+ for (int n = 0 ; n < N; n++) {
318+ for (int k = 0 ; k < K; k += 2 ) {
319+ const int group_idx = k / group_size;
320+ const uint8_t packed_val = weights_4x2[n][k / 2 ].item ().to <uint8_t >();
321+ const uint8_t second_val = packed_val & 0x0F ;
322+ const uint8_t first_val = (packed_val & 0xF0 ) >> 4 ;
323+
324+ const float scale = weight_scales[group_idx][n].item ().to <float >();
325+ const int zero = weight_zeros[group_idx][n].item ().to <int >();
326+
327+ weights_dequantized[n][k] =
328+ ((float (first_val) - 8.0 ) - float (zero)) * scale;
329+ weights_dequantized[n][k + 1 ] =
330+ ((float (second_val) - 8.0 ) - float (zero)) * scale;
331+ }
332+ }
333+
334+ at::Tensor linear_result = at::linear (x_float, weights_dequantized);
335+
336+ return linear_result;
337+ }
338+
//
// Test functions
//
@@ -425,15 +612,15 @@ TEST(VulkanLinearQGA4WTest, test_vulkan_impl_gemm) {
425612 /* N = */ 256 );
426613}
427614
428- TEST (VulkanLinearQCS4WTest, test_reference_impl) {
615+ TEST_F (VulkanLinearQCS4WTest, test_reference_impl) {
429616 test_reference_linear_qcs4w (
430617 /* B = */ 1 ,
431618 /* M = */ 4 ,
432619 /* K = */ 128 ,
433620 /* N = */ 32 );
434621}
435622
436- TEST (VulkanLinearQCS4WTest, test_vulkan_impl_small_m) {
623+ TEST_F (VulkanLinearQCS4WTest, test_vulkan_impl_small_m) {
437624 test_vulkan_linear_qcs4w (
438625 /* B = */ 1 ,
439626 /* M = */ 4 ,
@@ -447,7 +634,7 @@ TEST(VulkanLinearQCS4WTest, test_vulkan_impl_small_m) {
447634 /* N = */ 256 );
448635}
449636
450- TEST (VulkanLinearQCS4WTest, test_vulkan_impl_gemm) {
637+ TEST_F (VulkanLinearQCS4WTest, test_vulkan_impl_gemm) {
451638 test_vulkan_linear_qcs4w (
452639 /* B = */ 1 ,
453640 /* M = */ 32 ,
0 commit comments