Skip to content

Commit 84f82e8

Browse files
ggml-cuda: Add generic NVFP4 MMQ kernel (#21074)
* Introduced NVFP4 generic MMQ kernel * Added extra FP8 guard, hope to solve ci HIP failure * Rename tiles and use HIP_FP8_AVAILABLE * Removed remaining FP8 straggler and added const int * Const * Removed DECL_MMQ_CASE artifact * Removed newline * Removed space after else * Changed HIP FP8 NVFP4 conversion gate * Added new line to bottom of mmq.cu 270 * Removed extra spaces * Removed single space in front of else on line 814 * Added NVFP4 to generate cu script so HIP can see it, further tightened logic * Include generated mmq-instance-nvfp4.cu * Added NVFP4 mmq to HIP Check ignore list * Update ggml/src/ggml-cuda/mmq.cuh Changed to Q3_K tile to read MMQ_MMA_TILE_X_K_NVFP4 Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * Update ggml/src/ggml-cuda/mmq.cuh Changed to Q3_K tile to read MMQ_MMA_TILE_X_K_NVFP4 in tile assert Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * Update ggml/src/ggml-cuda/mmq.cuh Added function name ending for end if Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * Added function names to closing endif Co-authored-by: Johannes Gäßler <johannesg@5d6.de> --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
1 parent e1cb817 commit 84f82e8

File tree

7 files changed

+117
-19
lines changed

7 files changed

+117
-19
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -800,19 +800,32 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
800800
}
801801

802802
static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) {
803-
#ifdef FP8_AVAILABLE
804-
const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
805-
#if defined(GGML_USE_HIP) && defined(CDNA3)
806-
// ROCm dose not support fp8 in software on devices with fp8 hardware,
803+
#if defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000
804+
// ROCm does not support fp8 in software on devices with fp8 hardware,
807805
// but CDNA3 supports only e4m3_fnuz (no inf).
806+
const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
808807
const __hip_fp8_e4m3_fnuz xf = *reinterpret_cast<const __hip_fp8_e4m3_fnuz *>(&bits);
808+
return static_cast<float>(xf) / 2;
809809
#else
810+
#if defined(FP8_AVAILABLE) && !defined(GGML_USE_HIP)
811+
const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
810812
const __nv_fp8_e4m3 xf = *reinterpret_cast<const __nv_fp8_e4m3 *>(&bits);
811-
#endif // defined(GGML_USE_HIP) && defined(GGML_USE_HIP)
812813
return static_cast<float>(xf) / 2;
813814
#else
814-
NO_DEVICE_CODE;
815-
#endif // FP8_AVAILABLE
815+
if (x == 0 || (x == 0x7F && x != 0xFF)) { // Convert NaN to 0.0f
816+
return 0.0f;
817+
}
818+
const int exp = (x >> 3) & 0xF;
819+
const int man = x & 0x7;
820+
float raw;
821+
if (exp == 0) {
822+
raw = ldexpf((float) man, -9);
823+
} else {
824+
raw = ldexpf(1.0f + (float) man / 8.0f, exp - 7);
825+
}
826+
return static_cast<float>(raw / 2);
827+
#endif // defined(FP8_AVAILABLE) && !defined(GGML_USE_HIP)
828+
#endif // defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000
816829
}
817830

818831
__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4791,9 +4791,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
47914791
case GGML_TYPE_Q5_1:
47924792
case GGML_TYPE_Q8_0:
47934793
case GGML_TYPE_MXFP4:
4794-
#ifdef FP8_AVAILABLE
47954794
case GGML_TYPE_NVFP4:
4796-
#endif // FP8_AVAILABLE
47974795
case GGML_TYPE_Q2_K:
47984796
case GGML_TYPE_Q3_K:
47994797
case GGML_TYPE_Q4_K:

ggml/src/ggml-cuda/mmq.cu

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con
2323
case GGML_TYPE_MXFP4:
2424
mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
2525
break;
26+
case GGML_TYPE_NVFP4:
27+
mul_mat_q_case<GGML_TYPE_NVFP4>(ctx, args, stream);
28+
break;
2629
case GGML_TYPE_Q2_K:
2730
mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
2831
break;
@@ -273,6 +276,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
273276
case GGML_TYPE_Q5_1:
274277
case GGML_TYPE_Q8_0:
275278
case GGML_TYPE_MXFP4:
279+
case GGML_TYPE_NVFP4:
276280
case GGML_TYPE_Q2_K:
277281
case GGML_TYPE_Q3_K:
278282
case GGML_TYPE_Q4_K:
@@ -362,5 +366,4 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
362366
}
363367

364368
return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
365-
366369
}

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 82 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
6868
return MMQ_Q8_1_DS_LAYOUT_D4;
6969
case GGML_TYPE_MXFP4:
7070
return MMQ_Q8_1_DS_LAYOUT_D4;
71+
case GGML_TYPE_NVFP4:
72+
return MMQ_Q8_1_DS_LAYOUT_D4;
7173
case GGML_TYPE_Q2_K:
7274
return MMQ_Q8_1_DS_LAYOUT_D2S6;
7375
case GGML_TYPE_Q3_K:
@@ -189,6 +191,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
189191
case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
190192
case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
191193
case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1;
194+
case GGML_TYPE_NVFP4: return MMQ_DP4A_TXS_Q8_0_16;
192195
case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
193196
case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
194197
case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
@@ -206,12 +209,13 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
206209
}
207210
}
208211

