20 changes: 10 additions & 10 deletions ggml/src/ggml-cuda/common.cuh
@@ -376,11 +376,11 @@ struct ggml_cuda_unroll<1> {
template<int width = WARP_SIZE>
static __device__ __forceinline__ int warp_reduce_sum(int x) {
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
-return __reduce_add_sync(0xffffffff, x);
+return __reduce_add_sync(0xFFFFFFFF, x);
#else
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
-x += __shfl_xor_sync(0xffffffff, x, offset, width);
+x += __shfl_xor_sync(0xFFFFFFFF, x, offset, width);
}
return x;
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
@@ -390,7 +390,7 @@ template<int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
-x += __shfl_xor_sync(0xffffffff, x, offset, width);
+x += __shfl_xor_sync(0xFFFFFFFF, x, offset, width);
}
return x;
}
@@ -399,8 +399,8 @@ template<int width = WARP_SIZE>
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
-a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width);
-a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width);
+a.x += __shfl_xor_sync(0xFFFFFFFF, a.x, offset, width);
+a.y += __shfl_xor_sync(0xFFFFFFFF, a.y, offset, width);
}
return a;
}
@@ -410,7 +410,7 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#ifdef FP16_AVAILABLE
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
-a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width));
+a = __hadd2(a, __shfl_xor_sync(0xFFFFFFFF, a, offset, width));
}
return a;

@@ -445,20 +445,20 @@ static __device__ __forceinline__ int warp_reduce_all(int x) {
#ifdef GGML_USE_HIP
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
-x = x && __shfl_xor_sync(0xffffffff, x, offset, width);
+x = x && __shfl_xor_sync(0xFFFFFFFF, x, offset, width);
}
return x;
#else
static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented");
-return __all_sync(0xffffffff, x);
+return __all_sync(0xFFFFFFFF, x);
#endif // GGML_USE_HIP
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
-x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, width));
+x = fmaxf(x, __shfl_xor_sync(0xFFFFFFFF, x, offset, width));
}
return x;
}
@@ -501,7 +501,7 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
-x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width));
+x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xFFFFFFFF, x, offset, width));
}
return x;
#else
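For context, this is the XOR-shuffle butterfly pattern the hunks above touch; the patch only changes the hex case of the full-warp mask. Below is a minimal self-contained sketch of the same reduction — the kernel name, host driver, and launch configuration are illustrative assumptions, not taken from the patch:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Same idea as warp_reduce_sum in common.cuh: after log2(32) XOR-shuffle
// steps, every lane of the warp holds the sum of all 32 input values.
__device__ __forceinline__ float warp_sum(float x) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xFFFFFFFF, x, offset, 32);
    }
    return x;
}

__global__ void sum_one_warp(const float * in, float * out) {
    const float v = warp_sum(in[threadIdx.x]);
    if (threadIdx.x == 0) {
        *out = v; // any lane could write this; all 32 hold the same value
    }
}

int main() {
    float h_in[32], h_out = 0.0f;
    for (int i = 0; i < 32; ++i) {
        h_in[i] = 1.0f;
    }
    float * d_in  = nullptr;
    float * d_out = nullptr;
    cudaMalloc(&d_in, sizeof(h_in));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
    sum_one_warp<<<1, 32>>>(d_in, d_out);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("warp sum = %.1f\n", h_out); // expected: 32.0
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}
```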
6 changes: 5 additions & 1 deletion ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -15,7 +15,11 @@ namespace wmma = mtmusa::wmma;
namespace wmma = nvcuda::wmma;
#endif // GGML_USE_MUSA
#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
-#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
+#if HIP_VERSION >= 60500000
+#define HIP_DISABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers for ROCm 6.5+
+#else
+#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers before ROCm 6.5
+#endif // HIP_VERSION >= 60500000
#include <rocwmma/rocwmma.hpp>
namespace wmma = rocwmma;
#endif // !defined(GGML_USE_HIP)
4 changes: 2 additions & 2 deletions ggml/src/ggml-cuda/vendors/hip.h
@@ -137,7 +137,7 @@
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED

-#if HIP_VERSION >= 70000000
+#if HIP_VERSION >= 60500000
#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
@@ -149,7 +149,7 @@
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define cublasComputeType_t hipblasDatatype_t
#define cudaDataType_t hipblasDatatype_t
-#endif // HIP_VERSION >= 7000000
+#endif // HIP_VERSION >= 60500000

#if !defined(__HIP_PLATFORM_AMD__)
#error "The HIP backend supports only AMD targets"
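In case the threshold change above looks arbitrary: HIP encodes its version as major * 10000000 + minor * 100000 + patch, so the new gate 60500000 is ROCm 6.5.0 and the old 70000000 was ROCm 7.0.0. A tiny compile-time illustration of that arithmetic (assuming the standard HIP_VERSION encoding; the asserts are not part of the patch):

```cuda
// HIP_VERSION = HIP_VERSION_MAJOR * 10000000 + HIP_VERSION_MINOR * 100000 + HIP_VERSION_PATCH
static_assert(6 * 10000000 + 5 * 100000 + 0 == 60500000, "ROCm 6.5.0 threshold");
static_assert(7 * 10000000 + 0 * 100000 + 0 == 70000000, "previous ROCm 7.0.0 threshold");
```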