@@ -1851,13 +1851,24 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
18511851 // use cublasGemmBatchedEx
18521852 const int ne23 = ne12*ne13;
18531853
1854+ #ifdef GGML_USE_MUSA
1855+ const void ** ptrs_src;
1856+ void ** ptrs_dst;
1857+ CUDA_CHECK (cudaMalloc ((void **)&ptrs_src, sizeof (half *)*2 *ne23));
1858+ CUDA_CHECK (cudaMalloc ((void **)&ptrs_dst, sizeof (half *)*1 *ne23));
1859+ #else // GGML_USE_MUSA
18541860 ggml_cuda_pool_alloc<const void *> ptrs_src (ctx.pool (), 2 *ne23);
18551861 ggml_cuda_pool_alloc< void *> ptrs_dst (ctx.pool (), 1 *ne23);
1862+ #endif // GGML_USE_MUSA
18561863
18571864 dim3 block_dims (ne13, ne12);
18581865 k_compute_batched_ptrs<<<1 , block_dims, 0 , main_stream>>> (
18591866 src0_f16, src1_f16, dst_t ,
1867+ #ifdef GGML_USE_MUSA
1868+ ptrs_src, ptrs_dst,
1869+ #else // GGML_USE_MUSA
18601870 ptrs_src.get (), ptrs_dst.get (),
1871+ #endif // GGML_USE_MUSA
18611872 ne12, ne13,
18621873 ne23,
18631874 nb02, nb03,
@@ -1867,15 +1878,30 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
18671878 r2, r3);
18681879 CUDA_CHECK (cudaGetLastError ());
18691880
1881+ #ifdef GGML_USE_MUSA
1882+ const void **Aarray = (const void **) (ptrs_src + 0 * ne23);
1883+ const void **Barray = (const void **) (ptrs_src + 1 * ne23);
1884+ void **Carray = (void **) (ptrs_dst + 0 * ne23);
1885+ #else // GGML_USE_MUSA
1886+ const void **Aarray = (const void **) (ptrs_src.get () + 0 * ne23);
1887+ const void **Barray = (const void **) (ptrs_src.get () + 1 * ne23);
1888+ void **Carray = (void **) (ptrs_dst.get () + 0 * ne23);
1889+ #endif // GGML_USE_MUSA
1890+
18701891 CUBLAS_CHECK (
18711892 cublasGemmBatchedEx (ctx.cublas_handle (), CUBLAS_OP_T, CUBLAS_OP_N,
18721893 ne01, ne11, ne10,
1873- alpha, ( const void **) (ptrs_src. get () + 0 *ne23) , CUDA_R_16F, nb01/nb00,
1874- ( const void **) (ptrs_src. get () + 1 *ne23) , CUDA_R_16F, nb11/nb10,
1875- beta, ( void **) (ptrs_dst. get () + 0 *ne23) , cu_data_type, ne01,
1894+ alpha, Aarray , CUDA_R_16F, nb01/nb00,
1895+ Barray , CUDA_R_16F, nb11/nb10,
1896+ beta, Carray , cu_data_type, ne01,
18761897 ne23,
18771898 cu_compute_type,
18781899 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
1900+
1901+ #ifdef GGML_USE_MUSA
1902+ CUDA_CHECK (cudaFree (ptrs_src));
1903+ CUDA_CHECK (cudaFree (ptrs_dst));
1904+ #endif // GGML_USE_MUSA
18791905 }
18801906#endif
18811907
0 commit comments