From ba72130af08cf43ce0f9e4b86ae388478421ccb7 Mon Sep 17 00:00:00 2001 From: Slobodan Josic Date: Sun, 10 Aug 2025 18:49:14 -0400 Subject: [PATCH 1/6] Fix HIP warp synchronization mask compatibility for ROCm 7.0+ --- ggml/src/ggml-cuda/argmax.cu | 8 ++++---- ggml/src/ggml-cuda/common.cuh | 24 ++++++++++++------------ ggml/src/ggml-cuda/fattn-common.cuh | 4 ++-- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 10 +++++----- ggml/src/ggml-cuda/mma.cuh | 4 ++-- ggml/src/ggml-cuda/quantize.cu | 4 ++-- ggml/src/ggml-cuda/vendors/hip.h | 14 +++++++++++--- 7 files changed, 38 insertions(+), 30 deletions(-) diff --git a/ggml/src/ggml-cuda/argmax.cu b/ggml/src/ggml-cuda/argmax.cu index 5340eedc08916..68accf9402e2b 100644 --- a/ggml/src/ggml-cuda/argmax.cu +++ b/ggml/src/ggml-cuda/argmax.cu @@ -22,8 +22,8 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest #pragma unroll for (int offset = 16; offset > 0; offset >>= 1) { - const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE); - const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE); + const float val = __shfl_xor_sync(GGML_WARP_SYNC_MASK, maxval, offset, WARP_SIZE); + const int col = __shfl_xor_sync(GGML_WARP_SYNC_MASK, argmax, offset, WARP_SIZE); if (val > maxval) { maxval = val; argmax = col; @@ -51,8 +51,8 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest } #pragma unroll for (int offset = 16; offset > 0; offset >>= 1) { - const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE); - const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE); + const float val = __shfl_xor_sync(GGML_WARP_SYNC_MASK, maxval, offset, WARP_SIZE); + const int col = __shfl_xor_sync(GGML_WARP_SYNC_MASK, argmax, offset, WARP_SIZE); if (val > maxval) { maxval = val; argmax = col; diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 2e5d48797fa49..87d350d49818e 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -375,22 +375,22 @@ struct ggml_cuda_unroll<1> { template static __device__ __forceinline__ int warp_reduce_sum(int x) { -#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE - return __reduce_add_sync(0xffffffff, x); +#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE) || (defined(GGML_USE_HIP) && HIP_VERSION >= 70000000) + return __reduce_add_sync(GGML_WARP_SYNC_MASK, x); #else #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { - x += __shfl_xor_sync(0xffffffff, x, offset, width); + x += __shfl_xor_sync(GGML_WARP_SYNC_MASK, x, offset, width); } return x; -#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE) || (defined(GGML_USE_HIP) && HIP_VERSION >= 70000000) } template static __device__ __forceinline__ float warp_reduce_sum(float x) { #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { - x += __shfl_xor_sync(0xffffffff, x, offset, width); + x += __shfl_xor_sync(GGML_WARP_SYNC_MASK, x, offset, width); } return x; } @@ -399,8 +399,8 @@ template static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { - a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width); - a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width); + a.x += __shfl_xor_sync(GGML_WARP_SYNC_MASK, a.x, offset, width); + a.y += __shfl_xor_sync(GGML_WARP_SYNC_MASK, a.y, offset, width); } return a; } @@ -410,7 +410,7 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #ifdef FP16_AVAILABLE #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { - a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width)); + a = __hadd2(a, __shfl_xor_sync(GGML_WARP_SYNC_MASK, a, offset, width)); } return a; @@ -445,12 +445,12 @@ static __device__ __forceinline__ int warp_reduce_all(int x) { #ifdef GGML_USE_HIP #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { - x = x && __shfl_xor_sync(0xffffffff, x, offset, width); + x = x && __shfl_xor_sync(GGML_WARP_SYNC_MASK, x, offset, width); } return x; #else static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented"); - return __all_sync(0xffffffff, x); + return __all_sync(GGML_WARP_SYNC_MASK, x); #endif // GGML_USE_HIP } @@ -458,7 +458,7 @@ template static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { - x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, width)); + x = fmaxf(x, __shfl_xor_sync(GGML_WARP_SYNC_MASK, x, offset, width)); } return x; } @@ -501,7 +501,7 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000) #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { - x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width)); + x = ggml_cuda_hmax2(x, __shfl_xor_sync(GGML_WARP_SYNC_MASK, x, offset, width)); } return x; #else diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index e46f0e2081bdf..b706d6be02209 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -314,8 +314,8 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared( } #pragma unroll for (int mask = QI8_1/2; mask > 0; mask >>= 1) { - amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32)); - sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, 32); + amax = fmaxf(amax, __shfl_xor_sync(GGML_WARP_SYNC_MASK, amax, mask, 32)); + sum += __shfl_xor_sync(GGML_WARP_SYNC_MASK, sum, mask, 32); } const float d = amax / 127; diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 39731baaeb7f4..4a8fe453b5546 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -572,7 +572,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll for (int offset = 16; offset >= 4; offset >>= 1) { - KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(GGML_WARP_SYNC_MASK, KQ_max_new[col], offset, WARP_SIZE)); } } @@ -627,7 +627,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll for (int offset = 2; offset >= 1; offset >>= 1) { - KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(GGML_WARP_SYNC_MASK, KQ_max_new[col], offset, WARP_SIZE)); } } @@ -953,7 +953,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll for (int offset = offset_first; offset >= offset_last; offset >>= 1) { - KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE); + KQ_rowsum[col] += __shfl_xor_sync(GGML_WARP_SYNC_MASK, KQ_rowsum[col], offset, WARP_SIZE); } } } @@ -1086,7 +1086,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { if (offset < WARP_SIZE) { - KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); + KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(GGML_WARP_SYNC_MASK, KQ_cmn, offset, WARP_SIZE)); } } @@ -1104,7 +1104,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { if (offset < WARP_SIZE) { - KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); + KQ_crs += __shfl_xor_sync(GGML_WARP_SYNC_MASK, KQ_crs, offset, WARP_SIZE); } } diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 83ee16b27d0df..9fb827343ab7a 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -47,8 +47,8 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { const int shift_low = ((src_j + 0) % 2) * 16; const int shift_high = ((src_j + 1) % 2) * 16; - const int ret_low = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF; - const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000; + const int ret_low = (__shfl_sync(GGML_WARP_SYNC_MASK, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF; + const int ret_high = (__shfl_sync(GGML_WARP_SYNC_MASK, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000; return ret_low | ret_high; } diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index a0b03a740d74c..d797aacbe6219 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -89,7 +89,7 @@ static __global__ void quantize_mmq_q8_1( // Exchange max. abs. value between vals_per_scale/4 threads. #pragma unroll for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) { - amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE)); + amax = fmaxf(amax, __shfl_xor_sync(GGML_WARP_SYNC_MASK, amax, offset, WARP_SIZE)); } float sum; @@ -99,7 +99,7 @@ static __global__ void quantize_mmq_q8_1( // Calculate sums across vals_per_sum/4 threads. #pragma unroll for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) { - sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE); + sum += __shfl_xor_sync(GGML_WARP_SYNC_MASK, sum, offset, WARP_SIZE); } } diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index c31f319232252..b86ad587c1a00 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -23,8 +23,6 @@ #define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} -#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) -#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) #define cublasCreate hipblasCreate #define cublasDestroy hipblasDestroy #define cublasGemmEx hipblasGemmEx @@ -137,6 +135,7 @@ #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED +// HIP version-specific type mappings #if HIP_VERSION >= 70000000 #define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F #define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F @@ -149,7 +148,16 @@ #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F #define cublasComputeType_t hipblasDatatype_t #define cudaDataType_t hipblasDatatype_t -#endif // HIP_VERSION >= 7000000 +#endif // HIP_VERSION >= 70000000 + +// Warp sync functions and masks +#if HIP_VERSION >= 70000000 && defined(GGML_HIP_ROCWMMA_FATTN) +#define GGML_WARP_SYNC_MASK 0xffffffffffffffffULL // ROCm 7.0+ requires 64-bit masks for __*_*_sync functions +#else +#define GGML_WARP_SYNC_MASK 0xffffffff +#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) +#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) +#endif // HIP_VERSION >= 70000000 && defined(GGML_HIP_ROCWMMA_FATTN) #if !defined(__HIP_PLATFORM_AMD__) #error "The HIP backend supports only AMD targets" From 493f96ac6cda8c907c44f2168f154619380a556c Mon Sep 17 00:00:00 2001 From: Slobodan Josic Date: Mon, 11 Aug 2025 09:31:10 -0400 Subject: [PATCH 2/6] Fix CUDA/MUSA build: define GGML_WARP_SYNC_MASK in common.cuh --- ggml/src/ggml-cuda/common.cuh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 87d350d49818e..dcd0288a86339 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -35,6 +35,10 @@ #include "vendors/cuda.h" #endif // defined(GGML_USE_HIP) +#ifndef GGML_WARP_SYNC_MASK +#define GGML_WARP_SYNC_MASK 0xffffffff +#endif + #define STRINGIZE_IMPL(...) #__VA_ARGS__ #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__) From f264793b4605fb74f99e187c2265ee047ee7e22e Mon Sep 17 00:00:00 2001 From: Slobodan Josic Date: Mon, 11 Aug 2025 15:57:09 -0400 Subject: [PATCH 3/6] Addressed code review comments: GGML_WARP_SYNC_MASK renamed to GGML_CUDA_WARP_MASK --- ggml/src/ggml-cuda/argmax.cu | 8 ++++---- ggml/src/ggml-cuda/common.cuh | 26 +++++++++++++------------- ggml/src/ggml-cuda/fattn-common.cuh | 4 ++-- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 10 +++++----- ggml/src/ggml-cuda/mma.cuh | 4 ++-- ggml/src/ggml-cuda/quantize.cu | 4 ++-- ggml/src/ggml-cuda/vendors/hip.h | 4 ++-- 7 files changed, 30 insertions(+), 30 deletions(-) diff --git a/ggml/src/ggml-cuda/argmax.cu b/ggml/src/ggml-cuda/argmax.cu index 68accf9402e2b..6dfbf18b19c17 100644 --- a/ggml/src/ggml-cuda/argmax.cu +++ b/ggml/src/ggml-cuda/argmax.cu @@ -22,8 +22,8 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest #pragma unroll for (int offset = 16; offset > 0; offset >>= 1) { - const float val = __shfl_xor_sync(GGML_WARP_SYNC_MASK, maxval, offset, WARP_SIZE); - const int col = __shfl_xor_sync(GGML_WARP_SYNC_MASK, argmax, offset, WARP_SIZE); + const float val = __shfl_xor_sync(GGML_CUDA_WARP_MASK, maxval, offset, WARP_SIZE); + const int col = __shfl_xor_sync(GGML_CUDA_WARP_MASK, argmax, offset, WARP_SIZE); if (val > maxval) { maxval = val; argmax = col; @@ -51,8 +51,8 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest } #pragma unroll for (int offset = 16; offset > 0; offset >>= 1) { - const float val = __shfl_xor_sync(GGML_WARP_SYNC_MASK, maxval, offset, WARP_SIZE); - const int col = __shfl_xor_sync(GGML_WARP_SYNC_MASK, argmax, offset, WARP_SIZE); + const float val = __shfl_xor_sync(GGML_CUDA_WARP_MASK, maxval, offset, WARP_SIZE); + const int col = __shfl_xor_sync(GGML_CUDA_WARP_MASK, argmax, offset, WARP_SIZE); if (val > maxval) { maxval = val; argmax = col; diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index dcd0288a86339..e90d2810fa489 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -35,9 +35,9 @@ #include "vendors/cuda.h" #endif // defined(GGML_USE_HIP) -#ifndef GGML_WARP_SYNC_MASK -#define GGML_WARP_SYNC_MASK 0xffffffff -#endif +#ifndef GGML_CUDA_WARP_MASK +#define GGML_CUDA_WARP_MASK 0xffffffff +#endif // GGML_CUDA_WARP_MASK #define STRINGIZE_IMPL(...) #__VA_ARGS__ #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__) @@ -380,11 +380,11 @@ struct ggml_cuda_unroll<1> { template static __device__ __forceinline__ int warp_reduce_sum(int x) { #if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE) || (defined(GGML_USE_HIP) && HIP_VERSION >= 70000000) - return __reduce_add_sync(GGML_WARP_SYNC_MASK, x); + return __reduce_add_sync(GGML_CUDA_WARP_MASK, x); #else #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { - x += __shfl_xor_sync(GGML_WARP_SYNC_MASK, x, offset, width); + x += __shfl_xor_sync(GGML_CUDA_WARP_MASK, x, offset, width); } return x; #endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE) || (defined(GGML_USE_HIP) && HIP_VERSION >= 70000000) @@ -394,7 +394,7 @@ template static __device__ __forceinline__ float warp_reduce_sum(float x) { #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { - x += __shfl_xor_sync(GGML_WARP_SYNC_MASK, x, offset, width); + x += __shfl_xor_sync(GGML_CUDA_WARP_MASK, x, offset, width); } return x; } @@ -403,8 +403,8 @@ template static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { - a.x += __shfl_xor_sync(GGML_WARP_SYNC_MASK, a.x, offset, width); - a.y += __shfl_xor_sync(GGML_WARP_SYNC_MASK, a.y, offset, width); + a.x += __shfl_xor_sync(GGML_CUDA_WARP_MASK, a.x, offset, width); + a.y += __shfl_xor_sync(GGML_CUDA_WARP_MASK, a.y, offset, width); } return a; } @@ -414,7 +414,7 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #ifdef FP16_AVAILABLE #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { - a = __hadd2(a, __shfl_xor_sync(GGML_WARP_SYNC_MASK, a, offset, width)); + a = __hadd2(a, __shfl_xor_sync(GGML_CUDA_WARP_MASK, a, offset, width)); } return a; @@ -449,12 +449,12 @@ static __device__ __forceinline__ int warp_reduce_all(int x) { #ifdef GGML_USE_HIP #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { - x = x && __shfl_xor_sync(GGML_WARP_SYNC_MASK, x, offset, width); + x = x && __shfl_xor_sync(GGML_CUDA_WARP_MASK, x, offset, width); } return x; #else static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented"); - return __all_sync(GGML_WARP_SYNC_MASK, x); + return __all_sync(GGML_CUDA_WARP_MASK, x); #endif // GGML_USE_HIP } @@ -462,7 +462,7 @@ template static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { - x = fmaxf(x, __shfl_xor_sync(GGML_WARP_SYNC_MASK, x, offset, width)); + x = fmaxf(x, __shfl_xor_sync(GGML_CUDA_WARP_MASK, x, offset, width)); } return x; } @@ -505,7 +505,7 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000) #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { - x = ggml_cuda_hmax2(x, __shfl_xor_sync(GGML_WARP_SYNC_MASK, x, offset, width)); + x = ggml_cuda_hmax2(x, __shfl_xor_sync(GGML_CUDA_WARP_MASK, x, offset, width)); } return x; #else diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index b706d6be02209..0b0f2a1411f43 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -314,8 +314,8 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared( } #pragma unroll for (int mask = QI8_1/2; mask > 0; mask >>= 1) { - amax = fmaxf(amax, __shfl_xor_sync(GGML_WARP_SYNC_MASK, amax, mask, 32)); - sum += __shfl_xor_sync(GGML_WARP_SYNC_MASK, sum, mask, 32); + amax = fmaxf(amax, __shfl_xor_sync(GGML_CUDA_WARP_MASK, amax, mask, 32)); + sum += __shfl_xor_sync(GGML_CUDA_WARP_MASK, sum, mask, 32); } const float d = amax / 127; diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 4a8fe453b5546..c0b96c6439d25 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -572,7 +572,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll for (int offset = 16; offset >= 4; offset >>= 1) { - KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(GGML_WARP_SYNC_MASK, KQ_max_new[col], offset, WARP_SIZE)); + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(GGML_CUDA_WARP_MASK, KQ_max_new[col], offset, WARP_SIZE)); } } @@ -627,7 +627,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll for (int offset = 2; offset >= 1; offset >>= 1) { - KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(GGML_WARP_SYNC_MASK, KQ_max_new[col], offset, WARP_SIZE)); + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(GGML_CUDA_WARP_MASK, KQ_max_new[col], offset, WARP_SIZE)); } } @@ -953,7 +953,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll for (int offset = offset_first; offset >= offset_last; offset >>= 1) { - KQ_rowsum[col] += __shfl_xor_sync(GGML_WARP_SYNC_MASK, KQ_rowsum[col], offset, WARP_SIZE); + KQ_rowsum[col] += __shfl_xor_sync(GGML_CUDA_WARP_MASK, KQ_rowsum[col], offset, WARP_SIZE); } } } @@ -1086,7 +1086,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { if (offset < WARP_SIZE) { - KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(GGML_WARP_SYNC_MASK, KQ_cmn, offset, WARP_SIZE)); + KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(GGML_CUDA_WARP_MASK, KQ_cmn, offset, WARP_SIZE)); } } @@ -1104,7 +1104,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { if (offset < WARP_SIZE) { - KQ_crs += __shfl_xor_sync(GGML_WARP_SYNC_MASK, KQ_crs, offset, WARP_SIZE); + KQ_crs += __shfl_xor_sync(GGML_CUDA_WARP_MASK, KQ_crs, offset, WARP_SIZE); } } diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 9fb827343ab7a..2c6990c9b590d 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -47,8 +47,8 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { const int shift_low = ((src_j + 0) % 2) * 16; const int shift_high = ((src_j + 1) % 2) * 16; - const int ret_low = (__shfl_sync(GGML_WARP_SYNC_MASK, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF; - const int ret_high = (__shfl_sync(GGML_WARP_SYNC_MASK, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000; + const int ret_low = (__shfl_sync(GGML_CUDA_WARP_MASK, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF; + const int ret_high = (__shfl_sync(GGML_CUDA_WARP_MASK, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000; return ret_low | ret_high; } diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index d797aacbe6219..0faa0c9e384fd 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -89,7 +89,7 @@ static __global__ void quantize_mmq_q8_1( // Exchange max. abs. value between vals_per_scale/4 threads. #pragma unroll for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) { - amax = fmaxf(amax, __shfl_xor_sync(GGML_WARP_SYNC_MASK, amax, offset, WARP_SIZE)); + amax = fmaxf(amax, __shfl_xor_sync(GGML_CUDA_WARP_MASK, amax, offset, WARP_SIZE)); } float sum; @@ -99,7 +99,7 @@ static __global__ void quantize_mmq_q8_1( // Calculate sums across vals_per_sum/4 threads. #pragma unroll for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) { - sum += __shfl_xor_sync(GGML_WARP_SYNC_MASK, sum, offset, WARP_SIZE); + sum += __shfl_xor_sync(GGML_CUDA_WARP_MASK, sum, offset, WARP_SIZE); } } diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index b86ad587c1a00..1d8575ab11f2d 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -152,9 +152,9 @@ // Warp sync functions and masks #if HIP_VERSION >= 70000000 && defined(GGML_HIP_ROCWMMA_FATTN) -#define GGML_WARP_SYNC_MASK 0xffffffffffffffffULL // ROCm 7.0+ requires 64-bit masks for __*_*_sync functions +#define GGML_CUDA_WARP_MASK 0xffffffffffffffffULL // ROCm 7.0+ requires 64-bit masks for __*_*_sync functions #else -#define GGML_WARP_SYNC_MASK 0xffffffff +#define GGML_CUDA_WARP_MASK 0xffffffff #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) #endif // HIP_VERSION >= 70000000 && defined(GGML_HIP_ROCWMMA_FATTN) From 4d9df1d51474dd769a21e384ce964b0b80321929 Mon Sep 17 00:00:00 2001 From: Slobodan Josic Date: Mon, 11 Aug 2025 20:44:39 -0400 Subject: [PATCH 4/6] Revert all changes introduced in this PR --- ggml/src/ggml-cuda/argmax.cu | 8 ++++---- ggml/src/ggml-cuda/common.cuh | 28 ++++++++++++---------------- ggml/src/ggml-cuda/fattn-common.cuh | 4 ++-- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 10 +++++----- ggml/src/ggml-cuda/mma.cuh | 4 ++-- ggml/src/ggml-cuda/quantize.cu | 4 ++-- ggml/src/ggml-cuda/vendors/hip.h | 12 ++---------- 7 files changed, 29 insertions(+), 41 deletions(-) diff --git a/ggml/src/ggml-cuda/argmax.cu b/ggml/src/ggml-cuda/argmax.cu index 6dfbf18b19c17..5340eedc08916 100644 --- a/ggml/src/ggml-cuda/argmax.cu +++ b/ggml/src/ggml-cuda/argmax.cu @@ -22,8 +22,8 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest #pragma unroll for (int offset = 16; offset > 0; offset >>= 1) { - const float val = __shfl_xor_sync(GGML_CUDA_WARP_MASK, maxval, offset, WARP_SIZE); - const int col = __shfl_xor_sync(GGML_CUDA_WARP_MASK, argmax, offset, WARP_SIZE); + const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE); + const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE); if (val > maxval) { maxval = val; argmax = col; @@ -51,8 +51,8 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest } #pragma unroll for (int offset = 16; offset > 0; offset >>= 1) { - const float val = __shfl_xor_sync(GGML_CUDA_WARP_MASK, maxval, offset, WARP_SIZE); - const int col = __shfl_xor_sync(GGML_CUDA_WARP_MASK, argmax, offset, WARP_SIZE); + const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE); + const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE); if (val > maxval) { maxval = val; argmax = col; diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index e90d2810fa489..6f43fd9affef5 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -35,10 +35,6 @@ #include "vendors/cuda.h" #endif // defined(GGML_USE_HIP) -#ifndef GGML_CUDA_WARP_MASK -#define GGML_CUDA_WARP_MASK 0xffffffff -#endif // GGML_CUDA_WARP_MASK - #define STRINGIZE_IMPL(...) #__VA_ARGS__ #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__) @@ -379,22 +375,22 @@ struct ggml_cuda_unroll<1> { template static __device__ __forceinline__ int warp_reduce_sum(int x) { -#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE) || (defined(GGML_USE_HIP) && HIP_VERSION >= 70000000) - return __reduce_add_sync(GGML_CUDA_WARP_MASK, x); +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + return __reduce_add_sync(0xFFFFFFFF, x); #else #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { - x += __shfl_xor_sync(GGML_CUDA_WARP_MASK, x, offset, width); + x += __shfl_xor_sync(0xFFFFFFFF, x, offset, width); } return x; -#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE) || (defined(GGML_USE_HIP) && HIP_VERSION >= 70000000) +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE } template static __device__ __forceinline__ float warp_reduce_sum(float x) { #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { - x += __shfl_xor_sync(GGML_CUDA_WARP_MASK, x, offset, width); + x += __shfl_xor_sync(0xFFFFFFFF, x, offset, width); } return x; } @@ -403,8 +399,8 @@ template static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { - a.x += __shfl_xor_sync(GGML_CUDA_WARP_MASK, a.x, offset, width); - a.y += __shfl_xor_sync(GGML_CUDA_WARP_MASK, a.y, offset, width); + a.x += __shfl_xor_sync(0xFFFFFFFF, a.x, offset, width); + a.y += __shfl_xor_sync(0xFFFFFFFF, a.y, offset, width); } return a; } @@ -414,7 +410,7 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #ifdef FP16_AVAILABLE #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { - a = __hadd2(a, __shfl_xor_sync(GGML_CUDA_WARP_MASK, a, offset, width)); + a = __hadd2(a, __shfl_xor_sync(0xFFFFFFFF, a, offset, width)); } return a; @@ -449,12 +445,12 @@ static __device__ __forceinline__ int warp_reduce_all(int x) { #ifdef GGML_USE_HIP #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { - x = x && __shfl_xor_sync(GGML_CUDA_WARP_MASK, x, offset, width); + x = x && __shfl_xor_sync(0xFFFFFFFF, x, offset, width); } return x; #else static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented"); - return __all_sync(GGML_CUDA_WARP_MASK, x); + return __all_sync(0xFFFFFFFF, x); #endif // GGML_USE_HIP } @@ -462,7 +458,7 @@ template static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { - x = fmaxf(x, __shfl_xor_sync(GGML_CUDA_WARP_MASK, x, offset, width)); + x = fmaxf(x, __shfl_xor_sync(0xFFFFFFFF, x, offset, width)); } return x; } @@ -505,7 +501,7 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000) #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { - x = ggml_cuda_hmax2(x, __shfl_xor_sync(GGML_CUDA_WARP_MASK, x, offset, width)); + x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xFFFFFFFF, x, offset, width)); } return x; #else diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 0b0f2a1411f43..e46f0e2081bdf 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -314,8 +314,8 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared( } #pragma unroll for (int mask = QI8_1/2; mask > 0; mask >>= 1) { - amax = fmaxf(amax, __shfl_xor_sync(GGML_CUDA_WARP_MASK, amax, mask, 32)); - sum += __shfl_xor_sync(GGML_CUDA_WARP_MASK, sum, mask, 32); + amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32)); + sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, 32); } const float d = amax / 127; diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index c0b96c6439d25..39731baaeb7f4 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -572,7 +572,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll for (int offset = 16; offset >= 4; offset >>= 1) { - KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(GGML_CUDA_WARP_MASK, KQ_max_new[col], offset, WARP_SIZE)); + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); } } @@ -627,7 +627,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll for (int offset = 2; offset >= 1; offset >>= 1) { - KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(GGML_CUDA_WARP_MASK, KQ_max_new[col], offset, WARP_SIZE)); + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); } } @@ -953,7 +953,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll for (int offset = offset_first; offset >= offset_last; offset >>= 1) { - KQ_rowsum[col] += __shfl_xor_sync(GGML_CUDA_WARP_MASK, KQ_rowsum[col], offset, WARP_SIZE); + KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE); } } } @@ -1086,7 +1086,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { if (offset < WARP_SIZE) { - KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(GGML_CUDA_WARP_MASK, KQ_cmn, offset, WARP_SIZE)); + KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); } } @@ -1104,7 +1104,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { if (offset < WARP_SIZE) { - KQ_crs += __shfl_xor_sync(GGML_CUDA_WARP_MASK, KQ_crs, offset, WARP_SIZE); + KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); } } diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 2c6990c9b590d..83ee16b27d0df 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -47,8 +47,8 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { const int shift_low = ((src_j + 0) % 2) * 16; const int shift_high = ((src_j + 1) % 2) * 16; - const int ret_low = (__shfl_sync(GGML_CUDA_WARP_MASK, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF; - const int ret_high = (__shfl_sync(GGML_CUDA_WARP_MASK, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000; + const int ret_low = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF; + const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000; return ret_low | ret_high; } diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index 0faa0c9e384fd..a0b03a740d74c 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -89,7 +89,7 @@ static __global__ void quantize_mmq_q8_1( // Exchange max. abs. value between vals_per_scale/4 threads. #pragma unroll for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) { - amax = fmaxf(amax, __shfl_xor_sync(GGML_CUDA_WARP_MASK, amax, offset, WARP_SIZE)); + amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE)); } float sum; @@ -99,7 +99,7 @@ static __global__ void quantize_mmq_q8_1( // Calculate sums across vals_per_sum/4 threads. #pragma unroll for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) { - sum += __shfl_xor_sync(GGML_CUDA_WARP_MASK, sum, offset, WARP_SIZE); + sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE); } } diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 1d8575ab11f2d..9a6469ced9eb7 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -23,6 +23,8 @@ #define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} +#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) +#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) #define cublasCreate hipblasCreate #define cublasDestroy hipblasDestroy #define cublasGemmEx hipblasGemmEx @@ -135,7 +137,6 @@ #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED -// HIP version-specific type mappings #if HIP_VERSION >= 70000000 #define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F #define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F @@ -150,15 +151,6 @@ #define cudaDataType_t hipblasDatatype_t #endif // HIP_VERSION >= 70000000 -// Warp sync functions and masks -#if HIP_VERSION >= 70000000 && defined(GGML_HIP_ROCWMMA_FATTN) -#define GGML_CUDA_WARP_MASK 0xffffffffffffffffULL // ROCm 7.0+ requires 64-bit masks for __*_*_sync functions -#else -#define GGML_CUDA_WARP_MASK 0xffffffff -#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) -#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) -#endif // HIP_VERSION >= 70000000 && defined(GGML_HIP_ROCWMMA_FATTN) - #if !defined(__HIP_PLATFORM_AMD__) #error "The HIP backend supports only AMD targets" #endif // !defined(__HIP_PLATFORM_AMD__) From 74c7a2fe81d24483c1fc84c873109910a9e1a4ba Mon Sep 17 00:00:00 2001 From: Slobodan Josic Date: Mon, 11 Aug 2025 21:12:56 -0400 Subject: [PATCH 5/6] Starting from ROCm 6.5 (aka 6.4 with 7.0 preview) HIP_ENABLE_WARP_SYNC_BUILTINS has been replaced with HIP_DISABLE_WARP_SYNC_BUILTINS (https://github.com/ROCm/clr/blob/rocm-6.4.x-with-7.0-preview/hipamd/include/hip/amd_detail/amd_warp_sync_functions.h#L30) --- ggml/src/ggml-cuda/fattn-wmma-f16.cu | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index fdc4d17da2da9..e390b7d733b51 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -15,7 +15,11 @@ namespace wmma = mtmusa::wmma; namespace wmma = nvcuda::wmma; #endif // GGML_USE_MUSA #elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE) -#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers +#if HIP_VERSION >= 60500000 +#define HIP_DISABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers for ROCm 6.5+ +#else +#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers before ROCm 6.5 +#endif // HIP_VERSION >= 60500000 #include namespace wmma = rocwmma; #endif // !defined(GGML_USE_HIP) From 1299c04c66f33d2d6a877f171bb2649fa47cf149 Mon Sep 17 00:00:00 2001 From: Slobodan Josic Date: Mon, 11 Aug 2025 21:17:36 -0400 Subject: [PATCH 6/6] hipBLAS changes have been introduced also from ROCm 6.5 (aka 6.4 with 7.0 preview) --- ggml/src/ggml-cuda/vendors/hip.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 9a6469ced9eb7..ea13371fc7a90 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -137,7 +137,7 @@ #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED -#if HIP_VERSION >= 70000000 +#if HIP_VERSION >= 60500000 #define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F #define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F @@ -149,7 +149,7 @@ #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F #define cublasComputeType_t hipblasDatatype_t #define cudaDataType_t hipblasDatatype_t -#endif // HIP_VERSION >= 70000000 +#endif // HIP_VERSION >= 60500000 #if !defined(__HIP_PLATFORM_AMD__) #error "The HIP backend supports only AMD targets"