This repository was archived by the owner on May 11, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 23
3-bit dequantize to be tested #37
Copy link
Copy link
Open
Description
Hi @casper-hansen,
I just asked Claude to code the 3-bit kernel and it gave me this:
// 3-bit version
__inline__ __device__ void dequantize_s3_to_fp16x2(half2 const &source, uint4 *result)
{
uint32_t *h = reinterpret_cast<uint32_t *>(result);
uint32_t const i3s = reinterpret_cast<uint32_t const &>(source);
// For 3-bit, we can pack 10 values in 32 bits (with 2 bits unused)
// Masks for extracting 3 bits
static constexpr uint32_t MASK_0_2 = 0x00000007; // First 3 bits
static constexpr uint32_t MASK_3_5 = 0x00000038; // Next 3 bits
static constexpr uint32_t MASK_6_8 = 0x000001C0; // Next 3 bits
static constexpr uint32_t MASK_9_11 = 0x00000E00; // Next 3 bits
static constexpr uint32_t I3s_TO_F16s_MAGIC_NUM = 0x64006400;
// Extract values using bit masks and shifts
const uint32_t top_i3s = i3s >> 12; // Shift to get upper values
// Extract each group of 3 bits and position them
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i3s), "n"(MASK_0_2), "n"(I3s_TO_F16s_MAGIC_NUM), "n"(0xf0));
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i3s), "n"(MASK_3_5), "n"(I3s_TO_F16s_MAGIC_NUM), "n"(0xf0));
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i3s), "n"(MASK_6_8), "n"(I3s_TO_F16s_MAGIC_NUM), "n"(0xf0));
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i3s), "n"(MASK_9_11), "n"(I3s_TO_F16s_MAGIC_NUM), "n"(0xf0));
// Magic numbers for conversion
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint32_t ONE_EIGHTH = 0x2e002e00; // 1/8 for 3-bit range
static constexpr uint32_t NEG_32 = 0xd200d200; // -32 for 3-bit range
// Convert to final fp16 values
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_EIGHTH), "r"(NEG_32));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_EIGHTH), "r"(NEG_32));
}
Might be worth testing.
The issue is that in the code, PACK_NUM or PACK_FACTOR is not an arg at the moment but a constant.
In order to test it without to much dev, this need to be set PACK_FACTOR need to be 10 for 3 bit (instead of 8)
and the code needs to be adjusted to point to the new function above.
You are way more familiar than me with the code base so you may spend less time than me to test.
Cheers.
Metadata
Metadata
Assignees
Labels
No labels