Skip to content

Commit ef249d9

Browse files
committed
HIP: Cleanup hipification header
Switch over to hip_bf16 from the legacy hip_bfloat16. Simplify the RDNA3 define. Lower the switch-over to the new hipBLAS API to ROCm 6.5, as this version is used for the ROCm 7.0 previews.
1 parent e71d48e commit ef249d9

File tree

1 file changed

+6
-7
lines changed
  • ggml/src/ggml-cuda/vendors

1 file changed

+6
-7
lines changed

ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <hip/hip_runtime.h>
55
#include <hipblas/hipblas.h>
66
#include <hip/hip_fp16.h>
7-
#include <hip/hip_bfloat16.h>
7+
#include <hip/hip_bf16.h>
88
// for rocblas_initialize()
99
#include "rocblas/rocblas.h"
1010

@@ -137,7 +137,7 @@
137137
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
138138
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
139139

140-
#if HIP_VERSION >= 70000000
140+
#if HIP_VERSION >= 60500000
141141
#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
142142
#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
143143
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
@@ -149,7 +149,7 @@
149149
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
150150
#define cublasComputeType_t hipblasDatatype_t
151151
#define cudaDataType_t hipblasDatatype_t
152-
#endif // HIP_VERSION >= 7000000
152+
#endif // HIP_VERSION >= 60500000
153153

154154
#if !defined(__HIP_PLATFORM_AMD__)
155155
#error "The HIP backend supports only AMD targets"
@@ -181,8 +181,7 @@
181181
#define RDNA4
182182
#endif
183183

184-
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
185-
defined(__gfx1150__) || defined(__gfx1151__)
184+
#if defined(__GFX11__)
186185
#define RDNA3
187186
#endif
188187

@@ -199,8 +198,8 @@
199198
#define __has_builtin(x) 0
200199
#endif
201200

202-
typedef hip_bfloat16 nv_bfloat16;
203-
typedef short2 nv_bfloat162; // FIXME there is no 2x BF16 type being defined in bfloat16.h, ad-hoc compilation fix
201+
typedef __hip_bfloat16 nv_bfloat16;
202+
typedef __hip_bfloat162 nv_bfloat162;
204203

205204
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
206205
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));

0 commit comments

Comments
 (0)