Commit 534f158

CUDA: use mma PTX instructions for FlashAttention (ggml-org#11583)
* CUDA: use mma PTX instructions for FlashAttention
* __shfl_sync workaround for movmatrix
* add __shfl_sync to HIP

Authors: Johannes Gaessler and Slaren
1 parent 80fe9c3 commit 534f158
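
Of the changes above, the movmatrix workaround deserves a note: movmatrix is the PTX instruction for transposing an mma fragment within a warp, and where it cannot be used the same data movement can be expressed with warp shuffles (per the commit message, a __shfl_sync definition is also added for HIP builds, which lack one). The sketch below is illustrative only and uses a simplified 8x4 float tile, not the exact fragment layout of the kernels: each lane fetches, via __shfl_sync, the element it owns after the transpose.

    // Illustrative sketch (not code from this commit): transpose a tile that is
    // distributed one element per lane across a warp of 32 threads.
    // Before: lane l holds A[l/4][l%4] of an 8x4 tile.
    // After:  lane l holds the transposed element A[l%8][l/8] of the 4x8 tile.
    __device__ __forceinline__ float warp_transpose_8x4(float x) {
        const int lane = threadIdx.x % 32;
        const int src  = (lane % 8)*4 + lane/8; // lane that currently owns the needed element
        return __shfl_sync(0xFFFFFFFF, x, src, 32);
    }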

25 files changed: +2053 -993 lines

ggml/include/ggml.h

Lines changed: 1 addition & 1 deletion
@@ -1893,7 +1893,7 @@ extern "C" {
             struct ggml_tensor * a,
             int k);
 
-#define GGML_KQ_MASK_PAD 32
+#define GGML_KQ_MASK_PAD 64
 
     // q: [n_embd, n_batch, n_head, 1]
    // k: [n_embd, n_kv, n_head_kv, 1]
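
The doubled mask padding matters to callers that size the KQ mask: its token dimension must now be rounded up to a multiple of 64 instead of 32. A minimal stand-alone sketch of that rounding, with pad_to as a hypothetical stand-in for whatever helper the caller uses (ggml.h provides GGML_PAD for this purpose):

    #include <stdio.h>

    // Hypothetical helper: round n up to the next multiple of pad.
    static int pad_to(int n, int pad) {
        return (n + pad - 1) / pad * pad;
    }

    int main(void) {
        // 70 mask rows: padding to 32 gives 96, padding to 64 gives 128.
        printf("%d %d\n", pad_to(70, 32), pad_to(70, 64)); // prints: 96 128
        return 0;
    }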

ggml/src/ggml-cuda/common.cuh

Lines changed: 4 additions & 2 deletions
@@ -148,7 +148,7 @@ typedef float2 dfloat2;
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
-#define INT8_MMA_AVAILABLE
+#define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 
 #if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
@@ -159,11 +159,13 @@ static constexpr bool fast_fp16_available(const int cc) {
     return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
 }
 
+// Any FP16 tensor cores are available.
 static constexpr bool fp16_mma_available(const int cc) {
     return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
 }
 
-static constexpr bool int8_mma_available(const int cc) {
+// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
+static constexpr bool new_mma_available(const int cc) {
     return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
 }
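The renamed macro and helper form a pair: NEW_MMA_AVAILABLE guards device code at compile time (per __CUDA_ARCH__), while new_mma_available() gates kernel selection at run time from the device's compute capability. A minimal sketch of how such a pair is typically used; the kernel and launcher below are hypothetical, not code from this commit:

    // Hypothetical pairing of the compile-time and run-time checks.
    __global__ void example_mma_kernel() {
    #ifdef NEW_MMA_AVAILABLE
        // Compiled for Turing or newer: the mma PTX path can be used here.
    #else
        // Compiled for an architecture without the required mma support.
    #endif
    }

    static void example_launch(const int cc, cudaStream_t stream) {
        if (!new_mma_available(cc)) {
            return; // caller would fall back to a non-tensor-core FlashAttention kernel
        }
        example_mma_kernel<<<1, 32, 0, stream>>>();
    }
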
ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 154 additions & 25 deletions
@@ -653,6 +653,104 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
+template<int D, int ncols, int KQ_stride> // D == head size
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_stream_k_fixup(
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
+
+    const int iter_k = ne11 / KQ_stride;
+    const int iter_j = (ne01 + (ncols - 1)) / ncols;
+
+    const int bidx0 = blockIdx.x;
+
+    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
+
+    const bool did_not_have_any_data   = kbc0 == kbc0_stop;
+    const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
+    const bool did_not_write_last      = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
+    if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
+        return;
+    }
+
+    const int channel = kbc0 / (iter_k*iter_j);
+    const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
+
+    dst += jt*ncols*ne02*D + channel*D;
+
+    // Load the partial result that needs a fixup:
+    float dst_val[ncols] = {0.0f};
+    float max_val[ncols] = {0.0f};
+    float rowsum[ncols]  = {0.0f};
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (jt*ncols + j >= ne01) {
+            break;
+        }
+        dst_val[j] = dst[j*ne02*D + threadIdx.x];
+
+        const float2 tmp = dst_fixup[bidx0*ncols + j];
+        max_val[j] = tmp.x;
+        rowsum[j]  = tmp.y;
+    }
+
+    // Iterate over previous blocks and compute the combined results.
+    // All CUDA blocks that get here must have a previous block that needs a fixup.
+    int bidx     = bidx0 - 1;
+    int kbc_stop = kbc0;
+    while(true) {
+        const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
+        if (kbc == kbc_stop) { // Did not have any data.
+            bidx--;
+            kbc_stop = kbc;
+            continue;
+        }
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            if (jt*ncols + j >= ne01) {
+                break;
+            }
+            const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
+
+            const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
+
+            // Scale the current and new value accumulators depending on the max. values.
+            const float max_val_new = fmaxf(max_val[j], tmp.x);
+
+            const float diff_val = max_val[j] - max_val_new;
+            const float diff_add = tmp.x     - max_val_new;
+
+            const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
+            const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
+
+            dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
+            rowsum[j]  = scale_val*rowsum[j]  + scale_add*tmp.y;
+
+            max_val[j] = max_val_new;
+        }
+
+        // If this block started in a previous tile we are done and don't need to combine additional partial results.
+        if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
+            break;
+        }
+        bidx--;
+        kbc_stop = kbc;
+    }
+
+    // Write back final result:
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (jt*ncols + j >= ne01) {
+            return;
+        }
+        dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
+    }
+}
+
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
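
The rescaling loop in flash_attn_stream_k_fixup is the standard merge of partial softmax accumulators: each partial carries its weighted value sum, its running maximum, and its row sum, and merging rescales both sums to a common maximum before adding. A minimal stand-alone version of that arithmetic for one output element (names are illustrative, not from the kernel):

    #include <math.h>

    // One partial attention accumulator: weighted value sum, running max, row sum.
    typedef struct { float val; float max; float rowsum; } partial_t;

    // Merge two partials; the final output would be out.val / out.rowsum, as in
    // the kernel's write-back loop. The kernel additionally flushes a scale to
    // zero when its exponent falls below SOFTMAX_FTZ_THRESHOLD.
    static partial_t merge_partials(partial_t a, partial_t b) {
        const float m       = fmaxf(a.max, b.max);
        const float scale_a = expf(a.max - m);
        const float scale_b = expf(b.max - m);
        partial_t out;
        out.val    = scale_a*a.val    + scale_b*b.val;
        out.rowsum = scale_a*a.rowsum + scale_b*b.rowsum;
        out.max    = m;
        return out;
    }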
@@ -722,10 +820,11 @@ static void on_no_fattn_vec_case(const int D) {
     }
 }
 
-template <int D, int parallel_blocks>
+// parallel_blocks == 0 is stream-k decomposition
+template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
 void launch_fattn(
         ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-        const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
+        const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
 ) {
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
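
With parallel_blocks == 0 the launcher switches to a stream-k style decomposition: a fixed grid of blocks splits the flattened list of work units evenly, so a block may start or stop in the middle of an output tile, and flash_attn_stream_k_fixup later merges those fractional contributions. A small stand-alone sketch of the same partitioning arithmetic (plain C, values illustrative):

    #include <stdio.h>

    // Illustrative only: split ntotal work units over nblocks blocks with the
    // same integer arithmetic the kernels use (kbc = bidx*ntotal/nblocks).
    int main(void) {
        const int ntotal  = 10; // e.g. iter_k*iter_j*ne02
        const int nblocks = 4;  // e.g. 2*nsm
        for (int bidx = 0; bidx < nblocks; ++bidx) {
            const int kbc      = (bidx + 0)*ntotal / nblocks;
            const int kbc_stop = (bidx + 1)*ntotal / nblocks;
            printf("block %d: [%d, %d)\n", bidx, kbc, kbc_stop);
        }
        // Prints [0,2) [2,5) [5,7) [7,10): block boundaries need not align with
        // tile boundaries, which is exactly why the fixup pass exists.
        return 0;
    }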
@@ -744,20 +843,23 @@ void launch_fattn(
 
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
 
+    GGML_ASSERT(Q->ne[3] == 1);
+
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
 
     ggml_cuda_pool_alloc<half>   K_f16(pool);
     ggml_cuda_pool_alloc<half>   V_f16(pool);
     ggml_cuda_pool_alloc<float>  dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 
-    char * K_data = (char *) K->data;
+    const char * K_data = (const char *) K->data;
     size_t nb11 = K->nb[1];
     size_t nb12 = K->nb[2];
     size_t nb13 = K->nb[3];
 
-    char * V_data = (char *) V->data;
+    const char * V_data = (const char *) V->data;
     size_t nb21 = V->nb[1];
     size_t nb22 = V->nb[2];
     size_t nb23 = V->nb[3];
@@ -790,39 +892,60 @@ void launch_fattn(
         nb23 = nb23*bs*sizeof(half)/ts;
     }
 
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
+    const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
+    const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
 
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const int  shmem = 0;
+    dim3 blocks_num;
+    if (parallel_blocks == 0) {
+        // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
+        const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm;
+        const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
+        const bool short_context = K->ne[1] < 4096;
+
+        const int nblocks_stream_k = 2*nsm;
+
+        blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
+        blocks_num.y = 1;
+        blocks_num.z = 1;
+
+        dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
+    } else {
+        blocks_num.x = parallel_blocks*ntiles_x;
+        blocks_num.y = Q->ne[2];
+        blocks_num.z = Q->ne[3];
+
+        if (parallel_blocks > 1) {
+            dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+            dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+        }
+    }
 
     float scale         = 1.0f;
     float max_bias      = 0.0f;
     float logit_softcap = 0.0f;
 
-    memcpy(&scale,         (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias,      (float *) KQV->op_params + 1, sizeof(float));
-    memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
+    memcpy(&scale,         (const float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (const float *) KQV->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
     if (logit_softcap != 0.0f) {
         scale /= logit_softcap;
     }
 
     const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+    const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));
 
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
+    fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
-        (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+        (parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
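
In the stream-k branch a single dst_tmp_meta buffer carries everything the fixup pass needs. Per block and per column it holds two float2 metadata entries ({max, rowsum}) plus D floats of partial output, which is where the (2*2 + D) factor in the allocation comes from; the fixup kernel's pointer arithmetic skips the two metadata regions before reading the partials. A sketch of matching offset helpers (names illustrative, derived from the indexing in flash_attn_stream_k_fixup):

    #include <stddef.h>

    // dst_fixup layout, as implied by the kernel's indexing:
    //   [2][nblocks][ncols]   float2  -- metadata regions ({max, rowsum})
    //   [nblocks][ncols][D]   float   -- partial outputs
    // i.e. (2*2 + D) floats per block and column.
    static size_t meta_idx(size_t region, size_t nblocks, size_t ncols,
                           size_t bidx, size_t j) {
        return (region*nblocks + bidx)*ncols + j;   // index into the float2 array
    }

    static size_t data_idx(size_t nblocks, size_t ncols, size_t D,
                           size_t bidx, size_t j, size_t d) {
        return 2*2*ncols*nblocks                    // skip both metadata regions (in floats)
             + (bidx*ncols + j)*D + d;              // then per-block, per-column partials
    }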
@@ -834,16 +957,22 @@ void launch_fattn(
     );
     CUDA_CHECK(cudaGetLastError());
 
-    if ((parallel_blocks) == 1) {
-        return;
-    }
+    if constexpr (parallel_blocks == 0) {
+        if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+            const dim3 block_dim_combine(D, 1, 1);
+            const dim3 blocks_num_combine = blocks_num;
 
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
+            flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
+                <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+        }
+    } else if constexpr (parallel_blocks > 1) {
+        const dim3 block_dim_combine(D, 1, 1);
+        const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
 
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+        flash_attn_combine_results<D, parallel_blocks>
+            <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
+            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+    }
     CUDA_CHECK(cudaGetLastError());
 }
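
For orientation, a hypothetical caller is sketched below. It follows the new template and argument list, but the concrete values (head size, columns per block, KQ stride) and the surrounding function are placeholders, not taken from this commit:

    // Hypothetical caller (not from this commit): passing parallel_blocks == 0
    // selects the stream-k path; the dynamic shared memory size is now passed
    // explicitly in place of cols_per_block.
    static void launch_example(ggml_backend_cuda_context & ctx, ggml_tensor * dst,
                               fattn_kernel_t kernel, const int nwarps, const size_t nbytes_shared) {
        constexpr int D              = 128; // placeholder head size
        constexpr int cols_per_block = 64;  // placeholder
        constexpr int KQ_stride      = 128; // placeholder
        launch_fattn<D, cols_per_block, /*parallel_blocks=*/0, KQ_stride>(
            ctx, dst, kernel, nwarps, nbytes_shared, /*need_f16_K=*/true, /*need_f16_V=*/true);
    }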
