@@ -1200,7 +1200,9 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
 
-    if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+    if ((GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
+        (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2)) &&
+        src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
         ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
         if (src1->type != GGML_TYPE_BF16) {
             const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
@@ -1228,7 +1230,9 @@ static void ggml_cuda_op_mul_mat_cublas(
 
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
         to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-    } else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
+    } else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
+        (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2) ||
+        GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
@@ -1872,13 +1876,24 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         // use cublasGemmBatchedEx
         const int64_t ne23 = ne12*ne13;
 
+#ifdef GGML_USE_MUSA
+        const void ** ptrs_src;
+        void ** ptrs_dst;
+        CUDA_CHECK(cudaMalloc((void **)&ptrs_src, sizeof(void *)*2*ne23));
+        CUDA_CHECK(cudaMalloc((void **)&ptrs_dst, sizeof(void *)*1*ne23));
+#else // GGML_USE_MUSA
         ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
         ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
+#endif // GGML_USE_MUSA
 
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
                 src0_f16, src1_f16, dst_t,
+#ifdef GGML_USE_MUSA
+                ptrs_src, ptrs_dst,
+#else // GGML_USE_MUSA
                 ptrs_src.get(), ptrs_dst.get(),
+#endif // GGML_USE_MUSA
                 ne12, ne13,
                 ne23,
                 nb02, nb03,
@@ -1888,15 +1903,31 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
                 r2, r3);
         CUDA_CHECK(cudaGetLastError());
 
-        CUBLAS_CHECK(
+#ifdef GGML_USE_MUSA
+        CUDA_CHECK(cudaDeviceSynchronize());
+        const void **Aarray = (const void **) (ptrs_src + 0*ne23);
+        const void **Barray = (const void **) (ptrs_src + 1*ne23);
+              void **Carray = (      void **) (ptrs_dst + 0*ne23);
+#else // GGML_USE_MUSA
+        const void **Aarray = (const void **) (ptrs_src.get() + 0*ne23);
+        const void **Barray = (const void **) (ptrs_src.get() + 1*ne23);
+              void **Carray = (      void **) (ptrs_dst.get() + 0*ne23);
+#endif // GGML_USE_MUSA
+
+        CUBLAS_CHECK(
         cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F,   nb01/nb00,
-                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,   s11,
-                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
+                alpha, Aarray, CUDA_R_16F,   nb01/nb00,
+                       Barray, CUDA_R_16F,   s11,
+                beta,  Carray, cu_data_type, ne0,
                 ne23,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+#ifdef GGML_USE_MUSA
+        CUDA_CHECK(cudaFree(ptrs_src));
+        CUDA_CHECK(cudaFree(ptrs_dst));
+#endif // GGML_USE_MUSA
     }
 #endif
 
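Note on the hunk above: `cublasGemmBatchedEx` dereferences its A/B/C pointer arrays on the GPU, so those arrays must live in device memory; the MUSA branch obtains them with plain `cudaMalloc`/`cudaFree` instead of the backend pool, then hands them to the existing `k_compute_batched_ptrs` kernel. Below is a minimal standalone sketch of that calling pattern for readers unfamiliar with the API. The function name, shapes, and the host-side pointer setup are illustrative assumptions, not taken from the patch (the patch fills the pointer tables on the device instead), and error checking is trimmed.

```cpp
// Sketch: FP16 x FP16 -> FP32 batched GEMM through cublasGemmBatchedEx, with the
// per-batch pointer arrays resident in device memory, as required by the API.
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <vector>

void gemm_batched_f16(cublasHandle_t handle,
                      const half * A, const half * B, float * C, // contiguous batches, already on the device
                      int m, int n, int k, int batch) {
    // build the per-batch pointer tables on the host ...
    std::vector<const void *> hA(batch), hB(batch);
    std::vector<void *>       hC(batch);
    for (int i = 0; i < batch; ++i) {
        hA[i] = A + (size_t) i*m*k;
        hB[i] = B + (size_t) i*k*n;
        hC[i] = C + (size_t) i*m*n;
    }

    // ... and copy them into device memory: the library reads Aarray/Barray/Carray
    // from the GPU, so host-resident pointer tables are not acceptable
    const void ** dA; const void ** dB; void ** dC;
    cudaMalloc((void **) &dA, batch*sizeof(void *));
    cudaMalloc((void **) &dB, batch*sizeof(void *));
    cudaMalloc((void **) &dC, batch*sizeof(void *));
    cudaMemcpy(dA, hA.data(), batch*sizeof(void *), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, hB.data(), batch*sizeof(void *), cudaMemcpyHostToDevice);
    cudaMemcpy(dC, hC.data(), batch*sizeof(void *), cudaMemcpyHostToDevice);

    const float alpha = 1.0f;
    const float beta  = 0.0f;
    cublasGemmBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                        m, n, k,
                        &alpha, dA, CUDA_R_16F, m,   // column-major A: m x k, lda = m
                                dB, CUDA_R_16F, k,   // column-major B: k x n, ldb = k
                        &beta,  dC, CUDA_R_32F, m,   // column-major C: m x n, ldc = m
                        batch,
                        CUBLAS_COMPUTE_32F,
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP);

    cudaFree(dA); cudaFree(dB); cudaFree(dC);
}
```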
@@ -1920,6 +1951,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
 
     bool any_gpus_with_slow_fp16 = false;
     bool any_gpus_without_fp16_mma = false;
+    bool any_gpus_without_cublas_gemm = false;
 
     if (split) {
         ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
@@ -1930,16 +1962,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
                 continue;
             }
 
-            const int cc              = ggml_cuda_info().devices[id].cc;
-            use_mul_mat_q             = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-            any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
-            any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+            const int cc                 = ggml_cuda_info().devices[id].cc;
+            use_mul_mat_q                = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+            any_gpus_with_slow_fp16      = any_gpus_with_slow_fp16      || !fast_fp16_hardware_available(cc);
+            any_gpus_without_fp16_mma    = any_gpus_without_fp16_mma    || !fp16_mma_hardware_available(cc);
+            any_gpus_without_cublas_gemm = any_gpus_without_cublas_gemm || !(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
         }
     } else {
-        const int cc              = ggml_cuda_info().devices[ctx.device].cc;
-        use_mul_mat_q             = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-        any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
-        any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+        const int cc                 = ggml_cuda_info().devices[ctx.device].cc;
+        use_mul_mat_q                = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+        any_gpus_with_slow_fp16      = any_gpus_with_slow_fp16      || !fast_fp16_hardware_available(cc);
+        any_gpus_without_fp16_mma    = any_gpus_without_fp16_mma    || !fp16_mma_hardware_available(cc);
+        any_gpus_without_cublas_gemm = any_gpus_without_cublas_gemm || !(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
     }
 
     // debug helpers
@@ -1958,8 +1992,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
         ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
-        !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && !any_gpus_without_cublas_gemm && src0->type == GGML_TYPE_F16 &&
+        (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
+        !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_mul_mat_vec) {
@@ -2999,9 +3034,17 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                     return false;
                 }
 #ifdef GGML_USE_MUSA
-                if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
+                const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
+                if (GGML_CUDA_CC_IS_MTHREADS(cc) && b->ne[2]*b->ne[3] > 1 &&
                     !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
-                    return false;
+                    if (GGML_CUDA_CC_IS_QY1(cc) && op->op == GGML_OP_MUL_MAT &&
+                        a->type == GGML_TYPE_F16 && b->type == GGML_TYPE_F16) {
+                        return false;
+                    }
+                    if (GGML_CUDA_CC_IS_QY2(cc) && op->op == GGML_OP_MUL_MAT_ID &&
+                        a->type == GGML_TYPE_Q2_K && b->type == GGML_TYPE_F32) {
+                        return false;
+                    }
                 }
 #endif // GGML_USE_MUSA
                 switch (a->type) {
@@ -3028,11 +3071,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                     case GGML_TYPE_IQ4_NL:
                     case GGML_TYPE_IQ4_XS:
                     case GGML_TYPE_BF16:
-#ifdef GGML_USE_MUSA
-                        if (a->type == GGML_TYPE_Q3_K) {
-                            return false;
-                        }
-#endif // GGML_USE_MUSA
                         return true;
                     default:
                         return false;
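With the hunks above, whether a MUSA device accepts a given matmul now depends on its compute capability (QY1 vs. QY2) and the operand types, rather than a blanket MUSA rule. A small probe through the public ggml backend API can show what a device reports; this is a hedged sketch, assuming a CUDA/MUSA build where device index 0 is the GPU, and the shapes are illustrative only.

```cpp
// Sketch: query the backend's supports_op decision via the public device API.
#include "ggml.h"
#include "ggml-backend.h"
#include <cstdio>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, // only tensor metadata is needed for the query
    };
    struct ggml_context * ctx = ggml_init(params);

    // F16 x F16 matmul with a batch dimension > 1: the shape class the
    // MUSA-specific checks above reason about
    struct ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 4096, 4096, 1);
    struct ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 4096, 32, 4);
    struct ggml_tensor * c = ggml_mul_mat(ctx, a, b);

    ggml_backend_dev_t dev = ggml_backend_dev_get(0); // first registered device (assumed to be the GPU)
    printf("%s supports this GGML_OP_MUL_MAT: %s\n",
           ggml_backend_dev_name(dev),
           ggml_backend_dev_supports_op(dev, c) ? "yes" : "no");

    ggml_free(ctx);
    return 0;
}
```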