@@ -106,98 +106,76 @@ struct fattn_mma_f16_config<576, 512> {
 
 // ------------------------------------------------------------------------------------------------------------------
 
-// The compiler is unable to unroll loops with the k0_start == k0_stop condition.
-// Therefore, write functions for the loop iterations and unroll the loops manually.
+template <int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async>
+static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
+        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {
 
-template <int stride_tile, int nwarps, int nbatch_fa, int stride_k>
-static __device__ __forceinline__ void flash_attn_ext_f16_load_tile_async_loop_iter_async(
-        const half2 * const __restrict__ KV, const unsigned int tile_KV_32, const int chunks_per_row, const int stride_KV) {
-    constexpr int preload = 64;
-    constexpr int h2_per_chunk = 16/sizeof(half2);
+    // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
+    // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
 
-    const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
-    const int k0_stop  =                             chunks_per_row - chunks_per_row % (1*stride_k);
-    const int stride_i = WARP_SIZE / stride_k;
+    if (use_cp_async) {
+        constexpr int preload = 64;
+        constexpr int h2_per_chunk = 16/sizeof(half2);
+        const int chunks_per_row = D2 / h2_per_chunk;
 
-    if (k0_start == k0_stop) {
-        return;
-    }
+        const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
 
-#pragma unroll
-    for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
-        const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+        auto load = [&] __device__ (const int n) {
+            const int stride_k = WARP_SIZE >> n;
+            const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
+            const int k0_stop  =                             chunks_per_row - chunks_per_row % (1*stride_k);
+            const int stride_i = WARP_SIZE / stride_k;
 
-        if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
-            break;
-        }
+            if (k0_start == k0_stop) {
+                return;
+            }
 
 #pragma unroll
-        for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-            const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
-
-            cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
-        }
-    }
-}
-
-template <int stride_tile, int nwarps, int nbatch_fa, int stride_k>
-static __device__ __forceinline__ void flash_attn_ext_f16_load_tile_async_loop_iter_sync(
-        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {
-    const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
-    const int k0_stop  =                             D2 - D2 % (1*stride_k);
-    const int stride_i = WARP_SIZE / stride_k;
+            for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
+                const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
 
-    if (k0_start == k0_stop) {
-        return;
-    }
+                if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
+                    break;
+                }
 
 #pragma unroll
-    for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
-        const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
 
-        if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
-            break;
-        }
+                    cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
+                }
+            }
+        };
+        ggml_cuda_unroll<5>{}(load);
+    } else {
+        static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");
+        auto load = [&] __device__ (const int n) {
+            const int stride_k = WARP_SIZE >> n;
+            const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
+            const int k0_stop  =                             D2 - D2 % (1*stride_k);
+            const int stride_i = WARP_SIZE / stride_k;
+
+            if (k0_start == k0_stop) {
+                return;
+            }
 
 #pragma unroll
-        for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-            const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
-
-            tile_KV[i*stride_tile + k] = KV[i*stride_KV + k];
-        }
-    }
-}
+            for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
+                const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
 
-template <int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async>
-static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
-        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {
-
-    // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
-    // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
+                if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
+                    break;
+                }
 
-    if (use_cp_async) {
-        const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
-        constexpr int h2_per_chunk = 16/sizeof(half2);
-        const int chunks_per_row = D2 / h2_per_chunk;
+#pragma unroll
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
 
-        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE>
-            (KV, tile_KV_32, chunks_per_row, stride_KV);
-        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE/2>
-            (KV, tile_KV_32, chunks_per_row, stride_KV);
-        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE/4>
-            (KV, tile_KV_32, chunks_per_row, stride_KV);
-        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE/8>
-            (KV, tile_KV_32, chunks_per_row, stride_KV);
-        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE/16>
-            (KV, tile_KV_32, chunks_per_row, stride_KV);
-    } else {
-        static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");
-        flash_attn_ext_f16_load_tile_async_loop_iter_sync<stride_tile, nwarps, nbatch_fa, WARP_SIZE>
-            (KV, tile_KV, D2, stride_KV);
-        flash_attn_ext_f16_load_tile_async_loop_iter_sync<stride_tile, nwarps, nbatch_fa, WARP_SIZE/2>
-            (KV, tile_KV, D2, stride_KV);
-        flash_attn_ext_f16_load_tile_async_loop_iter_sync<stride_tile, nwarps, nbatch_fa, WARP_SIZE/4>
-            (KV, tile_KV, D2, stride_KV);
+                    tile_KV[i*stride_tile + k] = KV[i*stride_KV + k];
+                }
+            }
+        };
+        ggml_cuda_unroll<3>{}(load);
     }
 }
 
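Note on `ggml_cuda_unroll`: the hunk replaces the five hand-written `flash_attn_ext_f16_load_tile_async_loop_iter_async` instantiations (and the three `_sync` ones) with a single lambda that is expanded by `ggml_cuda_unroll<5>{}(load)` / `ggml_cuda_unroll<3>{}(load)`; the helper's definition is not part of this hunk. A minimal sketch of how such a compile-time unroll functor could look (the name `unroll_sketch` and the exact signature are assumptions for illustration, not taken from ggml):

```cpp
// Sketch only: a recursive functor that expands f(0), f(1), ..., f(n-1).
// After inlining, each call site sees a constant index, so stride_k = WARP_SIZE >> n
// becomes a compile-time constant and the per-granularity early return
// "if (k0_start == k0_stop) return;" can be folded away, which is the case the
// removed comment about manual unrolling was working around.
template <int n>
struct unroll_sketch {                 // hypothetical name
    template <typename Func>
    __device__ void operator()(const Func & f) const {
        unroll_sketch<n - 1>{}(f);     // expand indices 0 .. n-2 first
        f(n - 1);                      // then index n-1
    }
};

template <>
struct unroll_sketch<1> {
    template <typename Func>
    __device__ void operator()(const Func & f) const {
        f(0);
    }
};
```

With this shape, `unroll_sketch<5>{}(load)` expands to `load(0) ... load(4)`, i.e. `stride_k = WARP_SIZE, WARP_SIZE/2, ..., WARP_SIZE/16`, matching the five removed async call sites, and `unroll_sketch<3>{}(load)` covers the three synchronous ones.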
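The "decreasing granularity" comment can also be checked in isolation: for successive `stride_k = WARP_SIZE >> n`, the `k0_start`/`k0_stop` bounds tile a row of width `D2` into disjoint, contiguous ranges, with wide strides handling the aligned bulk and narrow strides the remainder. A small host-side sketch (the value of `D2` is chosen for illustration, not taken from the kernel):

```cpp
// Sketch: how the k0_start/k0_stop bounds from the hunk tile a row of width D2
// with decreasing stride granularity. Host-side, for illustration only.
#include <cstdio>

int main() {
    const int WARP_SIZE = 32;
    const int D2        = 72;  // example row width in half2 elements (not from the hunk)

    for (int n = 0; n < 3; ++n) {  // mirrors ggml_cuda_unroll<3> in the synchronous path
        const int stride_k = WARP_SIZE >> n;
        const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
        const int k0_stop  =                             D2 - D2 % (1*stride_k);
        if (k0_start == k0_stop) {
            continue;  // nothing left to load at this granularity
        }
        printf("stride_k = %2d covers [%2d, %2d)\n", stride_k, k0_start, k0_stop);
    }
    return 0;
}
```

For `D2 = 72` this prints `[ 0, 64)` for `stride_k = 32` and `[64, 72)` for `stride_k = 8`, while the `stride_k = 16` pass is skipped; the ranges are disjoint and together cover `[0, D2)` as long as `D2` is a multiple of the smallest stride handled.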