@@ -4955,7 +4955,9 @@ void kernel_flash_attn_ext_impl(
49554955 auto sot = so + 8 *sgitg;
49564956
49574957 FOR_UNROLL (short ii = 0 ; ii < NO; ++ii) {
4958- simdgroup_load (lo[ii], sot + 8 *ii*NSG, PV, 0 , false );
4958+ simdgroup_load (lo[ii], sot, PV, 0 , false );
4959+
4960+ sot += 8 *NSG;
49594961 }
49604962 }
49614963
@@ -4964,29 +4966,56 @@ void kernel_flash_attn_ext_impl(
49644966
49654967 pv += 8 *sgitg;
49664968
4967- FOR_UNROLL (short cc = 0 ; cc < C/8 ; ++cc) {
4968- s8x8_t vs;
4969- simdgroup_load (vs, ss + 8 *cc, SH, 0 , false );
4969+ if (DV <= 64 ) {
4970+ FOR_UNROLL (short cc = 0 ; cc < C/8 ; ++cc) {
4971+ s8x8_t vs;
4972+ simdgroup_load (vs, ss + 8 *cc, SH, 0 , false );
4973+
4974+ FOR_UNROLL (short ii = 0 ; ii < NO/2 ; ++ii) {
4975+ v8x8_t mv[2 ];
49704976
4971- FOR_UNROLL ( short ii = 0 ; ii < NO/ 2 ; ++ii) {
4972- v8x8_t mv[2 ] ;
4977+ simdgroup_load (mv[ 0 ], pv + 0 *NSG + 16 *ii*NSG, NS20, 0 , false );
4978+ simdgroup_load ( mv[1 ], pv + 8 *NSG + 16 *ii*NSG, NS20, 0 , false ) ;
49734979
4974- simdgroup_load (mv[0 ], pv + 0 *NSG + 16 *ii*NSG, NS20, 0 , false );
4975- simdgroup_load (mv[1 ], pv + 8 *NSG + 16 *ii*NSG, NS20, 0 , false );
4980+ simdgroup_multiply_accumulate (lo[2 *ii + 0 ], vs, mv[0 ], lo[2 *ii + 0 ]);
4981+ simdgroup_multiply_accumulate (lo[2 *ii + 1 ], vs, mv[1 ], lo[2 *ii + 1 ]);
4982+ }
49764983
4977- simdgroup_multiply_accumulate (lo[2 *ii + 0 ], vs, mv[0 ], lo[2 *ii + 0 ]);
4978- simdgroup_multiply_accumulate (lo[2 *ii + 1 ], vs, mv[1 ], lo[2 *ii + 1 ]);
4984+ pv += 8 *NS20;
49794985 }
4986+ } else {
4987+ FOR_UNROLL (short cc = 0 ; cc < (C/8 )/2 ; ++cc) {
4988+ s8x8_t vs[2 ];
4989+
4990+ simdgroup_load (vs[0 ], ss + 16 *cc + 0 , SH, 0 , false );
4991+ simdgroup_load (vs[1 ], ss + 16 *cc + 8 , SH, 0 , false );
49804992
4981- pv += 8 *NS20;
4993+ FOR_UNROLL (short ii = 0 ; ii < NO/2 ; ++ii) {
4994+ v8x8_t mv[4 ];
4995+
4996+ simdgroup_load (mv[0 ], pv + 0 *NSG + 16 *ii*NSG + 0 *8 *NS20, NS20, 0 , false );
4997+ simdgroup_load (mv[1 ], pv + 8 *NSG + 16 *ii*NSG + 0 *8 *NS20, NS20, 0 , false );
4998+ simdgroup_load (mv[2 ], pv + 0 *NSG + 16 *ii*NSG + 1 *8 *NS20, NS20, 0 , false );
4999+ simdgroup_load (mv[3 ], pv + 8 *NSG + 16 *ii*NSG + 1 *8 *NS20, NS20, 0 , false );
5000+
5001+ simdgroup_multiply_accumulate (lo[2 *ii + 0 ], vs[0 ], mv[0 ], lo[2 *ii + 0 ]);
5002+ simdgroup_multiply_accumulate (lo[2 *ii + 1 ], vs[0 ], mv[1 ], lo[2 *ii + 1 ]);
5003+ simdgroup_multiply_accumulate (lo[2 *ii + 0 ], vs[1 ], mv[2 ], lo[2 *ii + 0 ]);
5004+ simdgroup_multiply_accumulate (lo[2 *ii + 1 ], vs[1 ], mv[3 ], lo[2 *ii + 1 ]);
5005+ }
5006+
5007+ pv += 2 *8 *NS20;
5008+ }
49825009 }
49835010 }
49845011
49855012 {
49865013 auto sot = so + 8 *sgitg;
49875014
49885015 FOR_UNROLL (short ii = 0 ; ii < NO; ++ii) {
4989- simdgroup_store (lo[ii], sot + 8 *ii*NSG, PV, 0 , false );
5016+ simdgroup_store (lo[ii], sot, PV, 0 , false );
5017+
5018+ sot += 8 *NSG;
49905019 }
49915020 }
49925021 } else {
0 commit comments