1212#else
1313#define GGML_COMMON_DECL_CUDA
1414#define GGML_COMMON_IMPL_CUDA
15+ #if defined(GGML_USE_MUSA)
16+ #define GGML_COMMON_DECL_MUSA
17+ #define GGML_COMMON_IMPL_MUSA
18+ #endif
1519#endif
1620#include " ggml-common.h"
1721
114118#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
115119#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
116120#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
121+ #elif defined(GGML_USE_MUSA)
122+ #include < musa_runtime.h>
123+ #include < musa.h>
124+ #include < mublas.h>
125+ #include < musa_fp16.h>
126+ // XXX: Keep the following order the same as hipBLAS
127+ // #define CUBLAS_COMPUTE_16F MUBLAS_COMPUTE_16F
128+ // #define CUBLAS_COMPUTE_32F MUBLAS_COMPUTE_32F
129+ #define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
130+ #define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
131+ #define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
132+ #define CUBLAS_OP_N MUBLAS_OP_N
133+ #define CUBLAS_OP_T MUBLAS_OP_T
134+ #define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
135+ // #define CUBLAS_TF32_TENSOR_OP_MATH 0
136+ #define CUDA_R_16F MUSA_R_16F
137+ #define CUDA_R_32F MUSA_R_32F
138+ // #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
139+ // #define cublasComputeType_t mublasComputeType_t
140+ #define cublasCreate mublasCreate
141+ #define cublasDestroy mublasDestroy
142+ #define cublasGemmEx mublasGemmEx
143+ #define cublasGemmBatchedEx mublasGemmBatchedEx
144+ #define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
145+ #define cublasHandle_t mublasHandle_t
146+ // #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
147+ #define cublasSetMathMode mublasSetMathMode
148+ #define cublasSetStream mublasSetStream
149+ #define cublasSgemm mublasSgemm
150+ #define cublasStatus_t mublasStatus_t
151+ #define cudaDataType_t musaDataType_t // deprecated, new hipblasDatatype not in 5.6
152+ #define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
153+ #define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
154+ #define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
155+ #define cudaDeviceProp musaDeviceProp
156+ #define cudaDeviceSynchronize musaDeviceSynchronize
157+ #define cudaError_t musaError_t
158+ #define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
159+ #define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
160+ #define cudaEventCreateWithFlags musaEventCreateWithFlags
161+ #define cudaEventDisableTiming musaEventDisableTiming
162+ #define cudaEventRecord musaEventRecord
163+ #define cudaEventSynchronize musaEventSynchronize
164+ #define cudaEvent_t musaEvent_t
165+ #define cudaEventDestroy musaEventDestroy
166+ #define cudaFree musaFree
167+ #define cudaFreeHost musaFreeHost
168+ #define cudaGetDevice musaGetDevice
169+ #define cudaGetDeviceCount musaGetDeviceCount
170+ #define cudaGetDeviceProperties musaGetDeviceProperties
171+ #define cudaGetErrorString musaGetErrorString
172+ #define cudaGetLastError musaGetLastError
173+ #define cudaHostRegister musaHostRegister
174+ #define cudaHostRegisterPortable musaHostRegisterPortable
175+ #define cudaHostRegisterReadOnly musaHostRegisterReadOnly
176+ #define cudaHostUnregister musaHostUnregister
177+ #define cudaLaunchHostFunc musaLaunchHostFunc
178+ #define cudaMalloc musaMalloc
179+ #define cudaMallocHost musaMallocHost
180+ #define cudaMemcpy musaMemcpy
181+ #define cudaMemcpyAsync musaMemcpyAsync
182+ #define cudaMemcpyPeerAsync musaMemcpyPeerAsync
183+ #define cudaMemcpy2DAsync musaMemcpy2DAsync
184+ #define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
185+ #define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
186+ #define cudaMemcpyHostToDevice musaMemcpyHostToDevice
187+ #define cudaMemcpyKind musaMemcpyKind
188+ #define cudaMemset musaMemset
189+ #define cudaMemsetAsync musaMemsetAsync
190+ #define cudaMemGetInfo musaMemGetInfo
191+ #define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
192+ #define cudaSetDevice musaSetDevice
193+ #define cudaStreamCreateWithFlags musaStreamCreateWithFlags
194+ #define cudaStreamDestroy musaStreamDestroy
195+ #define cudaStreamFireAndForget musaStreamFireAndForget
196+ #define cudaStreamNonBlocking musaStreamNonBlocking
197+ #define cudaStreamPerThread musaStreamPerThread
198+ #define cudaStreamSynchronize musaStreamSynchronize
199+ #define cudaStreamWaitEvent musaStreamWaitEvent
200+ #define cudaStream_t musaStream_t
201+ #define cudaSuccess musaSuccess
202+
203+ // XXX: Other CUDA => MUSA mapping
204+ #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
205+ #define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
206+ #define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
207+ #define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
208+ #define CUdevice MUdevice
209+ #define CUdeviceptr MUdeviceptr
210+ #define CUmemAccessDesc MUmemAccessDesc
211+ #define CUmemAllocationProp MUmemAllocationProp
212+ #define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
213+ #define cuDeviceGet muDeviceGet
214+ #define cuDeviceGetAttribute muDeviceGetAttribute
215+ #define cuMemAddressFree muMemAddressFree
216+ #define cuMemAddressReserve muMemAddressReserve
217+ #define cuMemCreate muMemCreate
218+ #define cuMemGetAllocationGranularity muMemGetAllocationGranularity
219+ #define cuMemMap muMemMap
220+ #define cuMemRelease muMemRelease
221+ #define cuMemSetAccess muMemSetAccess
222+ #define cuMemUnmap muMemUnmap
223+ #define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
224+ #define cudaFuncSetAttribute musaFuncSetAttribute
225+ #define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
226+ #define make_cudaExtent make_musaExtent
227+ #define make_cudaPitchedPtr make_musaPitchedPtr
228+
229+ // XXX: USE_CUDA_GRAPH
230+ #define CUDA_SUCCESS MUSA_SUCCESS
231+ #define CUresult MUresult
232+ #define cuGetErrorString muGetErrorString
233+ #define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
234+ #define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
235+ #define cudaGraphDestroy musaGraphDestroy
236+ #define cudaGraphExecDestroy musaGraphExecDestroy
237+ #define cudaGraphExec_t musaGraphExec_t
238+ #define cudaGraphExecUpdate musaGraphExecUpdate
239+ #define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
240+ #define cudaGraphGetNodes musaGraphGetNodes
241+ #define cudaGraphInstantiate musaGraphInstantiate
242+ #define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
243+ #define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
244+ #define cudaGraphLaunch musaGraphLaunch
245+ #define cudaGraphNodeGetType musaGraphNodeGetType
246+ #define cudaGraphNode_t musaGraphNode_t
247+ #define cudaGraphNodeType musaGraphNodeType
248+ #define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
249+ #define cudaGraph_t musaGraph_t
250+ #define cudaKernelNodeParams musaKernelNodeParams
251+ #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
252+ #define cudaStreamEndCapture musaStreamEndCapture
253+
254+ // XXX: cuBLAS => muBLAS mapping
255+ #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
256+ #define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
257+ #define CUBLAS_COMPUTE_16F CUDA_R_16F
258+ #define CUBLAS_COMPUTE_32F CUDA_R_32F
259+ #define cublasComputeType_t cudaDataType_t
260+
261+ // XXX: Clang builtins mapping
262+ #define __vsubss4 __vsubss4_musa
263+ #define __vsub4 __vsub4_musa
264+ #define __vcmpeq4 __vcmpeq4_musa
265+ #define __vcmpne4 __vcmpne4_musa
117266#else
118267#include < cuda_runtime.h>
119268#include < cuda.h>
@@ -168,9 +317,13 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
168317
169318#define CUDA_CHECK (err ) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
170319
171- #if CUDART_VERSION >= 12000
172- static const char * cublas_get_error_str (const cublasStatus_t err) {
320+ #if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
321+ static const char * cublas_get_error_str (const mublasStatus_t err) {
322+ #ifndef GGML_USE_MUSA
173323 return cublasGetStatusString (err);
324+ #else
325+ return mublasStatus_to_string (err);
326+ #endif // GGML_USE_MUSA
174327 }
175328#else
176329 static const char * cublas_get_error_str (const cublasStatus_t err) {
@@ -200,7 +353,7 @@ static const char * cu_get_error_str(CUresult err) {
200353#define CU_CHECK (err ) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
201354#endif
202355
203- #if CUDART_VERSION >= 11100
356+ #if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
204357#define GGML_CUDA_ASSUME (x ) __builtin_assume(x)
205358#else
206359#define GGML_CUDA_ASSUME (x )
@@ -214,6 +367,62 @@ typedef float dfloat; // dequantize float
214367typedef float2 dfloat2;
215368#endif // GGML_CUDA_F16
216369
370+ #if defined(GGML_USE_MUSA)
371+ #ifndef __has_builtin
372+ #define __has_builtin (x ) 0
373+ #endif
374+
375+ typedef int8_t int8x4_t __attribute__ ((ext_vector_type(4 )));
376+ typedef uint8_t uint8x4_t __attribute__ ((ext_vector_type(4 )));
377+ static __device__ __forceinline__ int __vsubss4_musa (const int a, const int b) {
378+ const int8x4_t va = reinterpret_cast <const int8x4_t &>(a);
379+ const int8x4_t vb = reinterpret_cast <const int8x4_t &>(b);
380+ #if __has_builtin(__builtin_elementwise_sub_sat)
381+ const int8x4_t c = __builtin_elementwise_sub_sat (va, vb);
382+ return reinterpret_cast <const int &>(c);
383+ #else
384+ int8x4_t c;
385+ int16_t tmp;
386+ #pragma unroll
387+ for (int i = 0 ; i < 4 ; i++) {
388+ tmp = va[i] - vb[i];
389+ if (tmp > std::numeric_limits<int8_t >::max ()) tmp = std::numeric_limits<int8_t >::max ();
390+ if (tmp < std::numeric_limits<int8_t >::min ()) tmp = std::numeric_limits<int8_t >::min ();
391+ c[i] = tmp;
392+ }
393+ return reinterpret_cast <int &>(c);
394+ #endif // __has_builtin(__builtin_elementwise_sub_sat)
395+ }
396+
397+ static __device__ __forceinline__ int __vsub4_musa (const int a, const int b) {
398+ return __vsubss4_musa (a, b);
399+ }
400+
401+ static __device__ __forceinline__ unsigned int __vcmpeq4_musa (unsigned int a, unsigned int b) {
402+ const uint8x4_t & va = reinterpret_cast <const uint8x4_t &>(a);
403+ const uint8x4_t & vb = reinterpret_cast <const uint8x4_t &>(b);
404+ unsigned int c;
405+ uint8x4_t & vc = reinterpret_cast <uint8x4_t &>(c);
406+ #pragma unroll
407+ for (int i = 0 ; i < 4 ; ++i) {
408+ vc[i] = va[i] == vb[i] ? 0xff : 0x00 ;
409+ }
410+ return c;
411+ }
412+
413+ static __device__ __forceinline__ unsigned int __vcmpne4_musa (unsigned int a, unsigned int b) {
414+ const uint8x4_t & va = reinterpret_cast <const uint8x4_t &>(a);
415+ const uint8x4_t & vb = reinterpret_cast <const uint8x4_t &>(b);
416+ unsigned int c;
417+ uint8x4_t & vc = reinterpret_cast <uint8x4_t &>(c);
418+ #pragma unroll
419+ for (int i = 0 ; i < 4 ; ++i) {
420+ vc[i] = va[i] == vb[i] ? 0x00 : 0xff ;
421+ }
422+ return c;
423+ }
424+ #endif // defined(GGML_USE_MUSA)
425+
217426#if defined(GGML_USE_HIPBLAS)
218427#define __CUDA_ARCH__ 1300
219428
@@ -455,7 +664,7 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
455664 const uint32_t mask_high = 0xFFFF0000 * (float (__high2half (a)) > float (__high2half (b)));
456665 return mask_low | mask_high;
457666}
458- #endif // CUDART_VERSION < 12000
667+ #endif // CUDART_VERSION < CUDART_HMASK
459668
460669static __device__ __forceinline__ int ggml_cuda_dp4a (const int a, const int b, int c) {
461670#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
0 commit comments