@@ -89,7 +89,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         float2 & KQ_max,
         float2 & KQ_rowsum,
         const int kb0) {
-
+#ifdef NEW_MMA_AVAILABLE
     constexpr int np        = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column.
     constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
 
@@ -241,6 +241,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #ifndef CP_ASYNC_AVAILABLE
     __syncthreads(); // Only needed if tile_K == tile_V.
 #endif // CP_ASYNC_AVAILABLE
+
+#else
+    NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
 }
 
 template <int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
@@ -262,6 +266,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int jt,
         const int kb0_start,
         const int kb0_stop) {
+#ifdef NEW_MMA_AVAILABLE
     // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     static_assert(nwarps*tile_B::I % ncols == 0, "bad nwarps");
@@ -518,6 +523,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     if (np > 1) {
         __syncthreads();
     }
+#else
+    NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
 }
 
 template <int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap>
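
Taken together, the hunks wrap both device functions in the same guard: when NEW_MMA_AVAILABLE is not defined for the target architecture, the function body compiles down to NO_DEVICE_CODE instead of the MMA path, so the translation unit still builds for older GPUs. A minimal sketch of that pattern, with a hypothetical function name and assuming the ggml-cuda macros NEW_MMA_AVAILABLE and NO_DEVICE_CODE are in scope (e.g. via common.cuh):

// Sketch only, not the actual kernel: shows the guard structure used in the diff above.
static __device__ __forceinline__ void example_guarded_device_fn(float * dst) {
#ifdef NEW_MMA_AVAILABLE
    // Real device code is emitted only on architectures that support the new MMA instructions.
    *dst = 0.0f;
#else
    // On other architectures the body is stubbed out with NO_DEVICE_CODE so compilation still succeeds.
    NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}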