
Commit 39f4ef6

musa: enable MMA
Signed-off-by: Xiaodong Ye <[email protected]>
1 parent f0dd6a1 commit 39f4ef6

File tree — 3 files changed: 41 additions, 5 deletions

- ggml/src/ggml-cuda/common.cuh
- ggml/src/ggml-cuda/fattn-wmma-f16.cu
- ggml/src/ggml-cuda/ggml-cuda.cu


ggml/src/ggml-cuda/common.cuh
Lines changed: 8 additions & 2 deletions

```diff
@@ -215,6 +215,10 @@ typedef float2 dfloat2;
 #define FP16_MMA_AVAILABLE
 #endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
 
+#if defined(GGML_USE_MUSA) && !GGML_CUDA_MUSA_ARCH_IS_QY1
+#define FP16_MMA_AVAILABLE
+#endif // defined(GGML_USE_MUSA) && !GGML_CUDA_MUSA_ARCH_IS_QY1
+
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -232,12 +236,12 @@ static bool fp16_available(const int cc) {
 }
 
 static bool fast_fp16_available(const int cc) {
-    return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
+    return fp16_available(cc) && cc != 610;
 }
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fast_fp16_hardware_available(const int cc) {
-    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
+    return cc >= GGML_CUDA_CC_PASCAL && cc != 610 && cc != GGML_CUDA_CC_QY1;
 }
 
 // Any FP16 tensor core instructions are available for ggml code.
@@ -246,13 +250,15 @@ static bool fp16_mma_available(const int cc) {
     return false;
 #else
     return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
+           (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2) ||
           GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
 }
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
     return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
+           (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2) ||
           GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
 }
```
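Net effect of the common.cuh changes: MUSA devices now participate in the generic FP16 feature checks. QY1-class parts keep FP16 MMA disabled, while QY2 and newer report both fast FP16 and FP16 MMA. A minimal self-contained sketch of the resulting predicate logic, using hypothetical stand-in values for the GGML_CUDA_CC_* constants (the real definitions and the GGML_CUDA_CC_IS_* macros live in common.cuh):

```cpp
#include <cstdio>

// Hypothetical stand-ins; the real values are defined in ggml-cuda/common.cuh.
constexpr int CC_OFFSET_MTHREADS = 0x0100000; // assumed vendor offset for MUSA cc values
constexpr int CC_PASCAL          = 600;
constexpr int CC_VOLTA           = 700;
constexpr int CC_QY1             = CC_OFFSET_MTHREADS + 210; // no FP16 MMA
constexpr int CC_QY2             = CC_OFFSET_MTHREADS + 220; // FP16 MMA capable

static bool is_mthreads(int cc) { return cc >= CC_OFFSET_MTHREADS; }

// Mirrors the patched fast_fp16_hardware_available():
// Pascal or newer, minus cc 610 (compute 6.1 has slow FP16 math) and QY1.
static bool fast_fp16_hw(int cc) {
    return cc >= CC_PASCAL && cc != 610 && cc != CC_QY1;
}

// Mirrors the patched fp16_mma_hardware_available():
// Volta or newer on NVIDIA, QY2 or newer on MTHREADS.
static bool fp16_mma_hw(int cc) {
    return (!is_mthreads(cc) && cc >= CC_VOLTA) ||
           ( is_mthreads(cc) && cc >= CC_QY2);
}

int main() {
    printf("QY1: fast_fp16=%d fp16_mma=%d\n", fast_fp16_hw(CC_QY1), fp16_mma_hw(CC_QY1)); // 0 0
    printf("QY2: fast_fp16=%d fp16_mma=%d\n", fast_fp16_hw(CC_QY2), fp16_mma_hw(CC_QY2)); // 1 1
    return 0;
}
```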

ggml/src/ggml-cuda/fattn-wmma-f16.cu
Lines changed: 4 additions & 0 deletions

```diff
@@ -9,7 +9,11 @@
 #ifdef FP16_MMA_AVAILABLE
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #include <mma.h>
+#ifdef GGML_USE_MUSA
+namespace wmma = mtmusa::wmma;
+#else // GGML_USE_MUSA
 namespace wmma = nvcuda::wmma;
+#endif // GGML_USE_MUSA
 #elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
 #undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
 #include <rocwmma/rocwmma.hpp>
```
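The kernel body itself needs no changes because it only refers to the `wmma` alias; the diff just points that alias at the vendor's implementation. A standalone sketch (not the ggml FlashAttention kernel) of the fragment API that now compiles against either nvcuda::wmma or mtmusa::wmma, assuming the MUSA namespace mirrors the CUDA WMMA API as the alias implies:

```cpp
#include <cuda_fp16.h>
#include <mma.h>

#ifdef GGML_USE_MUSA
namespace wmma = mtmusa::wmma;  // MUSA's WMMA implementation (per this commit)
#else
namespace wmma = nvcuda::wmma;  // stock CUDA WMMA
#endif

// One warp multiplies a 16x16 half tile pair into a float accumulator tile.
__global__ void tile_mma_16x16x16(const half * a, const half * b, float * c) {
    wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> fa;
    wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::col_major> fb;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float>                 fc;

    wmma::fill_fragment(fc, 0.0f);
    wmma::load_matrix_sync(fa, a, 16);  // leading dimension = 16
    wmma::load_matrix_sync(fb, b, 16);
    wmma::mma_sync(fc, fa, fb, fc);
    wmma::store_matrix_sync(c, fc, 16, wmma::mem_row_major);
}
```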

ggml/src/ggml-cuda/ggml-cuda.cu
Lines changed: 29 additions & 3 deletions

```diff
@@ -1851,13 +1851,24 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     // use cublasGemmBatchedEx
     const int ne23 = ne12*ne13;
 
+#ifdef GGML_USE_MUSA
+    const void ** ptrs_src;
+    void ** ptrs_dst;
+    CUDA_CHECK(cudaMalloc((void **)&ptrs_src, sizeof(half *)*2*ne23));
+    CUDA_CHECK(cudaMalloc((void **)&ptrs_dst, sizeof(half *)*1*ne23));
+#else // GGML_USE_MUSA
     ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
     ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
+#endif // GGML_USE_MUSA
 
     dim3 block_dims(ne13, ne12);
     k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
             src0_f16, src1_f16, dst_t,
+#ifdef GGML_USE_MUSA
+            ptrs_src, ptrs_dst,
+#else // GGML_USE_MUSA
             ptrs_src.get(), ptrs_dst.get(),
+#endif // GGML_USE_MUSA
             ne12, ne13,
             ne23,
             nb02, nb03,
@@ -1867,15 +1878,30 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
             r2, r3);
     CUDA_CHECK(cudaGetLastError());
 
+#ifdef GGML_USE_MUSA
+    const void **Aarray = (const void **) (ptrs_src + 0 * ne23);
+    const void **Barray = (const void **) (ptrs_src + 1 * ne23);
+    void **Carray = (void **) (ptrs_dst + 0 * ne23);
+#else // GGML_USE_MUSA
+    const void **Aarray = (const void **) (ptrs_src.get() + 0 * ne23);
+    const void **Barray = (const void **) (ptrs_src.get() + 1 * ne23);
+    void **Carray = (void **) (ptrs_dst.get() + 0 * ne23);
+#endif // GGML_USE_MUSA
+
     CUBLAS_CHECK(
     cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
             ne01, ne11, ne10,
-            alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F,   nb01/nb00,
-                   (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,   nb11/nb10,
-            beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01,
+            alpha, Aarray, CUDA_R_16F,   nb01/nb00,
+                   Barray, CUDA_R_16F,   nb11/nb10,
+            beta,  Carray, cu_data_type, ne01,
             ne23,
             cu_compute_type,
             CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+#ifdef GGML_USE_MUSA
+    CUDA_CHECK(cudaFree(ptrs_src));
+    CUDA_CHECK(cudaFree(ptrs_dst));
+#endif // GGML_USE_MUSA
 }
 #endif
```
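The batched-GEMM change swaps ggml's pooled allocations for plain cudaMalloc/cudaFree on MUSA and names the three pointer tables (Aarray/Barray/Carray) so the same cublasGemmBatchedEx call serves both paths; the commit does not say why the pool allocator is avoided on MUSA. The underlying pattern is unchanged: cublasGemmBatchedEx consumes device-resident arrays of per-matrix pointers, filled by a small kernel before the call. A simplified, hypothetical sketch of that pattern (k_fill_ptrs and its stride parameters are invented for illustration; ggml's real kernel is k_compute_batched_ptrs):

```cpp
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// One thread block per batch entry writes the base pointer of its matrix,
// assuming byte strides nb2 (fast batch dim) and nb3 (slow batch dim).
__global__ void k_fill_ptrs(const half * base, const void ** ptrs,
                            int ne12, size_t nb2, size_t nb3) {
    const int i12 = blockIdx.x % ne12;
    const int i13 = blockIdx.x / ne12;
    ptrs[blockIdx.x] = (const char *) base + i12*nb2 + i13*nb3;
}

// Usage sketch, one entry per (i12, i13) matrix, ne23 = ne12*ne13 in total:
//   const void ** ptrs;
//   cudaMalloc((void **) &ptrs, sizeof(void *)*ne23);  // device-resident table
//   k_fill_ptrs<<<ne23, 1, 0, stream>>>(src, ptrs, ne12, nb2, nb3);
//   // ... pass `ptrs` as the Aarray/Barray argument of cublasGemmBatchedEx ...
//   cudaFree(ptrs);
```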
