Commit 0c7f7f3

Review: use warp_size, fix should_use_mmf condition
1 parent 99acee9

File tree: 3 files changed, +13 -8 lines

ggml/src/ggml-cuda/ggml-cuda.cu
Lines changed: 3 additions & 3 deletions

@@ -2030,15 +2030,15 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
             const int cc = ggml_cuda_info().devices[id].cc;
             const int warp_size = ggml_cuda_info().devices[id].warp_size;
             use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-            use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne);
+            use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
             use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
             any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
         }
     } else {
         const int cc = ggml_cuda_info().devices[ctx.device].cc;
         const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
         use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-        use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne);
+        use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
         use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
         any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
     }

@@ -2110,7 +2110,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         return;
     }

-    if (!ggml_is_quantized(src0->type) && ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne, ids)) {
+    if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2], ids)) {
         ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
         return;
     }
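An aside on the index choice above (this note and the helper below are illustrative, not part of the commit): ggml stores per-dimension extents in ne[], and the plain mul_mat path keeps src1's column count in ne[1], while mul_mat_id batches tokens per expert, so its token count sits in src1->ne[2]. A minimal sketch of the selection, under that assumed layout:

// Hypothetical helper for illustration only -- the commit inlines the
// index at each call site rather than adding a function like this.
static int64_t mmf_src1_ncols(const ggml_tensor * src1, bool has_ids) {
    // mul_mat:    src1 = [k, n_cols, ...]                  -> columns in ne[1]
    // mul_mat_id: src1 = [k, n_expert_used, n_tokens, ...] -> tokens in ne[2]
    return has_ids ? src1->ne[2] : src1->ne[1];
}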

ggml/src/ggml-cuda/mmf.cu
Lines changed: 9 additions & 4 deletions

@@ -32,7 +32,7 @@ static __global__ void mul_mat_f(

    if (ids) {
        int match = 0;
-       for(int j0 = 0; j0 < cols_per_block; j0 += warpSize) {
+       for(int j0 = 0; j0 < cols_per_block; j0 += warp_size) {
            const int j = j0 + threadIdx.x;
            if(j < cols_per_block) {
                match = ids[j*stride_row_id + channel_dst*stride_col_id] == expert_idx;

@@ -451,18 +451,23 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
    }
}

-bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int64_t * src1_ne, const ggml_tensor * ids) {
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols, const ggml_tensor * ids) {
+
+    if (ggml_is_quantized(type)) {
+        return false;
+    }
+
     if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) {
         return false;
     }
     if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
         return false;
     }
-    if (!ids && src1_ne[1] > 16) {
+    if (!ids && src1_ncols > 16) {
         return false;
     }

-    if (ids && src1_ne[2] > 16) {
+    if (ids && src1_ncols > 16) {
         return false;
     }
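Pieced together from the hunk above, the updated gate reads as follows (reassembled for readability; the comments are added here, not present in the source, and the tail of the function lies outside the hunk):

bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols, const ggml_tensor * ids) {

    // New early-out: MMF handles only non-quantized types, so the call site
    // in ggml_cuda_mul_mat_id no longer needs its own ggml_is_quantized() check.
    if (ggml_is_quantized(type)) {
        return false;
    }

    // The row length must be a multiple of the per-warp 4-byte load width.
    if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) {
        return false;
    }
    // The row count must tile evenly into MMF row blocks.
    if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
        return false;
    }
    // With a scalar column count, the 16-column cap is uniform; callers pick
    // ne[1] (mul_mat) or ne[2] (mul_mat_id) before calling.
    if (!ids && src1_ncols > 16) {
        return false;
    }

    if (ids && src1_ncols > 16) {
        return false;
    }
    // ... (remainder unchanged, outside this hunk)
}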
ggml/src/ggml-cuda/mmf.cuh
Lines changed: 1 addition & 1 deletion

@@ -2,4 +2,4 @@

 void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);

-bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int64_t * src1_ne, const ggml_tensor * ids = nullptr);
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, const ggml_tensor * ids = nullptr);
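A usage sketch (hypothetical calls mirroring the sites patched above, assuming src0, src1, ids, cc, and warp_size are in scope): the defaulted ids parameter lets the plain mul_mat path omit the expert-id tensor entirely.

// Plain mul_mat: ids defaults to nullptr, columns come from src1->ne[1].
const bool use_mmf    = ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
// mul_mat_id: pass the ids tensor and the token count from src1->ne[2].
const bool use_mmf_id = ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2], ids);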