209-
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
210-
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4)
211-
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
212-
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
213-
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
214-
#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
212+
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
213+
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4
214+
#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4
215+
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
216+
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
217+
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
218+
#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
215219

216220
static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
217221
static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
@@ -220,6 +224,8 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
220224
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
221225
static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
222226
static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
227+
static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding.");
228+
223229

224230
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
225231
switch (type) {
@@ -230,6 +236,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
230236
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
231237
// tile sizes are the same for Q8_1 and FP4 for blackwell
232238
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
239+
case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_NVFP4;
233240
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
234241
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
235242
case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
@@ -826,6 +833,65 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
826833
}
827834
}
828835

836+
837+
template <int mmq_y, bool need_check>
838+
static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x,
839+
int * __restrict__ x_tile,
840+
const int kb0,
841+
const int i_max,
842+
const int stride) {
843+
constexpr int nwarps = mmq_get_nwarps_device();
844+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
845+
846+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
847+
int * x_qs = (int *) x_tile;
848+
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
849+
#else
850+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_NVFP4, mmq_y);
851+
int * x_qs = (int *) x_tile;
852+
float * x_df = (float *) (x_qs + txs.qs);
853+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
854+
855+
constexpr int threads_per_row = MMQ_ITER_K / QK_NVFP4;
856+
constexpr int rows_per_warp = warp_size / threads_per_row;
857+
const int kbx = threadIdx.x % threads_per_row;
858+
const int row_in_warp = threadIdx.x / threads_per_row;
859+
860+
#pragma unroll
861+
for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
862+
int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
863+
864+
if constexpr (need_check) {
865+
i = min(i, i_max);
866+
}
867+
868+
const block_nvfp4 * bxi = (const block_nvfp4 *) x + kb0 + i * stride + kbx;
869+
const uint32_t * __restrict__ src_qs = reinterpret_cast<const uint32_t *>(bxi->qs);
870+
const int kqs = 16 * kbx;
871+
const int ksc = 4 * kbx;
872+
873+
#pragma unroll
874+
for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) {
875+
const int2 q0 = get_int_from_table_16(src_qs[2 * sub + 0], kvalues_mxfp4);
876+
const int2 q1 = get_int_from_table_16(src_qs[2 * sub + 1], kvalues_mxfp4);
877+
878+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
879+
x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 0] = q0.x;
880+
x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 1] = q1.x;
881+
x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 2] = q0.y;
882+
x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 3] = q1.y;
883+
x_df[i * MMQ_MMA_TILE_X_K_NVFP4 + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]);
884+
#else
885+
x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 0] = q0.x;
886+
x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 1] = q1.x;
887+
x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 2] = q0.y;
888+
x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 3] = q1.y;
889+
x_df[i * (2 * MMQ_TILE_NE_K * 2 / QI_NVFP4) + i / (QK_NVFP4_SUB / QI_NVFP4) + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]);
890+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
891+
}
892+
}
893+
}
894+
829895
template <int mmq_x, int mmq_y>
830896
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
831897
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -1229,7 +1295,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
12291295
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
12301296
}
12311297

1232-
// Used for Q3_K, IQ2_S, and IQ2_XS
1298+
// Used for NVFP4, Q3_K, IQ2_S, and IQ2_XS
12331299
template <int mmq_x, int mmq_y>
12341300
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
12351301
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -3261,6 +3327,14 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
32613327
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
32623328
};
32633329

3330+
template <int mmq_x, int mmq_y, bool need_check>
3331+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_NVFP4> {
3332+
static constexpr int vdr = VDR_NVFP4_Q8_1_MMQ;
3333+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4<mmq_y, need_check>;
3334+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
3335+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
3336+
};
3337+
32643338
template <int mmq_x, int mmq_y, bool need_check>
32653339
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
32663340
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
@@ -4069,6 +4143,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
40694143
extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
40704144
extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
40714145
extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
4146+
extern DECL_MMQ_CASE(GGML_TYPE_NVFP4);
40724147
extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
40734148
extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
40744149
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);

ggml/src/ggml-cuda/template-instances/generate_cu_files.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
3636
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
3737
"GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
38-
"GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4"
38+
"GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4", "GGML_TYPE_NVFP4"
3939
]
4040

4141
SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
2+
3+
#include "../mmq.cuh"
4+
5+
DECL_MMQ_CASE(GGML_TYPE_NVFP4);

scripts/hip/gcn-cdna-vgpr-check.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,11 @@ def main():
139139
'_ZL18flash_attn_ext_f16ILi96ELi96ELi4ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
140140
'_ZL18flash_attn_ext_vecILi128ELi2EL9ggml_type2ELS0_2ELb0EEvPKcS2_S2_S2_S2_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS6_IjLj3EEiiiiiiiiiiiliiliiiiil',
141141
'_ZL9mul_mat_qIL9ggml_type10ELi16ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
142-
'_ZL9mul_mat_qIL9ggml_type12ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii'
142+
'_ZL9mul_mat_qIL9ggml_type12ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
143+
'_ZL9mul_mat_qIL9ggml_type40ELi112ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
144+
'_ZL9mul_mat_qIL9ggml_type40ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
145+
'_ZL9mul_mat_qIL9ggml_type40ELi128ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
146+
'_ZL9mul_mat_qIL9ggml_type40ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii'
143147
}
144148

145149
functions = parse_log_file(log_file)

0 commit comments

Comments
 (0)