2 files changed: +8 -10 lines changed
Original file line number | Diff line number | Diff line change
@@ -174,7 +174,10 @@ static __global__ void flash_attn_vec_ext_f16(
174174    K     += blockIdx .y *D * nb11;
175175    V     += blockIdx .y *D * nb21;
176176    maskh += blockIdx .y *D;
177-     for  (int  k_VKQ_0 = blockIdx .y *D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim .y *D) {
177+     for  (int  k_VKQ_0 = blockIdx .y *D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim .y *D,
178+              //  Increment pointers after each loop:
179+              K += gridDim .y *D*nb11, V += gridDim .y *D*nb21, maskh += gridDim .y *D) {
180+ 
178181        //  Calculate KQ tile and keep track of new maximum KQ values:
179182
180183        if  (mask) {
@@ -291,10 +294,6 @@ static __global__ void flash_attn_vec_ext_f16(
291294            }
292295        }
293296
294-         K     += gridDim .y *D * nb11;
295-         V     += gridDim .y *D * nb21;
296-         maskh += gridDim .y *D;
297- 
298297        __syncthreads ();
299298    }
300299
Original file line number | Diff line number | Diff line change
@@ -180,7 +180,10 @@ static __global__ void flash_attn_vec_ext_f32(
180180    K     += blockIdx .y *D * nb11;
181181    V     += blockIdx .y *D * nb21;
182182    maskh += blockIdx .y *D;
183-     for  (int  k_VKQ_0 = blockIdx .y *D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim .y *D) {
183+     for  (int  k_VKQ_0 = blockIdx .y *D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim .y *D,
184+              //  Increment pointers after each loop:
185+              K += gridDim .y *D*nb11, V += gridDim .y *D*nb21, maskh += gridDim .y *D) {
186+ 
184187        //  Calculate KQ tile and keep track of new maximum KQ values:
185188
186189        if  (mask) {
@@ -286,10 +289,6 @@ static __global__ void flash_attn_vec_ext_f32(
286289            }
287290        }
288291
289-         K     += gridDim .y *D * nb11;
290-         V     += gridDim .y *D * nb21;
291-         maskh += gridDim .y *D;
292- 
293292        __syncthreads ();
294293    }
295294
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments