Skip to content

Commit 3bb64d9

Browse files
ikawrakow (Iwan Kawrakow)
and authored
Better TG performance for GQA models (CPU) (#332)
* Slightly better CPU TG performance for GQA
* Better CPU FA implementation for TG when GQA
* Minor

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent f7c5a94 commit 3bb64d9

File tree

3 files changed

+134
-12
lines changed

3 files changed

+134
-12
lines changed

ggml/src/ggml.c

Lines changed: 20 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -21781,19 +21781,27 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
2178121781
const struct ggml_tensor * q = node->src[0];
2178221782
const struct ggml_tensor * k = node->src[1];
2178321783
if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) {
21784-
int nstep_k = k->ne[1]/32;
21785-
int gcd_k = simple_gcd(nstep_k, n_tasks);
21786-
if (gcd_k > 1) {
21787-
int nth_k = n_tasks/gcd_k;
21788-
int rk2 = q->ne[2]/k->ne[2];
21789-
int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
21790-
size_t size = (Dv + 16)*nq_per_thread*sizeof(float)*n_tasks;
21791-
if (ggml_is_quantized(k->type)) {
21792-
enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
21793-
size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
21794-
size += q->ne[2]*row_size;
21795-
}
21784+
if (k->ne[2] > 1) {
21785+
int nk = 32 * (k->ne[2]*k->ne[1]/(32*n_tasks));
21786+
int nstep_k = k->ne[2]*k->ne[1]/nk;
21787+
size_t result_size = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
21788+
size_t size = nstep_k*result_size;
2179621789
cur = MAX(cur, size);
21790+
} else {
21791+
int nstep_k = k->ne[1]/32;
21792+
int gcd_k = simple_gcd(nstep_k, n_tasks);
21793+
if (gcd_k > 1) {
21794+
int nth_k = n_tasks/gcd_k;
21795+
int rk2 = q->ne[2]/k->ne[2];
21796+
int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
21797+
size_t size = (Dv + 16)*nq_per_thread*sizeof(float)*n_tasks;
21798+
if (ggml_is_quantized(k->type)) {
21799+
enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
21800+
size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
21801+
size += q->ne[2]*row_size;
21802+
}
21803+
cur = MAX(cur, size);
21804+
}
2179721805
}
2179821806
}
2179921807
#endif

ggml/src/iqk/iqk_flash_attn.cpp

Lines changed: 61 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -153,6 +153,67 @@ bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
153153
}
154154
}
155155

156+
if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) {
157+
int nk = 32 * (nek2*nek1/(32*nth));
158+
int nkk = (nek1 + nk - 1)/nk;
159+
int nstep_k = nek2*nkk;
160+
auto result_size = (Dv + 16)*rk2*sizeof(float);
161+
//if (ith == 0) printf("rk2 = %d, nek1 = %d, nek2 = %d, nk = %d, nkk = %d, nstep_k = %d\n", (int)rk2, (int)nek1, (int)nek2, nk, nkk, nstep_k);
162+
for (int istep_k = ith; istep_k < nstep_k; istep_k += nth) {
163+
int ik02 = istep_k/nkk;
164+
int ik01 = nk*(istep_k - ik02*nkk);
165+
int this_nk = ik01 + nk <= nek1 ? nk : nek1 - ik01;
166+
if (this_nk <= 0) break;
167+
auto this_result = (float *)((char *)work_buffer + istep_k*result_size);
168+
auto this_q = (const float *)((const char *)q + ik02*rk2*nbq2);
169+
auto this_k = (const char *)k + ik01*stride_k + ik02*nbk2;
170+
auto this_v = (const char *)v + ik01*stride_v + ik02*nbv2;
171+
auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here
172+
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
173+
Dk, Dv, rk2, this_nk, nbq2, stride_k, stride_v, 0, Dv,
174+
this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m,
175+
scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false;
176+
}
177+
178+
barrier(barrier_data);
179+
180+
// We have nkk results for each head
181+
for (int iq2 = ith; iq2 < neq2; iq2 += nth) {
182+
// ik02*rk2 + il = iq2 (il = 0...rk2-1) => ik02 = iq2/rk2, il = iq2%rk2;
183+
int ik02 = iq2/rk2;
184+
int il = iq2 - ik02*rk2;
185+
auto Racc = qkv + iq2*nb1/sizeof(float);
186+
std::memset(Racc, 0, Dv*sizeof(float));
187+
float M = -INFINITY, S = 0;
188+
for (int ikk = 0; ikk < nkk; ++ikk) {
189+
int istep_k = ik02*nkk + ikk;
190+
auto this_result = (float *)((char *)work_buffer + istep_k*result_size);
191+
const float * R = this_result + il*Dv;
192+
const float * Mj = this_result + Dv*rk2;
193+
const float * Sj = Mj + rk2;
194+
if (Mj[il] == -INFINITY) continue;
195+
if (Mj[il] > M) {
196+
if (M == -INFINITY) {
197+
std::memcpy(Racc, R, Dv*sizeof(float));
198+
S = Sj[il];
199+
} else {
200+
float c = exp(M - Mj[il]);
201+
S = c*S + Sj[il];
202+
for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i];
203+
}
204+
M = Mj[il];
205+
} else {
206+
float c = exp(Mj[il] - M);
207+
S += c*Sj[il];
208+
for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i];
209+
}
210+
}
211+
float norm = S > 0 ? 1/S : 1;
212+
for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
213+
}
214+
return true;
215+
}
216+
156217
// I keep changing my mind what is the best strategy to split the threads when processing
157218
// multiple heads. This is my current thinking, the commented out code below was the previous.
158219
int ntg = nth/simple_gcd(neq2*neq3, nth);

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 53 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -451,6 +451,51 @@ bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
451451
auto r3 = ne13 / ne03;
452452

