Commit 1ef7fd0

Review: formatting changes + remove syncthreads from tile + remove warp_reduce_max from wmma

1 parent 4946c19
3 files changed, +18 -24 lines

ggml/src/ggml-cuda/fattn-tile-f16.cu

Lines changed: 5 additions & 7 deletions

@@ -49,9 +49,9 @@ static __global__ void flash_attn_tile_ext_f16(
     const int sequence = blockIdx.z / ne02;
     const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
-    const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
-    const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+    const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
     const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
     const float * sinksf = (const float *) (sinks);

@@ -247,7 +247,7 @@ static __global__ void flash_attn_tile_ext_f16(
     if (sinksf && blockIdx.y == 0) {
         const half sink = __float2half(sinksf[head]);

-#pragma unroll
+#pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += nwarps) {
             half kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
             kqmax_new_j = warp_reduce_max(kqmax_new_j);
@@ -261,13 +261,11 @@ static __global__ void flash_attn_tile_ext_f16(
                 kqsum[j0/nwarps].x = __hadd(kqsum[j0/nwarps].x, val);
             }

-#pragma unroll
+#pragma unroll
             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                 VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
             }
         }
-
-        __syncthreads();
     }

     float2 * dst2 = (float2 *) dst;
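A note on the pointer setup in the first hunk (it is the same in all three files): with grouped query attention, ne02 Q heads share ne12 K/V heads, so the K and V base pointers are indexed with head / gqa_ratio while Q uses head directly. A minimal host-side illustration of that mapping, with made-up head counts (not code from the repository):

// Illustrative only: how the gqa_ratio indexing shares K/V heads among Q heads.
// With ne02 = 8 Q heads and ne12 = 2 K/V heads, gqa_ratio = 4, so Q heads 0..3
// read K/V head 0 and Q heads 4..7 read K/V head 1.
#include <stdio.h>

int main(void) {
    const int ne02 = 8;                 // number of Q heads (made-up)
    const int ne12 = 2;                 // number of K/V heads (made-up)
    const int gqa_ratio = ne02 / ne12;  // > 1 whenever grouped query attention is used

    for (int head = 0; head < ne02; ++head) {
        printf("Q head %d -> K/V head %d\n", head, head / gqa_ratio);
    }
    return 0;
}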

ggml/src/ggml-cuda/fattn-tile-f32.cu

Lines changed: 5 additions & 7 deletions

@@ -60,9 +60,9 @@ static __global__ void flash_attn_tile_ext_f32(
     const int sequence = blockIdx.z / ne02;
     const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
-    const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
-    const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+    const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
     const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
     const float * sinksf = (const float *) (sinks);

@@ -258,7 +258,7 @@ static __global__ void flash_attn_tile_ext_f32(
     if (sinksf && blockIdx.y == 0) {
         const float sink = sinksf[head];

-#pragma unroll
+#pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += nwarps) {
             float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
             kqmax_new_j = warp_reduce_max(kqmax_new_j);
@@ -272,14 +272,12 @@ static __global__ void flash_attn_tile_ext_f32(
                 kqsum[j0/nwarps] += val;
             }

-#pragma unroll
+#pragma unroll
             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                 VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
                 VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
             }
         }
-
-        __syncthreads();
     }

     float2 * dst2 = (float2 *) dst;
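For context on the removed __syncthreads(): in both tile kernels the sink epilogue updates only per-warp register state (kqmax, kqsum, VKQ, all indexed by j0/nwarps), so no block-wide barrier is needed between it and the write-out that follows, which is presumably why the review asked for its removal. The epilogue itself is the usual online-softmax update with the sink treated as one extra logit that has no value row: the running maximum is raised to max(kqmax, sink), kqsum and VKQ are rescaled by exp(old_max - new_max), and exp(sink - new_max) is added to kqsum only. A small host-side sanity check of that algebra (illustrative only, all names made up, not code from the repository):

// Check that the streaming update used in the sink epilogue matches a direct
// softmax that includes the sink as an extra logit whose value row is zero.
#include <assert.h>
#include <math.h>
#include <stdio.h>

int main(void) {
    const float logits[3] = {1.0f, 2.5f, 0.5f};  // made-up QK^T scores for one row
    const float values[3] = {4.0f, -1.0f, 2.0f}; // made-up (1-D) V entries
    const float sink      = 3.0f;                // made-up per-head sink logit

    // Streaming pass over the logits, as the kernel's main loop does it.
    float kqmax = -INFINITY, kqsum = 0.0f, vkq = 0.0f;
    for (int i = 0; i < 3; ++i) {
        const float m = fmaxf(kqmax, logits[i]);
        const float s = expf(kqmax - m);         // rescale factor for the old state
        const float p = expf(logits[i] - m);
        kqsum = kqsum*s + p;
        vkq   = vkq*s   + p*values[i];
        kqmax = m;
    }

    // Sink epilogue: same rescaling, but nothing is added to the value accumulator.
    const float kqmax_new = fmaxf(kqmax, sink);
    const float scale     = expf(kqmax - kqmax_new);
    kqsum = kqsum*scale + expf(sink - kqmax_new);
    vkq  *= scale;

    // Direct reference: softmax over {logits..., sink}, with the sink's value set to 0.
    float denom = expf(sink - kqmax_new);
    float numer = 0.0f;
    for (int i = 0; i < 3; ++i) {
        denom += expf(logits[i] - kqmax_new);
        numer += expf(logits[i] - kqmax_new)*values[i];
    }

    printf("streaming = %.6f, direct = %.6f\n", vkq/kqsum, numer/denom);
    assert(fabsf(vkq/kqsum - numer/denom) < 1e-6f);
    return 0;
}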

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 8 additions & 10 deletions

@@ -82,11 +82,11 @@ static __global__ void flash_attn_ext_f16(
     const int sequence = blockIdx.z / ne02;
     const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
-    const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
-    const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
-    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
-    const half2 * mask2 = (const half2 *) maskh;
+    const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+    const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const half2 * mask2 = (const half2 *) maskh;
     const float * sinksf = (const float *) sinks;

     const int stride_Q = nb01 / sizeof(float);

@@ -387,21 +387,20 @@ static __global__ void flash_attn_ext_f16(
         const float sinkf = sinksf[head];
         const half sinkh = __float2half(sinkf);

-#pragma unroll
+#pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += nwarps) {
             const int j = j0 + threadIdx.y;

             if (std::is_same<KQ_acc_t, float>::value) {
                 float kqmax_new = fmaxf(KQ_max_f[j0/nwarps], sinkf);
-                kqmax_new = warp_reduce_max<warp_size>(kqmax_new);

                 const float KQ_max_scale = expf(KQ_max_f[j0/nwarps] - kqmax_new);
                 KQ_max_f[j0/nwarps] = kqmax_new;

                 KQ_rowsum_f[j0/nwarps] = KQ_rowsum_f[j0/nwarps] * KQ_max_scale + expf(sinkf - KQ_max_f[j0/nwarps]);

                 const half2 scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
-#pragma unroll
+#pragma unroll
                 for (int i0 = 0; i0 < D/2; i0 += warp_size) {
                     const int i = i0 + threadIdx.x;
                     if (i0 + warp_size > D/2 && i >= D/2) break;
@@ -410,7 +409,6 @@ static __global__ void flash_attn_ext_f16(
             } else {
                 half kqmax_old = __low2half(KQ_max_h2[j0/nwarps]);
                 half kqmax_new = fmaxf(kqmax_old, sinkh);
-                kqmax_new = warp_reduce_max<warp_size>(kqmax_new);
                 KQ_max_h2[j0/nwarps] = __half2half2(kqmax_new);

                 const half KQ_max_scale_h = hexp(kqmax_old - kqmax_new);
@@ -420,7 +418,7 @@ static __global__ void flash_attn_ext_f16(
                 const half val = hexp(sinkh - kqmax_new);
                 KQ_rowsum_h2[j0/nwarps].x = __hadd(KQ_rowsum_h2[j0/nwarps].x, val);

-#pragma unroll
+#pragma unroll
                 for (int i0 = 0; i0 < D/2; i0 += warp_size) {
                     const int i = i0 + threadIdx.x;
                     if (i0 + warp_size > D/2 && i >= D/2) break;
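The warp_reduce_max calls removed in this file were redundant rather than wrong, assuming KQ_max_f[j0/nwarps] (respectively KQ_max_h2) is already warp-uniform at this point, having come out of warp reductions in the main loop, and the sink is a per-head scalar read identically by every lane; fmaxf of two warp-uniform values is itself warp-uniform, so a further warp max reduction returns its input unchanged. A tiny standalone sketch of that property (made-up names, not code from the repository):

// Sketch: a warp max reduction applied to a value that is already identical
// across the warp is a no-op, which is why dropping the call does not change
// the result.
#include <stdio.h>

static __device__ __forceinline__ float warp_max(float x) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xFFFFFFFF, x, offset, 32));
    }
    return x;
}

static __global__ void sink_max_is_warp_uniform(const float running_max, const float sink) {
    const float unreduced = fmaxf(running_max, sink); // both operands identical in all 32 lanes
    const float reduced   = warp_max(unreduced);      // hence the reduction cannot change anything
    if (threadIdx.x == 0) {
        printf("reduced == unreduced: %d\n", reduced == unreduced);
    }
}

int main(void) {
    sink_max_is_warp_uniform<<<1, 32>>>(1.25f, 3.5f); // kernel arguments are warp-uniform by construction
    cudaDeviceSynchronize();
    return 0;
}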
