|
10 | 10 | #include "rocblas/rocblas.h" |
11 | 11 | #endif // __HIP_PLATFORM_AMD__ |
12 | 12 |
|
13 | | -#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F |
14 | | -#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F |
15 | | -#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F |
16 | 13 | #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT |
17 | 14 | #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT |
18 | 15 | #define CUBLAS_OP_N HIPBLAS_OP_N |
|
30 | 27 | #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} |
31 | 28 | #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) |
32 | 29 | #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) |
33 | | -#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6 |
34 | 30 | #define cublasCreate hipblasCreate |
35 | 31 | #define cublasDestroy hipblasDestroy |
36 | 32 | #define cublasGemmEx hipblasGemmEx |
|
42 | 38 | #define cublasSgemm hipblasSgemm |
43 | 39 | #define cublasStatus_t hipblasStatus_t |
44 | 40 | #define cublasOperation_t hipblasOperation_t |
45 | | -#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6 |
46 | 41 | #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer |
47 | 42 | #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess |
48 | 43 | #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess |
|
144 | 139 | #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR |
145 | 140 | #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED |
146 | 141 |
|
| 142 | +#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION >= 70000000 |
| 143 | +#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F |
| 144 | +#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F |
| 145 | +#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F |
| 146 | +#define cublasComputeType_t hipblasComputeType_t |
| 147 | +#define cudaDataType_t hipDataType |
| 148 | +#else |
| 149 | +#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F |
| 150 | +#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F |
| 151 | +#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F |
| 152 | +#define cublasComputeType_t hipblasDatatype_t |
| 153 | +#define cudaDataType_t hipblasDatatype_t |
| 154 | +#endif |
| 155 | + |
147 | 156 | #define __CUDA_ARCH__ 1300 |
148 | 157 |
|
149 | 158 | #if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) |
|
0 commit comments