@@ -481,6 +481,52 @@ static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_
     }
 }
 
+template <typename dst_t>
+static __global__ void dequantize_block_mxfp6_e3m2(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    // QK_K(256) / QK_MXFP6_E3M2(32) = 8 sub-blocks per superblock
+    const int64_t i = blockIdx.x;
+    const block_mxfp6_e3m2 * x = (const block_mxfp6_e3m2 *) vx + i*(QK_K/QK_MXFP6_E3M2);
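+    // x now points at the first of the 8 sub-blocks of superblock i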
+
+    const int64_t tid = threadIdx.x;
+    // Each MXFP6 sub-block packs 32 6-bit values into 24 bytes.
+    // With 4 threads per sub-block, each thread handles 24/4 = 6 bytes = two 3-byte groups,
+    // i.e. 8 output values.
+    const int64_t il = tid/8; // 0...3 -> lane within a sub-block
+    const int64_t ib = tid%8; // 0...7 -> sub-block index within the superblock
+    dst_t * y = yy + i*QK_K + 32*ib;
+    const uint8_t * qs = x[ib].qs;
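+    // x[ib].e is the sub-block's shared E8M0 scale: a single power-of-two exponent for all 32 values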
+    const float d = ggml_cuda_e8m0_to_fp32(x[ib].e);
+    for (int g_idx = 0; g_idx < 2; ++g_idx) {
+        const int g = 2*il + g_idx;
+        // input offset: 3 bytes per group
+        const uint8_t * q3 = qs + 3*g;
+        // output offset: 4 values per group
+        const int y_offset = 4*g;
+
+        const uint8_t b0 = q3[0];
+        const uint8_t b1 = q3[1];
+        const uint8_t b2 = q3[2];
+
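+        // the four 6-bit indices straddle byte boundaries:
+        // v0 = b0[5:0], v1 = b1[3:0]:b0[7:6], v2 = b2[1:0]:b1[7:4], v3 = b2[7:2]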
+        const uint8_t v0_idx = b0 & 0x3F;
+        const uint8_t v1_idx = (b0 >> 6) | ((b1 & 0x0F) << 2);
+        const uint8_t v2_idx = (b1 >> 4) | ((b2 & 0x03) << 4);
+        const uint8_t v3_idx = b2 >> 2;
+
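+        // kvalues_mxfp6_e3m2 is assumed to store the E3M2 code points scaled by 16
+        // (so the smallest step, 2^-4, is an integer); 0.0625f = 1/16 undoes that scaling.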
+        y[y_offset + 0] = d * kvalues_mxfp6_e3m2[v0_idx]*0.0625f;
+        y[y_offset + 1] = d * kvalues_mxfp6_e3m2[v1_idx]*0.0625f;
+        y[y_offset + 2] = d * kvalues_mxfp6_e3m2[v2_idx]*0.0625f;
+        y[y_offset + 3] = d * kvalues_mxfp6_e3m2[v3_idx]*0.0625f;
+    }
+}
+
+
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static void dequantize_block_cuda(const void * vx, dst_t * y,
     const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
@@ -610,6 +656,13 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t
     dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);
 }
 
+template <typename dst_t>
+static void dequantize_row_mxfp6_e3m2_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = (k + QK_K - 1) / QK_K;
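+    // one CUDA block per QK_K superblock; 32 threads = 4 lanes x 8 sub-blocks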
+    dequantize_block_mxfp6_e3m2<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template <typename src_t, typename dst_t>
 static __global__ void convert_unary(
     const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
@@ -701,6 +754,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_MXFP4:
             return dequantize_row_mxfp4_cuda;
+        case GGML_TYPE_MXFP6_E3M2:
+            return dequantize_row_mxfp6_e3m2_cuda;
         case GGML_TYPE_F32:
             return convert_unary_cont_cuda<float>;
         case GGML_TYPE_BF16:
@@ -752,6 +807,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_MXFP4:
             return dequantize_row_mxfp4_cuda;
+        case GGML_TYPE_MXFP6_E3M2:
+            return dequantize_row_mxfp6_e3m2_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cont_cuda<half>;
         case GGML_TYPE_BF16: