
#include <mma.h>

-static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
-    }
-    return a;
-#else
-    GGML_UNUSED(a);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-}
-
-// static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-// #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-// #pragma unroll
-//     for (int mask = 16; mask > 0; mask >>= 1) {
-//         x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
-//     }
-//     return x;
-// #else
-//     GGML_UNUSED(x);
-//     NO_DEVICE_CODE;
-// #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-// }
-
#define FATTN_KQ_STRIDE 256

template<int D, int parallel_blocks> // D == head size
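The two helpers removed above are warp-level butterfly reductions: each of the five __shfl_xor_sync steps exchanges a lane's value with the lane whose ID differs in one bit (mask 16, 8, 4, 2, 1), so after log2(32) = 5 steps every lane of the warp holds the full sum (or maximum). A minimal sketch of the same pattern for a plain float, purely illustrative and not part of this commit (the helper name is made up):

// Illustrative sketch (not from this commit): the same butterfly reduction
// for a plain float. After the loop every lane in the warp holds the sum of
// all 32 lanes' inputs.
static __device__ __forceinline__ float warp_reduce_sum_f32_example(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
    }
    return x;
}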
@@ -61,6 +35,7 @@ static __global__ void flash_attn_vec_ext_f16(
        const int ne1,
        const int ne2,
        const int ne3) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
    const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y);
@@ -201,6 +176,9 @@ static __global__ void flash_attn_vec_ext_f16(
            dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(kqmax, kqsum);
        }
    }
+#else
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}

template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks> // D == head size, VKQ_stride == num VKQ rows calculated in parallel
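The added lines in this and the following hunks apply one and the same change to each kernel: the entire device body is wrapped in an architecture guard, so that on HIP/ROCm builds and on CUDA architectures below the required minimum (CC_PASCAL for the fp16 vector and combine kernels, CC_VOLTA for the tensor-core kernel) the body is compiled out and the NO_DEVICE_CODE stub is emitted instead. A minimal sketch of the pattern on a dummy kernel, reusing the macros this file already relies on (the kernel itself is made up):

// Illustrative only: a trivial kernel guarded the same way as the kernels in
// this diff. On unsupported targets the real body is never compiled; the stub
// from NO_DEVICE_CODE takes its place.
static __global__ void guarded_dummy_kernel(float * dst) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
    dst[threadIdx.x] *= 2.0f; // real work, only on supported CUDA GPUs
#else
    GGML_UNUSED(dst);
    NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}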
@@ -233,6 +211,7 @@ static __global__ void flash_attn_ext_f16(
        const int ne1,
        const int ne2,
        const int ne3) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
    static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
@@ -491,6 +470,9 @@ static __global__ void flash_attn_ext_f16(
                __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0]));
        }
    }
+#else
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
}

template<int D, int parallel_blocks> // D == head size
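flash_attn_ext_f16 above is the tensor-core path; its guard requires CC_VOLTA rather than CC_PASCAL because it builds on the WMMA fragments from <mma.h> (included at the top of this file), which only exist on Volta and newer. For orientation, a self-contained WMMA sketch, not taken from this file, of the kind of 16x16x16 half-precision matrix multiply such a guard protects; launch it with a single warp (32 threads):

#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

// Illustrative: one warp multiplies a 16x16 half A by a 16x16 half B and
// accumulates into a 16x16 float C using the tensor-core WMMA API.
static __global__ void wmma_16x16x16_example(const half * A, const half * B, float * C) {
    wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float>                 c_frag;

    wmma::fill_fragment(c_frag, 0.0f);
    wmma::load_matrix_sync(a_frag, A, 16); // leading dimension 16
    wmma::load_matrix_sync(b_frag, B, 16);
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    wmma::store_matrix_sync(C, c_frag, 16, wmma::mem_row_major);
}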
@@ -499,6 +481,7 @@ static __global__ void flash_attn_combine_results(
        const float * __restrict__ VKQ_parts,
        const half2 * __restrict__ VKQ_meta,
        float * __restrict__ dst) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL

    const int tid = threadIdx.x;
    __builtin_assume(tid < D);
@@ -527,6 +510,9 @@ static __global__ void flash_attn_combine_results(
    }

    dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
+#else
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}

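flash_attn_combine_results merges the partial results produced when parallel_blocks > 1. Reading it together with the dst_meta writes above (make_half2(kqmax, kqsum)), each partial l contributes an unnormalized output N_l plus its running softmax maximum m_l and denominator s_l; under that reading, and hedging since the hunks only show the final division, the combined value per element is the standard numerically stable merge

\mathrm{dst} = \frac{\sum_l e^{\,m_l - M}\, N_l}{\sum_l e^{\,m_l - M}\, s_l},
\qquad M = \max_l m_l,

i.e. each partial is rescaled onto the common maximum M before the single VKQ_numerator / VKQ_denominator division.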
constexpr int get_max_power_of_2(int x) {