453453
if (ne13 == 1 && Ny == 1 && r2 > 1) {
454+
if (Nx >= 256 && Nx%32 == 0) {
455+
int nx32 = Nx/32;
456+
int nchunk = nx32*ne02;
457+
if (r2 <= 8) {
458+
MulMat mm;
459+
if (!MulMat::prepare(typeA, typeB, ne00, mm, r2)) return false;
460+
int nx64 = Nx/64;
461+
int nchunk64 = nx64*ne02;
462+
for (int ichunk = ith; ichunk < nchunk64; ichunk += nth) {
463+
int i02 = ichunk/nx64;
464+
int ix = 64*(ichunk - i02*nx64);
465+
DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
466+
mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 64);
467+
}
468+
int ix0 = 64*nx64;
469+
if (ix0 < Nx) {
470+
nx32 -= 2*nx64;
471+
nchunk = nx32*ne02;
472+
for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
473+
int i02 = ichunk/nx32;
474+
int ix = ix0 + 32*(ichunk - i02*nx32);
475+
DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
476+
mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
477+
}
478+
}
479+
//for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
480+
// int i02 = ichunk/nx32;
481+
// int ix = 32*(ichunk - i02*nx32);
482+
// DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
483+
// mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
484+
//}
485+
return true;
486+
}
487+
for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
488+
int i02 = ichunk/nx32;
489+
int ix = ichunk - i02*nx32;
490+
if (!iqk_mul_mat(32, r2, ne00,
491+
typeA, (const char *)A + 32*ix*strideA + i02*nb02, strideA,
492+
typeB, (const char *)B + i02*r2*nb12, nb12,
493+
C + 32*ix + r2*i02*nb2, nb2, 0, 1)) return false;
494+
495+
}
496+
return true;
497+
}
498+
//if (ith == 0) printf("Using this: Nx = %d, r2 = %d, ne02 = %d\n", (int)Nx, (int)r2,(int)ne02);
454499
int gcd = simple_gcd(ne02, nth);
455500
int counter = 0;
456501
for (int64_t i12 = 0; i12 < ne02; i12++) {
@@ -17153,6 +17198,14 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str
1715317198
FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap);
1715417199
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
1715517200
}
17201+
else if (nq1 >= 4) {
17202+
FlashAttn<Dk, Dv, 4, k_step> fa(scale, softcap);
17203+
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
17204+
}
17205+
else if (nq1 >= 2) {
17206+
FlashAttn<Dk, Dv, 2, k_step> fa(scale, softcap);
17207+
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
17208+
}
1715617209
else {
1715717210
FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap);
1715817211
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);

0 commit comments

Comments
 (0)