| 
4 | 4 | #include <hip/hip_runtime.h>  | 
5 | 5 | #include <hipblas/hipblas.h>  | 
6 | 6 | #include <hip/hip_fp16.h>  | 
7 |  | -#include <hip/hip_bfloat16.h>  | 
 | 7 | +#include <hip/hip_bf16.h>  | 
8 | 8 | 
 
  | 
9 | 9 | #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT  | 
10 | 10 | #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT  | 
 | 
135 | 135 | #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR  | 
136 | 136 | #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED  | 
137 | 137 | 
 
  | 
138 |  | -#if HIP_VERSION >= 70000000  | 
 | 138 | +#if HIP_VERSION >= 60500000  | 
139 | 139 | #define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F  | 
140 | 140 | #define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F  | 
141 | 141 | #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F  | 
 | 
147 | 147 | #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F  | 
148 | 148 | #define cublasComputeType_t hipblasDatatype_t  | 
149 | 149 | #define cudaDataType_t hipblasDatatype_t  | 
150 |  | -#endif // HIP_VERSION >= 7000000  | 
 | 150 | +#endif // HIP_VERSION >= 6050000  | 
151 | 151 | 
 
  | 
152 | 152 | #if !defined(__HIP_PLATFORM_AMD__)  | 
153 | 153 | #error "The HIP backend supports only AMD targets"  | 
 | 
179 | 179 | #define RDNA4  | 
180 | 180 | #endif  | 
181 | 181 | 
 
  | 
182 |  | -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \  | 
183 |  | -    defined(__gfx1150__) || defined(__gfx1151__)  | 
 | 182 | +#if defined(__GFX11__)  | 
184 | 183 | #define RDNA3  | 
185 | 184 | #endif  | 
186 | 185 | 
 
  | 
 | 
197 | 196 |     #define __has_builtin(x) 0  | 
198 | 197 | #endif  | 
199 | 198 | 
 
  | 
200 |  | -typedef hip_bfloat16 nv_bfloat16;  | 
201 |  | -typedef short2 nv_bfloat162; // FIXME there is no 2x BF16 type being defined in bfloat16.h, ad-hoc compilation fix  | 
 | 199 | +typedef __hip_bfloat16 nv_bfloat16;  | 
 | 200 | +typedef __hip_bfloat162 nv_bfloat162;  | 
202 | 201 | 
 
  | 
203 | 202 | typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));  | 
204 | 203 | typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));  | 
 | 
0 commit comments