|
4 | 4 | #include <hip/hip_runtime.h> |
5 | 5 | #include <hipblas/hipblas.h> |
6 | 6 | #include <hip/hip_fp16.h> |
7 | | -#include <hip/hip_bfloat16.h> |
| 7 | +#include <hip/hip_bf16.h> |
8 | 8 |
|
9 | 9 | #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT |
10 | 10 | #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT |
|
135 | 135 | #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR |
136 | 136 | #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED |
137 | 137 |
|
138 | | -#if HIP_VERSION >= 70000000 |
| 138 | +#if HIP_VERSION >= 60500000 |
139 | 139 | #define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F |
140 | 140 | #define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F |
141 | 141 | #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F |
|
147 | 147 | #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F |
148 | 148 | #define cublasComputeType_t hipblasDatatype_t |
149 | 149 | #define cudaDataType_t hipblasDatatype_t |
150 | | -#endif // HIP_VERSION >= 7000000 |
| 150 | +#endif // HIP_VERSION >= 60500000
151 | 151 |
|
152 | 152 | #if !defined(__HIP_PLATFORM_AMD__) |
153 | 153 | #error "The HIP backend supports only AMD targets" |
|
179 | 179 | #define RDNA4 |
180 | 180 | #endif |
181 | 181 |
|
182 | | -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \ |
183 | | - defined(__gfx1150__) || defined(__gfx1151__) |
| 182 | +#if defined(__GFX11__) |
184 | 183 | #define RDNA3 |
185 | 184 | #endif |
186 | 185 |
|
|
197 | 196 | #define __has_builtin(x) 0 |
198 | 197 | #endif |
199 | 198 |
|
200 | | -typedef hip_bfloat16 nv_bfloat16; |
201 | | -typedef short2 nv_bfloat162; // FIXME there is no 2x BF16 type being defined in bfloat16.h, ad-hoc compilation fix |
| 199 | +typedef __hip_bfloat16 nv_bfloat16; |
| 200 | +typedef __hip_bfloat162 nv_bfloat162; |
202 | 201 |
|
203 | 202 | typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); |
204 | 203 | typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); |
|
0 commit comments