Skip to content

Commit fb28732

Browse files
author
horasal
committed
Change layout to keep consistency
1 parent bb9d978 commit fb28732

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

ggml/src/ggml-cuda/convert.cu

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,46 @@ static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_
481481
}
482482
}
483483

484+
template<typename dst_t>
static __global__ void dequantize_block_mxfp6_e3m2(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    // Dequantize one QK_K(256) superblock per CUDA block.
    // QK_K(256) / QK_MXFP6_E3M2(32) = 8 sub-blocks per superblock.
    // Expects a launch of 32 threads per block: 4 threads cooperate on each
    // sub-block, so each thread decodes 24 / 4 = 6 bytes = 2 3-byte groups
    // and emits 8 output values.
    const int64_t i = blockIdx.x;
    const block_mxfp6_e3m2 * x = (const block_mxfp6_e3m2 *) vx + i*(QK_K/QK_MXFP6_E3M2);

    const int64_t tid = threadIdx.x;
    // MXFP6 packs 32 6-bit values into 24 bytes per sub-block.
    const int64_t il = tid/8; // 0...3 -> which quarter of a sub-block this thread decodes
    const int64_t ib = tid%8; // 0...7 -> which sub-block this thread decodes
    dst_t * y = yy + i*QK_K + 32*ib;
    const uint8_t * qs = x[ib].qs;
    // Per-sub-block scale, stored as E8M0.
    const float d = ggml_cuda_e8m0_to_fp32(x[ib].e);
    for (int g_idx = 0; g_idx < 2; ++g_idx) {
        const int g = 2 * il + g_idx;
        // input index -> 3 bytes per group
        const uint8_t* q3 = qs + 3*g;
        // output index -> 4 values per group
        const int y_offset = 4 * g;

        const uint8_t b0 = q3[0];
        const uint8_t b1 = q3[1];
        const uint8_t b2 = q3[2]; // fix: was declared `b3`, but read as `b2` below (undeclared identifier)

        // Unpack four 6-bit indices from the three packed bytes.
        const uint8_t v0_idx = b0 & 0x3F;
        const uint8_t v1_idx = (b0 >> 6) | ((b1 & 0x0F) << 2);
        const uint8_t v2_idx = (b1 >> 4) | ((b2 & 0x03) << 4);
        const uint8_t v3_idx = b2 >> 2;

        // NOTE(review): the 0.0625f (= 1/16) factor suggests kvalues_mxfp6_e3m2
        // stores values pre-scaled by 16 — confirm against the table definition.
        y[y_offset + 0] = d * kvalues_mxfp6_e3m2[v0_idx]*0.0625f;
        y[y_offset + 1] = d * kvalues_mxfp6_e3m2[v1_idx]*0.0625f;
        y[y_offset + 2] = d * kvalues_mxfp6_e3m2[v2_idx]*0.0625f;
        y[y_offset + 3] = d * kvalues_mxfp6_e3m2[v3_idx]*0.0625f;
    }
}
522+
523+
484524
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
485525
static void dequantize_block_cuda(const void * vx, dst_t * y,
486526
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
@@ -610,6 +650,12 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t
610650
dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);
611651
}
612652

653+
template<typename dst_t>
static void dequantize_row_mxfp6_e3m2_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    // Host-side launcher: one 32-thread CUDA block per QK_K superblock,
    // rounding up so a partial trailing superblock still gets a block.
    const int num_blocks = (k + QK_K - 1) / QK_K; // ceil(k / QK_K)
    dequantize_block_mxfp6_e3m2<<<num_blocks, 32, 0, stream>>>(vx, y);
}
658+
613659
template <typename src_t, typename dst_t>
614660
static __global__ void convert_unary(
615661
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
@@ -701,6 +747,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
701747
return dequantize_row_iq3_s_cuda;
702748
case GGML_TYPE_MXFP4:
703749
return dequantize_row_mxfp4_cuda;
750+
case GGML_TYPE_MXFP6_E3M2:
751+
return dequantize_row_mxfp6_e3m2_cuda;
704752
case GGML_TYPE_F32:
705753
return convert_unary_cont_cuda<float>;
706754
case GGML_TYPE_BF16:
@@ -752,6 +800,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
752800
return dequantize_row_iq3_s_cuda;
753801
case GGML_TYPE_MXFP4:
754802
return dequantize_row_mxfp4_cuda;
803+
case GGML_TYPE_MXFP6_E3M2:
804+
return dequantize_row_mxfp6_e3m2_cuda;
755805
case GGML_TYPE_F16:
756806
return convert_unary_cont_cuda<half>;
757807
case GGML_TYPE_BF16:

0 commit comments

Comments
 (0)