@@ -4787,7 +4787,8 @@ void kernel_flash_attn_ext_impl(
47874787
47884788 constexpr short NC = (C/8 )/NSG;
47894789
4790- // TODO: not good to unroll for large contexts - not sure why?
4790+ // note: do not unroll for large heads
4791+ #pragma unroll (DK <= 64 ? NC : 1)
47914792 for (short cc = 0 ; cc < NC; ++cc) {
47924793 qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t , 8 >((qk_t ) 0 .0f );
47934794
@@ -4798,15 +4799,12 @@ void kernel_flash_attn_ext_impl(
47984799 FOR_UNROLL (short i = 0 ; i < DK8; ++i) {
47994800 simdgroup_barrier (mem_flags::mem_none);
48004801
4801- simdgroup_load (mk, pk, NS10, 0 , true );
4802- simdgroup_load (mq, pq, DK);
4802+ simdgroup_load (mk, pk + 8 *i , NS10, 0 , true );
4803+ simdgroup_load (mq, pq + 8 *i , DK);
48034804
48044805 simdgroup_barrier (mem_flags::mem_none);
48054806
48064807 simdgroup_multiply_accumulate (mqk, mq, mk, mqk);
4807-
4808- pk += 8 ;
4809- pq += 8 ;
48104808 }
48114809 } else {
48124810 k8x8_t mk[2 ];
@@ -4815,26 +4813,22 @@ void kernel_flash_attn_ext_impl(
48154813 FOR_UNROLL (short i = 0 ; i < DK8/2 ; ++i) {
48164814 simdgroup_barrier (mem_flags::mem_none);
48174815
4818- simdgroup_load (mk [0 ], pk + 0 *8 , NS10, 0 , true );
4819- simdgroup_load (mk [1 ], pk + 1 *8 , NS10, 0 , true );
4816+ simdgroup_load (mq [0 ], pq + 0 *8 + 16 *i, DK );
4817+ simdgroup_load (mq [1 ], pq + 1 *8 + 16 *i, DK );
48204818
4821- simdgroup_load (mq [0 ], pq + 0 *8 , DK );
4822- simdgroup_load (mq [1 ], pq + 1 *8 , DK );
4819+ simdgroup_load (mk [0 ], pk + 0 *8 + 16 *i, NS10, 0 , true );
4820+ simdgroup_load (mk [1 ], pk + 1 *8 + 16 *i, NS10, 0 , true );
48234821
48244822 simdgroup_barrier (mem_flags::mem_none);
48254823
48264824 simdgroup_multiply_accumulate (mqk, mq[0 ], mk[0 ], mqk);
48274825 simdgroup_multiply_accumulate (mqk, mq[1 ], mk[1 ], mqk);
4828-
4829- pk += 16 ;
4830- pq += 16 ;
48314826 }
48324827 }
48334828
48344829 simdgroup_store (mqk, ps, SH, 0 , false );
48354830
4836- pk += 8 *(NSG*NS10 - DK8);
4837- pq += 8 *(NSG*0 - DK8);
4831+ pk += 8 *(NSG*NS10);
48384832 ps += 8 *(NSG);
48394833 }
48404834 } else {
@@ -4961,44 +4955,38 @@ void kernel_flash_attn_ext_impl(
49614955 auto sot = so + 8 *sgitg;
49624956
49634957 FOR_UNROLL (short ii = 0 ; ii < NO; ++ii) {
4964- simdgroup_load (lo[ii], sot, PV, 0 , false );
4965-
4966- sot += 8 *NSG;
4958+ simdgroup_load (lo[ii], sot + 8 *ii*NSG, PV, 0 , false );
49674959 }
49684960 }
49694961
49704962 {
4971- auto sst = ss;
4972-
49734963 device const v_t * pv = (device const v_t *) (v + ic*args.nb21 );
49744964
49754965 pv += 8 *sgitg;
49764966
49774967 FOR_UNROLL (short cc = 0 ; cc < C/8 ; ++cc) {
49784968 s8x8_t vs;
4979- simdgroup_load (vs, sst , SH, 0 , false );
4969+ simdgroup_load (vs, ss + 8 *cc , SH, 0 , false );
49804970
4981- FOR_UNROLL (short ii = 0 ; ii < NO; ++ii) {
4982- v8x8_t mv;
4971+ FOR_UNROLL (short ii = 0 ; ii < NO/ 2 ; ++ii) {
4972+ v8x8_t mv[ 2 ] ;
49834973
4984- simdgroup_load (mv, pv, NS20, 0 , false );
4985- simdgroup_multiply_accumulate (lo[ii ], vs, mv, lo[ii] );
4974+ simdgroup_load (mv[ 0 ] , pv + 0 *NSG + 16 *ii*NSG , NS20, 0 , false );
4975+ simdgroup_load (mv[ 1 ], pv + 8 *NSG + 16 *ii*NSG, NS20, 0 , false );
49864976
4987- pv += 8 *NSG;
4977+ simdgroup_multiply_accumulate (lo[2 *ii + 0 ], vs, mv[0 ], lo[2 *ii + 0 ]);
4978+ simdgroup_multiply_accumulate (lo[2 *ii + 1 ], vs, mv[1 ], lo[2 *ii + 1 ]);
49884979 }
49894980
4990- pv += 8 *(NS20 - NO*NSG);
4991- sst += 8 ;
4981+ pv += 8 *NS20;
49924982 }
49934983 }
49944984
49954985 {
49964986 auto sot = so + 8 *sgitg;
49974987
49984988 FOR_UNROLL (short ii = 0 ; ii < NO; ++ii) {
4999- simdgroup_store (lo[ii], sot, PV, 0 , false );
5000-
5001- sot += 8 *NSG;
4989+ simdgroup_store (lo[ii], sot + 8 *ii*NSG, PV, 0 , false );
50024990 }
50034991 }
50044992 } else {
0 commit comments