|
4 | 4 | #include <hip/hip_runtime.h> |
5 | 5 | #include <hipblas/hipblas.h> |
6 | 6 | #include <hip/hip_fp16.h> |
7 | | -#include <hip/hip_bfloat16.h> |
| 7 | +#include <hip/hip_bf16.h> |
8 | 8 | // for rocblas_initialize() |
9 | 9 | #include "rocblas/rocblas.h" |
10 | 10 |
|
|
137 | 137 | #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR |
138 | 138 | #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED |
139 | 139 |
|
140 | | -#if HIP_VERSION >= 70000000 |
| 140 | +#if HIP_VERSION >= 60500000 |
141 | 141 | #define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F |
142 | 142 | #define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F |
143 | 143 | #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F |
|
149 | 149 | #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F |
150 | 150 | #define cublasComputeType_t hipblasDatatype_t |
151 | 151 | #define cudaDataType_t hipblasDatatype_t |
152 | | -#endif // HIP_VERSION >= 7000000 |
| 152 | +#endif // HIP_VERSION >= 6050000 |
153 | 153 |
|
154 | 154 | #if !defined(__HIP_PLATFORM_AMD__) |
155 | 155 | #error "The HIP backend supports only AMD targets" |
|
181 | 181 | #define RDNA4 |
182 | 182 | #endif |
183 | 183 |
|
184 | | -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \ |
185 | | - defined(__gfx1150__) || defined(__gfx1151__) |
| 184 | +#if defined(__GFX11__) |
186 | 185 | #define RDNA3 |
187 | 186 | #endif |
188 | 187 |
|
|
199 | 198 | #define __has_builtin(x) 0 |
200 | 199 | #endif |
201 | 200 |
|
202 | | -typedef hip_bfloat16 nv_bfloat16; |
203 | | -typedef short2 nv_bfloat162; // FIXME there is no 2x BF16 type being defined in bfloat16.h, ad-hoc compilation fix |
| 201 | +typedef __hip_bfloat16 nv_bfloat16; |
| 202 | +typedef __hip_bfloat162 nv_bfloat162; |
204 | 203 |
|
205 | 204 | typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); |
206 | 205 | typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); |
|
0 commit comments