Commit 39afc51

Update on "[ET-VK] Fix implementation of int4 quantized linear"
## Context

Fix the existing implementation of int4 weight quantized linear to conform with how the `_weight_int4packed_mm` op works in the ATen library.

For some additional context, the current op implementation does not actually match the behaviour of `_weight_int4packed_mm`. The ATen op expects that the weights have already been packed into a specific format, with `inner_k_tiles` as a packing parameter. The packing is accomplished by calling the `_convert_weight_to_int4pack` operator. Thus the current implementation in Vulkan is equivalent to calling `_convert_weight_to_int4pack` + `_weight_int4packed_mm`. To address this discrepancy, the operator implementation is registered under the `linear_weight_int4` custom op as of this diff.

The problems with the existing implementation were as follows:

* The expected size of the scales and zeros tensor was incorrect. Previously, the size was assumed to be `(2, N, num_groups)`, but the correct size is `(num_groups, N, 2)`.
* Previously, when unpacking a `uint8_t` into 2 unpacked int4 values, it was assumed that the LSB was the first value and the MSB was the second value. However, this ordering should be flipped.
* The original implementation expected the output tensor to be channels packed, but in practice we want the output tensor to be width packed.

This diff addresses the above issues, and introduces a dedicated test binary to test against an equivalent reference implementation expressed with ATen functions.

Differential Revision: [D64354773](https://our.internmc.facebook.com/intern/diff/D64354773/)

[ghstack-poisoned]
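To make the first two fixes concrete, here is a minimal C++ sketch, not the actual Vulkan shader or test code. The helper names `unpack_int4_pair` and `scales_and_zeros_offset` are hypothetical. The sketch assumes that "MSB first" means the high nibble of each byte holds the first int4 value, and that the `(num_groups, N, 2)` tensor is stored contiguously in row-major order. It does not model zero-point handling or the `inner_k_tiles` packing.

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>

// Hypothetical helper: split one packed byte into two 4-bit values.
// Assumption: the high nibble (MSB side) is the first value and the
// low nibble (LSB side) is the second, i.e. the flipped ordering
// described in the commit message.
inline std::pair<std::uint8_t, std::uint8_t> unpack_int4_pair(
    std::uint8_t packed) {
  const std::uint8_t first = static_cast<std::uint8_t>((packed >> 4) & 0x0F);
  const std::uint8_t second = static_cast<std::uint8_t>(packed & 0x0F);
  return {first, second};
}

// Hypothetical helper: flat offset into a contiguous, row-major tensor of
// shape (num_groups, N, 2). `which` selects one of the two entries in the
// innermost dimension (assumed here to be the scale/zero pair).
inline std::size_t scales_and_zeros_offset(
    std::size_t group,
    std::size_t n,
    std::size_t which,
    std::size_t N) {
  return (group * N + n) * 2 + which;
}
```

With this layout, the two quantization parameters for a given `(group, n)` pair sit next to each other in memory, whereas under the previously assumed `(2, N, num_groups)` shape they would be `N * num_groups` elements apart.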
2 parents 5376ef3 + 58ebdff commit 39afc51

1 file changed: +2 -2 lines changed
backends/vulkan/test/op_tests/linear_weight_int4_test.cpp

Lines changed: 2 additions & 2 deletions
@@ -202,15 +202,15 @@ void test_vulkan_linear_int4(
   ASSERT_TRUE(at::allclose(vk_out, out_ref, 1e-4, 1e-4));
 }

-TEST(VulkanSDPATest, test_reference_impl) {
+TEST(VulkanInt4LinearTest, test_reference_impl) {
   test_reference_linear_int4(
       /*B = */ 1,
       /*M = */ 4,
       /*K = */ 128,
       /*N = */ 32);
 }

-TEST(VulkanSDPATest, test_vulkan_impl) {
+TEST(VulkanInt4LinearTest, test_vulkan_impl) {
   if (!vkcompute::api::context()
            ->adapter_ptr()
            ->has_full_int8_buffers_support()) {

0 commit comments
