Commit 56e53a2

use dpp instead of shfls for reduction from maximumbusdatatype
Originally written by maximumbusdatatype in ggml-org#16291 (comment)
1 parent 1bcb140 commit 56e53a2
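
This commit routes every warp-level XOR-shuffle reduction through a new `ggml_cuda_shfl_xor_sync` wrapper so that the HIP build can use AMD DPP and `ds_swizzle` instructions for small offsets instead of generic shuffle calls. The sketch below is not part of the commit; it is a minimal, self-contained CUDA program showing the butterfly (XOR-shuffle) reduction pattern that every touched call site implements.

```cuda
// Minimal standalone sketch (not part of the commit): a warp-wide sum via the
// XOR-shuffle butterfly. After log2(32) = 5 steps every lane holds the sum.
#include <cstdio>

__global__ void warp_sum_demo(const float * x, float * out) {
    float v = x[threadIdx.x];
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        v += __shfl_xor_sync(0xFFFFFFFF, v, offset, 32);
    }
    if (threadIdx.x == 0) {
        *out = v; // all 32 lanes hold the same value; lane 0 writes it
    }
}

int main() {
    float h_x[32];
    for (int i = 0; i < 32; ++i) {
        h_x[i] = float(i + 1); // 1 + 2 + ... + 32 = 528
    }
    float *d_x, *d_out, h_out = 0.0f;
    cudaMalloc(&d_x, sizeof(h_x));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_x, h_x, sizeof(h_x), cudaMemcpyHostToDevice);
    warp_sum_demo<<<1, 32>>>(d_x, d_out);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("warp sum = %.0f\n", h_out); // expected: 528
    cudaFree(d_x);
    cudaFree(d_out);
    return 0;
}
```

After the offset-halving steps every lane of the warp holds the same result, which is why the call sites in the diffs below can read it from any lane.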

8 files changed: +67 additions, −26 deletions

ggml/src/ggml-cuda/argmax.cu

Lines changed: 4 additions & 4 deletions
@@ -22,8 +22,8 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
 
 #pragma unroll
     for (int offset = 16; offset > 0; offset >>= 1) {
-        const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
-        const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
+        const float val = ggml_cuda_shfl_xor_sync(maxval, offset);
+        const int   col = ggml_cuda_shfl_xor_sync(argmax, offset);
         if (val > maxval) {
             maxval = val;
             argmax = col;

@@ -51,8 +51,8 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
     }
 #pragma unroll
     for (int offset = 16; offset > 0; offset >>= 1) {
-        const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
-        const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
+        const float val = ggml_cuda_shfl_xor_sync(maxval, offset);
+        const int   col = ggml_cuda_shfl_xor_sync(argmax, offset);
         if (val > maxval) {
             maxval = val;
             argmax = col;
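
For reference, a hedged sketch of the pattern above as a standalone helper (illustrative name, not repo code; assumes ggml-cuda/common.cuh is included so `ggml_cuda_shfl_xor_sync` is visible): the value and its column index travel through the same XOR shuffle, so after the loop every lane holds the warp-wide maximum and its position.

```cuda
// Hedged sketch, not repo code: the (value, index) butterfly used by argmax_f32.
static __device__ __forceinline__ void warp_argmax_sketch(float & maxval, int & argmax) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        const float val = ggml_cuda_shfl_xor_sync(maxval, offset); // partner lane's best value
        const int   col = ggml_cuda_shfl_xor_sync(argmax, offset); // partner lane's best column
        if (val > maxval) {
            maxval = val;
            argmax = col;
        }
    }
}
```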

ggml/src/ggml-cuda/common.cuh

Lines changed: 51 additions & 9 deletions
@@ -359,14 +359,56 @@ struct ggml_cuda_unroll<1> {
     }
 };
 
+#ifdef GGML_USE_HIP
+template <int dpp_ctrl, typename T, int row_mask = 0xf, int bank_mask = 0xf, bool bound_ctrl = true>
+static __device__ __forceinline__ T hip_move_dpp(T old, T v) {
+    return __builtin_bit_cast(
+        T,
+        __builtin_amdgcn_update_dpp(
+            __builtin_bit_cast(int, old),
+            __builtin_bit_cast(int, v),
+            dpp_ctrl,
+            row_mask,
+            bank_mask,
+            bound_ctrl
+        )
+    );
+}
+
+template <int mask, typename T>
+static __device__ __forceinline__ T hip_ds_swizzle(T v) {
+    return __builtin_bit_cast(T, __builtin_amdgcn_ds_swizzle(__builtin_bit_cast(int, v), mask));
+}
+#endif // GGML_USE_HIP
+
+template<int width = WARP_SIZE, typename T>
+static __device__ __forceinline__ T ggml_cuda_shfl_xor_sync(T x, int offset) {
+#if defined(GGML_USE_HIP)
+    static T old;
+
+    // clang (v20) will not unroll loops with just the plain `offset` in switch
+    switch (~offset) {
+        // subgroups (width) should not make a difference for a butterfly shuffle pattern
+        case ~1:  return hip_move_dpp<0x160 + 1>(old, x); // row_xor_mask: offset
+        case ~2:  return hip_move_dpp<0x160 + 2>(old, x);
+        case ~4:  return hip_move_dpp<0x160 + 4>(old, x);
+        case ~8:  return hip_move_dpp<0x160 + 8>(old, x);
+        case ~16: return hip_ds_swizzle<0x401f>(x); // swap neighboring groups of 16
+        default:  return __shfl_xor(x, offset, width);
+    }
+#else
+    return __shfl_xor_sync(0xffffffff, x, offset, width);
+#endif // GGML_USE_HIP
+}
+
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ int warp_reduce_sum(int x) {
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
     return __reduce_add_sync(0xffffffff, x);
 #else
 #pragma unroll
     for (int offset = width/2; offset > 0; offset >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, offset, width);
+        x += ggml_cuda_shfl_xor_sync<width>(x, offset);
     }
     return x;
 #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE

@@ -376,7 +418,7 @@ template<int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
 #pragma unroll
     for (int offset = width/2; offset > 0; offset >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, offset, width);
+        x += ggml_cuda_shfl_xor_sync<width>(x, offset);
     }
     return x;
 }

@@ -385,8 +427,8 @@ template<int width = WARP_SIZE>
 static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 #pragma unroll
     for (int offset = width/2; offset > 0; offset >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width);
+        a.x += ggml_cuda_shfl_xor_sync<width>(a.x, offset);
+        a.y += ggml_cuda_shfl_xor_sync<width>(a.y, offset);
     }
     return a;
 }

@@ -396,7 +438,7 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #ifdef FP16_AVAILABLE
 #pragma unroll
     for (int offset = width/2; offset > 0; offset >>= 1) {
-        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width));
+        a = __hadd2(a, ggml_cuda_shfl_xor_sync<width>(a, offset));
     }
     return a;
 
@@ -413,7 +455,7 @@ static __device__ __forceinline__ int warp_reduce_all(int x) {
     } else {
 #pragma unroll
         for (int offset = width/2; offset > 0; offset >>= 1) {
-            x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
+            x = ggml_cuda_shfl_xor_sync<width>(x, offset) && x;
         }
         return x;
     }

@@ -426,7 +468,7 @@ static __device__ __forceinline__ int warp_reduce_any(int x) {
     } else {
 #pragma unroll
         for (int offset = width/2; offset > 0; offset >>= 1) {
-            x = __shfl_xor_sync(0xffffffff, x, offset, width) || x;
+            x = ggml_cuda_shfl_xor_sync<width>(x, offset) || x;
         }
         return x;
     }

@@ -436,7 +478,7 @@ template<int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
     for (int offset = width/2; offset > 0; offset >>= 1) {
-        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, width));
+        x = fmaxf(x, ggml_cuda_shfl_xor_sync<width>(x, offset));
     }
     return x;
 }

@@ -475,7 +517,7 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || defined(GGML_USE_HIP)
 #pragma unroll
     for (int offset = width/2; offset > 0; offset >>= 1) {
-        x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width));
+        x = ggml_cuda_hmax2(x, ggml_cuda_shfl_xor_sync<width>(x, offset));
     }
     return x;
 #else
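
The new wrapper is the heart of the change. On HIP, offsets 1, 2, 4, and 8 map to `row_xor_mask` DPP controls (`0x160 + offset`), offset 16 is handled with a `ds_swizzle` that swaps neighbouring groups of 16 lanes (a DPP row is 16 lanes wide), and any other offset, e.g. 32 on wave64 devices, falls back to plain `__shfl_xor`. The `old` argument only fills the intrinsic's old-value operand, and the `switch (~offset)` form is a workaround so clang fully unrolls the surrounding loops; on CUDA the wrapper compiles straight to `__shfl_xor_sync`. A hedged usage sketch follows (illustrative name, mirroring `warp_reduce_max` from the hunk above and assuming ggml-cuda/common.cuh is included):

```cuda
// Hedged usage sketch: callers keep the same unrolled power-of-two loop; the
// wrapper picks a DPP or ds_swizzle instruction per offset at compile time on
// HIP, or compiles to __shfl_xor_sync on CUDA.
template<int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_max_sketch(float x) {
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        x = fmaxf(x, ggml_cuda_shfl_xor_sync<width>(x, offset));
    }
    return x;
}
```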

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 2 additions & 2 deletions
@@ -268,8 +268,8 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
     }
 #pragma unroll
     for (int mask = QI8_1/2; mask > 0; mask >>= 1) {
-        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32));
-        sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, 32);
+        amax = fmaxf(amax, ggml_cuda_shfl_xor_sync<32>(amax, mask));
+        sum += ggml_cuda_shfl_xor_sync<32>(sum, mask);
     }
 
     const float d = amax / 127;

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 5 additions & 5 deletions
@@ -572,7 +572,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     for (int col = 0; col < cols_per_thread; ++col) {
 #pragma unroll
         for (int offset = 16; offset >= 4; offset >>= 1) {
-            KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+            KQ_max_new[col] = fmaxf(KQ_max_new[col], ggml_cuda_shfl_xor_sync(KQ_max_new[col], offset));
         }
     }
 
@@ -627,7 +627,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     for (int col = 0; col < cols_per_thread; ++col) {
 #pragma unroll
         for (int offset = 2; offset >= 1; offset >>= 1) {
-            KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+            KQ_max_new[col] = fmaxf(KQ_max_new[col], ggml_cuda_shfl_xor_sync(KQ_max_new[col], offset));
         }
     }
 
@@ -950,7 +950,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     for (int col = 0; col < cols_per_thread; ++col) {
 #pragma unroll
         for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
-            KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
+            KQ_rowsum[col] += ggml_cuda_shfl_xor_sync(KQ_rowsum[col], offset);
         }
     }
 }

@@ -1083,7 +1083,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 #pragma unroll
     for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
         if (offset < WARP_SIZE) {
-            KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
+            KQ_cmn = fmaxf(KQ_cmn, ggml_cuda_shfl_xor_sync(KQ_cmn, offset));
         }
     }
 
@@ -1101,7 +1101,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 #pragma unroll
     for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
         if (offset < WARP_SIZE) {
-            KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
+            KQ_crs += ggml_cuda_shfl_xor_sync(KQ_crs, offset);
         }
     }
 
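
In the MMA FlashAttention kernel the reductions deliberately cover only part of the warp: bounding the offsets to a sub-range combines values across just the lanes whose ids differ in those bit positions. A hedged illustrative helper (not repo code; the template parameters mirror the `offset_first`/`offset_last` loop bounds above, and ggml-cuda/common.cuh is assumed to be included):

```cuda
// Hedged sketch: a bounded butterfly that only combines lanes whose ids differ
// in the bits spanned by [offset_last, offset_first].
template<int offset_first, int offset_last>
static __device__ __forceinline__ float partial_warp_sum_sketch(float x) {
#pragma unroll
    for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
        x += ggml_cuda_shfl_xor_sync(x, offset);
    }
    return x;
}
```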

ggml/src/ggml-cuda/fattn-vec.cuh

Lines changed: 1 addition & 1 deletion
@@ -282,7 +282,7 @@ static __global__ void flash_attn_ext_vec(
     for (int j = 0; j < ncols; ++j) {
 #pragma unroll
         for (int offset = nthreads_KQ; offset < WARP_SIZE; offset <<= 1) {
-            KQ_max_new[j] = fmaxf(KQ_max_new[j], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[j], offset, WARP_SIZE));
+            KQ_max_new[j] = fmaxf(KQ_max_new[j], ggml_cuda_shfl_xor_sync(KQ_max_new[j], offset));
         }
         const float KQ_max_scale = expf(KQ_max[j] - KQ_max_new[j]);
         KQ_max[j] = KQ_max_new[j];
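
Here the offsets grow instead of shrink: starting at `nthreads_KQ` and doubling up to `WARP_SIZE` merges the per-group maxima left by a narrower reduction, and for an XOR butterfly over max the traversal order of the power-of-two offsets does not change the result. A hedged sketch of the same idea as a standalone template (illustrative name; assumes ggml-cuda/common.cuh for `ggml_cuda_shfl_xor_sync` and `WARP_SIZE`):

```cuda
// Hedged sketch: walk the offsets upward from a sub-group size to WARP_SIZE to
// merge the per-group maxima produced by a narrower reduction.
template<int group_size>
static __device__ __forceinline__ float merge_group_max_sketch(float x) {
#pragma unroll
    for (int offset = group_size; offset < WARP_SIZE; offset <<= 1) {
        x = fmaxf(x, ggml_cuda_shfl_xor_sync(x, offset));
    }
    return x;
}
```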

ggml/src/ggml-cuda/quantize.cu

Lines changed: 2 additions & 2 deletions
@@ -89,7 +89,7 @@ static __global__ void quantize_mmq_q8_1(
     // Exchange max. abs. value between vals_per_scale/4 threads.
 #pragma unroll
     for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) {
-        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE));
+        amax = fmaxf(amax, ggml_cuda_shfl_xor_sync(amax, offset));
     }
 
     float sum;
@@ -99,7 +99,7 @@ static __global__ void quantize_mmq_q8_1(
         // Calculate sums across vals_per_sum/4 threads.
 #pragma unroll
         for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
-            sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
+            sum += ggml_cuda_shfl_xor_sync(sum, offset);
         }
     }
 

ggml/src/ggml-cuda/topk-moe.cu

Lines changed: 2 additions & 2 deletions
@@ -90,8 +90,8 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 
 #pragma unroll
     for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
-        const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
-        const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
+        const float val    = ggml_cuda_shfl_xor_sync(max_val, mask);
+        const int   expert = ggml_cuda_shfl_xor_sync(max_expert, mask);
         if (val > max_val || (val == max_val && expert < max_expert)) {
             max_val    = val;
             max_expert = expert;
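
The top-k MoE kernel reduces a (value, expert) pair and breaks ties toward the smaller expert index, which keeps the winner deterministic regardless of lane ordering. A hedged sketch of that loop as a standalone helper (illustrative name; assumes ggml-cuda/common.cuh is included):

```cuda
// Hedged sketch of the loop above: ties on the value go to the smaller expert
// index, so every lane converges on the same (value, expert) winner.
static __device__ __forceinline__ void warp_top1_expert_sketch(float & max_val, int & max_expert) {
#pragma unroll
    for (int mask = WARP_SIZE/2; mask > 0; mask /= 2) {
        const float val    = ggml_cuda_shfl_xor_sync(max_val, mask);
        const int   expert = ggml_cuda_shfl_xor_sync(max_expert, mask);
        if (val > max_val || (val == max_val && expert < max_expert)) {
            max_val    = val;
            max_expert = expert;
        }
    }
}
```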

ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 0 additions & 1 deletion
@@ -27,7 +27,6 @@
 #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
 #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
 #define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width)
-#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
 #define __all_sync(mask, var) __all(var)
 #define __any_sync(mask, var) __any(var)
 #define cublasCreate hipblasCreate
