
Commit dd05446

fix loop unrolling for KV data load
Parent: 187054a


ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 78 additions & 55 deletions
@@ -104,75 +104,100 @@ struct fattn_mma_f16_config<576, 512> {
     static constexpr int nbatch_combine = 128;
 };

-template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async>
-static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
-        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {
+// ------------------------------------------------------------------------------------------------------------------

-    // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
-    // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
+// The compiler is unable to unroll loops with the k0_start == k0_stop condition.
+// Therefore, write functions for the loop iterations and unroll the loops manually.

-    if (use_cp_async) {
-        const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
+template<int stride_tile, int nwarps, int nbatch_fa, int stride_k>
+static __device__ __forceinline__ void flash_attn_ext_f16_load_tile_async_loop_iter_async(
+        const half2 * const __restrict__ KV, const unsigned int tile_KV_32, const int chunks_per_row, const int stride_KV) {
+    constexpr int preload = 64;
+    constexpr int h2_per_chunk = 16/sizeof(half2);

-        constexpr int preload = 64;
-        constexpr int h2_per_chunk = 16/sizeof(half2);
+    const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
+    const int k0_stop  =                             chunks_per_row - chunks_per_row % (1*stride_k);
+    const int stride_i = WARP_SIZE / stride_k;

-        const int chunks_per_row = D2 / h2_per_chunk;
+    if (k0_start == k0_stop) {
+        return;
+    }

 #pragma unroll
-        for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4, WARP_SIZE/8, WARP_SIZE/16}) {
-            const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
-            const int k0_stop  =                             chunks_per_row - chunks_per_row % (1*stride_k);
-            const int stride_i = WARP_SIZE / stride_k;
+    for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
+        const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);

-            if (k0_start == k0_stop) {
-                continue;
-            }
+        if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
+            break;
+        }

 #pragma unroll
-            for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
-                const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+        for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+            const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);

-                if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
-                    break;
-                }
+            cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
+        }
+    }
+}
+
+template<int stride_tile, int nwarps, int nbatch_fa, int stride_k>
+static __device__ __forceinline__ void flash_attn_ext_f16_load_tile_async_loop_iter_sync(
+        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {
+    const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
+    const int k0_stop  =                             D2 - D2 % (1*stride_k);
+    const int stride_i = WARP_SIZE / stride_k;
+
+    if (k0_start == k0_stop) {
+        return;
+    }

 #pragma unroll
-                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+    for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
+        const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);

-                    cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
-                }
-            }
+        if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
+            break;
         }
-    } else {
-        static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");
+
 #pragma unroll
-        for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
-            const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
-            const int k0_stop  =                             D2 - D2 % (1*stride_k);
-            const int stride_i = WARP_SIZE / stride_k;
+        for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+            const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);

-            if (k0_start == k0_stop || k0_stop <= 0) {
-                continue;
-            }
+            tile_KV[i*stride_tile + k] = KV[i*stride_KV + k];
+        }
+    }
+}

-#pragma unroll
-            for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
-                const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async>
+static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
+        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {

-                if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
-                    break;
-                }
+    // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
+    // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.

-#pragma unroll
-                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+    if (use_cp_async) {
+        const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
+        constexpr int h2_per_chunk = 16/sizeof(half2);
+        const int chunks_per_row = D2 / h2_per_chunk;

-                    tile_KV[i*stride_tile + k] = KV[i*stride_KV + k];
-                }
-            }
-        }
+        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE>
+            (KV, tile_KV_32, chunks_per_row, stride_KV);
+        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE/2>
+            (KV, tile_KV_32, chunks_per_row, stride_KV);
+        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE/4>
+            (KV, tile_KV_32, chunks_per_row, stride_KV);
+        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE/8>
+            (KV, tile_KV_32, chunks_per_row, stride_KV);
+        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE/16>
+            (KV, tile_KV_32, chunks_per_row, stride_KV);
+    } else {
+        static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");
+        flash_attn_ext_f16_load_tile_async_loop_iter_sync<stride_tile, nwarps, nbatch_fa, WARP_SIZE>
+            (KV, tile_KV, D2, stride_KV);
+        flash_attn_ext_f16_load_tile_async_loop_iter_sync<stride_tile, nwarps, nbatch_fa, WARP_SIZE/2>
+            (KV, tile_KV, D2, stride_KV);
+        flash_attn_ext_f16_load_tile_async_loop_iter_sync<stride_tile, nwarps, nbatch_fa, WARP_SIZE/4>
+            (KV, tile_KV, D2, stride_KV);
     }
 }

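Note on the pattern: the range-for over stride_k could not be fully unrolled because of the runtime k0_start == k0_stop early-out, so the commit turns each granularity into its own template instantiation and writes the calls out by hand. The following is a minimal, self-contained sketch of that pattern, not code from the commit: the kernel, the helper copy_with_stride, the element count, and the launch shape are all made up for illustration, and WARP_SIZE is redefined locally.

// Standalone sketch of the manual-unrolling pattern above; hypothetical helper and kernel names,
// not code from this commit. WARP_SIZE is redefined locally so the sketch compiles on its own.
#include <cuda_runtime.h>
#include <cstdio>

#define WARP_SIZE 32

// One former loop iteration: stride_k is a compile-time template parameter and the
// "nothing to do at this granularity" case is an early return instead of a continue.
template<int stride_k>
static __device__ __forceinline__ void copy_with_stride(const float * src, float * dst, const int n) {
    const int k0_start = stride_k == WARP_SIZE ? 0 : n - n % (2*stride_k);
    const int k0_stop  =                             n - n % (1*stride_k);

    if (k0_start == k0_stop) {
        return;
    }

#pragma unroll
    for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
        // For stride_k < WARP_SIZE several threads pick the same element; in the real code those
        // threads handle different rows, here the writes are merely redundant (and identical).
        const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
        dst[k] = src[k];
    }
}

// The caller unrolls the former outer loop over stride_k by hand:
static __global__ void copy_kernel(const float * src, float * dst, const int n) {
    copy_with_stride<WARP_SIZE  >(src, dst, n);
    copy_with_stride<WARP_SIZE/2>(src, dst, n);
    copy_with_stride<WARP_SIZE/4>(src, dst, n);
}

int main() {
    const int n = 104; // multiple of WARP_SIZE/4, so the three passes together cover every element
    float *src, *dst;
    cudaMallocManaged(&src, n*sizeof(float));
    cudaMallocManaged(&dst, n*sizeof(float));
    for (int i = 0; i < n; ++i) { src[i] = float(i); dst[i] = 0.0f; }
    copy_kernel<<<1, WARP_SIZE>>>(src, dst, n); // one warp per block, as the indexing assumes
    cudaDeviceSynchronize();
    printf("dst[%d] = %.1f\n", n - 1, dst[n - 1]); // expected: 103.0
    cudaFree(src);
    cudaFree(dst);
    return 0;
}

Because each granularity is a separate instantiation, the remaining loops have compile-time strides and an empty pass collapses to the early return, which avoids the continue that blocked unrolling in the original form.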

@@ -848,10 +873,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     }
 #pragma unroll
     for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
-        if (offset >= WARP_SIZE) {
-            continue;
+        if (offset < WARP_SIZE) {
+            KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
         }
-        KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
     }

     float KQ_cms[nmeta]; // KQ combine max scale per warp.
@@ -867,10 +891,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     }
 #pragma unroll
     for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
-        if (offset >= WARP_SIZE) {
-            continue;
+        if (offset < WARP_SIZE) {
+            KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
         }
-        KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
     }

     // Write back combined meta data:
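The two reduction hunks apply a related fix: rather than continue-ing past offsets that are >= WARP_SIZE, the shuffle now sits inside an if (offset < WARP_SIZE) body, which keeps the #pragma unroll loop straightforward for the compiler and lets the guard be resolved per unrolled iteration. Below is a minimal standalone sketch of that guarded-shuffle reduction, not code from the commit; the kernel warp_max_kernel and the template arguments are toy values chosen so the first offset (32) exercises the guard.

// Standalone sketch: guarded shuffle inside an unrolled butterfly reduction (illustrative only).
#include <cuda_runtime.h>
#include <cstdio>

#define WARP_SIZE 32

// np and cols_per_warp are toy compile-time values; np*cols_per_warp/2 == 32 here, so the
// first iteration fails the offset < WARP_SIZE guard and performs no shuffle.
template<int np, int cols_per_warp>
static __device__ __forceinline__ float warp_reduce_max_guarded(float v) {
#pragma unroll
    for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
        if (offset < WARP_SIZE) {
            v = fmaxf(v, __shfl_xor_sync(0xFFFFFFFF, v, offset, WARP_SIZE));
        }
    }
    return v;
}

static __global__ void warp_max_kernel(const float * x, float * out) {
    const float v = warp_reduce_max_guarded<64, 1>(x[threadIdx.x]);
    if (threadIdx.x == 0) {
        *out = v;
    }
}

int main() {
    float *x, *out;
    cudaMallocManaged(&x, WARP_SIZE*sizeof(float));
    cudaMallocManaged(&out, sizeof(float));
    for (int i = 0; i < WARP_SIZE; ++i) {
        x[i] = float(i);
    }
    warp_max_kernel<<<1, WARP_SIZE>>>(x, out); // one warp
    cudaDeviceSynchronize();
    printf("max = %.1f\n", *out); // expected: 31.0
    cudaFree(x);
    cudaFree(out);
    return 0;
}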
