Commit 97e91c2

musa: enable MMA
Signed-off-by: Xiaodong Ye <[email protected]>
1 parent 27aa259 commit 97e91c2

4 files changed: +56 -12 lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 7 additions & 1 deletion
@@ -215,6 +215,10 @@ typedef float2 dfloat2;
 #define FP16_MMA_AVAILABLE
 #endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
 
+#if defined(GGML_USE_MUSA) && !GGML_CUDA_MUSA_ARCH_IS_QY1
+#define FP16_MMA_AVAILABLE
+#endif // defined(GGML_USE_MUSA) && !GGML_CUDA_MUSA_ARCH_IS_QY1
+
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -237,7 +241,7 @@ static bool fast_fp16_available(const int cc) {
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fast_fp16_hardware_available(const int cc) {
-    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
+    return cc >= GGML_CUDA_CC_PASCAL && cc != 610 && cc != GGML_CUDA_CC_QY1;
 }
 
 // Any FP16 tensor core instructions are available for ggml code.
@@ -246,13 +250,15 @@ static bool fp16_mma_available(const int cc) {
     return false;
 #else
     return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
+           (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2) ||
            GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
 }
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
     return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
+           (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2) ||
            GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
 }
 
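Taken together, the gates above follow ggml-cuda's usual two-level pattern: a compile-time `FP16_MMA_AVAILABLE` define for device code plus runtime `*_hardware_available(cc)` predicates for host-side dispatch. Below is a small, self-contained sketch of that predicate shape; the vendor tag and compute-capability values are illustrative placeholders, not the constants defined in common.cuh.

```cpp
#include <cstdio>

// Illustrative stand-ins for ggml-cuda's compute-capability handling; the real
// constants and the GGML_CUDA_CC_IS_* helpers are defined in common.cuh.
enum class vendor { nvidia, mthreads };
struct dev_cc { vendor v; int cc; };

constexpr int CC_PASCAL = 600;
constexpr int CC_VOLTA  = 700;
constexpr int CC_QY1    = 210;  // placeholder: MUSA arch without FP16 MMA
constexpr int CC_QY2    = 220;  // placeholder: MUSA arch with FP16 MMA

// Shape of fast_fp16_hardware_available() after the change: Pascal or newer,
// minus the known exceptions (the cc == 610 parts and MUSA QY1).
static bool fast_fp16_hw(const dev_cc & d) {
    return d.cc >= CC_PASCAL && d.cc != 610 && d.cc != CC_QY1;
}

// Shape of fp16_mma_hardware_available(): Volta+ on NVIDIA, QY2+ on MooreThreads
// (the AMD cases are omitted here for brevity).
static bool fp16_mma_hw(const dev_cc & d) {
    return (d.v == vendor::nvidia   && d.cc >= CC_VOLTA) ||
           (d.v == vendor::mthreads && d.cc >= CC_QY2);
}

int main() {
    const dev_cc devices[] = {
        {vendor::nvidia, 610}, {vendor::nvidia, 750},
        {vendor::mthreads, CC_QY1}, {vendor::mthreads, CC_QY2},
    };
    for (const dev_cc & d : devices) {
        std::printf("cc=%d fast_fp16=%d fp16_mma=%d\n",
                    d.cc, (int) fast_fp16_hw(d), (int) fp16_mma_hw(d));
    }
    return 0;
}
```

The real checks are consumed by the mul_mat dispatch changes in ggml-cuda.cu further down.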

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 4 additions & 0 deletions
@@ -9,7 +9,11 @@
 #ifdef FP16_MMA_AVAILABLE
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #include <mma.h>
+#ifdef GGML_USE_MUSA
+namespace wmma = mtmusa::wmma;
+#else // GGML_USE_MUSA
 namespace wmma = nvcuda::wmma;
+#endif // GGML_USE_MUSA
 #elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
 #undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
 #include <rocwmma/rocwmma.hpp>
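The alias lets the same WMMA kernel source compile on both backends: under `GGML_USE_MUSA` the `wmma::` calls resolve to `mtmusa::wmma`, otherwise to `nvcuda::wmma`. For reference, here is a minimal, self-contained CUDA example of the wmma API surface the alias covers; it is generic 16x16x16 usage, not code taken from this kernel, and it assumes a Volta-class or newer device when built with something like `nvcc -arch=sm_70`.

```cpp
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <mma.h>
#include <cstdio>

namespace wmma = nvcuda::wmma;   // under GGML_USE_MUSA this would be mtmusa::wmma

// One warp computes a single 16x16x16 tile: C = A * B (fp16 inputs, fp32 accumulator).
__global__ void wmma_tile(const half * A, const half * B, float * C) {
    wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> a;
    wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::row_major> b;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> c;

    wmma::fill_fragment(c, 0.0f);
    wmma::load_matrix_sync(a, A, 16);   // leading dimension 16
    wmma::load_matrix_sync(b, B, 16);
    wmma::mma_sync(c, a, b, c);
    wmma::store_matrix_sync(C, c, 16, wmma::mem_row_major);
}

int main() {
    half * A; half * B; float * C;
    cudaMallocManaged(&A, 256 * sizeof(half));
    cudaMallocManaged(&B, 256 * sizeof(half));
    cudaMallocManaged(&C, 256 * sizeof(float));
    for (int i = 0; i < 256; ++i) { A[i] = __float2half(1.0f); B[i] = __float2half(1.0f); }

    wmma_tile<<<1, 32>>>(A, B, C);      // a single warp
    cudaDeviceSynchronize();

    std::printf("C[0] = %.1f (expected 16.0)\n", C[0]);
    cudaFree(A); cudaFree(B); cudaFree(C);
    return 0;
}
```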

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 43 additions & 11 deletions
@@ -1195,7 +1195,8 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
 
-    if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+    if ((GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2) &&
+        src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
         ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
         if (src1->type != GGML_TYPE_BF16) {
             const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
@@ -1865,13 +1866,24 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     // use cublasGemmBatchedEx
     const int64_t ne23 = ne12*ne13;
 
+#ifdef GGML_USE_MUSA
+    const void ** ptrs_src;
+          void ** ptrs_dst;
+    CUDA_CHECK(cudaMalloc((void **)&ptrs_src, sizeof(void *)*2*ne23));
+    CUDA_CHECK(cudaMalloc((void **)&ptrs_dst, sizeof(void *)*1*ne23));
+#else // GGML_USE_MUSA
     ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
     ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
+#endif // GGML_USE_MUSA
 
     dim3 block_dims(ne13, ne12);
     k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
             src0_f16, src1_f16, dst_t,
+#ifdef GGML_USE_MUSA
+            ptrs_src, ptrs_dst,
+#else // GGML_USE_MUSA
             ptrs_src.get(), ptrs_dst.get(),
+#endif // GGML_USE_MUSA
             ne12, ne13,
             ne23,
             nb02, nb03,
@@ -1881,15 +1893,31 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
             r2, r3);
     CUDA_CHECK(cudaGetLastError());
 
-    CUBLAS_CHECK(
+#ifdef GGML_USE_MUSA
+    cudaDeviceSynchronize();
+    const void **Aarray = (const void **) (ptrs_src + 0*ne23);
+    const void **Barray = (const void **) (ptrs_src + 1*ne23);
+          void **Carray = (      void **) (ptrs_dst + 0*ne23);
+#else // GGML_USE_MUSA
+    const void **Aarray = (const void **) (ptrs_src.get() + 0*ne23);
+    const void **Barray = (const void **) (ptrs_src.get() + 1*ne23);
+          void **Carray = (      void **) (ptrs_dst.get() + 0*ne23);
+#endif // GGML_USE_MUSA
+
+    CUBLAS_CHECK(
     cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
             ne01, ne11, ne10,
-            alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F,   nb01/nb00,
-                   (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,   s11,
-            beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
+            alpha, Aarray, CUDA_R_16F,   nb01/nb00,
+                   Barray, CUDA_R_16F,   s11,
+            beta,  Carray, cu_data_type, ne0,
            ne23,
            cu_compute_type,
            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+#ifdef GGML_USE_MUSA
+    CUDA_CHECK(cudaFree(ptrs_src));
+    CUDA_CHECK(cudaFree(ptrs_dst));
+#endif // GGML_USE_MUSA
 }
 #endif
 
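In the MUSA branch above, the per-batch pointer tables for `cublasGemmBatchedEx` are raw `cudaMalloc` allocations that are freed with `cudaFree` after the call, instead of ggml's pool allocator, and a `cudaDeviceSynchronize()` is issued before the cuBLAS call, presumably so the tables written by `k_compute_batched_ptrs` are complete when cuBLAS dereferences them. The sketch below shows the underlying requirement in isolation: the A/B/C pointer arrays passed to `cublasGemmBatchedEx` must live in device memory. It uses FP32 data and host-side table setup for brevity; it is illustrative, not the ggml code.

```cpp
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

#define CHECK_CUDA(x)   do { cudaError_t   e = (x); if (e != cudaSuccess)           { std::printf("CUDA error %d at line %d\n",   (int) e, __LINE__); return 1; } } while (0)
#define CHECK_CUBLAS(x) do { cublasStatus_t s = (x); if (s != CUBLAS_STATUS_SUCCESS) { std::printf("cuBLAS error %d at line %d\n", (int) s, __LINE__); return 1; } } while (0)

int main() {
    const int m = 16, n = 16, k = 16, batch = 4;

    // One contiguous allocation per operand, viewed as `batch` column-major matrices.
    float *dA, *dB, *dC;
    CHECK_CUDA(cudaMalloc(&dA, sizeof(float) * m * k * batch));
    CHECK_CUDA(cudaMalloc(&dB, sizeof(float) * k * n * batch));
    CHECK_CUDA(cudaMalloc(&dC, sizeof(float) * m * n * batch));
    CHECK_CUDA(cudaMemset(dA, 0, sizeof(float) * m * k * batch));
    CHECK_CUDA(cudaMemset(dB, 0, sizeof(float) * k * n * batch));

    // Build the per-batch pointer tables on the host (ggml fills them on the
    // device with the k_compute_batched_ptrs kernel instead)...
    std::vector<const void *> hA(batch), hB(batch);
    std::vector<void *>       hC(batch);
    for (int i = 0; i < batch; ++i) {
        hA[i] = dA + (size_t) i * m * k;
        hB[i] = dB + (size_t) i * k * n;
        hC[i] = dC + (size_t) i * m * n;
    }

    // ...and place them in plain device allocations, as the MUSA branch does.
    const void **dAarr; const void **dBarr; void **dCarr;
    CHECK_CUDA(cudaMalloc((void **) &dAarr, sizeof(void *) * batch));
    CHECK_CUDA(cudaMalloc((void **) &dBarr, sizeof(void *) * batch));
    CHECK_CUDA(cudaMalloc((void **) &dCarr, sizeof(void *) * batch));
    CHECK_CUDA(cudaMemcpy(dAarr, hA.data(), sizeof(void *) * batch, cudaMemcpyHostToDevice));
    CHECK_CUDA(cudaMemcpy(dBarr, hB.data(), sizeof(void *) * batch, cudaMemcpyHostToDevice));
    CHECK_CUDA(cudaMemcpy(dCarr, hC.data(), sizeof(void *) * batch, cudaMemcpyHostToDevice));

    cublasHandle_t handle;
    CHECK_CUBLAS(cublasCreate(&handle));

    const float alpha = 1.0f, beta = 0.0f;
    CHECK_CUBLAS(cublasGemmBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
            m, n, k,
            &alpha, dAarr, CUDA_R_32F, m,
                    dBarr, CUDA_R_32F, k,
            &beta,  dCarr, CUDA_R_32F, m,
            batch,
            CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT));
    CHECK_CUDA(cudaDeviceSynchronize());
    std::printf("batched GEMM finished for %d matrices\n", batch);

    // Raw allocations must be released explicitly, mirroring the added
    // cudaFree(ptrs_src) / cudaFree(ptrs_dst) calls.
    cublasDestroy(handle);
    cudaFree(dAarr); cudaFree(dBarr); cudaFree(dCarr);
    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    return 0;
}
```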

@@ -1913,6 +1941,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
 
     bool any_gpus_with_slow_fp16   = false;
     bool any_gpus_without_fp16_mma = false;
+    bool any_gpus_without_batched_cublas = false;
 
     if (split) {
         ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
@@ -1927,12 +1956,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
             use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
             any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
             any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+            any_gpus_without_batched_cublas = any_gpus_without_batched_cublas || !(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
         }
     } else {
         const int cc = ggml_cuda_info().devices[ctx.device].cc;
         use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
         any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
         any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+        any_gpus_without_batched_cublas = any_gpus_without_batched_cublas || !(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
     }
 
     // debug helpers
@@ -1951,7 +1982,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
         ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
+    } else if (!split && !any_gpus_without_batched_cublas && src0->type == GGML_TYPE_F16 &&
+               (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
         !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
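The new flag follows the existing `any_gpus_with_slow_fp16` / `any_gpus_without_fp16_mma` pattern: a "some device lacks the feature" flag is OR-accumulated over every device a split tensor touches, and the optional batched-cuBLAS path is taken only if the flag stayed false. A compact standalone sketch of that accumulation, with the device description and the capability check reduced to placeholders, is shown below.

```cpp
#include <cstdio>
#include <vector>

// Placeholder device description; ggml-cuda reads this from
// ggml_cuda_info().devices[...].cc and the GGML_CUDA_CC_IS_* helpers.
struct dev_info { bool is_mthreads; int cc; };
constexpr int CC_QY2 = 220;  // illustrative value

// Capability check as written in the diff: MooreThreads >= QY2.
static bool batched_cublas_ok(const dev_info & d) {
    return d.is_mthreads && d.cc >= CC_QY2;
}

// OR-accumulate "some device lacks it", then require the flag to stay false,
// matching `!split && !any_gpus_without_batched_cublas && ...` above.
static bool use_batched_cublas(const std::vector<dev_info> & devs) {
    bool any_without = false;
    for (const dev_info & d : devs) {
        any_without = any_without || !batched_cublas_ok(d);
    }
    return !any_without;
}

int main() {
    std::printf("all QY2: %d\n", (int) use_batched_cublas({{true, 220}, {true, 220}}));
    std::printf("one QY1: %d\n", (int) use_batched_cublas({{true, 220}, {true, 210}}));
    return 0;
}
```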
@@ -2989,12 +3021,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
                     return false;
                 }
-#ifdef GGML_USE_MUSA
+#if defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1
                 if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
                     !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
                     return false;
                 }
-#endif // GGML_USE_MUSA
+#endif // defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1
                 switch (a->type) {
                     case GGML_TYPE_F32:
                     case GGML_TYPE_F16:
@@ -3019,11 +3051,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                     case GGML_TYPE_IQ4_NL:
                     case GGML_TYPE_IQ4_XS:
                     case GGML_TYPE_BF16:
-#ifdef GGML_USE_MUSA
-                        if (a->type == GGML_TYPE_Q3_K) {
+#if defined(GGML_USE_MUSA) && !GGML_CUDA_MUSA_ARCH_IS_QY1
+                        if (a->type == GGML_TYPE_Q2_K) {
                             return false;
                         }
-#endif // GGML_USE_MUSA
+#endif // defined(GGML_USE_MUSA) && !GGML_CUDA_MUSA_ARCH_IS_QY1
                         return true;
                     default:
                         return false;

tests/test-backend-ops.cpp

Lines changed: 2 additions & 0 deletions
@@ -4277,6 +4277,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 193, {1, 1}, {4, 1}, {0, 2, 1, 3}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 67, {1, 1}, {4, 1}, {0, 2, 1, 3}));
 
+#if !defined(GGML_USE_MUSA) || (defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
     for (auto bs : {1,2,4,8}) {
         for (auto nr : {1,4}) {
             for (uint32_t m = 0; m < 2; ++m) {
@@ -4287,6 +4288,7 @@
             }
         }
     }
+#endif // !defined(__MUSA_ARCH__) || __MUSA_ARCH__ > 210
 
     // sycl backend will limit task global_range < MAX_INT
     // test case for f16-type-convert-to-fp32 kernel with large k under fp32 compute dtype (occurs in stable-diffusion)
