From a151674c3e073121a4058d28938c9664a4b26c3a Mon Sep 17 00:00:00 2001
From: uvos
Date: Wed, 29 Jan 2025 17:46:23 +0100
Subject: [PATCH 1/3] CUDA/HIP: add warp_size to cuda_device_info

---
 ggml/src/ggml-cuda/common.cuh   | 1 +
 ggml/src/ggml-cuda/ggml-cuda.cu | 6 ++++--
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index a66322da05a36..eec227dce3a1e 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -520,6 +520,7 @@ struct ggml_cuda_device_info {
         bool    vmm;                // virtual memory support
         size_t  vmm_granularity;    // granularity of virtual memory
         size_t  total_vram;
+        int     warp_size;          // Number of threads in a dispatch
     };
 
     cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index de3f9c2ca1ed5..ecf06fec408bb 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -242,6 +242,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
 
         info.devices[id].nsm   = prop.multiProcessorCount;
         info.devices[id].smpb  = prop.sharedMemPerBlock;
+        info.devices[id].warp_size = prop.warpSize;
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
         info.devices[id].smpbo = prop.sharedMemPerBlock;
 
@@ -256,8 +257,9 @@ static ggml_cuda_device_info ggml_cuda_init() {
                 info.devices[id].cc += prop.minor * 0x10;
             }
         }
-        GGML_LOG_INFO("  Device %d: %s, %s (0x%x), VMM: %s\n",
-                      id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff, device_vmm ? "yes" : "no");
+        GGML_LOG_INFO("  Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n",
+                      id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
+                      device_vmm ? "yes" : "no", prop.warpSize);
 #else
         info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
         info.devices[id].cc    = 100*prop.major + 10*prop.minor;

From af71052cddd351d7573b16807438f777f3aad779 Mon Sep 17 00:00:00 2001
From: uvos
Date: Wed, 29 Jan 2025 19:12:42 +0100
Subject: [PATCH 2/3] HIP: Prepare reduction operators for wave 64

---
 ggml/src/ggml-cuda/common.cuh   | 59 +++++++++++++++------------------
 ggml/src/ggml-cuda/ggml-cuda.cu |  4 +--
 2 files changed, 28 insertions(+), 35 deletions(-)

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index eec227dce3a1e..8d8d3932e0e58 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -190,53 +190,46 @@ static __device__ void no_device_code(
 #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
 #endif // __CUDA_ARCH__
 
+template <int width = WARP_SIZE>
 static __device__ __forceinline__ int warp_reduce_sum(int x) {
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
     return __reduce_add_sync(0xffffffff, x);
 #else
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, offset, width);
     }
     return x;
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 }
 
+template <int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, offset, width);
     }
     return x;
 }
 
+template <int width = WARP_SIZE>
 static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, 32);
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width);
     }
     return a;
 }
 
+template <int width = WARP_SIZE>
 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #ifdef FP16_AVAILABLE
-
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        const half2 a_other = __shfl_xor_sync(0xffffffff, a, offset, 32);
-        reinterpret_cast<half&>(a.x) += __low2half(a_other);
-        reinterpret_cast<half&>(a.y) += __high2half(a_other);
-    }
-    return a;
-#else
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, 32));
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width));
     }
     return a;
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 
 #else
     NO_DEVICE_CODE;
@@ -244,10 +237,11 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #endif // FP16_AVAILABLE
 }
 
+template <int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, width));
     }
     return x;
 }
@@ -269,35 +263,34 @@ static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b
 }
 
 static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-
-#if CUDART_VERSION >= CUDART_HMAX
+#if defined(GGML_USE_HIP) && HIP_VERSION >= 50700000
+    return half2(__hmax(a.x, b.x), __hmax(a.y, b.y));
+#elif !defined(GGML_USE_HIP) && CUDART_VERSION >= CUDART_HMAX
     return __hmax2(a, b);
-#else
+#elif !defined(GGML_USE_HIP)
     half2 ret;
     reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a),  __low2float(b)));
    reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
     return ret;
-#endif // CUDART_VERSION >= CUDART_HMAX
-
 #else
     GGML_UNUSED(a);
     GGML_UNUSED(b);
     NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif
 }
 
+template <int width = WARP_SIZE>
 static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width));
     }
     return x;
 #else
     GGML_UNUSED(x);
     NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
 }
 
 #if CUDART_VERSION < CUDART_HMASK
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index ecf06fec408bb..383131c7789d5 100644
--- 
a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -240,8 +240,8 @@ static ggml_cuda_device_info ggml_cuda_init() { info.default_tensor_split[id] = total_vram; total_vram += prop.totalGlobalMem; - info.devices[id].nsm = prop.multiProcessorCount; - info.devices[id].smpb = prop.sharedMemPerBlock; + info.devices[id].nsm = prop.multiProcessorCount; + info.devices[id].smpb = prop.sharedMemPerBlock; info.devices[id].warp_size = prop.warpSize; #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) info.devices[id].smpbo = prop.sharedMemPerBlock; From 7e1c85cd3e040ff30a67e7e7f1449f796f5c9b81 Mon Sep 17 00:00:00 2001 From: uvos Date: Wed, 29 Jan 2025 19:36:00 +0100 Subject: [PATCH 3/3] HIP: require at least HIP 5.5 --- ggml/src/ggml-hip/CMakeLists.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index ecc3bc66d44c0..7a877bdc11a6f 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -40,6 +40,10 @@ find_package(hip REQUIRED) find_package(hipblas REQUIRED) find_package(rocblas REQUIRED) +if (${hip_VERSION} VERSION_LESS 5.5) + message(FATAL_ERROR "At least ROCM/HIP V5.5 is required") +endif() + message(STATUS "HIP and hipBLAS found") file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
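
Below is a minimal standalone CUDA sketch of the width-templated reduction pattern
introduced in PATCH 2/3. It is an illustration under assumptions, not code from the
ggml tree: the WARP_SIZE value of 32 (an NVIDIA target), the kernel name
warp_sum_kernel, and the buffer names are all hypothetical.

    // warp_sum_demo.cu -- standalone sketch (assumed names); build with: nvcc warp_sum_demo.cu
    #include <cstdio>
    #include <cuda_runtime.h>

    #define WARP_SIZE 32  // assumption: 32-lane warps; a HIP wave64 build would instantiate width = 64

    // Same shape as the patched helper: the butterfly starts at width/2, so the
    // loop body is correct for any power-of-two warp/wavefront width.
    template <int width = WARP_SIZE>
    static __device__ __forceinline__ int warp_reduce_sum(int x) {
    #pragma unroll
        for (int offset = width/2; offset > 0; offset >>= 1) {
            x += __shfl_xor_sync(0xffffffff, x, offset, width);  // full-warp mask: all lanes active
        }
        return x;
    }

    // One warp sums its WARP_SIZE inputs; after the reduction every lane holds the total.
    __global__ void warp_sum_kernel(const int * in, int * out) {
        const int v = warp_reduce_sum<WARP_SIZE>(in[threadIdx.x]);
        if (threadIdx.x == 0) {
            *out = v;
        }
    }

    int main() {
        int h_in[WARP_SIZE];
        for (int i = 0; i < WARP_SIZE; ++i) {
            h_in[i] = i + 1;  // 1..32, expected sum 528
        }

        int * d_in  = nullptr;
        int * d_out = nullptr;
        int h_out   = 0;
        cudaMalloc(&d_in,  sizeof(h_in));
        cudaMalloc(&d_out, sizeof(int));
        cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);

        warp_sum_kernel<<<1, WARP_SIZE>>>(d_in, d_out);
        cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
        printf("warp sum = %d (expected 528)\n", h_out);

        cudaFree(d_in);
        cudaFree(d_out);
        return 0;
    }

Templating the reductions on width, with WARP_SIZE as the default, lets wave64 HIP
builds instantiate 64-wide variants from the same loop bodies, while the warp_size
field added to cuda_device_info in PATCH 1/3 reports the hardware width at runtime.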