Skip to content

Commit f1821c6

Browse files
committed
Revert "HIP: bump requirement to rocm 6.1 (ggml-org#15296)"
1 parent 239af52 commit f1821c6

File tree

3 files changed

+49
-5
lines changed

3 files changed

+49
-5
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -468,21 +468,25 @@ static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b
468468
}
469469

470470
static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
471-
#if defined(GGML_USE_HIP)
471+
#if defined(GGML_USE_HIP) && HIP_VERSION >= 50700000
472472
return half2(__hmax(a.x, b.x), __hmax(a.y, b.y));
473-
#elif CUDART_VERSION >= CUDART_HMAX
473+
#elif !defined(GGML_USE_HIP) && CUDART_VERSION >= CUDART_HMAX
474474
return __hmax2(a, b);
475-
#else
475+
#elif !defined(GGML_USE_HIP)
476476
half2 ret;
477477
reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a), __low2float(b)));
478478
reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
479479
return ret;
480+
#else
481+
GGML_UNUSED(a);
482+
GGML_UNUSED(b);
483+
NO_DEVICE_CODE;
480484
#endif
481485
}
482486

483487
template<int width = WARP_SIZE>
484488
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
485-
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || defined(GGML_USE_HIP)
489+
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
486490
#pragma unroll
487491
for (int offset = width/2; offset > 0; offset >>= 1) {
488492
x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width));
@@ -491,7 +495,7 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
491495
#else
492496
GGML_UNUSED(x);
493497
NO_DEVICE_CODE;
494-
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || defined(GGML_USE_HIP)
498+
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
495499
}
496500

497501
#if CUDART_VERSION < CUDART_HMASK

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,30 @@ static int ggml_cuda_parse_id(char devName[]) {
183183
#endif // defined(GGML_USE_HIP)
184184

185185
static ggml_cuda_device_info ggml_cuda_init() {
186+
#if defined(GGML_USE_HIP)
187+
// Workaround for a rocBLAS bug when using multiple graphics cards:
188+
// https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
189+
{
190+
int major_version = 0;
191+
size_t version_length = 0;
192+
if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) {
193+
std::vector<char> version(version_length+1, '\0');
194+
if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) {
195+
version.resize(::strlen(version.data()));
196+
int parsed_value = 0;
197+
if (std::from_chars(version.data(), version.data() + version.size(), parsed_value).ec == std::errc()) {
198+
major_version = parsed_value;
199+
}
200+
}
201+
}
202+
if (major_version < 4) {
203+
GGML_LOG_DEBUG(GGML_CUDA_NAME " calling rocblas_initialize as a workaround for a rocBLAS bug\n");
204+
rocblas_initialize();
205+
CUDA_CHECK(cudaDeviceSynchronize());
206+
}
207+
}
208+
#endif
209+
186210
ggml_cuda_device_info info = {};
187211

188212
cudaError_t err = cudaGetDeviceCount(&info.device_count);

ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include <hipblas/hipblas.h>
66
#include <hip/hip_fp16.h>
77
#include <hip/hip_bfloat16.h>
8+
// for rocblas_initialize()
9+
#include "rocblas/rocblas.h"
810

911
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
1012
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
@@ -249,3 +251,17 @@ static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigne
249251
}
250252
return c;
251253
}
254+
255+
#if HIP_VERSION < 50600000
256+
// __shfl_xor() for half2 was added in ROCm 5.6
257+
static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
258+
typedef union half2_b32 {
259+
half2 val;
260+
int b32;
261+
} half2_b32_t;
262+
half2_b32_t tmp;
263+
tmp.val = var;
264+
tmp.b32 = __shfl_xor(tmp.b32, laneMask, width);
265+
return tmp.val;
266+
}
267+
#endif // HIP_VERSION < 50600000

0 commit comments

Comments
 (0)