File tree Expand file tree Collapse file tree 1 file changed +7
-6
lines changed Expand file tree Collapse file tree 1 file changed +7
-6
lines changed Original file line number Diff line number Diff line change @@ -3450,7 +3450,7 @@ kernel void kernel_flash_attn_ext_vec(
34503450            {
34513451                //  each simdgroup processes 1 query and 4 keys
34523452                for  (short  cc = 0 ; cc < C/4 ; ++cc) {
3453-                     qk_t  mqk  = 0.0 ;
3453+                     qk_t  mqka[ 4 ]  = {  0.0 ,  0.0 ,  0.0 ,  0.0  } ;
34543454
34553455                    device const  kd4x4_t  * pk = (device const  kd4x4_t  *) ((device const  char  *) k + ((ic + 4 *cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
34563456
@@ -3461,13 +3461,14 @@ kernel void kernel_flash_attn_ext_vec(
34613461                        k4x4_t  mk;
34623462                        deq_k (pk + i/nl_k, i%nl_k, mk);
34633463
3464-                         mqk +=
3465-                             dot (mq[ii/NL][0 ], mk[0 ]) +
3466-                             dot (mq[ii/NL][1 ], mk[1 ]) +
3467-                             dot (mq[ii/NL][2 ], mk[2 ]) +
3468-                             dot (mq[ii/NL][3 ], mk[3 ]);
3464+                         mqka[0 ] += dot (mq[ii/NL][0 ], mk[0 ]);
3465+                         mqka[1 ] += dot (mq[ii/NL][1 ], mk[1 ]);
3466+                         mqka[2 ] += dot (mq[ii/NL][2 ], mk[2 ]);
3467+                         mqka[3 ] += dot (mq[ii/NL][3 ], mk[3 ]);
34693468                    }
34703469
3470+                     qk_t  mqk = mqka[0 ] + mqka[1 ] + mqka[2 ] + mqka[3 ];
3471+ 
34713472                    //  simdgroup reduce
34723473                    //  [ 0 ..  7] -> [ 0]
34733474                    //  [ 8 .. 15] -> [ 8]
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments