Commit 8777fc4

ikawrakow and Iwan Kawrakow authored
Fix CUDA FlashMLA-3 with quantized KV cache (#400)
Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 496451a commit 8777fc4

File tree

1 file changed: +37 −17 lines changed


ggml/src/ggml-cuda/fattn-new-mma.cu

Lines changed: 37 additions & 17 deletions
@@ -1362,26 +1362,46 @@ void launch_fattn_new_mma(
         to_fp16(K_data, K_f16.ptr, 1, ggml_nelements(K), main_stream);
         K_data = (char *) K_f16.ptr;
 
-        const size_t bs = ggml_blck_size(K->type);
-        const size_t ts = ggml_type_size(K->type);
-
-        nb11 = nb11*bs*sizeof(half)/ts;
-        nb12 = nb12*bs*sizeof(half)/ts;
-        nb13 = nb13*bs*sizeof(half)/ts;
+        nb11 = K->ne[0]*sizeof(half);
+        nb12 = nb11*K->ne[1];
+        nb13 = nb12*K->ne[2];
+
+        // Original PR in llama.cpp. I don't think that can work when K is not contiguous (e.g., nb11 > nb12), there are
+        // gaps between the rows, etc., as ggml_get_to_fp16_cuda stores into contiguous memory.
+        //const size_t bs = ggml_blck_size(K->type);
+        //const size_t ts = ggml_type_size(K->type);
+
+        //nb11 = nb11*bs*sizeof(half)/ts;
+        //nb12 = nb12*bs*sizeof(half)/ts;
+        //nb13 = nb13*bs*sizeof(half)/ts;
     }
 
     if (need_f16_V && V->type != GGML_TYPE_F16) {
-        V_f16.alloc(ggml_nelements(V));
-        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
-        to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream);
-        V_data = (char *) V_f16.ptr;
-
-        const size_t bs = ggml_blck_size(V->type);
-        const size_t ts = ggml_type_size(V->type);
-
-        nb21 = nb21*bs*sizeof(half)/ts;
-        nb22 = nb22*bs*sizeof(half)/ts;
-        nb23 = nb23*bs*sizeof(half)/ts;
+        if constexpr (DV == 512) {
+            // DeepSeek. In this case the V cache is the same as the K cache, except that
+            // it has 512 elements per row instead of 576.
+            nb21 = nb11;
+            nb22 = nb12;
+            nb23 = nb13;
+            V_data = K_data;
+        } else {
+            V_f16.alloc(ggml_nelements(V));
+            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+            to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream);
+            V_data = (char *) V_f16.ptr;
+
+            nb21 = K->ne[0]*sizeof(half);
+            nb22 = nb21*V->ne[1];
+            nb23 = nb22*V->ne[2];
+
+            // Original PR in llama.cpp. Same comment as above for the K cache.
+            //const size_t bs = ggml_blck_size(V->type);
+            //const size_t ts = ggml_type_size(V->type);
+
+            //nb21 = nb21*bs*sizeof(half)/ts;
+            //nb22 = nb22*bs*sizeof(half)/ts;
+            //nb23 = nb23*bs*sizeof(half)/ts;
+        }
     }
 
     int parallel_blocks = 1;
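The comment added for the K cache is the heart of the fix: ggml_get_to_fp16_cuda dequantizes into a newly allocated, densely packed buffer, so the byte strides handed to the kernel must describe that packed fp16 layout and cannot be obtained by rescaling the strides of the quantized source tensor. The following standalone sketch restates the stride arithmetic; it is written for this page, not taken from the repository, and the helper name contiguous_f16_strides and the plain ne[]/nb[] arrays are illustrative only.

// Illustrative sketch, not code from the repository: the dequantized fp16
// copy is densely packed, so its byte strides depend only on the logical
// shape ne[], never on the strides of the quantized source tensor.
#include <cstddef>
#include <cstdint>

// Hypothetical helper mirroring the new nb11/nb12/nb13 computation in the diff.
static void contiguous_f16_strides(const int64_t ne[4], size_t nb[4]) {
    nb[0] = sizeof(uint16_t);   // one fp16 element, stand-in for sizeof(half)
    nb[1] = ne[0] * nb[0];      // nb11 = K->ne[0]*sizeof(half)
    nb[2] = ne[1] * nb[1];      // nb12 = nb11*K->ne[1]
    nb[3] = ne[2] * nb[2];      // nb13 = nb12*K->ne[2]
}

The replaced formula, nb1x = nb1x*bs*sizeof(half)/ts, rescales the source strides instead; that is only valid when the quantized cache is itself contiguous, which is exactly the assumption the commit message and comment call out as broken.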

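The other half of the fix is the new DV == 512 branch: per the added comment, for DeepSeek the V cache is the same as the K cache except that a row holds 512 elements instead of 576, so the freshly dequantized K buffer can simply be aliased for V (V_data = K_data, same strides) and no second conversion pass is needed. A minimal sketch of that aliasing, assuming a simplified view struct; KVView and v_view_from_k are names invented for this illustration.

// Illustrative sketch, not code from the repository: for DeepSeek MLA
// (DV == 512) V shares storage with K, so V reuses K's dequantized fp16
// buffer and strides, and the shorter row (512 of 576 fp16 values) is
// expressed through the kernel's DV template parameter, not through a stride.
#include <cstddef>

struct KVView {
    const void *data;       // base pointer of the dequantized fp16 buffer
    size_t nb1, nb2, nb3;   // byte strides: row, head, sequence
};

// Hypothetical helper mirroring "nb21 = nb11; nb22 = nb12; nb23 = nb13; V_data = K_data;".
static KVView v_view_from_k(const KVView &k) {
    return KVView{k.data, k.nb1, k.nb2, k.nb3};
}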