Commit dcca0d3

cpu: introduce chunking for flash attention (#16829)
Factor out the core FA loop into ggml_compute_forward_flash_attn_ext_f16_one_chunk and add an outer loop on top that handles the chunks.
1 parent bacddc0 commit dcca0d3
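
The outer loop distributes q-row chunks to threads dynamically: each thread starts on the chunk matching its thread id, and when it finishes it grabs the next unprocessed chunk from a shared counter. Below is a minimal, self-contained sketch of that scheduling pattern under illustrative assumptions; process_rows, the std::thread setup, and the constants are stand-ins and not ggml API (the real code drives this through ggml_threadpool_chunk_set, ggml_threadpool_chunk_add, and ggml_barrier, as shown in the diff).

// Minimal sketch (not ggml code): dynamic chunk scheduling over nr rows.
// process_rows() stands in for ggml_compute_forward_flash_attn_ext_f16_one_chunk.
#include <algorithm>
#include <atomic>
#include <cstdint>
#include <cstdio>
#include <thread>
#include <vector>

static void process_rows(int64_t ir0, int64_t ir1) {
    std::printf("rows [%lld, %lld)\n", (long long) ir0, (long long) ir1);
}

int main() {
    const int     nth = 4;     // worker threads (illustrative)
    const int64_t nr  = 1000;  // total q rows, i.e. neq1*neq2*neq3 in the real code

    // 4x chunks per thread, same sizing as the commit
    const int64_t chunk_size = (nr + nth*4 - 1) / (nth*4);
    const int64_t nchunk     = (nr + chunk_size - 1) / chunk_size;
    const int64_t dr         = (nr + nchunk - 1) / nchunk;  // rows per chunk

    // Chunks 0..nth-1 are claimed implicitly by the thread ids, so the shared
    // counter starts at nth (the role of ggml_threadpool_chunk_set in the commit).
    std::atomic<int> next_chunk{nth};

    std::vector<std::thread> workers;
    for (int ith = 0; ith < nth; ++ith) {
        workers.emplace_back([&, ith] {
            int current_chunk = ith;  // first chunk comes from the thread id
            while (current_chunk < nchunk) {
                const int64_t ir0 = dr*current_chunk;
                const int64_t ir1 = std::min(ir0 + dr, nr);
                process_rows(ir0, ir1);
                // fetch_add returns the previous value: the next unprocessed chunk
                current_chunk = next_chunk.fetch_add(1);
            }
        });
    }
    for (auto & w : workers) {
        w.join();
    }
    return 0;
}

With nth = 4 and nr = 1000 this yields 16 chunks of 63 rows (the last one shorter), and a thread that finishes early simply picks up more chunks. That is the load-balancing benefit over the previous static split of dr = (nr + nth - 1)/nth rows per thread.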

File tree

1 file changed: +90 −16 lines


ggml/src/ggml-cpu/ops.cpp

Lines changed: 90 additions & 16 deletions
@@ -7909,10 +7909,10 @@ void ggml_compute_forward_argsort(
 
 // ggml_compute_forward_flash_attn_ext
 
-static void ggml_compute_forward_flash_attn_ext_f16(
+static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
         const ggml_compute_params * params,
-        ggml_tensor * dst) {
-
+        ggml_tensor * dst,
+        int ir0, int ir1) {
     const ggml_tensor * q = dst->src[0];
     const ggml_tensor * k = dst->src[1];
     const ggml_tensor * v = dst->src[2];
@@ -7928,9 +7928,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
     GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
 
-    const int ith = params->ith;
-    const int nth = params->nth;
-
     const int64_t DK = nek0;
     const int64_t DV = nev0;
     const int64_t N = neq1;
@@ -7964,16 +7961,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
     // parallelize by q rows using ggml_vec_dot_f32
 
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
     float scale = 1.0f;
     float max_bias = 0.0f;
     float logit_softcap = 0.0f;
@@ -8000,6 +7987,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT((                            q_to_vec_dot) && "fattn: unsupported K-type");
     GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float  ) && "fattn: unsupported V-type");
 
+    int ith = params->ith;
+
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
@@ -8147,6 +8136,91 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     }
 }
 
+static void ggml_compute_forward_flash_attn_ext_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * q = dst->src[0];
+    const ggml_tensor * k = dst->src[1];
+    const ggml_tensor * v = dst->src[2];
+
+    GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
+    GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
+    GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
+    GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+    GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
+    const int64_t DK = nek0;
+    const int64_t DV = nev0;
+    const int64_t N = neq1;
+
+    GGML_ASSERT(ne0 == DV);
+    GGML_ASSERT(ne2 == N);
+
+    // input tensor rows must be contiguous
+    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
+
+    GGML_ASSERT(neq0 == DK);
+    GGML_ASSERT(nek0 == DK);
+    GGML_ASSERT(nev0 == DV);
+
+    GGML_ASSERT(neq1 == N);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int64_t nr = neq1*neq2*neq3;
+
+    // rows per thread
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // disable for NUMA
+    const bool disable_chunking = ggml_is_numa();
+
+    // 4x chunks per thread
+    int nth_scaled = nth * 4;
+    int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+    int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
+
+    if (nth == 1 || nchunk < nth || disable_chunking) {
+        nchunk = nth;
+    }
+
+    if (ith == 0) {
+        // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
+        ggml_threadpool_chunk_set(params->threadpool, nth);
+    }
+
+    ggml_barrier(params->threadpool);
+
+    // The number of elements in each chunk
+    const int64_t dr = (nr + nchunk - 1) / nchunk;
+
+    // The first chunk comes from our thread_id, the rest will get auto-assigned.
+    int current_chunk = ith;
+
+    while (current_chunk < nchunk) {
+        const int64_t ir0 = dr * current_chunk;
+        const int64_t ir1 = MIN(ir0 + dr, nr);
+
+        ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
+
+        current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
+    }
+}
+
 void ggml_compute_forward_flash_attn_ext(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
