Skip to content

Commit 6c665fa

Browse files
committed
musa: enable fp16 mma (all) and cublas on qy2
Signed-off-by: Xiaodong Ye <[email protected]>
1 parent cdf94a1 commit 6c665fa

File tree

3 files changed

+27
-20
lines changed

3 files changed

+27
-20
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,9 @@
7676
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
7777

7878
// Moore Threads
79-
#define GGML_CUDA_MUSA_ARCH_IS_QY1 (__MUSA_ARCH__ <= 210)
80-
81-
#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
82-
#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
83-
#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
79+
#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
80+
#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
81+
#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
8482

8583
#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
8684
#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
@@ -211,6 +209,10 @@ typedef float2 dfloat2;
211209
#define FP16_MMA_AVAILABLE
212210
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
213211

212+
#if defined(GGML_USE_MUSA)
213+
#define FP16_MMA_AVAILABLE
214+
#endif // defined(GGML_USE_MUSA)
215+
214216
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
215217
#define NEW_MMA_AVAILABLE
216218
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -219,9 +221,9 @@ typedef float2 dfloat2;
219221
#define CP_ASYNC_AVAILABLE
220222
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
221223

222-
#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
224+
#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
223225
#define FLASH_ATTN_AVAILABLE
224-
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
226+
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
225227

226228
static bool fp16_available(const int cc) {
227229
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
@@ -233,7 +235,8 @@ static bool fast_fp16_available(const int cc) {
233235

234236
// To be used for feature selection of external libraries, e.g. cuBLAS.
235237
static bool fast_fp16_hardware_available(const int cc) {
236-
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
238+
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc) ||
239+
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
237240
}
238241

239242
// Any FP16 tensor core instructions are available for ggml code.
@@ -242,14 +245,16 @@ static bool fp16_mma_available(const int cc) {
242245
return false;
243246
#else
244247
return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
245-
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
248+
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc) ||
249+
GGML_CUDA_CC_IS_MTHREADS(cc);
246250
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
247251
}
248252

249253
// To be used for feature selection of external libraries, e.g. cuBLAS.
250254
static bool fp16_mma_hardware_available(const int cc) {
251255
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
252-
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
256+
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc) ||
257+
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
253258
}
254259

255260
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99
#ifdef FP16_MMA_AVAILABLE
1010
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
1111
#include <mma.h>
12+
#ifdef GGML_USE_MUSA
13+
namespace wmma = mtmusa::wmma;
14+
#else // GGML_USE_MUSA
1215
namespace wmma = nvcuda::wmma;
16+
#endif // GGML_USE_MUSA
1317
#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
1418
#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
1519
#include <rocwmma/rocwmma.hpp>

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1200,7 +1200,7 @@ static void ggml_cuda_op_mul_mat_cublas(
12001200

12011201
const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
12021202

1203-
if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
1203+
if ((GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) || (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2)) && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
12041204
ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
12051205
if (src1->type != GGML_TYPE_BF16) {
12061206
const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
@@ -1228,7 +1228,7 @@ static void ggml_cuda_op_mul_mat_cublas(
12281228

12291229
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
12301230
to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
1231-
} else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
1231+
} else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc) || (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2)) && use_fp16) {
12321232
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
12331233
ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
12341234
if (src0->type != GGML_TYPE_F16) {
@@ -3009,9 +3009,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
30093009
return false;
30103010
}
30113011
#ifdef GGML_USE_MUSA
3012-
if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
3013-
!ggml_is_transposed(a) && !ggml_is_transposed(b)) {
3014-
return false;
3012+
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
3013+
if (b->ne[2]*b->ne[3] > 1 && !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
3014+
if (GGML_CUDA_CC_IS_QY1(cc) && op->op == GGML_OP_MUL_MAT &&
3015+
a->type == GGML_TYPE_F16 && b->type == GGML_TYPE_F16) {
3016+
return false;
3017+
}
30153018
}
30163019
#endif // GGML_USE_MUSA
30173020
switch (a->type) {
@@ -3038,11 +3041,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
30383041
case GGML_TYPE_IQ4_NL:
30393042
case GGML_TYPE_IQ4_XS:
30403043
case GGML_TYPE_BF16:
3041-
#ifdef GGML_USE_MUSA
3042-
if (a->type == GGML_TYPE_Q3_K) {
3043-
return false;
3044-
}
3045-
#endif // GGML_USE_MUSA
30463044
return true;
30473045
default:
30483046
return false;

0 commit comments

Comments
 (0)