hip : fix warp mask width for rocWMMA compatibility #15239

Closed · wants to merge 1 commit
8 changes: 4 additions & 4 deletions ggml/src/ggml-cuda/argmax.cu
@@ -22,8 +22,8 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest

 #pragma unroll
     for (int offset = 16; offset > 0; offset >>= 1) {
-        const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
-        const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
+        const float val = __shfl_xor_sync(GGML_CUDA_WARP_MASK, maxval, offset, WARP_SIZE);
+        const int   col = __shfl_xor_sync(GGML_CUDA_WARP_MASK, argmax, offset, WARP_SIZE);
         if (val > maxval) {
             maxval = val;
             argmax = col;
@@ -51,8 +51,8 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
     }
 #pragma unroll
     for (int offset = 16; offset > 0; offset >>= 1) {
-        const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
-        const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
+        const float val = __shfl_xor_sync(GGML_CUDA_WARP_MASK, maxval, offset, WARP_SIZE);
+        const int   col = __shfl_xor_sync(GGML_CUDA_WARP_MASK, argmax, offset, WARP_SIZE);
         if (val > maxval) {
             maxval = val;
             argmax = col;
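
The pattern changed above is a warp-wide argmax butterfly: each XOR shuffle pairs lanes whose ids differ by offset, and the column index travels together with the value so the winning index survives every round. A minimal standalone sketch of the same idiom (illustration only; warp_argmax is a hypothetical helper, not part of this PR, and it assumes a full 32-lane warp is active):

// Sketch: warp-wide argmax. After the loop, every lane holds the maximum
// value and the index it came from.
static __device__ float warp_argmax(float val, int & idx) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        const float other_val = __shfl_xor_sync(GGML_CUDA_WARP_MASK, val, offset, 32);
        const int   other_idx = __shfl_xor_sync(GGML_CUDA_WARP_MASK, idx, offset, 32);
        if (other_val > val) { // on ties, each lane keeps its current index
            val = other_val;
            idx = other_idx;
        }
    }
    return val;
}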
24 changes: 14 additions & 10 deletions ggml/src/ggml-cuda/common.cuh
@@ -42,6 +42,10 @@
 #define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
 #define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons

+#ifndef GGML_CUDA_WARP_MASK
+#define GGML_CUDA_WARP_MASK 0xFFFFFFFF
+#endif
+
 #define GGML_CUDA_CC_PASCAL 600
 #define GGML_CUDA_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define GGML_CUDA_CC_VOLTA  700
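
Since the fallback above only fires when GGML_CUDA_WARP_MASK is not already defined (the HIP vendor header below defines it first), the mask literal automatically widens to 64 bits on rocWMMA builds. A compile-time sanity check along these lines would catch a width mismatch (a sketch under stated assumptions, not part of this PR):

// Sketch: 0xFFFFFFFF has type unsigned int (4 bytes) and
// 0xFFFFFFFFFFFFFFFFULL has type unsigned long long (8 bytes), so sizeof
// of the macro's expansion distinguishes the two configurations.
#ifdef GGML_HIP_ROCWMMA_FATTN
static_assert(sizeof(GGML_CUDA_WARP_MASK) == 8, "expected a 64-bit warp mask");
#else
static_assert(sizeof(GGML_CUDA_WARP_MASK) == 4, "expected a 32-bit warp mask");
#endif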
@@ -376,11 +380,11 @@ struct ggml_cuda_unroll<1> {
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ int warp_reduce_sum(int x) {
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
-    return __reduce_add_sync(0xffffffff, x);
+    return __reduce_add_sync(GGML_CUDA_WARP_MASK, x);
 #else
 #pragma unroll
     for (int offset = width/2; offset > 0; offset >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, offset, width);
+        x += __shfl_xor_sync(GGML_CUDA_WARP_MASK, x, offset, width);
     }
     return x;
 #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
@@ -390,7 +394,7 @@ template<int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
 #pragma unroll
     for (int offset = width/2; offset > 0; offset >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, offset, width);
+        x += __shfl_xor_sync(GGML_CUDA_WARP_MASK, x, offset, width);
     }
     return x;
 }
@@ -399,8 +403,8 @@ template<int width = WARP_SIZE>
 static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 #pragma unroll
     for (int offset = width/2; offset > 0; offset >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width);
+        a.x += __shfl_xor_sync(GGML_CUDA_WARP_MASK, a.x, offset, width);
+        a.y += __shfl_xor_sync(GGML_CUDA_WARP_MASK, a.y, offset, width);
     }
     return a;
 }
@@ -410,7 +414,7 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #ifdef FP16_AVAILABLE
 #pragma unroll
     for (int offset = width/2; offset > 0; offset >>= 1) {
-        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width));
+        a = __hadd2(a, __shfl_xor_sync(GGML_CUDA_WARP_MASK, a, offset, width));
     }
     return a;

@@ -445,20 +449,20 @@ static __device__ __forceinline__ int warp_reduce_all(int x) {
 #ifdef GGML_USE_HIP
 #pragma unroll
     for (int offset = width/2; offset > 0; offset >>= 1) {
-        x = x && __shfl_xor_sync(0xffffffff, x, offset, width);
+        x = x && __shfl_xor_sync(GGML_CUDA_WARP_MASK, x, offset, width);
     }
     return x;
 #else
     static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented");
-    return __all_sync(0xffffffff, x);
+    return __all_sync(GGML_CUDA_WARP_MASK, x);
 #endif // GGML_USE_HIP
 }

 template<int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
     for (int offset = width/2; offset > 0; offset >>= 1) {
-        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, width));
+        x = fmaxf(x, __shfl_xor_sync(GGML_CUDA_WARP_MASK, x, offset, width));
     }
     return x;
 }
@@ -501,7 +505,7 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
 #pragma unroll
     for (int offset = width/2; offset > 0; offset >>= 1) {
-        x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width));
+        x = ggml_cuda_hmax2(x, __shfl_xor_sync(GGML_CUDA_WARP_MASK, x, offset, width));
     }
     return x;
 #else
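
All of these helpers take a width template parameter that defaults to WARP_SIZE, so a caller can reduce across an independent sub-group of lanes instead of the whole warp. A usage sketch (hypothetical kernel, assuming common.cuh is included and each group of 16 consecutive lanes owns one row):

// Sketch: warp_reduce_sum<16> runs the XOR-shuffle butterfly with offsets
// 8, 4, 2, 1 and a shuffle width of 16, so each aligned group of 16 lanes
// reduces independently and every lane in a group ends up with its sum.
__global__ void row_sums_16(const float * x, float * out) {
    const int tid = blockIdx.x*blockDim.x + threadIdx.x;
    const float sum = warp_reduce_sum<16>(x[tid]);
    if (tid % 16 == 0) {
        out[tid/16] = sum;
    }
}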
6 changes: 3 additions & 3 deletions ggml/src/ggml-cuda/fattn-common.cuh
@@ -314,8 +314,8 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
     }
 #pragma unroll
     for (int mask = QI8_1/2; mask > 0; mask >>= 1) {
-        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32));
-        sum +=             __shfl_xor_sync(0xFFFFFFFF, sum,  mask, 32);
+        amax = fmaxf(amax, __shfl_xor_sync(GGML_CUDA_WARP_MASK, amax, mask, 32));
+        sum +=             __shfl_xor_sync(GGML_CUDA_WARP_MASK, sum,  mask, 32);
     }

     const float d = amax / 127;
@@ -690,7 +690,7 @@ static __global__ void flash_attn_combine_results(
     for (int l = 0; l < parallel_blocks; ++l) {
         const float diff = meta[l].x - kqmax;
         float KQ_max_scale = expf(diff);
-        const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
+        const uint32_t ftz_mask = GGML_CUDA_WARP_MASK * (diff > SOFTMAX_FTZ_THRESHOLD);
         *((uint32_t *) &KQ_max_scale) &= ftz_mask;

         VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
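
One note on the second hunk: ftz_mask is not a lane mask at all, just an all-ones bit pattern. Multiplying by the boolean yields either zero or all ones, and the bitwise AND flushes KQ_max_scale to exactly 0.0f whenever diff is at or below SOFTMAX_FTZ_THRESHOLD. Because ftz_mask is a uint32_t, a 64-bit GGML_CUDA_WARP_MASK truncates back to 0xFFFFFFFF on assignment, so behavior is unchanged on rocWMMA builds. A host-side sketch of the same branchless flush-to-zero (illustration only; flush_to_zero is a hypothetical helper):

// Sketch: branchless flush-to-zero through an integer mask. memcpy is the
// portable way to reinterpret float bits (the kernel uses a pointer cast).
#include <cstdint>
#include <cstring>

static float flush_to_zero(float scale, float diff, float threshold) {
    uint32_t bits;
    std::memcpy(&bits, &scale, sizeof(bits));
    bits &= UINT32_C(0xFFFFFFFF) * (diff > threshold); // all ones keeps, zero clears
    std::memcpy(&scale, &bits, sizeof(scale));
    return scale;
}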
10 changes: 5 additions & 5 deletions ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -572,7 +572,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     for (int col = 0; col < cols_per_thread; ++col) {
 #pragma unroll
         for (int offset = 16; offset >= 4; offset >>= 1) {
-            KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+            KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(GGML_CUDA_WARP_MASK, KQ_max_new[col], offset, WARP_SIZE));
         }
     }

@@ -627,7 +627,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     for (int col = 0; col < cols_per_thread; ++col) {
 #pragma unroll
         for (int offset = 2; offset >= 1; offset >>= 1) {
-            KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+            KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(GGML_CUDA_WARP_MASK, KQ_max_new[col], offset, WARP_SIZE));
         }
     }

@@ -953,7 +953,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     for (int col = 0; col < cols_per_thread; ++col) {
 #pragma unroll
         for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
-            KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
+            KQ_rowsum[col] += __shfl_xor_sync(GGML_CUDA_WARP_MASK, KQ_rowsum[col], offset, WARP_SIZE);
         }
     }
 }
@@ -1086,7 +1086,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 #pragma unroll
     for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
         if (offset < WARP_SIZE) {
-            KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
+            KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(GGML_CUDA_WARP_MASK, KQ_cmn, offset, WARP_SIZE));
         }
     }

@@ -1104,7 +1104,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 #pragma unroll
     for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
         if (offset < WARP_SIZE) {
-            KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
+            KQ_crs += __shfl_xor_sync(GGML_CUDA_WARP_MASK, KQ_crs, offset, WARP_SIZE);
         }
     }

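
Unlike the full reductions in common.cuh, these butterflies run over a sub-range of offsets. Stopping the first loop at offset 4, for example, combines only lanes whose ids differ in bits 2 through 4, so each lane ends up with the maximum over the eight lanes that share lane_id % 4, while lanes holding different columns stay untouched; the remaining offsets 2 and 1 are applied in a separate loop. A self-contained sketch of such a partial butterfly (illustration only; the lane grouping is the point, not the flash-attention tile layout):

// Sketch: partial XOR-shuffle reduction with offsets 16, 8, 4. Afterwards,
// each lane holds the max over the 8 lanes of the warp that share
// lane_id % 4; lanes with different lane_id % 4 never interact.
static __device__ float partial_max(float x) {
#pragma unroll
    for (int offset = 16; offset >= 4; offset >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(GGML_CUDA_WARP_MASK, x, offset, WARP_SIZE));
    }
    return x;
}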
4 changes: 2 additions & 2 deletions ggml/src/ggml-cuda/mma.cuh
@@ -47,8 +47,8 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
     const int shift_low  = ((src_j + 0) % 2) * 16;
     const int shift_high = ((src_j + 1) % 2) * 16;

-    const int ret_low  = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low,  WARP_SIZE) >> shift_low)  & 0x0000FFFF;
-    const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;
+    const int ret_low  = (__shfl_sync(GGML_CUDA_WARP_MASK, x, src_laneid_low,  WARP_SIZE) >> shift_low)  & 0x0000FFFF;
+    const int ret_high = (__shfl_sync(GGML_CUDA_WARP_MASK, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;

     return ret_low | ret_high;
 }
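
ggml_cuda_movmatrix uses __shfl_sync, which reads from an explicit source lane rather than a lane XOR, to gather the two 16-bit halves it reassembles. A minimal sketch of the lane-indexed form (hypothetical kernel, illustration only):

// Sketch: with an explicit source lane, __shfl_sync acts as a warp-wide
// broadcast; here every lane reads lane 0's value.
__global__ void broadcast_lane0(const int * in, int * out) {
    const int v = in[threadIdx.x];
    out[threadIdx.x] = __shfl_sync(GGML_CUDA_WARP_MASK, v, 0, WARP_SIZE);
}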
4 changes: 2 additions & 2 deletions ggml/src/ggml-cuda/quantize.cu
@@ -89,7 +89,7 @@ static __global__ void quantize_mmq_q8_1(
     // Exchange max. abs. value between vals_per_scale/4 threads.
 #pragma unroll
     for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) {
-        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE));
+        amax = fmaxf(amax, __shfl_xor_sync(GGML_CUDA_WARP_MASK, amax, offset, WARP_SIZE));
     }

     float sum;
@@ -99,7 +99,7 @@ static __global__ void quantize_mmq_q8_1(
         // Calculate sums across vals_per_sum/4 threads.
 #pragma unroll
         for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
-            sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
+            sum += __shfl_xor_sync(GGML_CUDA_WARP_MASK, sum, offset, WARP_SIZE);
         }
     }

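
Here the butterfly extent is layout-dependent: each thread already holds four values, so the group of cooperating threads is vals_per_scale/4 (or vals_per_sum/4) and the starting offset vals_per_scale/8 is half that group size. Threads in different groups never mix because every offset stays below the group stride. A sketch of the same parameterized grouping (illustration only; group_size is a hypothetical power-of-two constant up to 32):

// Sketch: max-reduce within independent, aligned groups of group_size
// lanes. Offsets group_size/2 ... 1 only pair lanes in the same group.
template <int group_size>
static __device__ float group_max(float x) {
#pragma unroll
    for (int offset = group_size/2; offset > 0; offset >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(GGML_CUDA_WARP_MASK, x, offset, WARP_SIZE));
    }
    return x;
}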
11 changes: 11 additions & 0 deletions ggml/src/ggml-cuda/vendors/hip.h
@@ -23,8 +23,19 @@
 #define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
 #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
 #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
+#ifdef GGML_HIP_ROCWMMA_FATTN
+// ROCm requires 64-bit masks for __shfl_*_sync functions
+#define GGML_CUDA_WARP_MASK 0xFFFFFFFFFFFFFFFFULL
+#else
+#define GGML_CUDA_WARP_MASK 0xFFFFFFFF
+// Only define __shfl_*_sync macros if they're not already available
+#ifndef __shfl_sync
+#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
+#endif
+#ifndef __shfl_xor_sync
+#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
+#endif
+#endif
 #define cublasCreate hipblasCreate
 #define cublasDestroy hipblasDestroy
 #define cublasGemmEx hipblasGemmEx
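
The net effect is that kernel code keeps a single call site and the vendor header decides what it expands to. A sketch of the three configurations (not part of the PR; warp_sum is a hypothetical helper):

// Sketch: one call site, three expansions.
//   CUDA:                            __shfl_xor_sync(0xFFFFFFFF, ...)  (common.cuh default)
//   HIP without rocWMMA:             __shfl_xor(...)                   (macro shim above, mask dropped)
//   HIP with GGML_HIP_ROCWMMA_FATTN: __shfl_xor_sync(0xFFFFFFFFFFFFFFFFULL, ...)
static __device__ float warp_sum(float x) {
#pragma unroll
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(GGML_CUDA_WARP_MASK, x, offset, WARP_SIZE);
    }
    return x;
}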