
Commit ee19a4a

fix KV cache padding, NaN from INFINITY (#6438)
1 parent: c63dfdf · commit: ee19a4a

2 files changed (+11, −8)


ggml-cuda/fattn.cu

Lines changed: 9 additions & 6 deletions
@@ -4,6 +4,7 @@
 #include <mma.h>

 #define FATTN_KQ_STRIDE 256
+#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.

 template<int D, int parallel_blocks> // D == head size
 __launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1)
@@ -59,13 +60,13 @@ static __global__ void flash_attn_vec_ext_f16(
     KQ[tid] = -INFINITY;
     half2 * KQ2 = (half2 *) KQ;

-    half kqmax = -INFINITY;
+    half kqmax = -HALF_MAX_HALF;
     half kqsum = 0.0f;

     __shared__ half kqmax_shared[WARP_SIZE];
     __shared__ half kqsum_shared[WARP_SIZE];
     if (threadIdx.y == 0) {
-        kqmax_shared[threadIdx.x] = -INFINITY;
+        kqmax_shared[threadIdx.x] = -HALF_MAX_HALF;
         kqsum_shared[threadIdx.x] = 0.0f;
     }
     __syncthreads();
@@ -139,7 +140,7 @@ static __global__ void flash_attn_vec_ext_f16(
     if (tid < D) {
 #pragma unroll
         for (int k0 = 0; k0 < D; k0 += 2) {
-            if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) {
+            if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
                 break;
             }

@@ -253,9 +254,9 @@ static __global__ void flash_attn_ext_f16(
     __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
     half2 * KQ2 = (half2 *) KQ;

-    half2 KQ_rowsum[ncols/nwarps] = {{0.0f, 0.0f}};
-    half2 KQ_max[ncols/nwarps] = {{-INFINITY, -INFINITY}};
-    half2 KQ_max_scale[ncols/nwarps] = {{0.0f, 0.0f}};
+    half2 KQ_rowsum[ncols/nwarps] = {{ 0.0f, 0.0f}};
+    half2 KQ_max[ncols/nwarps] = {{-HALF_MAX_HALF, -HALF_MAX_HALF}};
+    half2 KQ_max_scale[ncols/nwarps] = {{ 0.0f, 0.0f}};

     __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
     half2 * VKQ2 = (half2 *) VKQ;
@@ -578,6 +579,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
                 "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");

+    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
+
     ggml_cuda_set_device(ctx.device);

     const cudaStream_t main_stream = ctx.stream();
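
The new HALF_MAX_HALF constant addresses the NaN half of the commit title. The kernels keep a running KQ maximum for the online softmax and, as the new header comment notes, subtracting two values that are both still at their -INFINITY initialization gives (-inf) - (-inf) = NaN, which then poisons the exponentials derived from it. The following standalone CUDA sketch (not part of the commit; the kernel name and values are illustrative, and it needs sm_53+ for the half intrinsics) reproduces the arithmetic and shows why a large finite start value such as -HALF_MAX_HALF is safe.

// Standalone illustration (not from this commit) of the NaN described in the
// HALF_MAX_HALF comment above; compile with e.g. `nvcc -arch=sm_70 demo.cu`.
#include <cmath>
#include <cstdio>
#include <cuda_fp16.h>

#define HALF_MAX_HALF __float2half(65504.0f/2) // same constant the commit introduces

__global__ void kqmax_init_demo() {
    // Old initialization: the running maximum starts at -INFINITY, so before
    // any score has been folded in, old_max - new_max is (-inf) - (-inf) = NaN
    // and the exponential used for rescaling is NaN as well.
    const half kqmax_inf = __float2half(-INFINITY);
    const half bad_diff  = __hsub(kqmax_inf, kqmax_inf);
    const half bad_scale = hexp(bad_diff);

    // New initialization: -HALF_MAX_HALF (= -32752) is finite, so the
    // difference is 0 and the rescaling factor is a well-defined 1.
    const half kqmax_fin = __hneg(HALF_MAX_HALF);
    const half ok_diff   = __hsub(kqmax_fin, kqmax_fin);
    const half ok_scale  = hexp(ok_diff);

    printf("-INFINITY init:      diff = %f, scale = %f\n", __half2float(bad_diff), __half2float(bad_scale));
    printf("-HALF_MAX_HALF init: diff = %f, scale = %f\n", __half2float(ok_diff),  __half2float(ok_scale));
}

int main() {
    kqmax_init_demo<<<1, 1>>>();
    cudaDeviceSynchronize();
    return 0;
}

65504 is the largest finite FP16 value; halving it presumably leaves headroom so arithmetic involving the sentinel stays finite, while exp of such a large negative value simply underflows to 0 instead of producing NaN.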

llama.cpp

Lines changed: 2 additions & 2 deletions
@@ -9973,7 +9973,7 @@ static int llama_decode_internal(
                 // a heuristic, to avoid attending the full cache if it is not yet utilized
                 // after enough generations, the benefit from this heuristic disappears
                 // if we start defragmenting the cache, the benefit from this will be more important
-                kv_self.n = std::min(kv_self.size, std::max(128u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128)));
+                kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256)));
                 //kv_self.n = llama_kv_cache_cell_max(kv_self);
             }
         }
@@ -13909,7 +13909,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;

     // this is necessary due to kv_self.n being padded later during inference
-    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 32);
+    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);

     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
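
The llama.cpp changes enforce the padding half of the title: kv_self.n (the attended slice of the KV cache) is now padded to a multiple of 256 instead of 128, and n_ctx to 256 instead of 32, matching FATTN_KQ_STRIDE so the new GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0) in ggml_cuda_flash_attn_ext always holds. Below is a small host-side sketch of the padding arithmetic; pad_to() is a local stand-in that mirrors GGML_PAD's round-up-to-a-multiple behavior, and the cache sizes are hypothetical.

// Host-side sketch (not from this commit) of the KV cache padding arithmetic.
#include <algorithm>
#include <cstdio>

#define FATTN_KQ_STRIDE 256 // chunk size the CUDA FlashAttention kernel works in

// Local stand-in for GGML_PAD: round x up to the next multiple of n.
static unsigned int pad_to(unsigned int x, unsigned int n) {
    return (x + n - 1) / n * n;
}

int main() {
    const unsigned int kv_size  = 4096; // hypothetical kv_self.size
    const unsigned int cell_max = 300;  // hypothetical llama_kv_cache_cell_max() result

    // Before the commit: kv_self.n was padded to a multiple of 128, which can
    // leave K->ne[1] % FATTN_KQ_STRIDE != 0 and trip the new assert.
    const unsigned int n_old = std::min(kv_size, std::max(128u, pad_to(cell_max, 128)));
    // After the commit: padding to 256 keeps the attended KV slice a multiple
    // of FATTN_KQ_STRIDE, so the assert in ggml_cuda_flash_attn_ext holds.
    const unsigned int n_new = std::min(kv_size, std::max(256u, pad_to(cell_max, 256)));

    printf("cell_max=%u: old kv_self.n=%u (%% 256 = %u), new kv_self.n=%u (%% 256 = %u)\n",
           cell_max, n_old, n_old % FATTN_KQ_STRIDE, n_new, n_new % FATTN_KQ_STRIDE);
    // prints: cell_max=300: old kv_self.n=384 (% 256 = 128), new kv_self.n=512 (% 256 = 0)
    return 0;
}

The matching n_ctx change (32 → 256) rounds the context up at context-creation time, which is needed because, per the existing comment, kv_self.n is padded later during inference and can grow up to cparams.n_ctx.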
