@@ -465,6 +465,24 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
     }
 }
 
+template<typename dst_t>
+static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i   = blockIdx.x;
+    const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK_MXFP4);
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il  = tid/8; // 0...3
+    const int64_t ib  = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+    const uint8_t * q4 = x[ib].qs + 4*il;
+    const float d = ggml_cuda_e8m0_to_fp32(x[ib].e);
+    for (int j = 0; j < 4; ++j) {
+        y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf]*0.5f;
+        y[j+16] = d * kvalues_mxfp4[q4[j] >>  4]*0.5f;
+    }
+}
+
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static void dequantize_block_cuda(const void * vx, dst_t * y,
         const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
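For reference, here is a minimal host-side sketch of the dequantization the new kernel performs, with the type and table definitions restated (under `_ref` names) so it stands alone. It assumes the usual ggml MXFP4 layout: 32-element blocks with one shared E8M0 scale byte (decoding to 2^(e-127)) and 16 bytes of packed 4-bit codes, plus a lookup table that stores the FP4 (E2M1) code points doubled, which is why the kernel multiplies by 0.5f after applying the block scale.

#include <math.h>
#include <stdint.h>

#define QK_MXFP4 32

typedef struct {
    uint8_t e;               // shared E8M0 block scale
    uint8_t qs[QK_MXFP4/2];  // 32 packed 4-bit codes
} block_mxfp4_ref;

// FP4 (E2M1) values {0, 0.5, 1, 1.5, 2, 3, 4, 6} and their negatives, stored doubled
static const int8_t kvalues_mxfp4_ref[16] = {0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12};

static void dequantize_block_mxfp4_ref(const block_mxfp4_ref * x, float * y) {
    const float d = ldexpf(1.0f, (int) x->e - 127); // decode E8M0 scale: 2^(e-127)
    for (int j = 0; j < QK_MXFP4/2; ++j) {
        y[j]              = d * kvalues_mxfp4_ref[x->qs[j] & 0xf] * 0.5f; // low nibbles -> elements 0..15
        y[j + QK_MXFP4/2] = d * kvalues_mxfp4_ref[x->qs[j] >>  4] * 0.5f; // high nibbles -> elements 16..31
    }
}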
@@ -588,6 +606,12 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t
     dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
 }
 
+template<typename dst_t>
+static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = (k + QK_K - 1) / QK_K;
+    dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template <typename src_t, typename dst_t>
 static __global__ void convert_unary(
     const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
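The wrapper launches one 32-thread block per QK_K (256-value) superblock, i.e. per group of eight MXFP4 blocks, with each thread writing two runs of four values. A quick standalone check of that index arithmetic, mirroring the tid/il/ib mapping in the kernel above (illustrative only):

#include <assert.h>

int main(void) {
    int hits[256] = {0};
    for (int tid = 0; tid < 32; ++tid) {
        const int il = tid/8; // 0...3
        const int ib = tid%8; // 0...7
        for (int j = 0; j < 4; ++j) {
            hits[32*ib + 4*il + j     ]++; // low-nibble outputs
            hits[32*ib + 4*il + j + 16]++; // high-nibble outputs
        }
    }
    for (int i = 0; i < 256; ++i) {
        assert(hits[i] == 1); // every output in the superblock is written exactly once
    }
    return 0;
}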
@@ -677,6 +701,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_iq4_xs_cuda;
         case GGML_TYPE_IQ3_S:
             return dequantize_row_iq3_s_cuda;
+        case GGML_TYPE_MXFP4:
+            return dequantize_row_mxfp4_cuda;
         case GGML_TYPE_F32:
             return convert_unary_cont_cuda<float>;
         case GGML_TYPE_BF16:
@@ -726,6 +752,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq4_xs_cuda;
         case GGML_TYPE_IQ3_S:
             return dequantize_row_iq3_s_cuda;
+        case GGML_TYPE_MXFP4:
+            return dequantize_row_mxfp4_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cont_cuda<half>;
         case GGML_TYPE_BF16:
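With these two cases added, code that obtains converters through the generic lookups can dequantize MXFP4 tensors like any other quantized type. A hedged usage sketch; the device pointers, element count k, and stream below are placeholders, and it assumes the getters return nullptr for unsupported types:

const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(GGML_TYPE_MXFP4);
const to_fp32_cuda_t to_fp32 = ggml_get_to_fp32_cuda(GGML_TYPE_MXFP4);
if (to_fp16 && to_fp32) {
    to_fp16(d_src_mxfp4, d_dst_f16, k, stream); // writes k dequantized half values
    to_fp32(d_src_mxfp4, d_dst_f32, k, stream); // writes k dequantized float values
}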