Skip to content

Commit 2d011e6

Browse files
JohannesGaessler authored and gaugarg-nv committed
CUDA: determine FA parallel blocks at runtime
1 parent 0fd8487 commit 2d011e6

File tree

10 files changed

+166
-257
lines changed

10 files changed

+166
-257
lines changed

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 27 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -606,48 +606,47 @@ static __global__ void flash_attn_stream_k_fixup(
606606
*dst = dst_val / rowsum;
607607
}
608608

609-
template<int D, int parallel_blocks> // D == head size
609+
template<int D> // D == head size
610610
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
611611
__launch_bounds__(D, 1)
612612
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
613613
static __global__ void flash_attn_combine_results(
614614
const float * __restrict__ VKQ_parts,
615615
const float2 * __restrict__ VKQ_meta,
616-
float * __restrict__ dst) {
617-
VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
618-
VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x;
619-
dst += D * gridDim.y*blockIdx.x;
616+
float * __restrict__ dst,
617+
const int parallel_blocks) {
618+
VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
619+
VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x;
620+
dst += D * gridDim.z*blockIdx.x;
620621

621622
const int tid = threadIdx.x;
622623
__builtin_assume(tid < D);
623624

624-
__shared__ float2 meta[parallel_blocks];
625+
extern __shared__ float2 meta[];
625626
if (tid < 2*parallel_blocks) {
626-
((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
627+
((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid];
627628
}
628629

629630
__syncthreads();
630631

631632
float kqmax = meta[0].x;
632-
#pragma unroll
633633
for (int l = 1; l < parallel_blocks; ++l) {
634634
kqmax = max(kqmax, meta[l].x);
635635
}
636636

637637
float VKQ_numerator = 0.0f;
638638
float VKQ_denominator = 0.0f;
639-
#pragma unroll
640639
for (int l = 0; l < parallel_blocks; ++l) {
641640
const float diff = meta[l].x - kqmax;
642641
const float KQ_max_scale = expf(diff);
643642
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
644643
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
645644

646-
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
645+
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
647646
VKQ_denominator += KQ_max_scale * meta[l].y;
648647
}
649648

650-
dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
649+
dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
651650
}
652651

653652
static void on_no_fattn_vec_case(const int D) {
@@ -671,11 +670,10 @@ static void on_no_fattn_vec_case(const int D) {
671670
}
672671
}
673672

674-
// parallel_blocks == 0 is stream-k decomposition
675-
template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
673+
template <int D, int ncols1, int ncols2, int KQ_stride>
676674
void launch_fattn(
677675
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
678-
const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V,
676+
const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V, const bool stream_k,
679677
const int warp_size = WARP_SIZE
680678
) {
681679
constexpr int ncols = ncols1 * ncols2;
@@ -699,6 +697,9 @@ void launch_fattn(
699697

700698
GGML_ASSERT(Q->ne[3] == 1);
701699

700+
GGML_ASSERT(stream_k || ncols2 == 1);
701+
const int parallel_blocks = Q->ne[1] <= ncols1 ? 4 : 1;
702+
702703
ggml_cuda_pool & pool = ctx.pool();
703704
cudaStream_t main_stream = ctx.stream();
704705
const int id = ggml_cuda_get_device();
@@ -753,7 +754,7 @@ void launch_fattn(
753754

754755
const dim3 block_dim(warp_size, nwarps, 1);
755756
dim3 blocks_num;
756-
if (parallel_blocks == 0) {
757+
if (stream_k) {
757758
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
758759
const int max_blocks = 2*nsm;
759760
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
@@ -769,9 +770,9 @@ void launch_fattn(
769770

770771
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
771772
} else {
772-
blocks_num.x = parallel_blocks*ntiles_x;
773-
blocks_num.y = Q->ne[2];
774-
blocks_num.z = Q->ne[3];
773+
blocks_num.x = ntiles_x;
774+
blocks_num.y = parallel_blocks;
775+
blocks_num.z = Q->ne[2]*Q->ne[3];
775776

776777
if (parallel_blocks > 1) {
777778
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
@@ -803,7 +804,7 @@ void launch_fattn(
803804
K_data,
804805
V_data,
805806
mask ? ((const char *) mask->data) : nullptr,
806-
(parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
807+
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
807808
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
808809
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
809810
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@@ -815,7 +816,7 @@ void launch_fattn(
815816
);
816817
CUDA_CHECK(cudaGetLastError());
817818

818-
if constexpr (parallel_blocks == 0) {
819+
if (stream_k) {
819820
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
820821
const dim3 block_dim_combine(D, 1, 1);
821822
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
@@ -824,13 +825,14 @@ void launch_fattn(
824825
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
825826
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
826827
}
827-
} else if constexpr (parallel_blocks > 1) {
828+
} else if (parallel_blocks > 1) {
828829
const dim3 block_dim_combine(D, 1, 1);
829-
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
830+
const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
831+
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
830832

831-
flash_attn_combine_results<D, parallel_blocks>
832-
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
833-
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
833+
flash_attn_combine_results<D>
834+
<<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
835+
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
834836
}
835837
CUDA_CHECK(cudaGetLastError());
836838
}

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -970,7 +970,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
970970
fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
971971
}
972972

973-
launch_fattn<D, ncols1, ncols2, 0, KQ_per_iter>(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, true, true);
973+
launch_fattn<D, ncols1, ncols2, KQ_per_iter>(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, true, true, true);
974974
}
975975

976976

ggml/src/ggml-cuda/fattn-tile-f16.cu

Lines changed: 22 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44

55
#define FATTN_KQ_STRIDE_TILE_F16 64
66

7-
template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
7+
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
88
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
99
__launch_bounds__(nwarps*WARP_SIZE, 1)
1010
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -58,18 +58,17 @@ static __global__ void flash_attn_tile_ext_f16(
5858

5959
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
6060

61-
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
62-
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
61+
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
6362

6463
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
65-
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
66-
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
67-
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
64+
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
65+
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
66+
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
6867
const half * maskh = (const half *) mask + ne11*ic0;
6968

7069
const int stride_KV2 = nb11 / sizeof(half2);
7170

72-
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
71+
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
7372
const half slopeh = __float2half(slopef);
7473

7574
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -105,8 +104,7 @@ static __global__ void flash_attn_tile_ext_f16(
105104

106105
__syncthreads();
107106

108-
const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F16;
109-
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
107+
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
110108
// Calculate KQ tile and keep track of new maximum KQ values:
111109

112110
half kqmax_new[ncols/nwarps];
@@ -271,40 +269,40 @@ static __global__ void flash_attn_tile_ext_f16(
271269
const int i0 = i00 + 2*threadIdx.x;
272270

273271
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
274-
if (parallel_blocks == 1) {
272+
if (gridDim.y == 1) {
275273
dst_val /= __half2half2(kqsum_j);
276274
}
277-
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
278-
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
279-
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
275+
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
276+
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
277+
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
280278
}
281279

282-
if (parallel_blocks != 1 && threadIdx.x == 0) {
283-
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
280+
if (gridDim.y != 1 && threadIdx.x == 0) {
281+
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
284282
}
285283
}
286284
#else
287285
NO_DEVICE_CODE;
288286
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
289287
}
290288

291-
template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
289+
template <int cols_per_block, bool use_logit_softcap>
292290
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
293291
const ggml_tensor * Q = dst->src[0];
294292
switch (Q->ne[0]) {
295293
case 64: {
296294
constexpr int D = 64;
297295
constexpr int nwarps = 8;
298296
constexpr size_t nbytes_shared = 0;
299-
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
300-
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
297+
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
298+
launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true, false);
301299
} break;
302300
case 128: {
303301
constexpr int D = 128;
304302
constexpr int nwarps = 8;
305303
constexpr size_t nbytes_shared = 0;
306-
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
307-
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
304+
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
305+
launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true, false);
308306
} break;
309307
default: {
310308
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
@@ -324,37 +322,22 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
324322

325323
if (Q->ne[1] <= 16) {
326324
constexpr int cols_per_block = 16;
327-
constexpr int parallel_blocks = 4;
328325
if (logit_softcap == 0.0f) {
329326
constexpr bool use_logit_softcap = false;
330-
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
327+
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
331328
} else {
332329
constexpr bool use_logit_softcap = true;
333-
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
334-
}
335-
return;
336-
}
337-
338-
if (Q->ne[1] <= 32) {
339-
constexpr int cols_per_block = 32;
340-
constexpr int parallel_blocks = 4;
341-
if (logit_softcap == 0.0f) {
342-
constexpr bool use_logit_softcap = false;
343-
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
344-
} else {
345-
constexpr bool use_logit_softcap = true;
346-
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
330+
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
347331
}
348332
return;
349333
}
350334

351335
constexpr int cols_per_block = 32;
352-
constexpr int parallel_blocks = 1;
353336
if (logit_softcap == 0.0f) {
354337
constexpr bool use_logit_softcap = false;
355-
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
338+
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
356339
} else {
357340
constexpr bool use_logit_softcap = true;
358-
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
341+
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
359342
}
360343
}

0 commit comments

Comments
 (0)