@@ -104,75 +104,100 @@ struct fattn_mma_f16_config<576, 512> {
     static constexpr int nbatch_combine = 128;
 };

-template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async>
-static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
-        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {
+// ------------------------------------------------------------------------------------------------------------------

-    // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
-    // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
+// The compiler is unable to unroll loops with the k0_start == k0_stop condition.
+// Therefore, write functions for the loop iterations and unroll the loops manually.

-    if (use_cp_async) {
-        const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
+template<int stride_tile, int nwarps, int nbatch_fa, int stride_k>
+static __device__ __forceinline__ void flash_attn_ext_f16_load_tile_async_loop_iter_async(
+        const half2 * const __restrict__ KV, const unsigned int tile_KV_32, const int chunks_per_row, const int stride_KV) {
+    constexpr int preload = 64;
+    constexpr int h2_per_chunk = 16/sizeof(half2);

-        constexpr int preload = 64;
-        constexpr int h2_per_chunk = 16/sizeof(half2);
+    const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
+    const int k0_stop  = chunks_per_row - chunks_per_row % (1*stride_k);
+    const int stride_i = WARP_SIZE / stride_k;

-        const int chunks_per_row = D2 / h2_per_chunk;
+    if (k0_start == k0_stop) {
+        return;
+    }

 #pragma unroll
-        for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4, WARP_SIZE/8, WARP_SIZE/16}) {
-            const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
-            const int k0_stop  = chunks_per_row - chunks_per_row % (1*stride_k);
-            const int stride_i = WARP_SIZE / stride_k;
+    for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
+        const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);

-            if (k0_start == k0_stop) {
-                continue;
-            }
+        if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
+            break;
+        }

 #pragma unroll
-            for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
-                const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+        for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+            const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);

-                if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
-                    break;
-                }
+            cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
+        }
+    }
+}
+
+template<int stride_tile, int nwarps, int nbatch_fa, int stride_k>
+static __device__ __forceinline__ void flash_attn_ext_f16_load_tile_async_loop_iter_sync(
+        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {
+    const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
+    const int k0_stop  = D2 - D2 % (1*stride_k);
+    const int stride_i = WARP_SIZE / stride_k;
+
+    if (k0_start == k0_stop) {
+        return;
+    }

 #pragma unroll
-                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+    for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
+        const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);

-                    cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
-                }
-            }
+        if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
+            break;
         }
-    } else {
-        static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");
+
 #pragma unroll
-        for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
-            const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
-            const int k0_stop  = D2 - D2 % (1*stride_k);
-            const int stride_i = WARP_SIZE / stride_k;
+        for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+            const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);

-            if (k0_start == k0_stop || k0_stop <= 0) {
-                continue;
-            }
+            tile_KV[i*stride_tile + k] = KV[i*stride_KV + k];
+        }
+    }
+}

-#pragma unroll
-            for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
-                const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async>
+static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
+        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {

-                if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
-                    break;
-                }
+    // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
+    // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.

-#pragma unroll
-                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+    if (use_cp_async) {
+        const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
+        constexpr int h2_per_chunk = 16/sizeof(half2);
+        const int chunks_per_row = D2 / h2_per_chunk;

-                    tile_KV[i*stride_tile + k] = KV[i*stride_KV + k];
-                }
-            }
-        }
+        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE>
+            (KV, tile_KV_32, chunks_per_row, stride_KV);
+        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE/2>
+            (KV, tile_KV_32, chunks_per_row, stride_KV);
+        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE/4>
+            (KV, tile_KV_32, chunks_per_row, stride_KV);
+        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE/8>
+            (KV, tile_KV_32, chunks_per_row, stride_KV);
+        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE/16>
+            (KV, tile_KV_32, chunks_per_row, stride_KV);
+    } else {
+        static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");
+        flash_attn_ext_f16_load_tile_async_loop_iter_sync<stride_tile, nwarps, nbatch_fa, WARP_SIZE>
+            (KV, tile_KV, D2, stride_KV);
+        flash_attn_ext_f16_load_tile_async_loop_iter_sync<stride_tile, nwarps, nbatch_fa, WARP_SIZE/2>
+            (KV, tile_KV, D2, stride_KV);
+        flash_attn_ext_f16_load_tile_async_loop_iter_sync<stride_tile, nwarps, nbatch_fa, WARP_SIZE/4>
+            (KV, tile_KV, D2, stride_KV);
     }
 }

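The helper functions introduced in this hunk implement manual unrolling: the old runtime loop `for (int stride_k : {WARP_SIZE, WARP_SIZE/2, ...})` contained a `continue` on the `k0_start == k0_stop` condition, which (per the comment added above) prevents the compiler from unrolling, so each stride value becomes its own template instantiation that returns early instead. The following standalone sketch is my own illustration of that pattern with hypothetical names (`loop_iter`, `process`, `MY_WARP_SIZE`); it is not code from this commit, but it compiles as ordinary C++/CUDA host code and shows the transformation in isolation:

    // Sketch only: hypothetical stand-ins for the kernel's stride-unrolling pattern.
    #include <cstdio>

    #define MY_WARP_SIZE 32  // stand-in for the library's WARP_SIZE constant

    // One loop iteration per compile-time stride; an early return replaces the old `continue`.
    template <int stride_k>
    static void loop_iter(const int n) {
        const int k0_start = stride_k == MY_WARP_SIZE ? 0 : n - n % (2*stride_k);
        const int k0_stop  = n - n % (1*stride_k);

        if (k0_start == k0_stop) {
            return; // nothing left for this stride
        }

        for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
            printf("stride %2d handles k0 = %d\n", stride_k, k0);
        }
    }

    // Manually unrolled replacement for `for (int stride_k : {32, 16, 8})`.
    static void process(const int n) {
        loop_iter<MY_WARP_SIZE  >(n);
        loop_iter<MY_WARP_SIZE/2>(n);
        loop_iter<MY_WARP_SIZE/4>(n);
    }

    int main() {
        process(72); // strides 32 and 8 do work, stride 16 returns immediately
        return 0;
    }

Because `stride_k` is a template parameter, each instantiation is a separate straight-line call site, and the caller no longer contains a loop whose body may skip an iteration.
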
@@ -848,10 +873,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     }
 #pragma unroll
     for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
-        if (offset >= WARP_SIZE) {
-            continue;
+        if (offset < WARP_SIZE) {
+            KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
         }
-        KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
     }

     float KQ_cms[nmeta]; // KQ combine max scale per warp.
@@ -867,10 +891,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     }
 #pragma unroll
     for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
-        if (offset >= WARP_SIZE) {
-            continue;
+        if (offset < WARP_SIZE) {
+            KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
         }
-        KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
     }

     // Write back combined meta data:
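
Both hunks above rewrite the warp-level reductions of `KQ_cmn` and `KQ_crs` so that over-wide shuffle offsets are skipped by an `if (offset < WARP_SIZE)` guard around the shuffle rather than by `continue`, keeping the `#pragma unroll` loop body free of early exits, in keeping with the rationale stated in the first hunk. The reduced value is unchanged; only the control flow differs. Below is a standalone sketch of that guard pattern under my own assumptions: the names (`warp_reduce_max`, `reduce_demo`) and the simplified reduction bounds are hypothetical and not taken from this commit.

    // Sketch only: guarded butterfly shuffle whose starting offset may exceed the warp width.
    #include <cstdio>
    #include <cuda_runtime.h>

    #define WARP_SIZE 32

    // Max-reduction over a warp; offsets that would reach past the warp are skipped by a guard
    // around the shuffle rather than by `continue`, so the loop still unrolls cleanly.
    template <int reduce_width>
    static __device__ float warp_reduce_max(float v) {
    #pragma unroll
        for (int offset = reduce_width/2; offset >= 1; offset >>= 1) {
            if (offset < WARP_SIZE) {
                v = fmaxf(v, __shfl_xor_sync(0xFFFFFFFF, v, offset, WARP_SIZE));
            }
        }
        return v;
    }

    static __global__ void reduce_demo(const float * x, float * dst) {
        const float v = warp_reduce_max<64>(x[threadIdx.x]); // reduce_width > WARP_SIZE is harmless
        if (threadIdx.x == 0) {
            *dst = v;
        }
    }

    int main() {
        float hx[WARP_SIZE];
        for (int i = 0; i < WARP_SIZE; ++i) {
            hx[i] = (float) ((i*7) % 31); // arbitrary test data, maximum value is 30
        }
        float * dx;
        float * dout;
        cudaMalloc(&dx,   WARP_SIZE*sizeof(float));
        cudaMalloc(&dout, sizeof(float));
        cudaMemcpy(dx, hx, WARP_SIZE*sizeof(float), cudaMemcpyHostToDevice);
        reduce_demo<<<1, WARP_SIZE>>>(dx, dout);
        float hout = 0.0f;
        cudaMemcpy(&hout, dout, sizeof(float), cudaMemcpyDeviceToHost);
        printf("warp max = %g\n", hout); // expected: 30
        cudaFree(dx);
        cudaFree(dout);
        return 0;
    }

With `reduce_width = 64` the first offset (32) is simply skipped by the guard and the remaining offsets perform a full in-warp max, so the loop bounds can be written generically without ever shuffling past the warp boundary.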