@@ -1362,26 +1362,46 @@ void launch_fattn_new_mma(
         to_fp16(K_data, K_f16.ptr, 1, ggml_nelements(K), main_stream);
         K_data = (char *) K_f16.ptr;
 
-        const size_t bs = ggml_blck_size(K->type);
-        const size_t ts = ggml_type_size(K->type);
-
-        nb11 = nb11*bs*sizeof(half)/ts;
-        nb12 = nb12*bs*sizeof(half)/ts;
-        nb13 = nb13*bs*sizeof(half)/ts;
+        nb11 = K->ne[0]*sizeof(half);
+        nb12 = nb11*K->ne[1];
+        nb13 = nb12*K->ne[2];
+
+        // Original PR in llama.cpp. I don't think this can work when K is not contiguous (e.g., nb11 > nb12,
+        // gaps between the rows, etc.), because ggml_get_to_fp16_cuda stores into contiguous memory.
+        // const size_t bs = ggml_blck_size(K->type);
+        // const size_t ts = ggml_type_size(K->type);
+
+        // nb11 = nb11*bs*sizeof(half)/ts;
+        // nb12 = nb12*bs*sizeof(half)/ts;
+        // nb13 = nb13*bs*sizeof(half)/ts;
     }
 
     if (need_f16_V && V->type != GGML_TYPE_F16) {
-        V_f16.alloc(ggml_nelements(V));
-        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
-        to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream);
-        V_data = (char *) V_f16.ptr;
-
-        const size_t bs = ggml_blck_size(V->type);
-        const size_t ts = ggml_type_size(V->type);
-
-        nb21 = nb21*bs*sizeof(half)/ts;
-        nb22 = nb22*bs*sizeof(half)/ts;
-        nb23 = nb23*bs*sizeof(half)/ts;
+        if constexpr (DV == 512) {
+            // DeepSeek. In this case the V cache is the same as the K cache, except that
+            // it has 512 elements per row instead of 576.
+            nb21 = nb11;
+            nb22 = nb12;
+            nb23 = nb13;
+            V_data = K_data;
+        } else {
+            V_f16.alloc(ggml_nelements(V));
+            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+            to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream);
+            V_data = (char *) V_f16.ptr;
+
+            nb21 = V->ne[0]*sizeof(half);
+            nb22 = nb21*V->ne[1];
+            nb23 = nb22*V->ne[2];
+
+            // Original PR in llama.cpp. Same comment as above for the K cache.
+            // const size_t bs = ggml_blck_size(V->type);
+            // const size_t ts = ggml_type_size(V->type);
+
+            // nb21 = nb21*bs*sizeof(half)/ts;
+            // nb22 = nb22*bs*sizeof(half)/ts;
+            // nb23 = nb23*bs*sizeof(half)/ts;
+        }
     }
 
     int parallel_blocks = 1;
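
For context, a minimal, self-contained sketch of the stride reasoning behind the change above. The numbers are made up for illustration (a q8_0-like type with 32-element blocks of 34 bytes, plus an artificial 16 bytes of row padding), uint16_t stands in for half so the snippet builds without CUDA, and none of the names below are ggml API:

// Sketch: why the fp16 strides are rebuilt from ne[] instead of rescaling
// the source tensor's byte strides.
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical quantized K cache: 128 elements per row, 4 rows, 2 heads,
    // blocks of 32 elements stored in 34 bytes (q8_0-like layout).
    const int64_t ne0 = 128, ne1 = 4, ne2 = 2;
    const int64_t bs = 32;   // elements per block (ggml_blck_size)
    const size_t  ts = 34;   // bytes per block    (ggml_type_size)

    // Source strides with 16 bytes of padding after each row,
    // i.e. the source tensor is NOT contiguous.
    const size_t nb1 = (ne0/bs)*ts + 16;
    const size_t nb2 = nb1*ne1;
    const size_t nb3 = nb2*ne2;

    // The conversion kernel writes the fp16 copy into a freshly allocated,
    // contiguous buffer, so its strides follow directly from ne[]:
    const size_t nb1_new = ne0*sizeof(uint16_t);   // sizeof(half) on the device
    const size_t nb2_new = nb1_new*ne1;
    const size_t nb3_new = nb2_new*ne2;

    // Rescaling the source strides (the original llama.cpp PR) only matches
    // that contiguous layout when the source itself had no padding:
    const size_t nb1_rescaled = nb1*bs*sizeof(uint16_t)/ts;
    const size_t nb2_rescaled = nb2*bs*sizeof(uint16_t)/ts;
    const size_t nb3_rescaled = nb3*bs*sizeof(uint16_t)/ts;

    printf("row stride:   contiguous %zu vs rescaled %zu\n", nb1_new, nb1_rescaled);
    printf("plane stride: contiguous %zu vs rescaled %zu\n", nb2_new, nb2_rescaled);
    printf("batch stride: contiguous %zu vs rescaled %zu\n", nb3_new, nb3_rescaled);
    return 0;
}

With the padded rows, rescaling gives a 286-byte row stride while the contiguous fp16 copy produced by ggml_get_to_fp16_cuda has 256-byte rows, so the rescaled strides would step past the converted data; recomputing the strides from ne[] matches the copy by construction. The DV == 512 branch avoids a second conversion altogether by aliasing the already-converted K buffer, since for DeepSeek the V cache is just the first 512 elements of each 576-element K row.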