1 file changed: +7 -6 lines changed
Original file line number | Diff line number | Diff line change
@@ -3450,7 +3450,7 @@ kernel void kernel_flash_attn_ext_vec(
34503450 {
34513451 // each simdgroup processes 1 query and 4 keys
34523452 for (short cc = 0 ; cc < C/4 ; ++cc) {
3453- qk_t mqk = 0.0 ;
3453+ qk_t mqka[ 4 ] = { 0.0 , 0.0 , 0.0 , 0.0 } ;
34543454
34553455 device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4 *cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
34563456
@@ -3461,13 +3461,14 @@ kernel void kernel_flash_attn_ext_vec(
34613461 k4x4_t mk;
34623462 deq_k (pk + i/nl_k, i%nl_k, mk);
34633463
3464- mqk +=
3465- dot (mq[ii/NL][0 ], mk[0 ]) +
3466- dot (mq[ii/NL][1 ], mk[1 ]) +
3467- dot (mq[ii/NL][2 ], mk[2 ]) +
3468- dot (mq[ii/NL][3 ], mk[3 ]);
3464+ mqka[0 ] += dot (mq[ii/NL][0 ], mk[0 ]);
3465+ mqka[1 ] += dot (mq[ii/NL][1 ], mk[1 ]);
3466+ mqka[2 ] += dot (mq[ii/NL][2 ], mk[2 ]);
3467+ mqka[3 ] += dot (mq[ii/NL][3 ], mk[3 ]);
34693468 }
34703469
3470+ qk_t mqk = mqka[0 ] + mqka[1 ] + mqka[2 ] + mqka[3 ];
3471+
34713472 // simdgroup reduce
34723473 // [ 0 .. 7] -> [ 0]
34733474 // [ 8 .. 15] -> [ 8]
You can’t perform that action at this time.
0 commit comments