@@ -16,18 +16,18 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 }
 
-static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
-    }
-    return x;
-#else
-    GGML_UNUSED(x);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-}
+// static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+// #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+// #pragma unroll
+//     for (int mask = 16; mask > 0; mask >>= 1) {
+//         x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+//     }
+//     return x;
+// #else
+//     GGML_UNUSED(x);
+//     NO_DEVICE_CODE;
+// #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+// }
 
 #define FATTN_KQ_STRIDE 256
 
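For reference, the function commented out above is the standard warp-wide butterfly reduction. A minimal float32 sketch of the same pattern, independent of the __hmax2/CUDART_HMAX guard (not part of this commit), is:

    static __device__ __forceinline__ float warp_reduce_max_f32(float x) {
    #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            // XOR-shuffle pairs lanes 'offset' apart; after log2(32) = 5 steps
            // every lane of the warp holds the maximum over all 32 lanes.
            x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
        }
        return x;
    }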
@@ -472,21 +472,20 @@ static __global__ void flash_attn_ext_f16(
                 dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j;
             }
         }
-        return;
-    }
-
+    } else {
 #pragma unroll
-    for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) {
-        const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x;
-        if (i0 + nwarps*WARP_SIZE > D && i >= D) {
-            return;
+        for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) {
+            const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+            if (i0 + nwarps*WARP_SIZE > D && i >= D) {
+                return;
+            }
+            dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + i] = VKQ[i];
         }
-        dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + i] = VKQ[i];
-    }
 
-    if (threadIdx.y == 0 && threadIdx.x == 0) {
-        dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(
-            __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0]));
+        if (threadIdx.y == 0 && threadIdx.x == 0) {
+            dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(
+                __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0]));
+        }
     }
 }
 
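The pair written to dst_meta in the new else branch (the block's running maximum and row sum) is what a later combine pass needs to merge partial results from several parallel blocks without losing numerical stability. A host-side sketch of that merge for a single element, using hypothetical names that are not taken from this diff:

    #include <math.h>

    // Each partial holds a weighted value sum v, the running max m, and the
    // softmax denominator s produced by one parallel block.
    struct Partial { float v; float m; float s; };

    static Partial combine(Partial a, Partial b) {
        const float m  = fmaxf(a.m, b.m);
        const float ca = expf(a.m - m);  // rescale block a to the common maximum
        const float cb = expf(b.m - m);  // rescale block b to the common maximum
        return { a.v*ca + b.v*cb, m, a.s*ca + b.s*cb };
    }
    // The final output value is combined.v / combined.s.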
@@ -781,7 +780,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     } else {
         cols_per_block = 8;
     }
-    const int frag_m = cols_per_block == 8 ? 32 : 16;
     constexpr int nwarps = 4;
     const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
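The launch configuration left in place maps grid.x to tiles of cols_per_block query columns (ceiling division), grid.y to heads, and grid.z to the batch, with nwarps warps per block. A standalone sketch of the same launch shape, with a placeholder kernel and parameter names that are not from the actual code:

    __global__ void fattn_stub(float * dst) { /* placeholder body */ }

    static void launch_sketch(float * dst, int n_cols, int n_heads, int n_batch,
                              cudaStream_t stream) {
        constexpr int WARP_SIZE      = 32;
        constexpr int nwarps         = 4;
        constexpr int cols_per_block = 8;
        // one block per tile of cols_per_block query columns, per head, per batch entry
        const dim3 blocks_num((n_cols + cols_per_block - 1) / cols_per_block, n_heads, n_batch);
        const dim3 block_dim(WARP_SIZE, nwarps, 1);
        fattn_stub<<<blocks_num, block_dim, 0, stream>>>(dst);
    }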