@@ -18,6 +18,36 @@

#include <cassert>

class VulkanLinearQCS4WTest : public ::testing::Test {
public:
void SetUp() override {
if (!vkcompute::api::context()
->adapter_ptr()
->supports_int16_shader_types()) {
GTEST_SKIP();
}
}

void TearDown() override {
// Clean up any resources if needed
}
};

class VulkanLinearQTA8AQGA4WTest : public ::testing::Test {
public:
void SetUp() override {
if (!vkcompute::api::context()
->adapter_ptr()
->has_full_int8_buffers_support()) {
GTEST_SKIP();
}
}

void TearDown() override {
// Clean up any resources if needed
}
};
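// Note added for exposition (not part of the original patch): these fixtures
// let capability-gated tests be declared with TEST_F instead of TEST, as with
// the VulkanLinearQCS4WTest cases further down. When the device lacks int16
// shader types or full int8 buffer support, SetUp() calls GTEST_SKIP() so the
// tests are reported as skipped instead of failing.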

//
// Reference Implementations
//
@@ -149,6 +179,162 @@ at::Tensor linear_qcs4w_reference_impl(
return out.reshape(out_shape);
}

at::Tensor linear_qta8a_qga4w_quantized_matmul(
const at::Tensor& quantized_input, // [B, M, K] int8 quantized input
const at::Tensor& input_scale, // [B*M] per-token input scales
const at::Tensor& input_zero_point, // [B*M] per-token input zero points
const at::Tensor& weights_4x2, // [N, K/2] 4-bit packed weights
const int64_t group_size, // Group size for weight quantization
const at::Tensor& weight_scales, // [K/group_size, N] weight scales
const at::Tensor& weight_zeros) { // [K/group_size, N] weight zeros

const int64_t B = quantized_input.size(0);
const int64_t M = quantized_input.size(1);
const int64_t K = quantized_input.size(2);
const int64_t N = weights_4x2.size(0);

// Create output tensor for floating point results
at::Tensor float_output =
at::zeros({B, M, N}, at::device(at::kCPU).dtype(at::kFloat));

// Accessors for efficient access
auto input_accessor = quantized_input.accessor<int8_t, 3>();
auto output_accessor = float_output.accessor<float, 3>();
auto weights_accessor = weights_4x2.accessor<uint8_t, 2>();
auto weight_scales_accessor = weight_scales.accessor<float, 2>();
auto weight_zeros_accessor = weight_zeros.accessor<int32_t, 2>();
auto input_scale_accessor = input_scale.accessor<float, 1>();
auto input_zero_accessor = input_zero_point.accessor<int32_t, 1>();

// Perform quantized matrix multiplication following quantization.md equation
// (5): result_real_value = lhs_scale * rhs_scale * Sum_over_k(
// (lhs_quantized_value[k] - lhs_zero_point) *
// (rhs_quantized_value[k] - rhs_zero_point)
// )
for (int64_t b = 0; b < B; b++) {
for (int64_t m = 0; m < M; m++) {
const int64_t token_idx = b * M + m;
const float lhs_scale =
input_scale_accessor[token_idx]; // Per-token input scale
const int32_t lhs_zero_point =
input_zero_accessor[token_idx]; // Per-token input zero point

for (int64_t n = 0; n < N; n++) {
float result_real_value = 0.0f;

for (int64_t k = 0; k < K; k++) {
// Get per-group weight quantization parameters
const int64_t group_idx = k / group_size;
const float rhs_scale =
weight_scales_accessor[group_idx][n]; // Per-group weight scale
const int32_t rhs_zero_point =
weight_zeros_accessor[group_idx]
[n]; // Per-group weight zero point

// Unpack the 4-bit weight for this position
const uint8_t packed_val = weights_accessor[n][k / 2];
uint8_t weight_4bit;
if (k % 2 == 0) {
weight_4bit = (packed_val & 0xF0) >> 4; // First weight in pair
} else {
weight_4bit = packed_val & 0x0F; // Second weight in pair
}

// Get quantized values
const int32_t lhs_quantized_value =
static_cast<int32_t>(input_accessor[b][m][k]);
// Convert 4-bit weight to signed: subtract 8 to get range [-8, 7]
const int32_t rhs_quantized_value =
static_cast<int32_t>(weight_4bit) - 8;

// Apply proper quantization paradigm from quantization.md equation
// (3): real_value = scale * (quantized_value - zero_point) Following
// equation (5): result = lhs_scale * rhs_scale *
// (lhs_quantized - lhs_zero) * (rhs_quantized - rhs_zero)
const float lhs_diff =
static_cast<float>(lhs_quantized_value - lhs_zero_point);
const float rhs_diff =
static_cast<float>(rhs_quantized_value - rhs_zero_point);

result_real_value += lhs_scale * rhs_scale * lhs_diff * rhs_diff;
}

output_accessor[b][m][n] = result_real_value;
}
}
}

return float_output;
}
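
// Illustrative sketch (added here for exposition, not part of the original
// patch): a hypothetical standalone test showing the per-element arithmetic
// the reference loop above performs. Equation (3) dequantizes each value as
// real = scale * (q - zero_point); equation (5) pulls both scales out of the
// inner sum. The two forms must agree, which is what the accumulation in
// linear_qta8a_qga4w_quantized_matmul relies on. All numbers below are made up.
TEST(QuantizationArithmeticSketch, Equation5MatchesDequantizedProduct) {
  const float lhs_scale = 0.05f;      // hypothetical per-token input scale
  const int32_t lhs_zero_point = 3;   // hypothetical per-token input zero point
  const float rhs_scale = 0.02f;      // hypothetical per-group weight scale
  const int32_t rhs_zero_point = 0;   // hypothetical per-group weight zero point

  const int32_t lhs_quantized_value = 17;     // an int8 activation value
  const int32_t rhs_quantized_value = 7 - 8;  // 4-bit nibble 0x7 mapped to [-8, 7]

  // Equation (3): real_value = scale * (quantized_value - zero_point)
  const float lhs_real = lhs_scale * (lhs_quantized_value - lhs_zero_point);
  const float rhs_real = rhs_scale * (rhs_quantized_value - rhs_zero_point);

  // Equation (5): accumulate scale * scale * (q - zp) * (q - zp) directly
  const float contribution = lhs_scale * rhs_scale *
      (lhs_quantized_value - lhs_zero_point) *
      (rhs_quantized_value - rhs_zero_point);

  EXPECT_NEAR(lhs_real * rhs_real, contribution, 1e-6f);
}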

at::Tensor linear_qta8a_qga4w_4bit_dequant_impl(
const at::Tensor& quantized_input,
const at::Tensor& input_scale,
const at::Tensor& input_zero_point,
const at::Tensor& weights_4x2,
const int64_t group_size,
const at::Tensor& weight_scales,
const at::Tensor& weight_zeros) {
// Calculate number of input tokens
int64_t input_num_tokens = 1;
for (size_t i = 0; i < quantized_input.sizes().size() - 1; i++) {
input_num_tokens *= quantized_input.size(i);
}

// Manually dequantize the char tensor using per-token quantization
at::Tensor x_float = at::zeros_like(quantized_input, at::kFloat);

// Apply per-token dequantization
auto input_accessor = quantized_input.accessor<int8_t, 3>();
auto output_accessor = x_float.accessor<float, 3>();

for (int64_t token_idx = 0; token_idx < input_num_tokens; token_idx++) {
float scale_val = input_scale[token_idx].item<float>();
int zero_point_val = input_zero_point[token_idx].item<int>();

// Calculate batch and sequence indices for this token
int64_t b = token_idx / quantized_input.size(1);
int64_t m = token_idx % quantized_input.size(1);

// Apply dequantization for all features in this token
for (int64_t k = 0; k < quantized_input.size(-1); k++) {
float dequant_val =
(input_accessor[b][m][k] - zero_point_val) * scale_val;
output_accessor[b][m][k] = dequant_val;
}
}

std::vector<int64_t> weights_shape(weights_4x2.sizes().vec());
weights_shape[1] *= 2;

at::Tensor weights_dequantized =
at::empty(weights_shape, at::device(at::kCPU).dtype(at::kFloat));

const int64_t N = weights_dequantized.size(0);
const int64_t K = weights_dequantized.size(1);

for (int n = 0; n < N; n++) {
for (int k = 0; k < K; k += 2) {
const int group_idx = k / group_size;
const uint8_t packed_val = weights_4x2[n][k / 2].item().to<uint8_t>();
const uint8_t second_val = packed_val & 0x0F;
const uint8_t first_val = (packed_val & 0xF0) >> 4;

const float scale = weight_scales[group_idx][n].item().to<float>();
const int zero = weight_zeros[group_idx][n].item().to<int>();

weights_dequantized[n][k] =
((float(first_val) - 8.0) - float(zero)) * scale;
weights_dequantized[n][k + 1] =
((float(second_val) - 8.0) - float(zero)) * scale;
}
}

at::Tensor linear_result = at::linear(x_float, weights_dequantized);

return linear_result;
}
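
// Illustrative sketch (added here for exposition, not part of the original
// patch): a hypothetical standalone test pinning down the nibble layout that
// both reference implementations above assume. Each packed byte holds the
// weights for columns k (high nibble) and k + 1 (low nibble); every nibble is
// stored in [0, 15] and remapped to [-8, 7] by subtracting 8. The byte value
// below is made up.
TEST(PackedWeightLayoutSketch, HighNibbleIsEvenColumn) {
  const uint8_t packed_val = 0xB3;

  const uint8_t first_val = (packed_val & 0xF0) >> 4;  // weight at even k
  const uint8_t second_val = packed_val & 0x0F;        // weight at odd k

  EXPECT_EQ(static_cast<int32_t>(first_val) - 8, 3);    // 0xB = 11 -> 3
  EXPECT_EQ(static_cast<int32_t>(second_val) - 8, -5);  // 0x3 = 3 -> -5
}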

//
// Test functions
//
@@ -425,15 +611,15 @@ TEST(VulkanLinearQGA4WTest, test_vulkan_impl_gemm) {
/*N = */ 256);
}

-TEST(VulkanLinearQCS4WTest, test_reference_impl) {
+TEST_F(VulkanLinearQCS4WTest, test_reference_impl) {
test_reference_linear_qcs4w(
/*B = */ 1,
/*M = */ 4,
/*K = */ 128,
/*N = */ 32);
}

-TEST(VulkanLinearQCS4WTest, test_vulkan_impl_small_m) {
+TEST_F(VulkanLinearQCS4WTest, test_vulkan_impl_small_m) {
test_vulkan_linear_qcs4w(
/*B = */ 1,
/*M = */ 4,
@@ -447,7 +633,7 @@ TEST(VulkanLinearQCS4WTest, test_vulkan_impl_small_m) {
/*N = */ 256);
}

-TEST(VulkanLinearQCS4WTest, test_vulkan_impl_gemm) {
+TEST_F(VulkanLinearQCS4WTest, test_vulkan_impl_gemm) {
test_vulkan_linear_qcs4w(
/*B = */ 1,
/*M = */ 32,
2 changes: 1 addition & 1 deletion backends/vulkan/test/op_tests/targets.bzl
@@ -205,7 +205,7 @@ def define_common_targets(is_fbcode = False):
]
)
define_test_targets(
"linear_weight_int4_test",
"quantized_linear_test",
extra_deps = [
":test_utils",
]