@@ -1851,13 +1851,24 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
18511851 // use cublasGemmBatchedEx
18521852 const int ne23 = ne12*ne13;
18531853
1854+ #ifdef GGML_USE_MUSA
1855+ const void ** ptrs_src;
1856+ void ** ptrs_dst;
1857+ CUDA_CHECK (cudaMalloc ((void **)&ptrs_src, sizeof (void *)*2 *ne23));
1858+ CUDA_CHECK (cudaMalloc ((void **)&ptrs_dst, sizeof (void *)*1 *ne23));
1859+ #else // GGML_USE_MUSA
18541860 ggml_cuda_pool_alloc<const void *> ptrs_src (ctx.pool (), 2 *ne23);
18551861 ggml_cuda_pool_alloc< void *> ptrs_dst (ctx.pool (), 1 *ne23);
1862+ #endif // GGML_USE_MUSA
18561863
18571864 dim3 block_dims (ne13, ne12);
18581865 k_compute_batched_ptrs<<<1 , block_dims, 0 , main_stream>>> (
18591866 src0_f16, src1_f16, dst_t ,
1867+ #ifdef GGML_USE_MUSA
1868+ ptrs_src, ptrs_dst,
1869+ #else // GGML_USE_MUSA
18601870 ptrs_src.get (), ptrs_dst.get (),
1871+ #endif // GGML_USE_MUSA
18611872 ne12, ne13,
18621873 ne23,
18631874 nb02, nb03,
@@ -1867,15 +1878,31 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
18671878 r2, r3);
18681879 CUDA_CHECK (cudaGetLastError ());
18691880
1870- CUBLAS_CHECK (
1881+ #ifdef GGML_USE_MUSA
1882+ cudaDeviceSynchronize ();
1883+ const void **Aarray = (const void **) (ptrs_src + 0 * ne23);
1884+ const void **Barray = (const void **) (ptrs_src + 1 * ne23);
1885+ void **Carray = (void **) (ptrs_dst + 0 * ne23);
1886+ #else // GGML_USE_MUSA
1887+ const void **Aarray = (const void **) (ptrs_src.get () + 0 * ne23);
1888+ const void **Barray = (const void **) (ptrs_src.get () + 1 * ne23);
1889+ void **Carray = (void **) (ptrs_dst.get () + 0 * ne23);
1890+ #endif // GGML_USE_MUSA
1891+
1892+ CUBLAS_CHECK (
18711893 cublasGemmBatchedEx (ctx.cublas_handle (), CUBLAS_OP_T, CUBLAS_OP_N,
18721894 ne01, ne11, ne10,
1873- alpha, ( const void **) (ptrs_src. get () + 0 *ne23) , CUDA_R_16F, nb01/nb00,
1874- ( const void **) (ptrs_src. get () + 1 *ne23) , CUDA_R_16F, nb11/nb10,
1875- beta, ( void **) (ptrs_dst. get () + 0 *ne23) , cu_data_type, ne01,
1895+ alpha, Aarray , CUDA_R_16F, nb01/nb00,
1896+ Barray , CUDA_R_16F, nb11/nb10,
1897+ beta, Carray , cu_data_type, ne01,
18761898 ne23,
18771899 cu_compute_type,
18781900 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
1901+
1902+ #ifdef GGML_USE_MUSA
1903+ CUDA_CHECK (cudaFree (ptrs_src));
1904+ CUDA_CHECK (cudaFree (ptrs_dst));
1905+ #endif // GGML_USE_MUSA
18791906 }
18801907#endif
18811908
@@ -3011,12 +3038,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
30113038 if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
30123039 return false ;
30133040 }
3014- #ifdef GGML_USE_MUSA
3015- if (b->type == GGML_TYPE_F16 && b->ne [2 ]*b->ne [3 ] > 1 &&
3016- !ggml_is_transposed (a) && !ggml_is_transposed (b)) {
3017- return false ;
3018- }
3019- #endif // GGML_USE_MUSA
3041+ // #ifdef GGML_USE_MUSA
3042+ // if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
3043+ // !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
3044+ // return false;
3045+ // }
3046+ // #endif // GGML_USE_MUSA
30203047 switch (a->type ) {
30213048 case GGML_TYPE_F32:
30223049 case GGML_TYPE_F16:
@@ -3041,11 +3068,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
30413068 case GGML_TYPE_IQ4_NL:
30423069 case GGML_TYPE_IQ4_XS:
30433070 case GGML_TYPE_BF16:
3044- #ifdef GGML_USE_MUSA
3045- if (a->type == GGML_TYPE_Q3_K) {
3046- return false ;
3047- }
3048- #endif // GGML_USE_MUSA
3071+ // #ifdef GGML_USE_MUSA
3072+ // if (a->type == GGML_TYPE_Q3_K) {
3073+ // return false;
3074+ // }
3075+ // #endif // GGML_USE_MUSA
30493076 return true ;
30503077 default :
30513078 return false ;
0 commit comments