Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 179 additions & 7 deletions ggml/src/ggml-cuda/fattn-mma-f16.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,45 @@ typedef tile<16, 16, float> tile_C_KQ_16;
typedef tile<16, 4, half2> tile_C_VKQ;
typedef tile<16, 8, half2> tile_C_VKQ_16;

typedef void (* fattn_kernel_mma_t)(
const char * __restrict__ Q,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int2 * __restrict__ bounds,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const float softcap,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int nb31,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3);

template<int D, int nwarps, int KQ_per_iter>
static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) {
Expand Down Expand Up @@ -871,6 +910,7 @@ static __global__ void flash_attn_mma_ext_f16(
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int2 * __restrict__ bounds,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
Expand Down Expand Up @@ -948,8 +988,13 @@ static __global__ void flash_attn_mma_ext_f16(

const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;

const int kb0_start_kernel = kb0_start * kb_niter;
const int kb0_stop_kernel = kb0_stop * kb_niter;
int kb0_start_kernel = kb0_start * kb_niter;
int kb0_stop_kernel = kb0_stop * kb_niter;

if (bounds) {
kb0_start_kernel = max(kb0_start_kernel, bounds[jt].x / KQ_per_iter);
kb0_stop_kernel = min(kb0_stop_kernel, bounds[jt].y / KQ_per_iter);
}

constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
if (kb0_start == 0) {
Expand Down Expand Up @@ -987,8 +1032,12 @@ static __global__ void flash_attn_mma_ext_f16(

const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;

const int kb0_start_kernel = kb0_start * kb_niter;
const int kb0_stop_kernel = kb0_stop * kb_niter;
int kb0_start_kernel = kb0_start * kb_niter;
int kb0_stop_kernel = kb0_stop * kb_niter;
if (bounds) {
kb0_start_kernel = max(kb0_start_kernel, bounds[jt].x / KQ_per_iter);
kb0_stop_kernel = min(kb0_stop_kernel, bounds[jt].y / KQ_per_iter);
}

constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
constexpr bool needs_fixup = false;
Expand Down Expand Up @@ -1144,9 +1193,109 @@ static __global__ void flash_attn_mma_combine_results(
dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ int warp_reduce_all(int x) {
if constexpr (width == WARP_SIZE) { //ggml_cuda_get_physical_warp_size()) {
return __all_sync(0xffffffff, x);
} else {
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
}
return x;
}
}

template <int ncols1, bool is_swa>
__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
static __global__ void flash_attn_mask_to_KV_min_max(
const half2 * __restrict__ mask, int2 * __restrict__ KV_min_max, const int ne30, const int s31, const int s33) {
const int ne31 = gridDim.x;
const int tid = threadIdx.x;
const int sequence = blockIdx.y;
const int jt = blockIdx.x;

mask += sequence*s33 + jt*ncols1*s31;

__shared__ int buf_iw[WARP_SIZE];
if (tid < WARP_SIZE) {
buf_iw[tid] = 1;
}
__syncthreads();

int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
int all_inf = 1;

#pragma unroll
for (int j = 0; j < ncols1; ++j) {
const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);
all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
}

all_inf = warp_reduce_all(all_inf);
if (tid % WARP_SIZE == 0) {
buf_iw[tid / WARP_SIZE] = all_inf;
}
__syncthreads();
all_inf = buf_iw[tid % WARP_SIZE];
__syncthreads();
all_inf = warp_reduce_all(all_inf);

if (!all_inf) {
break;
}
}

if constexpr (!is_swa) {
if (threadIdx.x == 0) {
KV_min_max[sequence*ne31 + jt] = {0, KV_max_sj + FATTN_KQ_STRIDE};
}
return;
}

if (threadIdx.x == 0) {
KV_min_max[sequence*ne31 + jt].y = KV_max_sj + FATTN_KQ_STRIDE;
}

if (tid < WARP_SIZE) {
buf_iw[tid] = 1;
}
__syncthreads();

int KV_min_sj = 0;
for (; KV_min_sj < KV_max_sj; KV_min_sj += FATTN_KQ_STRIDE) {
int all_inf = 1;

#pragma unroll
for (int j = 0; j < ncols1; ++j) {
const float2 tmp = __half22float2(mask[j*s31 + KV_min_sj/2 + tid]);
all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
}

all_inf = warp_reduce_all(all_inf);
if (tid % WARP_SIZE == 0) {
buf_iw[tid / WARP_SIZE] = all_inf;
}
__syncthreads();
all_inf = buf_iw[tid % WARP_SIZE];
__syncthreads();
all_inf = warp_reduce_all(all_inf);

if (!all_inf) {
break;
}
}

if (threadIdx.x == 0) {
KV_min_max[sequence*ne31 + jt].x = KV_min_sj;
}
}


template <int D, int ncols1, int ncols2, int KQ_stride>
void launch_fattn_mma(
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_mma_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
) {
constexpr int ncols = ncols1 * ncols2;
Expand All @@ -1171,6 +1320,9 @@ void launch_fattn_mma(

GGML_ASSERT(Q->ne[3] == 1);

int n_swa;
memcpy(&n_swa, (const int *) KQV->op_params + 4, sizeof(int));

ggml_cuda_pool & pool = ctx.pool();
cudaStream_t main_stream = ctx.stream();
const int id = ggml_cuda_get_device();
Expand All @@ -1179,6 +1331,7 @@ void launch_fattn_mma(

ggml_cuda_pool_alloc<half> K_f16(pool);
ggml_cuda_pool_alloc<half> V_f16(pool);
ggml_cuda_pool_alloc<int2> KV_min_max(pool);
ggml_cuda_pool_alloc<float> dst_tmp(pool);
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);

Expand Down Expand Up @@ -1225,11 +1378,29 @@ void launch_fattn_mma(
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];

if (mask && (Q->ne[1] >= 1024 || (n_swa > 0 && K->ne[1] >= FATTN_KQ_STRIDE + n_swa))) {
const int s31 = mask->nb[1] / sizeof(half2);
const int s33 = mask->nb[3] / sizeof(half2);
const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);
const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
KV_min_max.alloc(ne_KV_max);
if (n_swa > 0) {
flash_attn_mask_to_KV_min_max<ncols1, true><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
((const half2 *) mask->data, KV_min_max.ptr, iter_k, s31, s33);
} else {
flash_attn_mask_to_KV_min_max<ncols1, false><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
((const half2 *) mask->data, KV_min_max.ptr, iter_k, s31, s33);
}
CUDA_CHECK(cudaGetLastError());
}

const dim3 block_dim(warp_size, nwarps, 1);
dim3 blocks_num;
if (stream_k) {
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
const int max_blocks = 2*nsm;
const int max_blocks = Q->ne[1] > 1 ? 2*nsm : nsm;
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);

Expand Down Expand Up @@ -1313,6 +1484,7 @@ void launch_fattn_mma(
V_data,
mask ? ((const char *) mask->data) : nullptr,
sinks ? ((const char *)sinks->data) : nullptr,
KV_min_max.ptr,
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
Expand Down Expand Up @@ -1372,7 +1544,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

fattn_kernel_t fattn_kernel;
fattn_kernel_mma_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_mma_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
Expand Down
2 changes: 2 additions & 0 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -9008,6 +9008,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
float params[] = { scale, max_bias, softcap };
ggml_set_op_params(result, params, sizeof(params));

ggml_set_op_params_i32(result, 4, 0);

result->op = GGML_OP_FLASH_ATTN_EXT;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = q;
Expand Down
24 changes: 16 additions & 8 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7985,7 +7985,7 @@ static struct ggml_tensor * llm_build_kqv(
float kq_scale,
const llm_build_cb & cb,
int il,
ggml_tensor * sinks = nullptr) {
ggml_tensor * sinks = nullptr, int n_swa = 0) {
const llama_model & model = lctx.model;
const llama_hparams & hparams = lctx.model.hparams;
const llama_cparams & cparams = lctx.cparams;
Expand Down Expand Up @@ -8033,6 +8033,9 @@ static struct ggml_tensor * llm_build_kqv(
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
ggml_flash_attn_ext_add_sinks(cur, sinks);
if (n_swa > 0) {
((int32_t *)cur->op_params)[4] = n_swa;
}

// Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
// For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when uding fp16 for TG.
Expand Down Expand Up @@ -8190,7 +8193,7 @@ static struct ggml_tensor * llm_build_kv(
float kq_scale,
const llm_build_cb & cb,
int il,
ggml_tensor * sinks = nullptr) {
ggml_tensor * sinks = nullptr, int n_swa = 0) {
const llama_hparams & hparams = lctx.model.hparams;
const llama_cparams & cparams = lctx.cparams;

Expand All @@ -8205,7 +8208,7 @@ static struct ggml_tensor * llm_build_kv(
struct ggml_tensor * cur;

cur = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b,
q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks);
q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks, n_swa);
cb(cur, "kqv_out", il);

return cur;
Expand Down Expand Up @@ -8766,7 +8769,8 @@ struct llm_build_context {

cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il, nullptr,
this_KQ_mask == KQ_mask_swa ? hparams.n_swa : 0);
}

if (il == n_layer - 1) {
Expand Down Expand Up @@ -12198,7 +12202,8 @@ struct llm_build_context {

cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il);
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il, nullptr,
KQ_mask_l == KQ_mask_swa ? hparams.n_swa : 0);
}

cur = llm_build_norm(ctx0, cur, hparams,
Expand Down Expand Up @@ -12335,7 +12340,8 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);

cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il);
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il, nullptr,
KQ_mask_l == KQ_mask_swa ? hparams.n_swa : 0);
}

cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, cb, il);
Expand Down Expand Up @@ -14400,7 +14406,8 @@ struct llm_build_context {
}

cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur,
KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il);
KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il, nullptr,
is_sliding ? hparams.n_swa : 0);
}

if (il == n_layer - 1) {
Expand Down Expand Up @@ -15490,7 +15497,8 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);

cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, kq_scale, cb, il, model.layers[il].attn_sinks);
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, kq_scale, cb, il, model.layers[il].attn_sinks,
is_sliding ? hparams.n_swa : 0);

cb(cur, "attn_out", il);
}
Expand Down