Skip to content

Commit 25a7705

Browse files
committed
metal : unroll update
1 parent 88e73e8 commit 25a7705

File tree

1 file changed

+19
-31
lines changed

1 file changed

+19
-31
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4787,7 +4787,8 @@ void kernel_flash_attn_ext_impl(
47874787

47884788
constexpr short NC = (C/8)/NSG;
47894789

4790-
// TODO: not good to unroll for large contexts - not sure why?
4790+
// note: do not unroll for large heads
4791+
#pragma unroll (DK <= 64 ? NC : 1)
47914792
for (short cc = 0; cc < NC; ++cc) {
47924793
qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
47934794

@@ -4798,15 +4799,12 @@ void kernel_flash_attn_ext_impl(
47984799
FOR_UNROLL (short i = 0; i < DK8; ++i) {
47994800
simdgroup_barrier(mem_flags::mem_none);
48004801

4801-
simdgroup_load(mk, pk, NS10, 0, true);
4802-
simdgroup_load(mq, pq, DK);
4802+
simdgroup_load(mk, pk + 8*i, NS10, 0, true);
4803+
simdgroup_load(mq, pq + 8*i, DK);
48034804

48044805
simdgroup_barrier(mem_flags::mem_none);
48054806

48064807
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
4807-
4808-
pk += 8;
4809-
pq += 8;
48104808
}
48114809
} else {
48124810
k8x8_t mk[2];
@@ -4815,26 +4813,22 @@ void kernel_flash_attn_ext_impl(
48154813
FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
48164814
simdgroup_barrier(mem_flags::mem_none);
48174815

4818-
simdgroup_load(mk[0], pk + 0*8, NS10, 0, true);
4819-
simdgroup_load(mk[1], pk + 1*8, NS10, 0, true);
4816+
simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
4817+
simdgroup_load(mq[1], pq + 1*8 + 16*i, DK);
48204818

4821-
simdgroup_load(mq[0], pq + 0*8, DK);
4822-
simdgroup_load(mq[1], pq + 1*8, DK);
4819+
simdgroup_load(mk[0], pk + 0*8 + 16*i, NS10, 0, true);
4820+
simdgroup_load(mk[1], pk + 1*8 + 16*i, NS10, 0, true);
48234821

48244822
simdgroup_barrier(mem_flags::mem_none);
48254823

48264824
simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk);
48274825
simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk);
4828-
4829-
pk += 16;
4830-
pq += 16;
48314826
}
48324827
}
48334828

48344829
simdgroup_store(mqk, ps, SH, 0, false);
48354830

4836-
pk += 8*(NSG*NS10 - DK8);
4837-
pq += 8*(NSG*0 - DK8);
4831+
pk += 8*(NSG*NS10);
48384832
ps += 8*(NSG);
48394833
}
48404834
} else {
@@ -4961,44 +4955,38 @@ void kernel_flash_attn_ext_impl(
49614955
auto sot = so + 8*sgitg;
49624956

49634957
FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
4964-
simdgroup_load(lo[ii], sot, PV, 0, false);
4965-
4966-
sot += 8*NSG;
4958+
simdgroup_load(lo[ii], sot + 8*ii*NSG, PV, 0, false);
49674959
}
49684960
}
49694961

49704962
{
4971-
auto sst = ss;
4972-
49734963
device const v_t * pv = (device const v_t *) (v + ic*args.nb21);
49744964

49754965
pv += 8*sgitg;
49764966

49774967
FOR_UNROLL (short cc = 0; cc < C/8; ++cc) {
49784968
s8x8_t vs;
4979-
simdgroup_load(vs, sst, SH, 0, false);
4969+
simdgroup_load(vs, ss + 8*cc, SH, 0, false);
49804970

4981-
FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
4982-
v8x8_t mv;
4971+
FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {
4972+
v8x8_t mv[2];
49834973

4984-
simdgroup_load(mv, pv, NS20, 0, false);
4985-
simdgroup_multiply_accumulate(lo[ii], vs, mv, lo[ii]);
4974+
simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false);
4975+
simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false);
49864976

4987-
pv += 8*NSG;
4977+
simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]);
4978+
simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]);
49884979
}
49894980

4990-
pv += 8*(NS20 - NO*NSG);
4991-
sst += 8;
4981+
pv += 8*NS20;
49924982
}
49934983
}
49944984

49954985
{
49964986
auto sot = so + 8*sgitg;
49974987

49984988
FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
4999-
simdgroup_store(lo[ii], sot, PV, 0, false);
5000-
5001-
sot += 8*NSG;
4989+
simdgroup_store(lo[ii], sot + 8*ii*NSG, PV, 0, false);
50024990
}
50034991
}
50044992
} else {

0 commit comments

Comments
 (0)