Skip to content

Commit 080a880

Browse files
JohannesGaesslerqnixsynapse
authored andcommitted
CUDA: mul_mat_v support for batch sizes > 1 (ggml-org#14262)
* CUDA: mul_mat_v support for batch sizes > 1 * use 64 bit math for initial offset calculation
1 parent 3f064bb commit 080a880

File tree

3 files changed

+15
-31
lines changed

3 files changed

+15
-31
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,10 @@ static bool fp16_mma_hardware_available(const int cc) {
262262
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
263263
}
264264

265+
static bool bf16_mma_hardware_available(const int cc) {
266+
return GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE;
267+
}
268+
265269
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
266270
static bool new_mma_available(const int cc) {
267271
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1943,16 +1943,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
19431943
&& ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;
19441944

19451945
bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
1946-
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
1947-
&& src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
1946+
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
19481947
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
19491948
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
19501949
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
19511950
bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear
19521951
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
19531952

1954-
bool any_gpus_with_slow_fp16 = false;
1955-
bool any_gpus_without_fp16_mma = false;
1953+
bool any_gpus_with_slow_fp16 = false;
19561954

19571955
if (split) {
19581956
ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
@@ -1963,16 +1961,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
19631961
continue;
19641962
}
19651963

1966-
const int cc = ggml_cuda_info().devices[id].cc;
1967-
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
1968-
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
1969-
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
1964+
const int cc = ggml_cuda_info().devices[id].cc;
1965+
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
1966+
use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
1967+
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
19701968
}
19711969
} else {
1972-
const int cc = ggml_cuda_info().devices[ctx.device].cc;
1973-
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
1974-
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
1975-
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
1970+
const int cc = ggml_cuda_info().devices[ctx.device].cc;
1971+
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
1972+
use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
1973+
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
19761974
}
19771975

19781976
// debug helpers
@@ -1983,7 +1981,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
19831981
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
19841982
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
19851983

1986-
if (!split && use_mul_mat_vec && (src0->ne[1] <= MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
1984+
if (!split && use_mul_mat_vec) {
19871985
// the custom F16 vector kernel can be used over batched cuBLAS GEMM
19881986
// but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
19891987
ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);

ggml/src/ggml-cuda/mmv.cu

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -456,11 +456,6 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
456456
return ne11 <= 4;
457457
}
458458
return ne11 <= 3;
459-
} else if (GGML_CUDA_CC_IS_AMD(cc)) {
460-
if (fp32_mma_hardware_available(cc)) {
461-
return ne11 <= 3;
462-
}
463-
return ne11 <= 8;
464459
}
465460
return ne11 <= 8;
466461
case GGML_TYPE_F16:
@@ -473,14 +468,6 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
473468
return src0_small && ne11 <= 3;
474469
}
475470
return ne11 <= 8;
476-
} else if (GGML_CUDA_CC_IS_AMD(cc)) {
477-
if (fp16_mma_hardware_available(cc)) {
478-
if (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
479-
return ne11 <= 5;
480-
}
481-
return ne11 <= 2;
482-
}
483-
return ne11 <= 8;
484471
}
485472
return ne11 <= 8;
486473
case GGML_TYPE_BF16:
@@ -493,11 +480,6 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
493480
return src0_small && ne11 <= 3;
494481
}
495482
return ne11 <= 8;
496-
} else if (GGML_CUDA_CC_IS_AMD(cc)) {
497-
if (bf16_mma_hardware_available(cc)) {
498-
return ne11 <= 3;
499-
}
500-
return ne11 <= 8;
501483
}
502484
return ne11 <= 8;
503485
default:

0 commit comments

Comments
 (0)