Skip to content

Commit 2ef0c5f

Browse files
CUDA: generalized (mma) FA, add Volta support
1 parent 583cb83 commit 2ef0c5f

File tree

10 files changed: +935 additions, −735 deletions

ggml/include/ggml.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2278,7 +2278,7 @@ extern "C" {
22782278
float stop,
22792279
float step);
22802280

2281 -    #define GGML_KQ_MASK_PAD 64
2281 +    #define GGML_KQ_MASK_PAD 1
22822282

22832283
// q: [n_embd_k, n_batch, n_head, ne3 ]
22842284
// k: [n_embd_k, n_kv, n_head_kv, ne3 ]

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ typedef void (* fattn_kernel_t)(
2525
const float m1,
2626
const uint32_t n_head_log2,
2727
const float logit_softcap,
28 -    const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
28 +    const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
2929
const int32_t nb01, const int32_t nb02, const int32_t nb03,
3030
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
3131
const int32_t nb11, const int32_t nb12, const int64_t nb13,
@@ -621,7 +621,8 @@ static __global__ void flash_attn_mask_to_KV_max(
621621
template<int D, int ncols1, int ncols2> // D == head size
622622
__launch_bounds__(D, 1)
623623
static __global__ void flash_attn_stream_k_fixup(
624 -    float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
624 +    float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11,
625 +    const int nbatch_fa) {
625626
constexpr int ncols = ncols1*ncols2;
626627

627628
const int bidx0 = blockIdx.x;
@@ -632,8 +633,8 @@ static __global__ void flash_attn_stream_k_fixup(
632633

633634
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
634635

635 -    const int iter_k = ne11 / FATTN_KQ_STRIDE;
636 -    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
636 +    const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
637 +    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
637638

638639
const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
639640
const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
@@ -765,7 +766,7 @@ static __global__ void flash_attn_combine_results(
765766
template <int DV, int ncols1, int ncols2>
766767
void launch_fattn(
767768
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
768 -    const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
769 +    const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
769770
) {
770771
constexpr int ncols = ncols1 * ncols2;
771772

@@ -790,8 +791,6 @@ void launch_fattn(
790791
GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
791792

792793
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
793 -    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
794 -        "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
795794

796795
ggml_cuda_pool & pool = ctx.pool();
797796
cudaStream_t main_stream = ctx.stream();
@@ -915,7 +914,7 @@ void launch_fattn(
915914

916915
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
917916
} else {
918 -    const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
917 +    const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
919918

920919
// parallel_blocks must not be larger than what the tensor size allows:
921920
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
@@ -970,6 +969,9 @@ void launch_fattn(
970969
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
971970
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
972971

972 +    // TODO other tensor dimensions after removal of WMMA kernel:
973 +    const uint3 ne01 = init_fastdiv_values(Q->ne[1]);
974 +
973975
GGML_ASSERT(block_dim.x % warp_size == 0);
974976
fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
975977
(const char *) Q->data,
@@ -980,7 +982,7 @@ void launch_fattn(
980982
KV_max.ptr,
981983
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
982984
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
983 -    Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
985 +    Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
984986
K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
985987
nb21, nb22, nb23,
986988
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
@@ -995,7 +997,7 @@ void launch_fattn(
995997

996998
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
997999
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
998 -    ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
1000 +    ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa);
9991001
}
10001002
} else if (parallel_blocks > 1) {
10011003
const dim3 block_dim_combine(DV, 1, 1);

0 commit comments

Comments (0)