
Commit fe2b775

do loop unrolling via C++ template
1 parent: dd05446

3 files changed, 74 insertions(+), 77 deletions(-)

ggml/src/ggml-cuda/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -100,7 +100,7 @@ if (CUDAToolkit_FOUND)
 
     set(CUDA_CXX_FLAGS "")
 
-    set(CUDA_FLAGS -use_fast_math)
+    set(CUDA_FLAGS -use_fast_math -extended-lambda)
 
     if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
         # Options are:
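Note on the new flag: nvcc's --extended-lambda option (passed as -extended-lambda above) allows explicit __device__ annotations on lambda expressions, which the rewritten tile loader below uses. A minimal hypothetical snippet, not part of this commit, showing the construct the flag enables:

// Hypothetical illustration only: without -extended-lambda, nvcc rejects the
// explicit __device__ annotation on the lambda below.
__device__ float sum4(const float * x) {
    float acc = 0.0f;

    // The __device__ annotation mirrors the lambdas used in fattn-mma-f16.cuh.
    auto add = [&] __device__ (const int i) {
        acc += x[i];
    };

    add(0);
    add(1);
    add(2);
    add(3);
    return acc;
}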

ggml/src/ggml-cuda/common.cuh

Lines changed: 19 additions & 0 deletions
@@ -300,6 +300,25 @@ static __device__ void no_device_code(
 #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
 #endif // __CUDA_ARCH__
 
+// The compiler is not always able to unroll loops if they contain continue expressions.
+// In such cases loop unrolling can still be achieved via recursion:
+template <int n>
+struct ggml_cuda_unroll {
+    template <typename Func, typename... Args>
+    __device__ void operator()(const Func & f, Args... args) const {
+        f(n - 1, args...);
+        ggml_cuda_unroll<n - 1>{}(f, args...);
+    }
+};
+
+template <>
+struct ggml_cuda_unroll<1> {
+    template <typename Func, typename... Args>
+    __device__ void operator()(const Func & f, Args... args) const {
+        f(0, args...);
+    }
+};
+
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ int warp_reduce_sum(int x) {
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
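For context, a hypothetical standalone sketch (not part of this commit) of how ggml_cuda_unroll replaces a loop whose body contains continue-like control flow: the functor invokes the callable with indices n-1 down to 0, and because every call is a separate statement the whole chain is expanded at compile time.

// Hypothetical usage sketch, assuming the ggml_cuda_unroll definition above is in scope.
// A plain (unannotated) lambda defined in device code is already a device lambda,
// so this particular sketch does not even need -extended-lambda.
__global__ void sum_even_indices(const float * x, float * out) {
    float acc = 0.0f;

    // Behaves like "for (int i = 0; i < 8; ++i) { if (i % 2 != 0) continue; acc += x[i]; }"
    // (iterated in descending index order), but is fully expanded at compile time.
    auto step = [&](const int i) {
        if (i % 2 != 0) {
            return; // plays the role of "continue"
        }
        acc += x[i];
    };
    ggml_cuda_unroll<8>{}(step);

    if (threadIdx.x == 0 && blockIdx.x == 0) {
        *out = acc;
    }
}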

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 54 additions & 76 deletions
@@ -106,98 +106,76 @@ struct fattn_mma_f16_config<576, 512> {
 
 // ------------------------------------------------------------------------------------------------------------------
 
-// The compiler is unable to unroll loops with the k0_start == k0_stop condition.
-// Therefore, write functions for the loop iterations and unroll the loops manually.
+template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async>
+static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
+        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {
 
-template<int stride_tile, int nwarps, int nbatch_fa, int stride_k>
-static __device__ __forceinline__ void flash_attn_ext_f16_load_tile_async_loop_iter_async(
-        const half2 * const __restrict__ KV, const unsigned int tile_KV_32, const int chunks_per_row, const int stride_KV) {
-    constexpr int preload = 64;
-    constexpr int h2_per_chunk = 16/sizeof(half2);
+    // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
+    // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
 
-    const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
-    const int k0_stop  = chunks_per_row - chunks_per_row % (1*stride_k);
-    const int stride_i = WARP_SIZE / stride_k;
+    if (use_cp_async) {
+        constexpr int preload = 64;
+        constexpr int h2_per_chunk = 16/sizeof(half2);
+        const int chunks_per_row = D2 / h2_per_chunk;
 
-    if (k0_start == k0_stop) {
-        return;
-    }
+        const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
 
-#pragma unroll
-    for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
-        const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+        auto load = [&] __device__ (const int n) {
+            const int stride_k = WARP_SIZE >> n;
+            const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
+            const int k0_stop  = chunks_per_row - chunks_per_row % (1*stride_k);
+            const int stride_i = WARP_SIZE / stride_k;
 
-        if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
-            break;
-        }
+            if (k0_start == k0_stop) {
+                return;
+            }
 
 #pragma unroll
-        for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-            const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
-
-            cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
-        }
-    }
-}
-
-template<int stride_tile, int nwarps, int nbatch_fa, int stride_k>
-static __device__ __forceinline__ void flash_attn_ext_f16_load_tile_async_loop_iter_sync(
-        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {
-    const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
-    const int k0_stop  = D2 - D2 % (1*stride_k);
-    const int stride_i = WARP_SIZE / stride_k;
+            for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
+                const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
 
-    if (k0_start == k0_stop) {
-        return;
-    }
+                if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
+                    break;
+                }
 
 #pragma unroll
-    for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
-        const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
 
-        if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
-            break;
-        }
+                    cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
+                }
+            }
+        };
+        ggml_cuda_unroll<5>{}(load);
+    } else {
+        static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");
+        auto load = [&] __device__ (const int n) {
+            const int stride_k = WARP_SIZE >> n;
+            const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
+            const int k0_stop  = D2 - D2 % (1*stride_k);
+            const int stride_i = WARP_SIZE / stride_k;
+
+            if (k0_start == k0_stop) {
+                return;
+            }
 
 #pragma unroll
-        for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-            const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
-
-            tile_KV[i*stride_tile + k] = KV[i*stride_KV + k];
-        }
-    }
-}
+            for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
+                const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
 
-template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async>
-static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
-        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {
-
-    // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
-    // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
+                if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
+                    break;
+                }
 
-    if (use_cp_async) {
-        const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
-        constexpr int h2_per_chunk = 16/sizeof(half2);
-        const int chunks_per_row = D2 / h2_per_chunk;
+#pragma unroll
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
 
-        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE>
-            (KV, tile_KV_32, chunks_per_row, stride_KV);
-        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE/2>
-            (KV, tile_KV_32, chunks_per_row, stride_KV);
-        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE/4>
-            (KV, tile_KV_32, chunks_per_row, stride_KV);
-        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE/8>
-            (KV, tile_KV_32, chunks_per_row, stride_KV);
-        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE/16>
-            (KV, tile_KV_32, chunks_per_row, stride_KV);
-    } else {
-        static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");
-        flash_attn_ext_f16_load_tile_async_loop_iter_sync<stride_tile, nwarps, nbatch_fa, WARP_SIZE>
-            (KV, tile_KV, D2, stride_KV);
-        flash_attn_ext_f16_load_tile_async_loop_iter_sync<stride_tile, nwarps, nbatch_fa, WARP_SIZE/2>
-            (KV, tile_KV, D2, stride_KV);
-        flash_attn_ext_f16_load_tile_async_loop_iter_sync<stride_tile, nwarps, nbatch_fa, WARP_SIZE/4>
-            (KV, tile_KV, D2, stride_KV);
+                    tile_KV[i*stride_tile + k] = KV[i*stride_KV + k];
+                }
+            }
+        };
+        ggml_cuda_unroll<3>{}(load);
     }
 }
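To see why the early return is needed at all: the five ggml_cuda_unroll<5> passes correspond to stride_k = WARP_SIZE >> n for n = 4,...,0, and for any given row width most of them hit the k0_start == k0_stop case. A hypothetical host-side sketch (not part of this commit; values chosen for illustration, assuming WARP_SIZE == 32 and a head size of 576, i.e. 72 16-byte chunks per row of half2 data):

// Hypothetical host-side trace: prints which of the five cp.async passes
// actually loads chunks and which ones take the early return.
#include <cstdio>

int main() {
    const int WARP_SIZE      = 32;
    const int chunks_per_row = 72; // e.g. D = 576 -> 288 half2 values -> 72 chunks of 16 bytes

    for (int n = 4; n >= 0; --n) { // same order as ggml_cuda_unroll<5>
        const int stride_k = WARP_SIZE >> n;
        const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
        const int k0_stop  = chunks_per_row - chunks_per_row % (1*stride_k);

        if (k0_start == k0_stop) {
            printf("n=%d stride_k=%2d: skipped (k0_start == k0_stop == %d)\n", n, stride_k, k0_stop);
        } else {
            printf("n=%d stride_k=%2d: loads chunks [%d, %d)\n", n, stride_k, k0_start, k0_stop);
        }
    }
    return 0;
}

For this input only the stride_k == 8 and stride_k == 32 passes do any work; the other three return immediately. That continue-like control flow is exactly what the old code had to unroll by hand and what the recursion-based ggml_cuda_unroll now handles generically.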

0 commit comments
