Skip to content

Commit 1c62c5a

Browse files
committed
Revert all changes from PR #2: Fix HIP warp synchronization mask compatibility for ROCm 7.0+ due to requests and comments from ggml-org#15241
Parent: 86793ee · Commit: 1c62c5a

File tree

7 files changed

+29
-37
lines changed

7 files changed

+29
-37
lines changed

ggml/src/ggml-cuda/argmax.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
2222

2323
#pragma unroll
2424
for (int offset = 16; offset > 0; offset >>= 1) {
25-
const float val = __shfl_xor_sync(GGML_WARP_SYNC_MASK, maxval, offset, WARP_SIZE);
26-
const int col = __shfl_xor_sync(GGML_WARP_SYNC_MASK, argmax, offset, WARP_SIZE);
25+
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
26+
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
2727
if (val > maxval) {
2828
maxval = val;
2929
argmax = col;
@@ -51,8 +51,8 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
5151
}
5252
#pragma unroll
5353
for (int offset = 16; offset > 0; offset >>= 1) {
54-
const float val = __shfl_xor_sync(GGML_WARP_SYNC_MASK, maxval, offset, WARP_SIZE);
55-
const int col = __shfl_xor_sync(GGML_WARP_SYNC_MASK, argmax, offset, WARP_SIZE);
54+
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
55+
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
5656
if (val > maxval) {
5757
maxval = val;
5858
argmax = col;

ggml/src/ggml-cuda/common.cuh

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -375,22 +375,22 @@ struct ggml_cuda_unroll<1> {
375375

376376
template<int width = WARP_SIZE>
377377
static __device__ __forceinline__ int warp_reduce_sum(int x) {
378-
#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE) || (defined(GGML_USE_HIP) && HIP_VERSION >= 70000000)
379-
return __reduce_add_sync(GGML_WARP_SYNC_MASK, x);
378+
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
379+
return __reduce_add_sync(0xFFFFFFFF, x);
380380
#else
381381
#pragma unroll
382382
for (int offset = width/2; offset > 0; offset >>= 1) {
383-
x += __shfl_xor_sync(GGML_WARP_SYNC_MASK, x, offset, width);
383+
x += __shfl_xor_sync(0xFFFFFFFF, x, offset, width);
384384
}
385385
return x;
386-
#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE) || (defined(GGML_USE_HIP) && HIP_VERSION >= 70000000)
386+
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
387387
}
388388

389389
template<int width = WARP_SIZE>
390390
static __device__ __forceinline__ float warp_reduce_sum(float x) {
391391
#pragma unroll
392392
for (int offset = width/2; offset > 0; offset >>= 1) {
393-
x += __shfl_xor_sync(GGML_WARP_SYNC_MASK, x, offset, width);
393+
x += __shfl_xor_sync(0xFFFFFFFF, x, offset, width);
394394
}
395395
return x;
396396
}
@@ -399,8 +399,8 @@ template<int width = WARP_SIZE>
399399
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
400400
#pragma unroll
401401
for (int offset = width/2; offset > 0; offset >>= 1) {
402-
a.x += __shfl_xor_sync(GGML_WARP_SYNC_MASK, a.x, offset, width);
403-
a.y += __shfl_xor_sync(GGML_WARP_SYNC_MASK, a.y, offset, width);
402+
a.x += __shfl_xor_sync(0xFFFFFFFF, a.x, offset, width);
403+
a.y += __shfl_xor_sync(0xFFFFFFFF, a.y, offset, width);
404404
}
405405
return a;
406406
}
@@ -410,7 +410,7 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
410410
#ifdef FP16_AVAILABLE
411411
#pragma unroll
412412
for (int offset = width/2; offset > 0; offset >>= 1) {
413-
a = __hadd2(a, __shfl_xor_sync(GGML_WARP_SYNC_MASK, a, offset, width));
413+
a = __hadd2(a, __shfl_xor_sync(0xFFFFFFFF, a, offset, width));
414414
}
415415
return a;
416416

@@ -445,20 +445,20 @@ static __device__ __forceinline__ int warp_reduce_all(int x) {
445445
#ifdef GGML_USE_HIP
446446
#pragma unroll
447447
for (int offset = width/2; offset > 0; offset >>= 1) {
448-
x = x && __shfl_xor_sync(GGML_WARP_SYNC_MASK, x, offset, width);
448+
x = x && __shfl_xor_sync(0xFFFFFFFF, x, offset, width);
449449
}
450450
return x;
451451
#else
452452
static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented");
453-
return __all_sync(GGML_WARP_SYNC_MASK, x);
453+
return __all_sync(0xFFFFFFFF, x);
454454
#endif // GGML_USE_HIP
455455
}
456456

457457
template<int width = WARP_SIZE>
458458
static __device__ __forceinline__ float warp_reduce_max(float x) {
459459
#pragma unroll
460460
for (int offset = width/2; offset > 0; offset >>= 1) {
461-
x = fmaxf(x, __shfl_xor_sync(GGML_WARP_SYNC_MASK, x, offset, width));
461+
x = fmaxf(x, __shfl_xor_sync(0xFFFFFFFF, x, offset, width));
462462
}
463463
return x;
464464
}
@@ -501,7 +501,7 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
501501
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
502502
#pragma unroll
503503
for (int offset = width/2; offset > 0; offset >>= 1) {
504-
x = ggml_cuda_hmax2(x, __shfl_xor_sync(GGML_WARP_SYNC_MASK, x, offset, width));
504+
x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xFFFFFFFF, x, offset, width));
505505
}
506506
return x;
507507
#else

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,8 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
314314
}
315315
#pragma unroll
316316
for (int mask = QI8_1/2; mask > 0; mask >>= 1) {
317-
amax = fmaxf(amax, __shfl_xor_sync(GGML_WARP_SYNC_MASK, amax, mask, 32));
318-
sum += __shfl_xor_sync(GGML_WARP_SYNC_MASK, sum, mask, 32);
317+
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32));
318+
sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, 32);
319319
}
320320

