Skip to content

Commit bb0d51a

Browse files
JohannesGaessler authored and ggerganov committed
fix excessive KQ_b loads
1 parent e1ecd3b commit bb0d51a

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

ggml-cuda/fattn.cu

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -387,12 +387,16 @@ static __global__ void flash_attn_ext_f16(
387387

388388
__syncthreads();
389389

390-
frag_b KQ_b[FATTN_KQ_STRIDE/16][ncols/frag_n];
390+
frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
391391
#pragma unroll
392392
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
393393
#pragma unroll
394-
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += 16) {
395-
nvcuda::wmma::load_matrix_sync(KQ_b[k0/16][j0/frag_n], KQ + j0*kqs_padded + k0, kqs_padded);
394+
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
395+
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
396+
nvcuda::wmma::load_matrix_sync(
397+
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
398+
KQ + j0*kqs_padded + k,
399+
kqs_padded);
396400
}
397401
}
398402

@@ -412,7 +416,7 @@ static __global__ void flash_attn_ext_f16(
412416
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
413417
#pragma unroll
414418
for (int j = 0; j < ncols/frag_n; ++j) {
415-
nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k/16][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
419+
nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
416420
}
417421
}
418422
}

0 commit comments

Comments
 (0)