@@ -8,11 +8,14 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int
     if (GGML_CUDA_CC_IS_AMD(cc)) {
         switch (D) {
             case 64:
-                return ncols <= 16 ? 32 : 64;
+                return 64;
             case 128:
-                return ncols <= 16 ? 64 : warp_size;
             case 256:
-                return 64;
+                if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
+                    return ncols <= 16 ? 64 : 32;
+                } else {
+                    return 64;
+                }
             default:
                 GGML_ABORT("fatal error");
                 return -1;
@@ -41,17 +44,26 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int
             GGML_ABORT("fatal error");
             return -1;
     }
+    GGML_UNUSED(warp_size);
 }
 
 static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) {
 #ifdef GGML_USE_HIP
     switch (D) {
         case 64:
-            return ncols <= 16 ? 32 : 64;
+            return 64;
         case 128:
-            return ncols <= 16 ? 64 : warp_size;
+#if defined(GCN) || defined(CDNA)
+            return ncols <= 16 ? 64 : 32;
+#else
+            return 64;
+#endif // defined(GCN) || defined(CDNA)
         case 256:
+#if defined(GCN) || defined(CDNA)
+            return ncols <= 16 ? 64 : 32;
+#else
             return 64;
+#endif // defined(GCN) || defined(CDNA)
         default:
             return -1;
     }
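
The host and device selectors above encode the same table; the device side just has to resolve it at compile time, hence the `GCN`/`CDNA` preprocessor branches in place of the runtime `cc` check. As a standalone illustration (not part of the patch), the stride the HIP path now picks can be tabulated like this, with a hypothetical `is_gcn_or_cdna` flag standing in for those macros:

```cpp
// Standalone illustration of the new HIP-side KQ-stride table (not part of the patch).
#include <cstdio>

static int kq_stride_hip(int D, int ncols, bool is_gcn_or_cdna) {
    switch (D) {
        case 64:
            return 64;
        case 128:
        case 256:
            // GCN/CDNA prefer a narrower stride for larger batches; other AMD GPUs keep 64.
            return is_gcn_or_cdna ? (ncols <= 16 ? 64 : 32) : 64;
        default:
            return -1;
    }
}

int main() {
    const int Ds[]     = {64, 128, 256};
    const int ncolss[] = {8, 16, 32};
    for (int D : Ds) {
        for (int ncols : ncolss) {
            printf("D=%3d ncols=%2d  GCN/CDNA=%2d  other AMD=%2d\n",
                   D, ncols, kq_stride_hip(D, ncols, true), kq_stride_hip(D, ncols, false));
        }
    }
    return 0;
}
```

In short: for D >= 128 with more than 16 columns, GCN/CDNA now process 32 KQ rows per iteration while the remaining AMD targets stay at 64.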
@@ -88,9 +100,17 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
         case 64:
             return 64;
         case 128:
-            return ncols <= 16 ? 2*warp_size : 128;
+#if defined(GCN) || defined(CDNA)
+            return ncols <= 16 ? 64 : 128;
+#else
+            return 64;
+#endif // defined(GCN) || defined(CDNA)
         case 256:
-            return ncols <= 16 ? 128 : 2*warp_size;
+#if defined(GCN) || defined(CDNA)
+            return ncols <= 16 ? 64 : 128;
+#else
+            return ncols <= 16 ? 64 : 256;
+#endif // defined(GCN) || defined(CDNA)
         default:
             return -1;
     }
@@ -196,14 +216,21 @@ static __global__ void flash_attn_tile(
 
     const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 
+#if defined(GGML_USE_HIP)
+    constexpr int cpy_nb = 16;
+#else
+    constexpr int cpy_nb = 8;
+#endif // defined(GGML_USE_HIP)
+    constexpr int cpy_ne = cpy_nb / 4;
+
     __shared__ float KQ[ncols][kq_stride];
 #ifdef FAST_FP16_AVAILABLE
     __shared__ half2 Q_tmp[ncols][D/2];
-    __shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + 1)]; // Padded to avoid memory bank conflicts.
+    __shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
     half2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
 #else
     __shared__ float Q_tmp[ncols][D];
-    __shared__ float KV_tmp_f[kq_stride * (kq_nbatch + 1)]; // Padded to avoid memory bank conflicts.
+    __shared__ float KV_tmp_f[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
     float2 * KV_tmp_f2 = (float2 *) KV_tmp_f;
     float2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
 #endif // FAST_FP16_AVAILABLE
@@ -256,11 +283,11 @@ static __global__ void flash_attn_tile(
             for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size) {
                 const half2 tmp_h2 = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x];
 #ifdef FAST_FP16_AVAILABLE
-                KV_tmp_h2[i_KQ*(kq_nbatch/2 + 1) + k_KQ_1 + threadIdx.x] = tmp_h2;
+                KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x] = tmp_h2;
 #else
                 const float2 tmp_f2 = __half22float2(tmp_h2);
-                KV_tmp_f[i_KQ*(kq_nbatch + 1) + 2*k_KQ_1 + threadIdx.x] = tmp_f2.x;
-                KV_tmp_f[i_KQ*(kq_nbatch + 1) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y;
+                KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + threadIdx.x] = tmp_f2.x;
+                KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y;
 #endif // FAST_FP16_AVAILABLE
             }
         }
@@ -269,42 +296,45 @@ static __global__ void flash_attn_tile(
 
 #ifdef FAST_FP16_AVAILABLE
 #pragma unroll
-        for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; ++k_KQ_1) {
-            half2 K_k[kq_stride/warp_size];
-            half2 Q_k[ncols/nwarps];
+        for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) {
+            half2 K_k[kq_stride/warp_size][cpy_ne];
+            half2 Q_k[ncols/nwarps][cpy_ne];
 #else
 #pragma unroll
-        for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; ++k_KQ_1) {
-            float K_k[kq_stride/warp_size];
-            float Q_k[ncols/nwarps];
+        for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) {
+            float K_k[kq_stride/warp_size][cpy_ne];
+            float Q_k[ncols/nwarps][cpy_ne];
 #endif // FAST_FP16_AVAILABLE
 
 #pragma unroll
             for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
                 const int i_KQ = i_KQ_0 + threadIdx.x;
 
 #ifdef FAST_FP16_AVAILABLE
-                K_k[i_KQ_0/warp_size] = KV_tmp_h2[i_KQ*(kq_nbatch/2 + 1) + k_KQ_1];
+                ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
 #else
-                K_k[i_KQ_0/warp_size] = KV_tmp_f[i_KQ*(kq_nbatch + 1) + k_KQ_1];
+                ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]);
 #endif // FAST_FP16_AVAILABLE
             }
 #pragma unroll
             for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
                 const int j_KQ = j_KQ_0 + threadIdx.y;
 
 #ifdef FAST_FP16_AVAILABLE
-                Q_k[j_KQ_0/nwarps] = Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1];
+                ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
 #else
-                Q_k[j_KQ_0/nwarps] = Q_tmp[j_KQ][k_KQ_0 + k_KQ_1];
+                ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]);
 #endif // FAST_FP16_AVAILABLE
             }
 
 #pragma unroll
             for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
 #pragma unroll
                 for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
-                    ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size], Q_k[j_KQ_0/nwarps]);
+#pragma unroll
+                    for (int k = 0; k < cpy_ne; ++k) {
+                        ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0/nwarps][k]);
+                    }
                 }
             }
         }
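
With the row padding widened to `cpy_ne`, each thread now pulls `cpy_ne` consecutive K and Q values with a single `cpy_nb`-byte copy and feeds them to `ggml_cuda_mad` in a short unrolled loop. A minimal sketch of the FP32-path accumulation in isolation (assuming `ggml_cuda_mad` behaves like a fused multiply-add; `mad_block` is an illustrative name, not part of the patch):

```cuda
// Illustration of the widened dot-product step on the FP32 path (not the patch itself):
// acc += sum over k of K_k[k] * Q_k[k], with cpy_ne values fetched per wide shared-memory load.
template <int cpy_ne>
static __device__ __forceinline__ void mad_block(float & acc, const float * K_k, const float * Q_k) {
#pragma unroll
    for (int k = 0; k < cpy_ne; ++k) {
        acc = fmaf(K_k[k], Q_k[k], acc); // same effect as ggml_cuda_mad on float operands
    }
}
```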
@@ -345,14 +375,54 @@ static __global__ void flash_attn_tile(
             kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
 
             float kqsum_add = 0.0f;
+            if (kq_stride % (4*warp_size) == 0 && cpy_ne % 4 == 0) {
 #pragma unroll
-            for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
-                const int i = i0 + threadIdx.x;
+                for (int i0 = 0; i0 < kq_stride; i0 += 4*warp_size) {
+                    const int i = i0 + 4*threadIdx.x;
 
-                const float diff = KQ[j][i] - kqmax[j0/nwarps];
-                const float val = expf(diff);
-                kqsum_add += val;
-                KQ[j][i] = val;
+                    float4 val = *(const float4 *) &KQ[j][i];
+                    val.x = expf(val.x - kqmax[j0/nwarps]);
+                    val.y = expf(val.y - kqmax[j0/nwarps]);
+                    val.z = expf(val.z - kqmax[j0/nwarps]);
+                    val.w = expf(val.w - kqmax[j0/nwarps]);
+                    kqsum_add += val.x + val.y + val.z + val.w;
+
+#ifdef FAST_FP16_AVAILABLE
+                    const half2 tmp[2] = {make_half2(val.x, val.y), make_half2(val.z, val.w)};
+                    ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
+#else
+                    ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
+#endif // FAST_FP16_AVAILABLE
+                }
+            } else if (kq_stride % (2*warp_size) == 0 && cpy_ne % 2 == 0) {
+#pragma unroll
+                for (int i0 = 0; i0 < kq_stride; i0 += 2*warp_size) {
+                    const int i = i0 + 2*threadIdx.x;
+
+                    float2 val = *(const float2 *) &KQ[j][i];
+                    val.x = expf(val.x - kqmax[j0/nwarps]);
+                    val.y = expf(val.y - kqmax[j0/nwarps]);
+                    kqsum_add += val.x + val.y;
+#ifdef FAST_FP16_AVAILABLE
+                    const half2 tmp = make_half2(val.x, val.y);
+                    ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
+#else
+                    ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
+#endif // FAST_FP16_AVAILABLE
+                }
+            } else {
+                for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
+                    const int i = i0 + threadIdx.x;
+
+                    const float diff = KQ[j][i] - kqmax[j0/nwarps];
+                    const float val = expf(diff);
+                    kqsum_add += val;
+#ifdef FAST_FP16_AVAILABLE
+                    ((half *) KQ[j])[i] = val;
+#else
+                    KQ[j][i] = val;
+#endif // FAST_FP16_AVAILABLE
+                }
             }
             kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
 
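
The softmax rescaling now exponentiates four (or two) KQ values per thread against the running row maximum and, on FP16 builds, writes the results back compacted to half precision, which halves the shared-memory traffic for the later KQ reads; the scalar tail branch keeps the old per-element behaviour when the stride does not divide evenly. A self-contained sketch of the float4 path's exp-and-repack step (illustrative only; `exp_and_pack` is a made-up name and `__floats2half2_rn` stands in for the patch's `make_half2` calls):

```cuda
#include <cuda_fp16.h>

// Sketch of the vectorized softmax step: exponentiate 4 scores against the row max,
// accumulate their sum, and pack them into two half2 values for the FP16 storage path.
static __device__ __forceinline__ void exp_and_pack(float4 v, float row_max, half2 out[2], float & kqsum) {
    v.x = expf(v.x - row_max);
    v.y = expf(v.y - row_max);
    v.z = expf(v.z - row_max);
    v.w = expf(v.w - row_max);
    kqsum += v.x + v.y + v.z + v.w;
    out[0] = __floats2half2_rn(v.x, v.y); // two softmax terms per 32-bit store
    out[1] = __floats2half2_rn(v.z, v.w);
}
```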
@@ -419,8 +489,7 @@ static __global__ void flash_attn_tile(
                 const int j = j0 + threadIdx.y;
 
 #ifdef FAST_FP16_AVAILABLE
-                const float tmp = KQ[j][k0 + k1];
-                KQ_k[j0/nwarps] = make_half2(tmp, tmp);
+                KQ_k[j0/nwarps] = __half2half2(((const half *) KQ[j])[k0 + k1]);
 #else
                 KQ_k[j0/nwarps] = KQ[j][k0 + k1];
 #endif // FAST_FP16_AVAILABLE
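
Since KQ is now stored as half on FP16 builds, the broadcast reads the stored `half` directly and duplicates it into both lanes with the cuda_fp16 intrinsic `__half2half2` instead of bouncing through a float temporary. A small comparison of the two broadcast forms (illustrative; the helper names are not from the patch):

```cuda
#include <cuda_fp16.h>

// Old path: promote to float, then build a half2 from two float->half conversions.
static __device__ half2 broadcast_old(float f) { return make_half2(__float2half(f), __float2half(f)); }
// New path: duplicate the already-stored half into both lanes of a half2.
static __device__ half2 broadcast_new(half h)  { return __half2half2(h); }
```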