
Fix HIP warp synchronization function conflicts for ROCm 7.0+ #15241


Closed
8 changes: 4 additions & 4 deletions ggml/src/ggml-cuda/argmax.cu
@@ -22,8 +22,8 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest

#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
- const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
- const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
+ const float val = __shfl_xor_sync(GGML_WARP_SYNC_MASK, maxval, offset, WARP_SIZE);
+ const int col = __shfl_xor_sync(GGML_WARP_SYNC_MASK, argmax, offset, WARP_SIZE);
if (val > maxval) {
maxval = val;
argmax = col;
@@ -51,8 +51,8 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
}
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
- const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
- const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
+ const float val = __shfl_xor_sync(GGML_WARP_SYNC_MASK, maxval, offset, WARP_SIZE);
+ const int col = __shfl_xor_sync(GGML_WARP_SYNC_MASK, argmax, offset, WARP_SIZE);
if (val > maxval) {
maxval = val;
argmax = col;
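For context, the loop being patched is an XOR-butterfly argmax: each lane exchanges a (value, index) pair with a partner lane and keeps the larger entry, so both shuffles must use the same offset within the same iteration. A minimal standalone sketch of the pattern (hypothetical demo kernel; one 32-lane warp assumed, with GGML_WARP_SYNC_MASK as introduced by this PR):

```cuda
// Hypothetical demo kernel: warp-wide argmax over 32 floats, one per lane.
// Mirrors the pair-exchange in argmax_f32.
__global__ void warp_argmax_demo(const float * x, int * out) {
    float maxval = x[threadIdx.x];
    int   argmax = threadIdx.x;
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        // both halves of the pair travel together to the partner lane
        const float val = __shfl_xor_sync(GGML_WARP_SYNC_MASK, maxval, offset, 32);
        const int   col = __shfl_xor_sync(GGML_WARP_SYNC_MASK, argmax, offset, 32);
        if (val > maxval) {
            maxval = val;
            argmax = col;
        }
    }
    if (threadIdx.x == 0) {
        *out = argmax; // after 5 rounds every lane holds the warp argmax
    }
}
```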
28 changes: 16 additions & 12 deletions ggml/src/ggml-cuda/common.cuh
@@ -35,6 +35,10 @@
#include "vendors/cuda.h"
#endif // defined(GGML_USE_HIP)

+ #ifndef GGML_WARP_SYNC_MASK
+ #define GGML_WARP_SYNC_MASK 0xffffffff
+ #endif

#define STRINGIZE_IMPL(...) #__VA_ARGS__
#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)

@@ -375,22 +379,22 @@ struct ggml_cuda_unroll<1> {

template<int width = WARP_SIZE>
static __device__ __forceinline__ int warp_reduce_sum(int x) {
- #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
- return __reduce_add_sync(0xffffffff, x);
+ #if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE) || (defined(GGML_USE_HIP) && HIP_VERSION >= 70000000)
+ return __reduce_add_sync(GGML_WARP_SYNC_MASK, x);
#else
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
- x += __shfl_xor_sync(0xffffffff, x, offset, width);
+ x += __shfl_xor_sync(GGML_WARP_SYNC_MASK, x, offset, width);
}
return x;
- #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+ #endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE) || (defined(GGML_USE_HIP) && HIP_VERSION >= 70000000)
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
- x += __shfl_xor_sync(0xffffffff, x, offset, width);
+ x += __shfl_xor_sync(GGML_WARP_SYNC_MASK, x, offset, width);
}
return x;
}
@@ -399,8 +403,8 @@ template<int width = WARP_SIZE>
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
- a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width);
- a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width);
+ a.x += __shfl_xor_sync(GGML_WARP_SYNC_MASK, a.x, offset, width);
+ a.y += __shfl_xor_sync(GGML_WARP_SYNC_MASK, a.y, offset, width);
}
return a;
}
@@ -410,7 +414,7 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#ifdef FP16_AVAILABLE
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
- a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width));
+ a = __hadd2(a, __shfl_xor_sync(GGML_WARP_SYNC_MASK, a, offset, width));
}
return a;

@@ -445,20 +449,20 @@ static __device__ __forceinline__ int warp_reduce_all(int x) {
#ifdef GGML_USE_HIP
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
- x = x && __shfl_xor_sync(0xffffffff, x, offset, width);
+ x = x && __shfl_xor_sync(GGML_WARP_SYNC_MASK, x, offset, width);
}
return x;
#else
static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented");
- return __all_sync(0xffffffff, x);
+ return __all_sync(GGML_WARP_SYNC_MASK, x);
#endif // GGML_USE_HIP
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
- x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, width));
+ x = fmaxf(x, __shfl_xor_sync(GGML_WARP_SYNC_MASK, x, offset, width));
}
return x;
}
@@ -501,7 +505,7 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
- x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width));
+ x = ggml_cuda_hmax2(x, __shfl_xor_sync(GGML_WARP_SYNC_MASK, x, offset, width));
}
return x;
#else
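Why the mask's width matters: in the __*_sync builtins the mask is a bitmap of participating lanes, one bit per lane, so 0xffffffff covers a 32-lane NVIDIA warp but only half of a 64-lane AMD wavefront. Routing every call site through GGML_WARP_SYNC_MASK lets the HIP vendor header widen the literal without touching the reduction helpers. A minimal sketch of one such call site, assuming WARP_SIZE and the macro as defined in this file:

```cuda
// Sketch: one reduction call site that compiles against either mask width.
// GGML_WARP_SYNC_MASK is 0xffffffff on CUDA and a 64-bit literal on ROCm 7.0+.
template<int width = WARP_SIZE>
static __device__ __forceinline__ float warp_sum_sketch(float x) {
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        // XOR pairing: after log2(width) rounds every lane holds the total
        x += __shfl_xor_sync(GGML_WARP_SYNC_MASK, x, offset, width);
    }
    return x;
}
```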
4 changes: 2 additions & 2 deletions ggml/src/ggml-cuda/fattn-common.cuh
@@ -314,8 +314,8 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
}
#pragma unroll
for (int mask = QI8_1/2; mask > 0; mask >>= 1) {
- amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32));
- sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, 32);
+ amax = fmaxf(amax, __shfl_xor_sync(GGML_WARP_SYNC_MASK, amax, mask, 32));
+ sum += __shfl_xor_sync(GGML_WARP_SYNC_MASK, sum, mask, 32);
}

const float d = amax / 127;
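The patched reduction feeds q8_1 quantization: once every participating lane knows the group's absolute maximum and sum, the scale is d = amax/127 and each value is rounded to x/d. A scalar sketch of that step (hypothetical helper; the real kernel writes a packed shared-memory layout and keeps the sum alongside the scale):

```cuda
// Hypothetical scalar sketch of the q8_1 math after amax/sum are reduced.
__device__ void quantize_q8_1_sketch(const float * x, int n,
                                     float amax, int8_t * q, float * d_out) {
    const float d = amax / 127.0f;               // largest magnitude maps to +-127
    for (int i = 0; i < n; ++i) {
        q[i] = (int8_t) roundf(d != 0.0f ? x[i] / d : 0.0f);
    }
    *d_out = d;                                  // dequantize later as q[i] * d
}
```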
10 changes: 5 additions & 5 deletions ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -572,7 +572,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
for (int col = 0; col < cols_per_thread; ++col) {
#pragma unroll
for (int offset = 16; offset >= 4; offset >>= 1) {
- KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+ KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(GGML_WARP_SYNC_MASK, KQ_max_new[col], offset, WARP_SIZE));
}
}

@@ -627,7 +627,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
for (int col = 0; col < cols_per_thread; ++col) {
#pragma unroll
for (int offset = 2; offset >= 1; offset >>= 1) {
- KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+ KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(GGML_WARP_SYNC_MASK, KQ_max_new[col], offset, WARP_SIZE));
}
}

@@ -953,7 +953,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
for (int col = 0; col < cols_per_thread; ++col) {
#pragma unroll
for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
- KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
+ KQ_rowsum[col] += __shfl_xor_sync(GGML_WARP_SYNC_MASK, KQ_rowsum[col], offset, WARP_SIZE);
}
}
}
@@ -1086,7 +1086,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
#pragma unroll
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
if (offset < WARP_SIZE) {
- KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
+ KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(GGML_WARP_SYNC_MASK, KQ_cmn, offset, WARP_SIZE));
}
}

@@ -1104,7 +1104,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
#pragma unroll
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
if (offset < WARP_SIZE) {
- KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
+ KQ_crs += __shfl_xor_sync(GGML_WARP_SYNC_MASK, KQ_crs, offset, WARP_SIZE);
}
}

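Unlike the full-warp reductions elsewhere in this PR, these loops clamp the offset range (16 down to 4, or 2 down to 1), so each butterfly only mixes the lanes that hold fragments of the same logical row. A sketch of such a segmented reduction (hypothetical; aligned 4-lane groups assumed):

```cuda
// Sketch: an XOR butterfly limited to offsets 2 and 1 reduces independently
// within each aligned group of 4 lanes ({0..3}, {4..7}, ...).
static __device__ __forceinline__ float group4_max_sketch(float v) {
#pragma unroll
    for (int offset = 2; offset >= 1; offset >>= 1) {
        v = fmaxf(v, __shfl_xor_sync(GGML_WARP_SYNC_MASK, v, offset, WARP_SIZE));
    }
    return v; // every lane now holds its 4-lane group's maximum
}
```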
4 changes: 2 additions & 2 deletions ggml/src/ggml-cuda/mma.cuh
@@ -47,8 +47,8 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
const int shift_low = ((src_j + 0) % 2) * 16;
const int shift_high = ((src_j + 1) % 2) * 16;

- const int ret_low = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF;
- const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;
+ const int ret_low = (__shfl_sync(GGML_WARP_SYNC_MASK, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF;
+ const int ret_high = (__shfl_sync(GGML_WARP_SYNC_MASK, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;

return ret_low | ret_high;
}
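ggml_cuda_movmatrix uses the gather form __shfl_sync, which reads from a computed source lane instead of a fixed XOR partner; the same mask rules apply. A reduced sketch of the core move (hypothetical; the real lane and shift arithmetic depends on the matrix fragment layout):

```cuda
// Sketch: fetch a packed int from another lane and keep one 16-bit half.
static __device__ __forceinline__ int gather_low_half_sketch(int x, int src_lane) {
    const int remote = __shfl_sync(GGML_WARP_SYNC_MASK, x, src_lane, WARP_SIZE);
    return remote & 0x0000FFFF; // low half of the remote lane's value
}
```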
4 changes: 2 additions & 2 deletions ggml/src/ggml-cuda/quantize.cu
@@ -89,7 +89,7 @@ static __global__ void quantize_mmq_q8_1(
// Exchange max. abs. value between vals_per_scale/4 threads.
#pragma unroll
for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) {
- amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE));
+ amax = fmaxf(amax, __shfl_xor_sync(GGML_WARP_SYNC_MASK, amax, offset, WARP_SIZE));
}

float sum;
@@ -99,7 +99,7 @@
// Calculate sums across vals_per_sum/4 threads.
#pragma unroll
for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
- sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
+ sum += __shfl_xor_sync(GGML_WARP_SYNC_MASK, sum, offset, WARP_SIZE);
}
}

14 changes: 11 additions & 3 deletions ggml/src/ggml-cuda/vendors/hip.h
@@ -23,8 +23,6 @@
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
- #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
- #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx
@@ -137,6 +135,7 @@
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED

+ // HIP version-specific type mappings
#if HIP_VERSION >= 70000000
#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
@@ -149,7 +148,16 @@
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define cublasComputeType_t hipblasDatatype_t
#define cudaDataType_t hipblasDatatype_t
- #endif // HIP_VERSION >= 7000000
+ #endif // HIP_VERSION >= 70000000

+ // Warp sync functions and masks
+ #if HIP_VERSION >= 70000000 && defined(GGML_HIP_ROCWMMA_FATTN)
+ #define GGML_WARP_SYNC_MASK 0xffffffffffffffffULL // ROCm 7.0+ requires 64-bit masks for __*_*_sync functions
+ #else
+ #define GGML_WARP_SYNC_MASK 0xffffffff
+ #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
+ #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
+ #endif // HIP_VERSION >= 70000000 && defined(GGML_HIP_ROCWMMA_FATTN)

#if !defined(__HIP_PLATFORM_AMD__)
#error "The HIP backend supports only AMD targets"
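The dispatch in this header preserves one invariant: a call site written once compiles on CUDA, on pre-7.0 HIP (where the mask argument is dropped and the legacy maskless builtins are used, which HIP has historically treated as safe on lockstep wavefronts), and on ROCm 7.0+ with rocWMMA flash attention (native sync builtins with a 64-bit all-lanes mask). A hypothetical smoke-test kernel for that invariant:

```cuda
// Hypothetical smoke test: the same shuffle call site must build in all
// three configurations (CUDA, HIP < 7.0 shim, ROCm 7.0+ native builtins).
__global__ void mask_shim_demo(float * out) {
    float v = (float) threadIdx.x;
    v += __shfl_xor_sync(GGML_WARP_SYNC_MASK, v, 1, WARP_SIZE);
    out[threadIdx.x] = v; // lane i ends up with i + (i ^ 1)
}
```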