Skip to content

Commit 6022a58

Browse files
committed
musa: enable MMA
Signed-off-by: Xiaodong Ye <[email protected]>
1 parent 27aa259 commit 6022a58

File tree

4 files changed

+70
-25
lines changed

4 files changed

+70
-25
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,11 @@
7676
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
7777

7878
// Moore Threads
79-
#define GGML_CUDA_MUSA_ARCH_IS_QY1 (__MUSA_ARCH__ <= 210)
80-
81-
#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
82-
#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
83-
#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
79+
#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 210) // MTT S80, MTT S3000
80+
#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 220) // MTT S4000
81+
#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 310) // TBD
8482

83+
#define GGML_CUDA_CC_TO_MTHREADS(cc) ((cc) - GGML_CUDA_CC_OFFSET_MTHREADS)
8584
#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
8685
#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
8786
#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG)
@@ -203,9 +202,9 @@ typedef float2 dfloat2;
203202
#define FP16_AVAILABLE
204203
#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
205204

206-
#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
205+
#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != GGML_CUDA_CC_DP4A
207206
#define FAST_FP16_AVAILABLE
208-
#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
207+
#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != GGML_CUDA_CC_DP4A
209208

210209
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
211210
#define FP16_MMA_AVAILABLE
@@ -215,6 +214,10 @@ typedef float2 dfloat2;
215214
#define FP16_MMA_AVAILABLE
216215
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
217216

217+
#if defined(GGML_USE_MUSA) && __MUSA_ARCH__ >= GGML_CUDA_CC_TO_MTHREADS(GGML_CUDA_CC_QY2)
218+
#define FP16_MMA_AVAILABLE
219+
#endif // defined(GGML_USE_MUSA) && __MUSA_ARCH__ >= GGML_CUDA_CC_TO_MTHREADS(GGML_CUDA_CC_QY2)
220+
218221
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
219222
#define NEW_MMA_AVAILABLE
220223
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -223,21 +226,21 @@ typedef float2 dfloat2;
223226
#define CP_ASYNC_AVAILABLE
224227
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
225228

226-
#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
229+
#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < GGML_CUDA_CC_TO_MTHREADS(GGML_CUDA_CC_QY2))
227230
#define FLASH_ATTN_AVAILABLE
228-
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
231+
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < GGML_CUDA_CC_TO_MTHREADS(GGML_CUDA_CC_QY2))
229232

230233
static bool fp16_available(const int cc) {
231234
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
232235
}
233236

234237
static bool fast_fp16_available(const int cc) {
235-
return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
238+
return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != GGML_CUDA_CC_DP4A) || GGML_CUDA_CC_IS_AMD(cc);
236239
}
237240

238241
// To be used for feature selection of external libraries, e.g. cuBLAS.
239242
static bool fast_fp16_hardware_available(const int cc) {
240-
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
243+
return cc >= GGML_CUDA_CC_PASCAL && cc != GGML_CUDA_CC_DP4A && cc != GGML_CUDA_CC_QY1;
241244
}
242245

243246
// Any FP16 tensor core instructions are available for ggml code.
@@ -246,13 +249,15 @@ static bool fp16_mma_available(const int cc) {
246249
return false;
247250
#else
248251
return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
252+
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2) ||
249253
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
250254
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
251255
}
252256

