
Commit 1d72c84

CUDA: GEMM for FP32/FP16/BF16 and ne11 <= 16 (#15131)
* CUDA: GEMM for FP32/FP16/BF16 and ne11 <= 16
1 parent 20638e4 commit 1d72c84

File tree: 13 files changed (+750, -225 lines)


ggml/src/ggml-cuda/common.cuh

Lines changed: 10 additions & 2 deletions
@@ -233,9 +233,13 @@ typedef float2 dfloat2;
 #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
 
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
-#define NEW_MMA_AVAILABLE
+#define TURING_MMA_AVAILABLE
 #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#define AMPERE_MMA_AVAILABLE
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #define CP_ASYNC_AVAILABLE
 #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
@@ -303,10 +307,14 @@ static bool amd_mfma_available(const int cc) {
 }
 
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
-static bool new_mma_available(const int cc) {
+static bool turing_mma_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
 }
 
+static bool ampere_mma_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
+}
+
 static bool cp_async_available(const int cc) {
     return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
 }
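Note: turing_mma_available() and the new ampere_mma_available() are host-side gates that take the device's compute capability and only admit NVIDIA architectures at or above the respective threshold, provided a matching arch was compiled in. The sketch below is a stand-alone approximation of that gate, assuming ggml's convention of encoding the compute capability as 100*major + 10*minor and the Turing/Ampere thresholds 750/800; it deliberately omits the ggml_cuda_highest_compiled_arch() part of the check.

// Minimal host-side sketch, not the ggml implementation. EXAMPLE_CC_TURING,
// EXAMPLE_CC_AMPERE and the cc encoding are assumptions mirroring common.cuh.
#include <cstdio>
#include <cuda_runtime.h>

static const int EXAMPLE_CC_TURING = 750; // sm_75
static const int EXAMPLE_CC_AMPERE = 800; // sm_80

static bool example_turing_mma_available(int cc) { return cc >= EXAMPLE_CC_TURING; }
static bool example_ampere_mma_available(int cc) { return cc >= EXAMPLE_CC_AMPERE; }

int main() {
    cudaDeviceProp prop{};
    if (cudaGetDeviceProperties(&prop, /*device=*/0) != cudaSuccess) {
        printf("no CUDA device found\n");
        return 0;
    }
    const int cc = 100*prop.major + 10*prop.minor; // e.g. 860 for an RTX 3090 (sm_86)
    printf("cc=%d turing_mma=%d ampere_mma=%d\n",
           cc, example_turing_mma_available(cc), example_ampere_mma_available(cc));
    return 0;
}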

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 6 additions & 6 deletions
@@ -418,7 +418,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         float * const __restrict__ KQ_max,
         float * const __restrict__ KQ_rowsum,
         const int kb0) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
     typedef fattn_mma_f16_config<DKQ, DV> c;
 
 #ifdef CP_ASYNC_AVAILABLE
@@ -776,7 +776,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
     GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
     NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
 }
 
 template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
@@ -800,7 +800,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int jt,
         const int kb0_start,
         const int kb0_stop) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     typedef fattn_mma_f16_config<DKQ, DV> c;
@@ -1196,7 +1196,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask);
     GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
     NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
 }
 
 template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
@@ -1223,7 +1223,7 @@ static __global__ void flash_attn_ext_f16(
         const int32_t nb21, const int32_t nb22, const int64_t nb23,
         const int32_t ne31, const int32_t ne32, const int32_t ne33,
         const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
+#if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
@@ -1354,7 +1354,7 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
     NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
 }
 
 template <int DKQ, int DV, int ncols1, int ncols2>
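Note: every kernel touched in this file keeps the same guard shape, only the macro name changes: the body is compiled when TURING_MMA_AVAILABLE is defined for the current __CUDA_ARCH__, otherwise all parameters are marked unused and the stub falls through to NO_DEVICE_CODE. A self-contained sketch of that shape follows, using hypothetical EXAMPLE_* names in place of ggml's macros and a trivial kernel body.

#include <cuda_runtime.h>

// Stand-ins for ggml's macros; names and the 750 threshold are illustrative only.
#define EXAMPLE_UNUSED(x) (void)(x)
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
#define EXAMPLE_TURING_MMA_AVAILABLE
#endif

__global__ void example_attn_tile(const float * Q, const float * K, float * dst, const int n) {
#ifdef EXAMPLE_TURING_MMA_AVAILABLE
    // A real kernel would run the mma pipeline here; this stand-in just adds the inputs.
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = Q[i] + K[i];
    }
#else
    // Arch too old: silence unused-parameter warnings and trap if ever reached,
    // analogous to the GGML_UNUSED(...)/NO_DEVICE_CODE fallback in the hunks above.
    EXAMPLE_UNUSED(Q); EXAMPLE_UNUSED(K); EXAMPLE_UNUSED(dst); EXAMPLE_UNUSED(n);
    __trap();
#endif // EXAMPLE_TURING_MMA_AVAILABLE
}

int main() {
    const int n = 256;
    float *Q, *K, *dst;
    cudaMalloc(&Q,   n*sizeof(float));
    cudaMalloc(&K,   n*sizeof(float));
    cudaMalloc(&dst, n*sizeof(float));
    // Buffers are left uninitialized; the launch only exercises the guard shape.
    example_attn_tile<<<(n + 127)/128, 128>>>(Q, K, dst, n);
    cudaDeviceSynchronize();
    cudaFree(Q); cudaFree(K); cudaFree(dst);
    return 0;
}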

ggml/src/ggml-cuda/fattn.cu

Lines changed: 2 additions & 2 deletions
@@ -327,7 +327,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
     const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
     const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192);
-    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
+    const bool mma_faster_for_bs1 = turing_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
         (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
@@ -340,7 +340,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
-    if (fp16_mma_available(cc) && !new_mma_available(cc)) {
+    if (fp16_mma_available(cc) && !turing_mma_available(cc)) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
         return;
     }
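Note: condensed, the selection logic these two hunks touch reads as: batch-size-1 cases prefer the vector kernel unless the mma path is expected to be faster, Volta (which has FP16 tensor cores but not the Turing mma layout) keeps the old WMMA kernel, and everything newer uses the mma kernel. The sketch below is a hypothetical boiled-down version of that order, not the function's actual control flow, with all inputs reduced to booleans.

#include <cstdio>

enum class example_fattn_kernel { VEC, WMMA_F16, MMA_F16 };

// Hypothetical condensation of the dispatch order in ggml_cuda_flash_attn_ext.
static example_fattn_kernel example_choose_fattn_kernel(
        const bool batch_size_1, const bool can_use_vector_kernel, const bool mma_faster_for_bs1,
        const bool fp16_mma_available, const bool turing_mma_available) {
    if (batch_size_1 && can_use_vector_kernel && !mma_faster_for_bs1) {
        return example_fattn_kernel::VEC;      // single-token decode: vector kernel
    }
    if (fp16_mma_available && !turing_mma_available) {
        return example_fattn_kernel::WMMA_F16; // Volta: old WMMA code
    }
    return example_fattn_kernel::MMA_F16;      // Turing and newer: mma-based kernel
}

int main() {
    // e.g. prompt processing on a Volta card: FP16 mma available, Turing mma not.
    const example_fattn_kernel k = example_choose_fattn_kernel(false, true, false, true, false);
    printf("selected kernel: %d\n", static_cast<int>(k));
    return 0;
}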

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 20 additions & 11 deletions
@@ -22,8 +22,9 @@
 #include "ggml-cuda/fattn.cuh"
 #include "ggml-cuda/getrows.cuh"
 #include "ggml-cuda/im2col.cuh"
+#include "ggml-cuda/mmf.cuh"
 #include "ggml-cuda/mmq.cuh"
-#include "ggml-cuda/mmv.cuh"
+#include "ggml-cuda/mmvf.cuh"
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"
@@ -2008,7 +2009,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE
         && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;
 
-    bool use_mul_mat_vec   = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
+    bool use_mul_mat_vec_f = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+    bool use_mul_mat_f     = !ggml_is_quantized(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
@@ -2028,14 +2031,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
             }
 
             const int cc = ggml_cuda_info().devices[id].cc;
+            const int warp_size = ggml_cuda_info().devices[id].warp_size;
             use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-            use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
+            use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
+            use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
             any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
         }
     } else {
        const int cc = ggml_cuda_info().devices[ctx.device].cc;
+        const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
        use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-        use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
+        use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
+        use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
        any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
     }
 
@@ -2048,15 +2055,17 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
     //TODO update for generic tensor parallelism
-    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
     bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
     bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
 
-    if (!split && use_mul_mat_vec) {
+    if (!split && use_mul_mat_vec_f) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
-        ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
+        ggml_cuda_mul_mat_vec_f(ctx, src0, src1, nullptr, dst);
+    } else if (!split && use_mul_mat_f) {
+        ggml_cuda_mul_mat_f(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_vec_q) {
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
@@ -2065,8 +2074,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
-    } else if (use_mul_mat_vec) {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr);
+    } else if (use_mul_mat_vec_f) {
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_f, nullptr);
     } else if (use_mul_mat_vec_q) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
     } else if (use_mul_mat_q) {
@@ -2094,7 +2103,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         if (ggml_is_quantized(src0->type)) {
             ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
         } else {
-            ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
+            ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
         }
         return;
     }
@@ -3516,7 +3525,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
 #endif // FLASH_ATTN_AVAILABLE
             if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
                 const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
-                if (!new_mma_available(cc)) {
+                if (!turing_mma_available(cc)) {
                     return false;
                 }
                 const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
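Note: the new mmf path is what the commit title describes: a tensor-core GEMM for non-quantized (FP32/FP16/BF16) weights when src1 has at most 16 columns (ne11 <= 16), slotted between the float vector kernel and the quantized paths in the dispatch above. The gate below is a hypothetical illustration of those two title-level conditions only; the real ggml_cuda_should_use_mmf() also looks at cc, warp_size and src0->ne and is not reproduced here.

#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for the type enum; only the float cases matter here.
enum example_type { EXAMPLE_TYPE_F32, EXAMPLE_TYPE_F16, EXAMPLE_TYPE_BF16, EXAMPLE_TYPE_Q4_0 };

// Illustrative gate only: float weights and a small batch of src1 columns.
static bool example_should_use_mmf(const example_type type_src0, const int64_t ne11) {
    const bool float_weights = type_src0 == EXAMPLE_TYPE_F32 ||
                               type_src0 == EXAMPLE_TYPE_F16 ||
                               type_src0 == EXAMPLE_TYPE_BF16;
    const bool small_batch = ne11 <= 16; // ne11 = number of src1 columns (tokens in the batch)
    return float_weights && small_batch;
}

int main() {
    printf("%d\n", example_should_use_mmf(EXAMPLE_TYPE_F16,  /*ne11=*/8));  // 1: handled by mmf
    printf("%d\n", example_should_use_mmf(EXAMPLE_TYPE_F16,  /*ne11=*/17)); // 0: falls back to other GEMM paths
    printf("%d\n", example_should_use_mmf(EXAMPLE_TYPE_Q4_0, /*ne11=*/8));  // 0: quantized weights use mmq/mmvq
    return 0;
}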
