ggml/src/ggml-cuda/fattn-new-mma.cu: 54 changes (37 additions & 17 deletions)
@@ -1362,26 +1362,46 @@ void launch_fattn_new_mma(
to_fp16(K_data, K_f16.ptr, 1, ggml_nelements(K), main_stream);
K_data = (char *) K_f16.ptr;

-const size_t bs = ggml_blck_size(K->type);
-const size_t ts = ggml_type_size(K->type);
-
-nb11 = nb11*bs*sizeof(half)/ts;
-nb12 = nb12*bs*sizeof(half)/ts;
-nb13 = nb13*bs*sizeof(half)/ts;
+nb11 = K->ne[0]*sizeof(half);
+nb12 = nb11*K->ne[1];
+nb13 = nb12*K->ne[2];
+
+// Original PR in llama.cpp. I don't think that can work when K is not contiguous (e.g., nb11 > nb12), there are
+// gaps between the rows, etc., as ggml_get_to_fp16_cuda stores into contiguous memory.
+//const size_t bs = ggml_blck_size(K->type);
+//const size_t ts = ggml_type_size(K->type);
+
+//nb11 = nb11*bs*sizeof(half)/ts;
+//nb12 = nb12*bs*sizeof(half)/ts;
+//nb13 = nb13*bs*sizeof(half)/ts;
}

if (need_f16_V && V->type != GGML_TYPE_F16) {
-V_f16.alloc(ggml_nelements(V));
-to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
-to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream);
-V_data = (char *) V_f16.ptr;
-
-const size_t bs = ggml_blck_size(V->type);
-const size_t ts = ggml_type_size(V->type);
-
-nb21 = nb21*bs*sizeof(half)/ts;
-nb22 = nb22*bs*sizeof(half)/ts;
-nb23 = nb23*bs*sizeof(half)/ts;
+if constexpr (DV == 512) {
+    // DeepSeek. In this case the V cache is the same as the K cache, except that
+    // it has 512 elements per row instead of 576.
+    nb21 = nb11;
+    nb22 = nb12;
+    nb23 = nb13;
+    V_data = K_data;
+} else {
+    V_f16.alloc(ggml_nelements(V));
+    to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+    to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream);
+    V_data = (char *) V_f16.ptr;
+
+    nb21 = K->ne[0]*sizeof(half);
+    nb22 = nb21*V->ne[1];
+    nb23 = nb22*V->ne[2];
+
+    // Original PR in llama.cpp. Same comment as above for the K cache.
+    //const size_t bs = ggml_blck_size(V->type);
+    //const size_t ts = ggml_type_size(V->type);
+
+    //nb21 = nb21*bs*sizeof(half)/ts;
+    //nb22 = nb22*bs*sizeof(half)/ts;
+    //nb23 = nb23*bs*sizeof(half)/ts;
+}
}

int parallel_blocks = 1;
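A note for readers of this hunk: the point of the new stride computation is that ggml_get_to_fp16_cuda writes the dequantized data into a freshly allocated, fully contiguous F16 buffer, so the byte strides handed to the kernel must be rebuilt from that buffer's element counts rather than obtained by rescaling the source tensor's strides, which may encode padding or a permuted view. The sketch below is a minimal stand-alone illustration of the two approaches, assuming a hypothetical TensorView struct and helper names; it is not the ggml API.

#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for a ggml-style tensor view (not the ggml API):
// ne[i] = number of elements along dimension i, nb[i] = byte stride of dimension i.
struct TensorView {
    int64_t ne[4];
    size_t  nb[4];
};

// Byte strides of the contiguous F16 buffer that the dequantization step
// (ggml_get_to_fp16_cuda in the hunk above) produces. Only the element counts
// matter: any row padding or permutation in the source layout is gone after
// the contiguous copy.
inline void contiguous_f16_strides(TensorView & t) {
    t.nb[0] = sizeof(uint16_t);      // one half per element
    t.nb[1] = t.ne[0] * t.nb[0];     // nb11 = ne0 * sizeof(half)
    t.nb[2] = t.ne[1] * t.nb[1];     // nb12 = nb11 * ne1
    t.nb[3] = t.ne[2] * t.nb[2];     // nb13 = nb12 * ne2
}

// The rescaling used by the original llama.cpp PR: map a quantized byte stride
// to the corresponding F16 byte stride. This matches the function above only
// when the quantized tensor itself is contiguous (no gaps between rows).
inline size_t rescaled_f16_stride(size_t nb_quant, size_t blck_size, size_t type_size) {
    return nb_quant * blck_size * sizeof(uint16_t) / type_size;
}

For a contiguous quantized cache the two agree (nb_quant is ne0/blck_size * type_size, so the rescaled value is again ne0 * sizeof(half)); once the source has padded or permuted rows, the rescaled strides describe memory the converter never wrote, which is the problem the comment in the hunk points out.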
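The DV == 512 branch deserves a separate note: in the DeepSeek case the V cache is the K cache viewed with shorter rows (512 elements instead of 576), so after K has been converted there is nothing left to convert for V; the pointer and strides are reused as-is. A minimal sketch of that aliasing, again with hypothetical names rather than the ggml/llama.cpp API:

#include <cstddef>

// Hedged sketch of the DV == 512 (DeepSeek) branch above: the V cache is the
// same physical buffer as the K cache, each V row being the leading 512 of
// the 576 elements of the matching K row, so V simply aliases K's
// already-converted F16 buffer and its strides.
struct VFromK {
    const void * data;        // aliases the converted K buffer
    size_t nb21, nb22, nb23;  // byte strides, copied from K's nb11/nb12/nb13
};

inline VFromK make_v_view(const void * k_f16_data, size_t nb11, size_t nb12, size_t nb13) {
    // Each V row starts at the same offset as its K row; the shorter row
    // length (512 vs. 576 elements) is handled by the kernel reading only
    // DV elements per row, not by the strides.
    return VFromK{ k_f16_data, nb11, nb12, nb13 };
}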