2 files changed: +8 -10 lines changed
Original file line number | Diff line number | Diff line change
@@ -174,7 +174,10 @@ static __global__ void flash_attn_vec_ext_f16(
174174 K += blockIdx .y *D * nb11;
175175 V += blockIdx .y *D * nb21;
176176 maskh += blockIdx .y *D;
177- for (int k_VKQ_0 = blockIdx .y *D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim .y *D) {
177+ for (int k_VKQ_0 = blockIdx .y *D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim .y *D,
178+ // Increment pointers after each loop:
179+ K += gridDim .y *D*nb11, V += gridDim .y *D*nb21, maskh += gridDim .y *D) {
180+
178181 // Calculate KQ tile and keep track of new maximum KQ values:
179182
180183 if (mask) {
@@ -291,10 +294,6 @@ static __global__ void flash_attn_vec_ext_f16(
291294 }
292295 }
293296
294- K += gridDim .y *D * nb11;
295- V += gridDim .y *D * nb21;
296- maskh += gridDim .y *D;
297-
298297 __syncthreads ();
299298 }
300299
Original file line number | Diff line number | Diff line change
@@ -180,7 +180,10 @@ static __global__ void flash_attn_vec_ext_f32(
180180 K += blockIdx .y *D * nb11;
181181 V += blockIdx .y *D * nb21;
182182 maskh += blockIdx .y *D;
183- for (int k_VKQ_0 = blockIdx .y *D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim .y *D) {
183+ for (int k_VKQ_0 = blockIdx .y *D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim .y *D,
184+ // Increment pointers after each loop:
185+ K += gridDim .y *D*nb11, V += gridDim .y *D*nb21, maskh += gridDim .y *D) {
186+
184187 // Calculate KQ tile and keep track of new maximum KQ values:
185188
186189 if (mask) {
@@ -286,10 +289,6 @@ static __global__ void flash_attn_vec_ext_f32(
286289 }
287290 }
288291
289- K += gridDim .y *D * nb11;
290- V += gridDim .y *D * nb21;
291- maskh += gridDim .y *D;
292-
293292 __syncthreads ();
294293 }
295294
You can’t perform that action at this time.
0 commit comments