@@ -268,8 +268,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
         // FIXME: Ensure compatibility with varying warp sizes across different MUSA archs.
         info.devices[id].warp_size = 32;
         info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
-        info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100;
-        info.devices[id].cc += prop.minor * 0x10;
+        info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + 100*prop.major + 10*prop.minor;
         GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n",
                       id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
 #else
@@ -1195,7 +1194,8 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
 
-    if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+    if ((GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2) &&
+        src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
         ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
         if (src1->type != GGML_TYPE_BF16) {
             const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
@@ -1865,13 +1865,24 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         // use cublasGemmBatchedEx
         const int64_t ne23 = ne12*ne13;
 
+#ifdef GGML_USE_MUSA
+        const void ** ptrs_src;
+        void ** ptrs_dst;
+        CUDA_CHECK(cudaMalloc((void **)&ptrs_src, sizeof(void *)*2*ne23));
+        CUDA_CHECK(cudaMalloc((void **)&ptrs_dst, sizeof(void *)*1*ne23));
+#else // GGML_USE_MUSA
         ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
         ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
+#endif // GGML_USE_MUSA
 
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
                 src0_f16, src1_f16, dst_t,
+#ifdef GGML_USE_MUSA
+                ptrs_src, ptrs_dst,
+#else // GGML_USE_MUSA
                 ptrs_src.get(), ptrs_dst.get(),
+#endif // GGML_USE_MUSA
                 ne12, ne13,
                 ne23,
                 nb02, nb03,
@@ -1881,15 +1892,31 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
                 r2, r3);
         CUDA_CHECK(cudaGetLastError());
 
-        CUBLAS_CHECK(
+#ifdef GGML_USE_MUSA
+        cudaDeviceSynchronize();
+        const void **Aarray = (const void **) (ptrs_src + 0*ne23);
+        const void **Barray = (const void **) (ptrs_src + 1*ne23);
+              void **Carray = (      void **) (ptrs_dst + 0*ne23);
+#else // GGML_USE_MUSA
+        const void **Aarray = (const void **) (ptrs_src.get() + 0*ne23);
+        const void **Barray = (const void **) (ptrs_src.get() + 1*ne23);
+              void **Carray = (      void **) (ptrs_dst.get() + 0*ne23);
+#endif // GGML_USE_MUSA
+
+        CUBLAS_CHECK(
         cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F,   nb01/nb00,
-                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,   s11,
-                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
+                alpha, Aarray, CUDA_R_16F,   nb01/nb00,
+                       Barray, CUDA_R_16F,   s11,
+                beta,  Carray, cu_data_type, ne0,
                 ne23,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+#ifdef GGML_USE_MUSA
+        CUDA_CHECK(cudaFree(ptrs_src));
+        CUDA_CHECK(cudaFree(ptrs_dst));
+#endif // GGML_USE_MUSA
     }
 #endif
 
@@ -1913,6 +1940,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
 
     bool any_gpus_with_slow_fp16   = false;
     bool any_gpus_without_fp16_mma = false;
+    bool any_gpus_without_batched_cublas = false;
 
     if (split) {
         ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
@@ -1927,12 +1955,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
             use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
             any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
             any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+            any_gpus_without_batched_cublas = any_gpus_without_batched_cublas || !(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
         }
     } else {
         const int cc = ggml_cuda_info().devices[ctx.device].cc;
         use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
         any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
         any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+        any_gpus_without_batched_cublas = any_gpus_without_batched_cublas || !(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
     }
 
     // debug helpers
@@ -1951,7 +1981,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
         ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
+    } else if (!split && !any_gpus_without_batched_cublas && src0->type == GGML_TYPE_F16 &&
+        (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
         !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
@@ -2989,12 +3020,14 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
                     return false;
                 }
-#ifdef GGML_USE_MUSA
-                if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
+#if defined(GGML_USE_MUSA)
+                const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
+                if (GGML_CUDA_CC_IS_MTHREADS(cc) && GGML_CUDA_CC_IS_QY1(cc) &&
+                    b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
                     !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
                     return false;
                 }
-#endif // GGML_USE_MUSA
+#endif // defined(GGML_USE_MUSA)
                 switch (a->type) {
                     case GGML_TYPE_F32:
                     case GGML_TYPE_F16:
@@ -3019,11 +3052,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                     case GGML_TYPE_IQ4_NL:
                     case GGML_TYPE_IQ4_XS:
                     case GGML_TYPE_BF16:
-#ifdef GGML_USE_MUSA
-                        if (a->type == GGML_TYPE_Q3_K) {
+#if defined(GGML_USE_MUSA)
+                        if (GGML_CUDA_CC_IS_MTHREADS(cc) && GGML_CUDA_CC_IS_QY2(cc) &&
+                            a->type == GGML_TYPE_Q2_K) {
                             return false;
                         }
-#endif // GGML_USE_MUSA
+#endif // defined(GGML_USE_MUSA)
                         return true;
                     default:
                         return false;