This repository was archived by the owner on May 11, 2025. It is now read-only.

3-bit dequantize to be tested #37

@vince62s

Hi @casper-hansen,

I just asked Claude to code the 3-bit kernel and it gave me this:

// 3-bit version
__inline__ __device__ void dequantize_s3_to_fp16x2(half2 const &source, uint4 *result) 
{
    uint32_t *h = reinterpret_cast<uint32_t *>(result);
    uint32_t const i3s = reinterpret_cast<uint32_t const &>(source);

    // For 3-bit, we can pack 10 values in 32 bits (with 2 bits unused)
    // Masks for extracting 3 bits
    static constexpr uint32_t MASK_0_2 = 0x00000007;   // First 3 bits
    static constexpr uint32_t MASK_3_5 = 0x00000038;   // Next 3 bits
    static constexpr uint32_t MASK_6_8 = 0x000001C0;   // Next 3 bits
    static constexpr uint32_t MASK_9_11 = 0x00000E00;  // Next 3 bits
    static constexpr uint32_t I3s_TO_F16s_MAGIC_NUM = 0x64006400;

    // Extract values using bit masks and shifts
    const uint32_t top_i3s = i3s >> 12;  // Shift to get upper values

    // Extract each group of 3 bits and position them
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[0])
                 : "r"(i3s), "n"(MASK_0_2), "n"(I3s_TO_F16s_MAGIC_NUM), "n"(0xf0));

    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[1])
                 : "r"(i3s), "n"(MASK_3_5), "n"(I3s_TO_F16s_MAGIC_NUM), "n"(0xf0));

    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[2])
                 : "r"(top_i3s), "n"(MASK_6_8), "n"(I3s_TO_F16s_MAGIC_NUM), "n"(0xf0));

    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[3])
                 : "r"(top_i3s), "n"(MASK_9_11), "n"(I3s_TO_F16s_MAGIC_NUM), "n"(0xf0));

    // Magic numbers for conversion
    static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
    static constexpr uint32_t ONE_EIGHTH = 0x2e002e00;     // 1/8 for 3-bit range
    static constexpr uint32_t NEG_32 = 0xd200d200;         // -32 for 3-bit range

    // Convert to final fp16 values
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_EIGHTH), "r"(NEG_32));
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_EIGHTH), "r"(NEG_32));
}
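
Before testing on a GPU, the fp16 constants above can be sanity-checked on the host. The sketch below is only an illustration and is not part of the repository; `half_bits_to_float` and `magic_decode` are hypothetical helper names. It decodes an IEEE 754 binary16 bit pattern and shows the magic-number trick: OR-ing a small integer v into the mantissa of 0x6400 (1024.0) yields the half value 1024 + v, so subtracting 1024 recovers v. Decoded this way, 0x2e00 comes out as 0.09375 and 0xd200 as -48, whereas 1/8 would be 0x3000 and -32 would be 0xd000, so the ONE_EIGHTH and NEG_32 constants probably need another look.

```cpp
#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical helper: decode an IEEE 754 binary16 bit pattern into a float.
inline float half_bits_to_float(std::uint16_t h) {
    const int sign = (h >> 15) & 0x1;
    const int exponent = (h >> 10) & 0x1F;
    const int mantissa = h & 0x3FF;
    float value;
    if (exponent == 0) {
        // Zero or subnormal: mantissa * 2^-24
        value = std::ldexp(static_cast<float>(mantissa), -24);
    } else if (exponent == 31) {
        value = mantissa ? std::numeric_limits<float>::quiet_NaN()
                         : std::numeric_limits<float>::infinity();
    } else {
        // Normal number: (1 + mantissa / 1024) * 2^(exponent - 15)
        value = std::ldexp(1.0f + static_cast<float>(mantissa) / 1024.0f,
                           exponent - 15);
    }
    return sign ? -value : value;
}

// Hypothetical helper: the magic-number trick for one 16-bit lane.
// 0x6400 is 1024.0 in fp16; at that exponent one mantissa step equals 1,
// so (0x6400 | v) decodes to 1024 + v for 0 <= v < 1024.
inline float magic_decode(std::uint32_t v) {
    const std::uint16_t bits =
        static_cast<std::uint16_t>(0x6400u | (v & 0x3FFu));
    return half_bits_to_float(bits) - 1024.0f;
}
```

For example, `magic_decode(5)` returns 5.0f. A 3-bit field left in place at bits 3..5 decodes to 8 times its value, which is why the shifted lanes need a multiply by 1/8 in the fma step.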

Might be worth testing.

The issue is that PACK_NUM or PACK_FACTOR is currently a constant in the code, not an argument.
To test this without too much development, PACK_FACTOR needs to be set to 10 for 3-bit (instead of 8),
and the code needs to be adjusted to call the new function above.
You are far more familiar with the code base than I am, so testing this would probably take you less time than it would take me.
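
To make the PACK_FACTOR = 10 idea concrete, here is a minimal host-side reference for packing and unpacking ten 3-bit values per 32-bit word. It is a sketch under assumptions: the names `pack_s3`, `unpack_s3` and `kPackFactor` are made up for illustration, and the layout (value k stored in bits 3k to 3k+2, top 2 bits unused) is only one possible ordering; the real kernels may interleave values differently. Something like this could serve as a CPU reference to compare a device dequantize function against.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Assumed constants for a 3-bit scheme: 10 values * 3 bits = 30 bits used.
constexpr int kBits = 3;
constexpr int kPackFactor = 10;

// Hypothetical helper: pack ten values in [0, 7] into one 32-bit word.
// Assumed layout: value k occupies bits [3k, 3k + 2]; bits 30..31 stay zero.
inline std::uint32_t pack_s3(const std::array<std::uint8_t, kPackFactor>& vals) {
    std::uint32_t word = 0;
    for (int k = 0; k < kPackFactor; ++k) {
        assert(vals[k] < 8);  // each value must fit in 3 bits
        word |= static_cast<std::uint32_t>(vals[k]) << (kBits * k);
    }
    return word;
}

// Hypothetical helper: inverse of pack_s3 under the same assumed layout.
inline std::array<std::uint8_t, kPackFactor> unpack_s3(std::uint32_t word) {
    std::array<std::uint8_t, kPackFactor> vals{};
    for (int k = 0; k < kPackFactor; ++k) {
        vals[k] = static_cast<std::uint8_t>((word >> (kBits * k)) & 0x7u);
    }
    return vals;
}
```

Note that the function above writes a uint4, which holds only 8 half values, so a pack factor of 10 would also need a wider output (or two calls) to cover all ten values.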

Cheers.