321321
const float d = amax / 127;

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
572572
for (int col = 0; col < cols_per_thread; ++col) {
573573
#pragma unroll
574574
for (int offset = 16; offset >= 4; offset >>= 1) {
575-
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(GGML_WARP_SYNC_MASK, KQ_max_new[col], offset, WARP_SIZE));
575+
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
576576
}
577577
}
578578

@@ -627,7 +627,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
627627
for (int col = 0; col < cols_per_thread; ++col) {
628628
#pragma unroll
629629
for (int offset = 2; offset >= 1; offset >>= 1) {
630-
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(GGML_WARP_SYNC_MASK, KQ_max_new[col], offset, WARP_SIZE));
630+
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
631631
}
632632
}
633633

@@ -953,7 +953,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
953953
for (int col = 0; col < cols_per_thread; ++col) {
954954
#pragma unroll
955955
for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
956-
KQ_rowsum[col] += __shfl_xor_sync(GGML_WARP_SYNC_MASK, KQ_rowsum[col], offset, WARP_SIZE);
956+
KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
957957
}
958958
}
959959
}
@@ -1086,7 +1086,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
10861086
#pragma unroll
10871087
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
10881088
if (offset < WARP_SIZE) {
1089-
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(GGML_WARP_SYNC_MASK, KQ_cmn, offset, WARP_SIZE));
1089+
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
10901090
}
10911091
}
10921092

@@ -1104,7 +1104,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
11041104
#pragma unroll
11051105
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
11061106
if (offset < WARP_SIZE) {
1107-
KQ_crs += __shfl_xor_sync(GGML_WARP_SYNC_MASK, KQ_crs, offset, WARP_SIZE);
1107+
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
11081108
}
11091109
}
11101110

ggml/src/ggml-cuda/mma.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
4747
const int shift_low = ((src_j + 0) % 2) * 16;
4848
const int shift_high = ((src_j + 1) % 2) * 16;
4949

50-
const int ret_low = (__shfl_sync(GGML_WARP_SYNC_MASK, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF;
51-
const int ret_high = (__shfl_sync(GGML_WARP_SYNC_MASK, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;
50+
const int ret_low = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF;
51+
const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;
5252

5353
return ret_low | ret_high;
5454
}

ggml/src/ggml-cuda/quantize.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ static __global__ void quantize_mmq_q8_1(
8989
// Exchange max. abs. value between vals_per_scale/4 threads.
9090
#pragma unroll
9191
for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) {
92-
amax = fmaxf(amax, __shfl_xor_sync(GGML_WARP_SYNC_MASK, amax, offset, WARP_SIZE));
92+
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE));
9393
}
9494

9595
float sum;
@@ -99,7 +99,7 @@ static __global__ void quantize_mmq_q8_1(
9999
// Calculate sums across vals_per_sum/4 threads.
100100
#pragma unroll
101101
for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
102-
sum += __shfl_xor_sync(GGML_WARP_SYNC_MASK, sum, offset, WARP_SIZE);
102+
sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
103103
}
104104
}
105105

ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
2424
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
2525
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
26+
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
27+
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
2628
#define cublasCreate hipblasCreate
2729
#define cublasDestroy hipblasDestroy
2830
#define cublasGemmEx hipblasGemmEx
@@ -135,7 +137,6 @@
135137
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
136138
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
137139

138-
// HIP version-specific type mappings
139140
#if HIP_VERSION >= 70000000
140141
#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
141142
#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
@@ -150,15 +151,6 @@
150151
#define cudaDataType_t hipblasDatatype_t
151152
#endif // HIP_VERSION >= 70000000
152153

153-
// Warp sync functions and masks
154-
#if HIP_VERSION >= 70000000 && defined(GGML_HIP_ROCWMMA_FATTN)
155-
#define GGML_WARP_SYNC_MASK 0xffffffffffffffffULL // ROCm 7.0+ requires 64-bit masks for __*_*_sync functions
156-
#else
157-
#define GGML_WARP_SYNC_MASK 0xffffffff
158-
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
159-
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
160-
#endif // HIP_VERSION >= 70000000 && defined(GGML_HIP_ROCWMMA_FATTN)
161-
162154
#if !defined(__HIP_PLATFORM_AMD__)
163155
#error "The HIP backend supports only AMD targets"
164156
#endif // !defined(__HIP_PLATFORM_AMD__)

0 commit comments

Comments (0)