4 | 4 | #include <mma.h> |
5 | 5 |
6 | 6 | #define FATTN_KQ_STRIDE 256 |
| 7 | +#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. |
7 | 8 |
8 | 9 | template<int D, int parallel_blocks> // D == head size |
9 | 10 | __launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1) |
@@ -59,13 +60,13 @@ static __global__ void flash_attn_vec_ext_f16( |
59 | 60 | KQ[tid] = -INFINITY; |
60 | 61 | half2 * KQ2 = (half2 *) KQ; |
61 | 62 |
62 | | - half kqmax = -INFINITY; |
| 63 | + half kqmax = -HALF_MAX_HALF; |
63 | 64 | half kqsum = 0.0f; |
64 | 65 |
65 | 66 | __shared__ half kqmax_shared[WARP_SIZE]; |
66 | 67 | __shared__ half kqsum_shared[WARP_SIZE]; |
67 | 68 | if (threadIdx.y == 0) { |
68 | | - kqmax_shared[threadIdx.x] = -INFINITY; |
| 69 | + kqmax_shared[threadIdx.x] = -HALF_MAX_HALF; |
69 | 70 | kqsum_shared[threadIdx.x] = 0.0f; |
70 | 71 | } |
71 | 72 | __syncthreads(); |
@@ -139,7 +140,7 @@ static __global__ void flash_attn_vec_ext_f16( |
139 | 140 | if (tid < D) { |
140 | 141 | #pragma unroll |
141 | 142 | for (int k0 = 0; k0 < D; k0 += 2) { |
142 | | - if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) { |
| 143 | + if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) { |
143 | 144 | break; |
144 | 145 | } |
145 | 146 |
@@ -253,9 +254,9 @@ static __global__ void flash_attn_ext_f16( |
253 | 254 | __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; |
254 | 255 | half2 * KQ2 = (half2 *) KQ; |
255 | 256 |
256 | | - half2 KQ_rowsum[ncols/nwarps] = {{0.0f, 0.0f}}; |
257 | | - half2 KQ_max[ncols/nwarps] = {{-INFINITY, -INFINITY}}; |
258 | | - half2 KQ_max_scale[ncols/nwarps] = {{0.0f, 0.0f}}; |
| 257 | + half2 KQ_rowsum[ncols/nwarps] = {{ 0.0f, 0.0f}}; |
| 258 | + half2 KQ_max[ncols/nwarps] = {{-HALF_MAX_HALF, -HALF_MAX_HALF}}; |
| 259 | + half2 KQ_max_scale[ncols/nwarps] = {{ 0.0f, 0.0f}}; |
259 | 260 |
260 | 261 | __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. |
261 | 262 | half2 * VKQ2 = (half2 *) VKQ; |
@@ -578,6 +579,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst |
578 | 579 | GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && |
579 | 580 | "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); |
580 | 581 |
| 582 | + GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); |
| 583 | + |
581 | 584 | ggml_cuda_set_device(ctx.device); |
582 | 585 |
583 | 586 | const cudaStream_t main_stream = ctx.stream(); |
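
For context, a minimal sketch (plain C++, not code from this commit) of the failure mode the new HALF_MAX_HALF sentinel avoids: when the running KQ maximum is initialized to -INFINITY and a KQ score is also -INFINITY (e.g. fully masked), the max-subtraction in the softmax evaluates inf - inf, which is NaN and then propagates through the exponential into the row sum. A large-but-finite sentinel such as -HALF_MAX_HALF (-65504/2 = -32752 in fp16) keeps the subtraction well defined, so the exponential merely underflows to zero.

#include <cmath>
#include <cstdio>

// Standalone illustration in single precision; IEEE half precision behaves the
// same way with respect to infinities and NaN.
int main() {
    const float neg_inf    = -INFINITY;
    const float finite_min = -65504.0f/2.0f; // same value as -HALF_MAX_HALF

    // -inf - (-inf) is NaN and expf(NaN) is NaN: the row sum gets poisoned.
    printf("%f\n", expf(neg_inf - neg_inf));
    // With the finite sentinel the difference is -inf and expf(-inf) is 0.0f:
    // a masked-out score simply contributes nothing.
    printf("%f\n", expf(neg_inf - finite_min));

    return 0;
}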
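
The added GGML_ASSERT documents a second assumption: the kernel walks the KV dimension in chunks of FATTN_KQ_STRIDE, so K->ne[1] must be padded to a multiple of 256 by the caller ("Incorrect KV cache padding." otherwise). A hypothetical rounding helper (the name is mine, not from this commit) that would satisfy the assertion:

#include <cstdint>

#define FATTN_KQ_STRIDE 256

// Round a KV length up to the next multiple of FATTN_KQ_STRIDE (256) so that
// GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0) holds.
static int64_t fattn_pad_kv_len(int64_t n_kv) {
    return ((n_kv + FATTN_KQ_STRIDE - 1) / FATTN_KQ_STRIDE) * FATTN_KQ_STRIDE;
}

// e.g. fattn_pad_kv_len(1000) == 1024, fattn_pad_kv_len(256) == 256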