
Commit 07d781e

cuda : add mxfp4 dequantization support for cuBLAS
ggml-ci
1 parent 1ea3769 commit 07d781e


ggml/src/ggml-cuda/convert.cu

Lines changed: 28 additions & 0 deletions
@@ -465,6 +465,24 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
     }
 }
 
+template<typename dst_t>
+static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i = blockIdx.x;
+    const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK_MXFP4);
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+    const uint8_t * q4 = x[ib].qs + 4*il;
+    const float d = ggml_cuda_e8m0_to_fp32(x[ib].e);
+    for (int j = 0; j < 4; ++j) {
+        y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf]*0.5f;
+        y[j+16] = d * kvalues_mxfp4[q4[j] >> 4]*0.5f;
+    }
+}
+
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static void dequantize_block_cuda(const void * vx, dst_t * y,
     const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
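Note (editorial, not part of the commit): the kernel above decodes MXFP4 blocks, i.e. 32 FP4 (E2M1) values sharing one E8M0 power-of-two scale. The scalar CPU sketch below shows what one block decode does, assuming the block_mxfp4 layout from ggml-common.h (one scale byte e plus 16 bytes qs of packed nibbles) and assuming kvalues_mxfp4 stores the E2M1 values doubled as integers, which would explain the 0.5f factor in the kernel.

// Scalar reference sketch, illustrative only.
#include <stdint.h>
#include <math.h>

// Assumption: 2x the signed E2M1 values {0, 0.5, 1, 1.5, 2, 3, 4, 6, ...},
// stored as integers; the 0.5f factor below undoes the doubling.
static const int8_t kvalues_mxfp4_ref[16] = {
    0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
};

// E8M0 is a pure power-of-two scale with bias 127 (the value 255 encodes NaN
// in the OCP MX spec; not handled in this sketch).
static float e8m0_to_fp32_ref(uint8_t e) {
    return ldexpf(1.0f, (int) e - 127);
}

// Decode one 32-value MXFP4 block: byte j's low nibble becomes y[j] and its
// high nibble becomes y[j+16], matching the kernel's j / j+16 split.
static void dequantize_mxfp4_block_ref(const uint8_t * qs, uint8_t e, float * y) {
    const float d = e8m0_to_fp32_ref(e);
    for (int j = 0; j < 16; ++j) {
        y[j +  0] = d * kvalues_mxfp4_ref[qs[j] & 0xf] * 0.5f;
        y[j + 16] = d * kvalues_mxfp4_ref[qs[j] >>  4] * 0.5f;
    }
}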
@@ -588,6 +606,12 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t
     dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
 }
 
+template<typename dst_t>
+static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = (k + QK_K - 1) / QK_K;
+    dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template <typename src_t, typename dst_t>
 static __global__ void convert_unary(
     const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
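Note (editorial): the launcher rounds k up to whole QK_K (256-value) superblocks and starts one 32-thread block per superblock, so each thread expands 8 values. A hypothetical host-side call, not part of the commit, might look like this; d_src, d_dst, and k are assumed to be set up by the caller.

// Hypothetical usage sketch. Assumes device pointers d_src (the quantized
// MXFP4 blocks) and d_dst (room for k half values) were allocated and filled
// elsewhere; half comes from cuda_fp16.h.
cudaStream_t stream;
cudaStreamCreate(&stream);

// One 32-thread block per QK_K (= 256) values.
dequantize_row_mxfp4_cuda<half>(d_src, d_dst, k, stream);

cudaStreamSynchronize(stream);
cudaStreamDestroy(stream);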
@@ -677,6 +701,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_iq4_xs_cuda;
         case GGML_TYPE_IQ3_S:
             return dequantize_row_iq3_s_cuda;
+        case GGML_TYPE_MXFP4:
+            return dequantize_row_mxfp4_cuda;
         case GGML_TYPE_F32:
             return convert_unary_cont_cuda<float>;
         case GGML_TYPE_BF16:
@@ -726,6 +752,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq4_xs_cuda;
         case GGML_TYPE_IQ3_S:
             return dequantize_row_iq3_s_cuda;
+        case GGML_TYPE_MXFP4:
+            return dequantize_row_mxfp4_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cont_cuda<half>;
         case GGML_TYPE_BF16:
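Note (editorial): with these two cases added, the cuBLAS fallback can obtain an MXFP4 converter through the existing dispatch tables like any other quantized type. A hedged caller-side sketch, with illustrative buffer names that are not from this commit:

// Illustrative only. The dispatch returns nullptr for unsupported types,
// so a caller can test before using the converter.
const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(GGML_TYPE_MXFP4);
if (to_fp16 != nullptr) {
    // Dequantize nelements values into a temporary fp16 buffer that
    // cuBLAS can consume directly.
    to_fp16(d_quantized, d_fp16, nelements, stream);
}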
