diff --git a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp b/backends/vulkan/test/op_tests/quantized_linear_test.cpp
similarity index 64%
rename from backends/vulkan/test/op_tests/linear_weight_int4_test.cpp
rename to backends/vulkan/test/op_tests/quantized_linear_test.cpp
index e48042c4620..108770bb02e 100644
--- a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp
+++ b/backends/vulkan/test/op_tests/quantized_linear_test.cpp
@@ -18,6 +18,36 @@
 #include
 
+class VulkanLinearQCS4WTest : public ::testing::Test {
+ public:
+  void SetUp() override {
+    if (!vkcompute::api::context()
+            ->adapter_ptr()
+            ->supports_int16_shader_types()) {
+      GTEST_SKIP();
+    }
+  }
+
+  void TearDown() override {
+    // Clean up any resources if needed
+  }
+};
+
+class VulkanLinearQTA8AQGA4WTest : public ::testing::Test {
+ public:
+  void SetUp() override {
+    if (!vkcompute::api::context()
+            ->adapter_ptr()
+            ->has_full_int8_buffers_support()) {
+      GTEST_SKIP();
+    }
+  }
+
+  void TearDown() override {
+    // Clean up any resources if needed
+  }
+};
+
 //
 // Reference Implementations
 //
@@ -149,6 +179,162 @@ at::Tensor linear_qcs4w_reference_impl(
   return out.reshape(out_shape);
 }
 
+at::Tensor linear_qta8a_qga4w_quantized_matmul(
+    const at::Tensor& quantized_input, // [B, M, K] int8 quantized input
+    const at::Tensor& input_scale, // [B*M] per-token input scales
+    const at::Tensor& input_zero_point, // [B*M] per-token input zero points
+    const at::Tensor& weights_4x2, // [N, K/2] 4-bit packed weights
+    const int64_t group_size, // Group size for weight quantization
+    const at::Tensor& weight_scales, // [K/group_size, N] weight scales
+    const at::Tensor& weight_zeros) { // [K/group_size, N] weight zeros
+  const int64_t B = quantized_input.size(0);
+  const int64_t M = quantized_input.size(1);
+  const int64_t K = quantized_input.size(2);
+  const int64_t N = weights_4x2.size(0);
+
+  // Create output tensor for floating point results
+  at::Tensor float_output =
+      at::zeros({B, M, N}, at::device(at::kCPU).dtype(at::kFloat));
+
+  // Accessors for efficient element access (element types follow the
+  // shapes/dtypes documented in the parameter comments above)
+  auto input_accessor = quantized_input.accessor<int8_t, 3>();
+  auto output_accessor = float_output.accessor<float, 3>();
+  auto weights_accessor = weights_4x2.accessor<uint8_t, 2>();
+  auto weight_scales_accessor = weight_scales.accessor<float, 2>();
+  auto weight_zeros_accessor = weight_zeros.accessor<int32_t, 2>();
+  auto input_scale_accessor = input_scale.accessor<float, 1>();
+  auto input_zero_accessor = input_zero_point.accessor<int32_t, 1>();
+
+  // Perform quantized matrix multiplication following quantization.md
+  // equation (5):
+  //   result_real_value = lhs_scale * rhs_scale * Sum_over_k(
+  //       (lhs_quantized_value[k] - lhs_zero_point) *
+  //       (rhs_quantized_value[k] - rhs_zero_point))
+  for (int64_t b = 0; b < B; b++) {
+    for (int64_t m = 0; m < M; m++) {
+      const int64_t token_idx = b * M + m;
+      const float lhs_scale =
+          input_scale_accessor[token_idx]; // Per-token input scale
+      const int32_t lhs_zero_point =
+          input_zero_accessor[token_idx]; // Per-token input zero point
+
+      for (int64_t n = 0; n < N; n++) {
+        float result_real_value = 0.0f;
+
+        for (int64_t k = 0; k < K; k++) {
+          // Get per-group weight quantization parameters
+          const int64_t group_idx = k / group_size;
+          const float rhs_scale =
+              weight_scales_accessor[group_idx][n]; // Per-group weight scale
+          const int32_t rhs_zero_point =
+              weight_zeros_accessor[group_idx][n]; // Per-group weight zero point
+
+          // Unpack the 4-bit weight for this position
+          const uint8_t packed_val = weights_accessor[n][k / 2];
+          uint8_t weight_4bit;
+          if (k % 2 == 0) {
+            weight_4bit = (packed_val & 0xF0) >> 4; // First weight in pair
+          } else {
+            weight_4bit = packed_val & 0x0F; // Second weight in pair
+          }
+
+          // Get quantized values
+          const int32_t lhs_quantized_value =
+              static_cast<int32_t>(input_accessor[b][m][k]);
+          // Convert the 4-bit weight to signed: subtract 8 to get range [-8, 7]
+          const int32_t rhs_quantized_value =
+              static_cast<int32_t>(weight_4bit) - 8;
+
+          // Apply the quantization paradigm from quantization.md equation (3):
+          //   real_value = scale * (quantized_value - zero_point)
+          // and accumulate following equation (5):
+          //   result = lhs_scale * rhs_scale *
+          //       (lhs_quantized - lhs_zero) * (rhs_quantized - rhs_zero)
+          const float lhs_diff =
+              static_cast<float>(lhs_quantized_value - lhs_zero_point);
+          const float rhs_diff =
+              static_cast<float>(rhs_quantized_value - rhs_zero_point);
+
+          result_real_value += lhs_scale * rhs_scale * lhs_diff * rhs_diff;
+        }
+
+        output_accessor[b][m][n] = result_real_value;
+      }
+    }
+  }
+
+  return float_output;
+}
+
+at::Tensor linear_qta8a_qga4w_4bit_dequant_impl(
+    const at::Tensor& quantized_input,
+    const at::Tensor& input_scale,
+    const at::Tensor& input_zero_point,
+    const at::Tensor& weights_4x2,
+    const int64_t group_size,
+    const at::Tensor& weight_scales,
+    const at::Tensor& weight_zeros) {
+  // Calculate the number of input tokens
+  int64_t input_num_tokens = 1;
+  for (size_t i = 0; i < quantized_input.sizes().size() - 1; i++) {
+    input_num_tokens *= quantized_input.size(i);
+  }
+
+  // Manually dequantize the int8 input tensor using per-token quantization
+  at::Tensor x_float = at::zeros_like(quantized_input, at::kFloat);
+
+  // Apply per-token dequantization
+  auto input_accessor = quantized_input.accessor<int8_t, 3>();
+  auto output_accessor = x_float.accessor<float, 3>();
+
+  for (int64_t token_idx = 0; token_idx < input_num_tokens; token_idx++) {
+    float scale_val = input_scale[token_idx].item<float>();
+    int zero_point_val = input_zero_point[token_idx].item<int>();
+
+    // Calculate batch and sequence indices for this token
+    int64_t b = token_idx / quantized_input.size(1);
+    int64_t m = token_idx % quantized_input.size(1);
+
+    // Apply dequantization for all features in this token
+    for (int64_t k = 0; k < quantized_input.size(-1); k++) {
+      float dequant_val =
+          (input_accessor[b][m][k] - zero_point_val) * scale_val;
+      output_accessor[b][m][k] = dequant_val;
+    }
+  }
+
+  std::vector<int64_t> weights_shape(weights_4x2.sizes().vec());
+  weights_shape[1] *= 2;
+
+  at::Tensor weights_dequantized =
+      at::empty(weights_shape, at::device(at::kCPU).dtype(at::kFloat));
+
+  const int64_t N = weights_dequantized.size(0);
+  const int64_t K = weights_dequantized.size(1);
+
+  for (int n = 0; n < N; n++) {
+    for (int k = 0; k < K; k += 2) {
+      const int group_idx = k / group_size;
+      const uint8_t packed_val = weights_4x2[n][k / 2].item().to<uint8_t>();
+      const uint8_t second_val = packed_val & 0x0F;
+      const uint8_t first_val = (packed_val & 0xF0) >> 4;
+
+      const float scale = weight_scales[group_idx][n].item().to<float>();
+      const int zero = weight_zeros[group_idx][n].item().to<int>();
+
+      weights_dequantized[n][k] =
+          ((float(first_val) - 8.0) - float(zero)) * scale;
+      weights_dequantized[n][k + 1] =
+          ((float(second_val) - 8.0) - float(zero)) * scale;
+    }
+  }
+
+  at::Tensor linear_result = at::linear(x_float, weights_dequantized);
+
+  return linear_result;
+}
+
 //
 // Test functions
 //
@@ -425,7 +611,7 @@ TEST(VulkanLinearQGA4WTest, test_vulkan_impl_gemm) {
       /*N = */ 256);
 }
 
-TEST(VulkanLinearQCS4WTest, test_reference_impl) {
+TEST_F(VulkanLinearQCS4WTest, test_reference_impl) {
   test_reference_linear_qcs4w(
       /*B = */ 1,
       /*M = */ 4,
@@ -433,7 +619,7 @@ TEST(VulkanLinearQCS4WTest, test_reference_impl) {
       /*N = */ 32);
 }
 
-TEST(VulkanLinearQCS4WTest, test_vulkan_impl_small_m) {
+TEST_F(VulkanLinearQCS4WTest, test_vulkan_impl_small_m) {
   test_vulkan_linear_qcs4w(
       /*B = */ 1,
       /*M = */ 4,
@@ -447,7 +633,7 @@ TEST(VulkanLinearQCS4WTest, test_vulkan_impl_small_m) {
       /*N = */ 256);
 }
 
-TEST(VulkanLinearQCS4WTest, test_vulkan_impl_gemm) {
+TEST_F(VulkanLinearQCS4WTest, test_vulkan_impl_gemm) {
   test_vulkan_linear_qcs4w(
       /*B = */ 1,
       /*M = */ 32,
diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl
index 0d014c7ef29..9eac90ac33d 100644
--- a/backends/vulkan/test/op_tests/targets.bzl
+++ b/backends/vulkan/test/op_tests/targets.bzl
@@ -205,7 +205,7 @@ def define_common_targets(is_fbcode = False):
         ]
     )
     define_test_targets(
-        "linear_weight_int4_test",
+        "quantized_linear_test",
         extra_deps = [
             ":test_utils",
         ]
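
Note on the arithmetic used by the new reference helpers: weights_4x2 packs two
4-bit weights per byte (high nibble first, low nibble second), each stored with a
+8 offset so the signed range is [-8, 7]; every group of group_size weights shares
one scale/zero-point pair, while the input carries per-token scale/zero-point
pairs. The standalone C++ sketch below is not part of the patch, and its values
are made up purely to illustrate that accumulating in the quantized domain per
quantization.md equation (5) matches dequantizing per equation (3) and then
multiplying, which is the identity the two reference implementations encode.

// Illustrative only: toy single-byte version of the unpack + dequantize and the
// quantized-domain accumulation performed by the reference helpers above.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // One packed byte: high nibble = first weight, low nibble = second weight.
  const uint8_t packed_val = 0xA3; // first nibble = 10, second nibble = 3
  const int32_t w_q0 = static_cast<int32_t>((packed_val & 0xF0) >> 4) - 8; // 2
  const int32_t w_q1 = static_cast<int32_t>(packed_val & 0x0F) - 8; // -5

  // Hypothetical per-group weight parameters and per-token input parameters.
  const float rhs_scale = 0.05f;
  const int32_t rhs_zero = 1;
  const float lhs_scale = 0.1f;
  const int32_t lhs_zero = -2;

  // Two int8 activations paired with the two weights above (K = 2).
  const int32_t x_q0 = 17;
  const int32_t x_q1 = -9;

  // Equation (5): accumulate lhs_scale * rhs_scale * (q - zp) * (q - zp) over k.
  float quantized_acc = 0.0f;
  quantized_acc +=
      lhs_scale * rhs_scale * float(x_q0 - lhs_zero) * float(w_q0 - rhs_zero);
  quantized_acc +=
      lhs_scale * rhs_scale * float(x_q1 - lhs_zero) * float(w_q1 - rhs_zero);

  // Equation (3) applied to each operand, followed by a plain dot product.
  const float x0 = lhs_scale * float(x_q0 - lhs_zero);
  const float x1 = lhs_scale * float(x_q1 - lhs_zero);
  const float w0 = rhs_scale * float(w_q0 - rhs_zero);
  const float w1 = rhs_scale * float(w_q1 - rhs_zero);
  const float float_acc = x0 * w0 + x1 * w1;

  // Both paths give 0.305 here; they agree up to floating-point rounding.
  std::printf("quantized path: %f, float path: %f\n", quantized_acc, float_acc);
  assert(std::fabs(quantized_acc - float_acc) < 1e-5f);
  return 0;
}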