Commit 2d1b309

HIP: Prepare reduction operators for wave 64

1 parent a151674
1 file changed: +27 −30 lines

ggml/src/ggml-cuda/common.cuh

Lines changed: 27 additions & 30 deletions
@@ -190,64 +190,58 @@ static __device__ void no_device_code(
 #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
 #endif // __CUDA_ARCH__
 
+template<int width = WARP_SIZE>
 static __device__ __forceinline__ int warp_reduce_sum(int x) {
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
     return __reduce_add_sync(0xffffffff, x);
 #else
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, offset, width);
     }
     return x;
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 }
 
+template<int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, offset, width);
     }
     return x;
 }
 
+template<int width = WARP_SIZE>
 static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, 32);
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width);
     }
     return a;
 }
 
+template<int width = WARP_SIZE>
 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #ifdef FP16_AVAILABLE
-
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        const half2 a_other = __shfl_xor_sync(0xffffffff, a, offset, 32);
-        reinterpret_cast<half&>(a.x) += __low2half(a_other);
-        reinterpret_cast<half&>(a.y) += __high2half(a_other);
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width));
     }
     return a;
-#else
-#pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, 32));
-    }
-    return a;
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 
 #else
     NO_DEVICE_CODE;
     return a;
 #endif // FP16_AVAILABLE
 }
 
+template<int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, width));
     }
     return x;
 }
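
Every reduction above now takes a width template parameter defaulting to WARP_SIZE, so the butterfly loop starts at width/2 instead of a hard-coded 16 and runs log2(width) steps: existing call sites compile unchanged, while HIP builds where WARP_SIZE is 64 get a full 64-lane reduction. A minimal usage sketch, not part of this commit, assuming the warp_reduce_sum and WARP_SIZE definitions above; the kernel name and launch shape are hypothetical:

static __global__ void row_sum(const float * x, float * dst, const int ncols) {
    // one warp per row: launched as row_sum<<<nrows, WARP_SIZE>>>(x, dst, ncols)
    const int row = blockIdx.x;

    // each lane accumulates a strided slice of the row
    float sum = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += WARP_SIZE) {
        sum += x[row*ncols + col];
    }

    // width defaults to WARP_SIZE: a 32-lane butterfly on NVIDIA,
    // a 64-lane butterfly on wave64 AMD GPUs under HIP
    sum = warp_reduce_sum(sum);

    // after the butterfly every lane holds the total; lane 0 writes it
    if (threadIdx.x == 0) {
        dst[row] = sum;
    }
}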
@@ -269,35 +263,38 @@ static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b
 }
 
 static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
 
-#if CUDART_VERSION >= CUDART_HMAX
+#if defined(GGML_USE_HIP)
+    return half2(__hmax(a.x, b.x), __hmax(a.y, b.y));
+#elif CUDART_VERSION >= CUDART_HMAX
     return __hmax2(a, b);
 #else
     half2 ret;
     reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a),  __low2float(b)));
     reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
     return ret;
-#endif // CUDART_VERSION >= CUDART_HMAX
+#endif
 
 #else
     GGML_UNUSED(a);
     GGML_UNUSED(b);
     NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
 }
 
+template<int width = WARP_SIZE>
 static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width));
     }
     return x;
 #else
     GGML_UNUSED(x);
     NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
 }
 
 #if CUDART_VERSION < CUDART_HMASK
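
In this second hunk, the HIP_VERSION >= 50700000 guard (i.e. HIP/ROCm 5.7, which provides the scalar __hmax intrinsic) enables a native half-precision path on AMD: ggml_cuda_hmax2 is composed from two __hmax calls, and warp_reduce_max(half2) stops falling through to NO_DEVICE_CODE on those builds. Note that the width parameter also permits sub-warp reductions. A sketch of what an explicit width of 4 unrolls to, with a hypothetical helper name, not code from this commit:

static __device__ __forceinline__ float group_max4(float v) {
    // width = 4 confines the butterfly to each group of 4 consecutive
    // lanes, so a 32-lane warp computes 8 independent maxima at once
    v = fmaxf(v, __shfl_xor_sync(0xffffffff, v, 2, 4)); // offset width/2 = 2
    v = fmaxf(v, __shfl_xor_sync(0xffffffff, v, 1, 4)); // offset 1
    return v; // all 4 lanes of each group now hold the group maximum
}

This is exactly what the templated warp_reduce_max<4>(v) from the first hunk expands to for float after loop unrolling.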
