@@ -190,64 +190,58 @@ static __device__ void no_device_code(
 #define NO_DEVICE_CODE // GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
 #endif // __CUDA_ARCH__
 
+template <int width = WARP_SIZE>
 static __device__ __forceinline__ int warp_reduce_sum(int x) {
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
     return __reduce_add_sync(0xffffffff, x);
 #else
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, offset, width);
     }
     return x;
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 }
 
+template <int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, offset, width);
     }
     return x;
 }
 
+template <int width = WARP_SIZE>
 static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, 32);
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width);
     }
     return a;
 }
 
+template <int width = WARP_SIZE>
 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #ifdef FP16_AVAILABLE
-
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        const half2 a_other = __shfl_xor_sync(0xffffffff, a, offset, 32);
-        reinterpret_cast<half&>(a.x) += __low2half(a_other);
-        reinterpret_cast<half&>(a.y) += __high2half(a_other);
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width));
     }
     return a;
-#else
-#pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, 32));
-    }
-    return a;
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 
 #else
     NO_DEVICE_CODE;
     return a;
 #endif // FP16_AVAILABLE
 }
 
+template <int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, width));
     }
     return x;
 }
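
The new `width` template parameter (defaulting to `WARP_SIZE`) lets the same XOR-butterfly shuffle reduction run over a sub-warp group instead of the full 32 lanes. As a rough illustration only, here is a minimal standalone sketch of that usage pattern: the reduction helper is a simplified copy of the templated float `warp_reduce_sum` above, while the `group_sums` kernel, its launch configuration, and the 16-lane group size are assumptions made for this example, not part of the patch.

```cuda
// Standalone sketch: sub-warp reduction with a compile-time group width.
// Mirrors the templated pattern above; everything below is illustrative.
#include <cuda_runtime.h>
#include <cstdio>

#define WARP_SIZE 32

template <int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        // XOR butterfly: after log2(width) steps every lane in the
        // width-sized group holds the group's full sum.
        x += __shfl_xor_sync(0xffffffff, x, offset, width);
    }
    return x;
}

// Hypothetical kernel: each 16-lane group sums its own 16 inputs.
static __global__ void group_sums(const float * x, float * out, int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    float v = i < n ? x[i] : 0.0f;
    v = warp_reduce_sum<16>(v);               // reduce within 16-lane groups
    if (i < n && threadIdx.x % 16 == 0) {
        out[i/16] = v;                        // one result per group
    }
}

int main() {
    const int n = 64;
    float hx[n], hout[n/16];
    for (int i = 0; i < n; ++i) hx[i] = 1.0f; // each group should sum to 16
    float *dx, *dout;
    cudaMalloc(&dx,   n*sizeof(float));
    cudaMalloc(&dout, (n/16)*sizeof(float));
    cudaMemcpy(dx, hx, n*sizeof(float), cudaMemcpyHostToDevice);
    group_sums<<<1, n>>>(dx, dout, n);
    cudaMemcpy(hout, dout, (n/16)*sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < n/16; ++i) printf("group %d: %g\n", i, hout[i]);
    cudaFree(dx); cudaFree(dout);
    return 0;
}
```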
@@ -269,35 +263,38 @@ static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b
 }
 
 static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
 
-#if CUDART_VERSION >= CUDART_HMAX
+#if defined(GGML_USE_HIP)
+    return half2(__hmax(a.x, b.x), __hmax(a.y, b.y));
+#elif CUDART_VERSION >= CUDART_HMAX
     return __hmax2(a, b);
 #else
     half2 ret;
     reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a),  __low2float(b)));
     reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
     return ret;
-#endif // CUDART_VERSION >= CUDART_HMAX
+#endif
 
 #else
     GGML_UNUSED(a);
     GGML_UNUSED(b);
     NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
 }
 
+template <int width = WARP_SIZE>
 static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width));
     }
     return x;
 #else
     GGML_UNUSED(x);
     NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
 }
 
 #if CUDART_VERSION < CUDART_HMASK
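
For the half2 path, the max reduction composes the same shuffle pattern with a per-component max. The sketch below is a standalone, hedged illustration that uses the portable float-conversion fallback (the `fmaxf` branch above) instead of `__hmax`/`__hmax2`, so it does not require Ampere-or-newer CUDA or HIP >= 5.7; the helper names `hmax2_portable` and `warp_reduce_max_h2`, the kernel, and the launch setup are hypothetical and not part of ggml.

```cuda
// Standalone sketch: templated half2 max reduction over a sub-warp group,
// built on a portable per-component fallback rather than __hmax2.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdio>

// Per-component max via float conversions (works on all FP16-capable arches).
static __device__ __forceinline__ half2 hmax2_portable(const half2 a, const half2 b) {
    return __floats2half2_rn(fmaxf( __low2float(a),  __low2float(b)),
                             fmaxf(__high2float(a), __high2float(b)));
}

template <int width = 32>
static __device__ __forceinline__ half2 warp_reduce_max_h2(half2 x) {
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        x = hmax2_portable(x, __shfl_xor_sync(0xffffffff, x, offset, width));
    }
    return x;
}

// Hypothetical kernel: each 16-lane group finds the max of its 32 floats,
// packed two-per-thread into a half2.
static __global__ void group_max(const float * x, float * out) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    half2 v = __floats2half2_rn(x[2*i + 0], x[2*i + 1]);
    v = warp_reduce_max_h2<16>(v);            // max within 16-lane groups
    if (threadIdx.x % 16 == 0) {
        out[i/16] = fmaxf(__low2float(v), __high2float(v));
    }
}

int main() {
    const int n = 64;                         // 32 threads x 2 values each
    float hx[n], hout[2];
    for (int i = 0; i < n; ++i) hx[i] = (float) (i % 37);
    float *dx, *dout;
    cudaMalloc(&dx,   n*sizeof(float));
    cudaMalloc(&dout, 2*sizeof(float));
    cudaMemcpy(dx, hx, n*sizeof(float), cudaMemcpyHostToDevice);
    group_max<<<1, n/2>>>(dx, dout);          // 32 threads, two 16-lane groups
    cudaMemcpy(hout, dout, 2*sizeof(float), cudaMemcpyDeviceToHost);
    printf("group 0 max: %g, group 1 max: %g\n", hout[0], hout[1]);
    cudaFree(dx); cudaFree(dout);
    return 0;
}
```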