Skip to content

Commit 3574c15

Browse files
committed
cont : more unroll
1 parent 25a7705 commit 3574c15

File tree

2 files changed

+45
-16
lines changed

2 files changed

+45
-16
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4955,7 +4955,9 @@ void kernel_flash_attn_ext_impl(
49554955
auto sot = so + 8*sgitg;
49564956

49574957
FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
4958-
simdgroup_load(lo[ii], sot + 8*ii*NSG, PV, 0, false);
4958+
simdgroup_load(lo[ii], sot, PV, 0, false);
4959+
4960+
sot += 8*NSG;
49594961
}
49604962
}
49614963

@@ -4964,29 +4966,56 @@ void kernel_flash_attn_ext_impl(
49644966

49654967
pv += 8*sgitg;
49664968

4967-
FOR_UNROLL (short cc = 0; cc < C/8; ++cc) {
4968-
s8x8_t vs;
4969-
simdgroup_load(vs, ss + 8*cc, SH, 0, false);
4969+
if (DV <= 64) {
4970+
FOR_UNROLL (short cc = 0; cc < C/8; ++cc) {
4971+
s8x8_t vs;
4972+
simdgroup_load(vs, ss + 8*cc, SH, 0, false);
4973+
4974+
FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {
4975+
v8x8_t mv[2];
49704976

4971-
FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {
4972-
v8x8_t mv[2];
4977+
simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false);
4978+
simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false);
49734979

4974-
simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false);
4975-
simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false);
4980+
simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]);
4981+
simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]);
4982+
}
49764983

4977-
simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]);
4978-
simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]);
4984+
pv += 8*NS20;
49794985
}
4986+
} else {
4987+
FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) {
4988+
s8x8_t vs[2];
4989+
4990+
simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
4991+
simdgroup_load(vs[1], ss + 16*cc + 8, SH, 0, false);
49804992

4981-
pv += 8*NS20;
4993+
FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {
4994+
v8x8_t mv[4];
4995+
4996+
simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false);
4997+
simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false);
4998+
simdgroup_load(mv[2], pv + 0*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false);
4999+
simdgroup_load(mv[3], pv + 8*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false);
5000+
5001+
simdgroup_multiply_accumulate(lo[2*ii + 0], vs[0], mv[0], lo[2*ii + 0]);
5002+
simdgroup_multiply_accumulate(lo[2*ii + 1], vs[0], mv[1], lo[2*ii + 1]);
5003+
simdgroup_multiply_accumulate(lo[2*ii + 0], vs[1], mv[2], lo[2*ii + 0]);
5004+
simdgroup_multiply_accumulate(lo[2*ii + 1], vs[1], mv[3], lo[2*ii + 1]);
5005+
}
5006+
5007+
pv += 2*8*NS20;
5008+
}
49825009
}
49835010
}
49845011

49855012
{
49865013
auto sot = so + 8*sgitg;
49875014

49885015
FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
4989-
simdgroup_store(lo[ii], sot + 8*ii*NSG, PV, 0, false);
5016+
simdgroup_store(lo[ii], sot, PV, 0, false);
5017+
5018+
sot += 8*NSG;
49905019
}
49915020
}
49925021
} else {

tests/test-backend-ops.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6884,10 +6884,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
68846884
}
68856885
}
68866886

6887-
for (int kv : { 4096, 8192, 16384, }) {
6888-
for (int hs : { 64, 128, }) {
6889-
for (int nr : { 1, 4, }) {
6890-
test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, {nr, 1}, kv, 1, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
6887+
for (int kv : { 512, 4096, 8192, 16384, 32768, 65536, }) {
6888+
for (int hs : { 40, 64, 128, 256 }) {
6889+
for (int nr : { 1, }) {
6890+
test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, {nr, 1}, kv, 2048, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
68916891
}
68926892
}
68936893
}

0 commit comments

Comments
 (0)