253257
// To be used for feature selection of external libraries, e.g. cuBLAS.
254258
static bool fp16_mma_hardware_available(const int cc) {
255259
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
260+
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2) ||
256261
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
257262
}
258263

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99
#ifdef FP16_MMA_AVAILABLE
1010
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
1111
#include <mma.h>
12+
#ifdef GGML_USE_MUSA
13+
namespace wmma = mtmusa::wmma;
14+
#else // GGML_USE_MUSA
1215
namespace wmma = nvcuda::wmma;
16+
#endif // GGML_USE_MUSA
1317
#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
1418
#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
1519
#include <rocwmma/rocwmma.hpp>

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
268268
// FIXME: Ensure compatibility with varying warp sizes across different MUSA archs.
269269
info.devices[id].warp_size = 32;
270270
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
271-
info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100;
272-
info.devices[id].cc += prop.minor * 0x10;
271+
info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + 100*prop.major + 10*prop.minor;
273272
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
274273
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
275274
#else
@@ -1195,7 +1194,8 @@ static void ggml_cuda_op_mul_mat_cublas(
11951194

11961195
const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
11971196

1198-
if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
1197+
if ((GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2) &&
1198+
src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
11991199
ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
12001200
if (src1->type != GGML_TYPE_BF16) {
12011201
const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
@@ -1865,13 +1865,24 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
18651865
// use cublasGemmBatchedEx
18661866
const int64_t ne23 = ne12*ne13;
18671867

1868+
#ifdef GGML_USE_MUSA
1869+
const void ** ptrs_src;
1870+
void ** ptrs_dst;
1871+
CUDA_CHECK(cudaMalloc((void **)&ptrs_src, sizeof(void *)*2*ne23));
1872+
CUDA_CHECK(cudaMalloc((void **)&ptrs_dst, sizeof(void *)*1*ne23));
1873+
#else // GGML_USE_MUSA
18681874
ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
18691875
ggml_cuda_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
1876+
#endif // GGML_USE_MUSA
18701877

18711878
dim3 block_dims(ne13, ne12);
18721879
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
18731880
src0_f16, src1_f16, dst_t,
1881+
#ifdef GGML_USE_MUSA
1882+
ptrs_src, ptrs_dst,
1883+
#else // GGML_USE_MUSA
18741884
ptrs_src.get(), ptrs_dst.get(),
1885+
#endif // GGML_USE_MUSA
18751886
ne12, ne13,
18761887
ne23,
18771888
nb02, nb03,
@@ -1881,15 +1892,31 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
18811892
r2, r3);
18821893
CUDA_CHECK(cudaGetLastError());
18831894

1884-
CUBLAS_CHECK(
1895+
#ifdef GGML_USE_MUSA
1896+
cudaDeviceSynchronize();
1897+
const void **Aarray = (const void **) (ptrs_src + 0*ne23);
1898+
const void **Barray = (const void **) (ptrs_src + 1*ne23);
1899+
void **Carray = ( void **) (ptrs_dst + 0*ne23);
1900+
#else // GGML_USE_MUSA
1901+
const void **Aarray = (const void **) (ptrs_src.get() + 0*ne23);
1902+
const void **Barray = (const void **) (ptrs_src.get() + 1*ne23);
1903+
void **Carray = ( void **) (ptrs_dst.get() + 0*ne23);
1904+
#endif // GGML_USE_MUSA
1905+
1906+
CUBLAS_CHECK(
18851907
cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
18861908
ne01, ne11, ne10,
1887-
alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
1888-
(const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, s11,
1889-
beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
1909+
alpha, Aarray, CUDA_R_16F, nb01/nb00,
1910+
Barray, CUDA_R_16F, s11,
1911+
beta, Carray, cu_data_type, ne0,
18901912
ne23,
18911913
cu_compute_type,
18921914
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
1915+
1916+
#ifdef GGML_USE_MUSA
1917+
CUDA_CHECK(cudaFree(ptrs_src));
1918+
CUDA_CHECK(cudaFree(ptrs_dst));
1919+
#endif // GGML_USE_MUSA
18931920
}
18941921
#endif
18951922

@@ -1913,6 +1940,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
19131940

19141941
bool any_gpus_with_slow_fp16 = false;
19151942
bool any_gpus_without_fp16_mma = false;
1943+
bool any_gpus_without_batched_cublas = false;
19161944

19171945
if (split) {
19181946
ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
@@ -1927,12 +1955,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
19271955
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
19281956
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
19291957
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
1958+
any_gpus_without_batched_cublas = any_gpus_without_batched_cublas || !(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
19301959
}
19311960
} else {
19321961
const int cc = ggml_cuda_info().devices[ctx.device].cc;
19331962
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
19341963
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
19351964
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
1965+
any_gpus_without_batched_cublas = any_gpus_without_batched_cublas || !(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
19361966
}
19371967

19381968
// debug helpers
@@ -1951,7 +1981,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
19511981
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
19521982
} else if (!split && use_mul_mat_q) {
19531983
ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
1954-
} else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
1984+
} else if (!split && !any_gpus_without_batched_cublas && src0->type == GGML_TYPE_F16 &&
1985+
(src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
19551986
!ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
19561987
// general KQ + KQV multi-batch without FlashAttention
19571988
ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
@@ -2989,12 +3020,14 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
29893020
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
29903021
return false;
29913022
}
2992-
#ifdef GGML_USE_MUSA
2993-
if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
3023+
#if defined(GGML_USE_MUSA)
3024+
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
3025+
if (GGML_CUDA_CC_IS_MTHREADS(cc) && GGML_CUDA_CC_IS_QY1(cc) &&
3026+
b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
29943027
!ggml_is_transposed(a) && !ggml_is_transposed(b)) {
29953028
return false;
29963029
}
2997-
#endif // GGML_USE_MUSA
3030+
#endif // defined(GGML_USE_MUSA)
29983031
switch (a->type) {
29993032
case GGML_TYPE_F32:
30003033
case GGML_TYPE_F16:
@@ -3019,11 +3052,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
30193052
case GGML_TYPE_IQ4_NL:
30203053
case GGML_TYPE_IQ4_XS:
30213054
case GGML_TYPE_BF16:
3022-
#ifdef GGML_USE_MUSA
3023-
if (a->type == GGML_TYPE_Q3_K) {
3055+
#if defined(GGML_USE_MUSA)
3056+
if (GGML_CUDA_CC_IS_MTHREADS(cc) && GGML_CUDA_CC_IS_QY2(cc) &&
3057+
a->type == GGML_TYPE_Q2_K) {
30243058
return false;
30253059
}
3026-
#endif // GGML_USE_MUSA
3060+
#endif // defined(GGML_USE_MUSA)
30273061
return true;
30283062
default:
30293063
return false;

tests/test-backend-ops.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4277,6 +4277,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
42774277
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 193, {1, 1}, {4, 1}, {0, 2, 1, 3}));
42784278
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 67, {1, 1}, {4, 1}, {0, 2, 1, 3}));
42794279

4280+
#if !defined(GGML_USE_MUSA) || (defined(GGML_USE_MUSA) && __MUSA_ARCH__ >= 220)
42804281
for (auto bs : {1,2,4,8}) {
42814282
for (auto nr : {1,4}) {
42824283
for (uint32_t m = 0; m < 2; ++m) {
@@ -4287,6 +4288,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
42874288
}
42884289
}
42894290
}
4291+
#endif // !defined(GGML_USE_MUSA) || (defined(GGML_USE_MUSA) && __MUSA_ARCH__ >= 220)
42904292

42914293
// sycl backend will limit task global_range < MAX_INT
42924294
// test case for f16-type-convert-to-fp32 kernel with large k under fp32 compute dtype (occurs in stable-diffusion)

0 commit comments

Comments
 (0)