Skip to content

Commit e7b884d

Browse files
try AMD fix
1 parent 07c814b commit e7b884d

File tree

3 files changed

+8
-12
lines changed

3 files changed

+8
-12
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 4 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -420,12 +420,9 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
420 420
421 421   template<int width = WARP_SIZE>
422 422   static __device__ __forceinline__ int warp_reduce_all(int x) {
423     - #ifndef GGML_USE_HIP
424     - if (width == WARP_SIZE) {
    423 + if (width == ggml_cuda_get_physical_warp_size()) {
425 424   return __all_sync(0xffffffff, x);
426     - } else
427     - #endif // GGML_USE_HIP
428     - {
    425 + } else {
429 426   #pragma unroll
430 427   for (int offset = width/2; offset > 0; offset >>= 1) {
431 428   x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
@@ -436,12 +433,9 @@ static __device__ __forceinline__ int warp_reduce_all(int x) {
436 433
437 434   template<int width = WARP_SIZE>
438 435   static __device__ __forceinline__ int warp_reduce_any(int x) {
439     - #ifndef GGML_USE_HIP
440     - if (width == WARP_SIZE) {
    436 + if (width == ggml_cuda_get_physical_warp_size()) {
441 437   return __any_sync(0xffffffff, x);
442     - } else
443     - #endif // GGML_USE_HIP
444     - {
    438 + } else {
445 439   #pragma unroll
446 440   for (int offset = width/2; offset > 0; offset >>= 1) {
447 441   x = __shfl_xor_sync(0xffffffff, x, offset, width) || x;

ggml/src/ggml-cuda/mmq.cu

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,7 @@ static __global__ void mmq_ids_helper(
41 41   ids_dst_shared[it_compact] = it*n_expert_used + iex_used;
42 42   }
43 43
44    - if (warp_reduce_any(iex_used != -1)) {
   44 + if (warp_reduce_any<warp_size>(iex_used != -1)) {
45 45   it_compact++;
46 46   }
47 47   }
@@ -80,7 +80,7 @@ static __global__ void mmq_ids_helper(
80 80   it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
81 81   }
82 82   }
83    - nex_prev = warp_reduce_sum(nex_prev);
   83 + nex_prev = warp_reduce_sum<warp_size>(nex_prev);
84 84
85 85   for (int it = threadIdx.x; it < it_compact; it += warp_size) {
86 86   ids_src1[nex_prev + it] = ids_src1_shared[it];

ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,8 @@
24 24   #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
25 25   #define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width)
26 26   #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
   27 + #define __all_sync(mask, var) __all(var)
   28 + #define __any_sync(mask, var) __any(var)
27 29   #define cublasCreate hipblasCreate
28 30   #define cublasDestroy hipblasDestroy
29 31   #define cublasGemmEx hipblasGemmEx

0 commit comments

Comments
 (0)