From f380e1bb244b34f3eab5e21d516bfb6afd8f239a Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Tue, 15 Jul 2025 14:01:44 -0400 Subject: [PATCH] [ROCm] warpSize is being made non constexpr in ROCm 7.0 (#20330) Signed-off-by: Gregory Shtrasberg --- csrc/attention/attention_kernels.cuh | 8 +------- csrc/attention/paged_attention_v1.cu | 8 +------- csrc/attention/paged_attention_v2.cu | 8 +------- csrc/cuda_compat.h | 6 +++--- 4 files changed, 6 insertions(+), 24 deletions(-) diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh index eb216dc8baf1..eefa2815a39f 100644 --- a/csrc/attention/attention_kernels.cuh +++ b/csrc/attention/attention_kernels.cuh @@ -24,6 +24,7 @@ #include "attention_dtypes.h" #include "attention_utils.cuh" +#include "cuda_compat.h" #ifdef USE_ROCM #include <hip/hip_bf16.h> @@ -33,12 +34,6 @@ typedef __hip_bfloat16 __nv_bfloat16; #include "../quantization/fp8/nvidia/quant_utils.cuh" #endif -#ifndef USE_ROCM - #define WARP_SIZE 32 -#else - #define WARP_SIZE warpSize -#endif - #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) @@ -670,7 +665,6 @@ __global__ void paged_attention_v2_reduce_kernel( } // namespace vllm -#undef WARP_SIZE #undef MAX #undef MIN #undef DIVIDE_ROUND_UP diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu index 9b3a5c4b1014..83cc0dc47507 100644 --- a/csrc/attention/paged_attention_v1.cu +++ b/csrc/attention/paged_attention_v1.cu @@ -18,12 +18,7 @@ */ #include "attention_kernels.cuh" - -#ifndef USE_ROCM - #define WARP_SIZE 32 -#else - #define WARP_SIZE warpSize -#endif +#include "cuda_compat.h" #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? 
(a) : (b)) @@ -190,7 +185,6 @@ void paged_attention_v1( CALL_V1_LAUNCHER_BLOCK_SIZE) } -#undef WARP_SIZE #undef MAX #undef MIN #undef DIVIDE_ROUND_UP \ No newline at end of file diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index 9935359e02fb..1d48c3da681e 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -18,12 +18,7 @@ */ #include "attention_kernels.cuh" - -#ifndef USE_ROCM - #define WARP_SIZE 32 -#else - #define WARP_SIZE warpSize -#endif +#include "cuda_compat.h" #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) @@ -200,7 +195,6 @@ void paged_attention_v2( CALL_V2_LAUNCHER_BLOCK_SIZE) } -#undef WARP_SIZE #undef MAX #undef MIN #undef DIVIDE_ROUND_UP \ No newline at end of file diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index 82e55613d915..affa051c7595 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -4,10 +4,10 @@ #include <hip/hip_runtime.h> #endif -#ifndef USE_ROCM - #define WARP_SIZE 32 +#if defined(USE_ROCM) && defined(__GFX9__) + #define WARP_SIZE 64 #else - #define WARP_SIZE warpSize + #define WARP_SIZE 32 #endif #ifndef USE_ROCM