@@ -68,6 +68,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
             return MMQ_Q8_1_DS_LAYOUT_D4;
         case GGML_TYPE_MXFP4:
             return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_NVFP4:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
         case GGML_TYPE_Q2_K:
             return MMQ_Q8_1_DS_LAYOUT_D2S6;
         case GGML_TYPE_Q3_K:
@@ -189,6 +191,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
         case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
         case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
         case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1;
+        case GGML_TYPE_NVFP4: return MMQ_DP4A_TXS_Q8_0_16;
         case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
         case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
         case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
@@ -206,12 +209,13 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
     }
 }

-#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
-#define MMQ_MMA_TILE_X_K_FP4  (2*MMQ_TILE_NE_K + 8 + 4)
-#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
-#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
-#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
-#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q8_0  (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_FP4   (2*MMQ_TILE_NE_K + 8 + 4)                     // MXFP4
+#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)       // NVFP4
+#define MMQ_MMA_TILE_X_K_Q8_1  (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_Q2_K  (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
+#define MMQ_MMA_TILE_X_K_Q3_K  (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
+#define MMQ_MMA_TILE_X_K_Q6_K  (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)

 static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
@@ -220,6 +224,8 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
+static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding.");
+

 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     switch (type) {
@@ -230,6 +236,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
         case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
         // tile sizes are the same for Q8_1 and FP4 for Blackwell
         case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
+        case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_NVFP4;
         case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
         case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
         case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
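
A quick consistency check on the new tile size: with MMQ_TILE_NE_K == 32 (its mainline value; treated here as an assumption), a tile row holds 2*32 ints of expanded q values (256 int8 values), 256/16 = 16 scale floats — one per 16-value NVFP4 sub-block — and 4 ints of padding, presumably for the same shared-memory bank-conflict reasons as the other tiles. A standalone sketch of the arithmetic behind the static_assert above (the *_ASSUMED name is ours, not the file's):

// Standalone sanity check of MMQ_MMA_TILE_X_K_NVFP4, assuming
// MMQ_TILE_NE_K == 32 (assumption ours, matching mainline mmq.cuh).
constexpr int MMQ_TILE_NE_K_ASSUMED = 32;

// 2*32 ints of q values + 32/2 = 16 scale slots + 4 ints of padding = 84.
constexpr int tile_x_k_nvfp4 = 2*MMQ_TILE_NE_K_ASSUMED + MMQ_TILE_NE_K_ASSUMED/2 + 4;

static_assert(tile_x_k_nvfp4 == 84,    "84 ints per tile row");
static_assert(tile_x_k_nvfp4 % 8 == 4, "matches the padding convention the asserts above check");
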
@@ -826,6 +833,65 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
     }
 }

+
+template <int mmq_y, bool need_check>
+static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x,
+                                                        int * __restrict__ x_tile,
+                                                        const int kb0,
+                                                        const int i_max,
+                                                        const int stride) {
+    constexpr int nwarps    = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_NVFP4, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+    // Each thread loads one block_nvfp4 per row: threads_per_row blocks cover a full tile row.
+    constexpr int threads_per_row = MMQ_ITER_K / QK_NVFP4;
+    constexpr int rows_per_warp   = warp_size / threads_per_row;
+    const int kbx         = threadIdx.x % threads_per_row;
+    const int row_in_warp = threadIdx.x / threads_per_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp*nwarps) {
+        int i = i0 + threadIdx.y*rows_per_warp + row_in_warp;
+
+        if constexpr (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_nvfp4 * bxi = (const block_nvfp4 *) x + kb0 + i*stride + kbx;
+        const uint32_t * __restrict__ src_qs = reinterpret_cast<const uint32_t *>(bxi->qs);
+        const int kqs = 16*kbx;
+        const int ksc =  4*kbx;
+
+#pragma unroll
+        for (int sub = 0; sub < QK_NVFP4/QK_NVFP4_SUB; ++sub) {
+            // Expand 2x8 packed e2m1 nibbles into 16 int8 values via the shared FP4 lookup table.
+            const int2 q0 = get_int_from_table_16(src_qs[2*sub + 0], kvalues_mxfp4);
+            const int2 q1 = get_int_from_table_16(src_qs[2*sub + 1], kvalues_mxfp4);
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+            x_qs[i*MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4*sub + 0] = q0.x;
+            x_qs[i*MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4*sub + 1] = q1.x;
+            x_qs[i*MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4*sub + 2] = q0.y;
+            x_qs[i*MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4*sub + 3] = q1.y;
+            x_df[i*MMQ_MMA_TILE_X_K_NVFP4 + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]);
+#else
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + kqs + 4*sub + 0] = q0.x;
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + kqs + 4*sub + 1] = q1.x;
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + kqs + 4*sub + 2] = q0.y;
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + kqs + 4*sub + 3] = q1.y;
+            x_df[i*(2*MMQ_TILE_NE_K*2/QI_NVFP4) + i/(QK_NVFP4_SUB/QI_NVFP4) + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        }
+    }
+}
+
 template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
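
The heavy lifting in load_tiles_nvfp4 is get_int_from_table_16, which turns one 32-bit word of eight packed 4-bit codes into eight int8 values via a 16-entry table. A scalar sketch of that expansion, under two assumptions: the table contents are the e2m1 code points scaled by 2 (the integer encoding used for MXFP4's kvalues_mxfp4, which NVFP4 reuses since its element format is also e2m1), and the output order is simplified to low-nibble-first rather than the kernel's interleaved int2 packing:

#include <cstdint>

// e2m1 code points {0, 0.5, 1, 1.5, 2, 3, 4, 6} scaled by 2 so they fit in
// int8; the top bit of the nibble is the sign. Assumed to match kvalues_mxfp4.
static const int8_t kvalues_e2m1_x2[16] = {
      0,  1,  2,  3,  4,  6,  8,  12,
      0, -1, -2, -3, -4, -6, -8, -12,
};

// Expand eight packed FP4 nibbles into eight int8 values (scalar reference;
// the device code packs the results into two ints and interleaves q0/q1).
static void expand_fp4_word(uint32_t w, int8_t out[8]) {
    for (int j = 0; j < 8; ++j) {
        out[j] = kvalues_e2m1_x2[(w >> 4*j) & 0xF];
    }
}
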
@@ -1229,7 +1295,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
 #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 }

-// Used for Q3_K, IQ2_S, and IQ2_XS
+// Used for NVFP4, Q3_K, IQ2_S, and IQ2_XS
 template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
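
NVFP4 can reuse the q8_0_16 dot products because its scale granularity is 16 elements (QK_NVFP4_SUB values per ue4m3 scale), the same granularity these kernels already handle for Q3_K. As a scalar reference for what "q8_0_16 × q8_1" computes, a sketch with illustrative names (not the kernel's signatures), assuming the y side carries one float scale per 32 values as in q8_1:

// Scalar reference: dot of expanded x (int8, one scale per 16 values)
// with quantized activations y (int8, one scale per 32 values).
static float vec_dot_q8_0_16_ref(const int8_t * xq, const float * xd,
                                 const int8_t * yq, const float * yd, int n) {
    float sum = 0.0f;
    for (int g = 0; g < n/16; ++g) {
        int isum = 0; // 16 int8*int8 products fit comfortably in an int
        for (int j = 0; j < 16; ++j) {
            isum += xq[16*g + j] * yq[16*g + j];
        }
        sum += (float) isum * xd[g] * yd[g/2]; // per-16 x scale, per-32 y scale
    }
    return sum;
}

In the real kernels the inner 16-wide integer dot is done with dp4a (or MMA tiles), but the scaling structure is the same.
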
@@ -3261,6 +3327,14 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
     static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
 };

+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_NVFP4> {
+    static constexpr int              vdr          = VDR_NVFP4_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_nvfp4<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
 template <int mmq_x, int mmq_y, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
     static constexpr int              vdr          = VDR_Q2_K_Q8_1_MMQ;
@@ -4069,6 +4143,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
 extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
 extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
+extern DECL_MMQ_CASE(GGML_TYPE_NVFP4);
 extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
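
The extern declaration only promises that the GGML_TYPE_NVFP4 instantiation exists somewhere; by analogy with the other types it would be provided by a small translation unit under template-instances/, along these lines (file name hypothetical):

// template-instances/mmq-instance-nvfp4.cu (hypothetical, mirroring the
// existing per-type instance files that pair with these extern declarations):
#include "../mmq.cuh"

DECL_MMQ_CASE(GGML_TYPE_NVFP4);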