@@ -1865,13 +1865,24 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
18651865        //  use cublasGemmBatchedEx
18661866        const  int64_t  ne23 = ne12*ne13;
18671867
1868+ #ifdef  GGML_USE_MUSA
1869+         const  void  ** ptrs_src;
1870+         void  ** ptrs_dst;
1871+         CUDA_CHECK (cudaMalloc ((void  **)&ptrs_src, sizeof (void  *)*2 *ne23));
1872+         CUDA_CHECK (cudaMalloc ((void  **)&ptrs_dst, sizeof (void  *)*1 *ne23));
1873+ #else  //  GGML_USE_MUSA
18681874        ggml_cuda_pool_alloc<const  void  *> ptrs_src (ctx.pool (), 2 *ne23);
18691875        ggml_cuda_pool_alloc<      void  *> ptrs_dst (ctx.pool (), 1 *ne23);
1876+ #endif  //  GGML_USE_MUSA
18701877
18711878        dim3  block_dims (ne13, ne12);
18721879        k_compute_batched_ptrs<<<1 , block_dims, 0 , main_stream>>> (
18731880                src0_f16, src1_f16, dst_t ,
1881+ #ifdef  GGML_USE_MUSA
1882+                 ptrs_src, ptrs_dst,
1883+ #else  //  GGML_USE_MUSA
18741884                ptrs_src.get (), ptrs_dst.get (),
1885+ #endif  //  GGML_USE_MUSA
18751886                ne12, ne13,
18761887                ne23,
18771888                nb02, nb03,
@@ -1881,15 +1892,31 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
18811892                r2, r3);
18821893        CUDA_CHECK (cudaGetLastError ());
18831894
1884-         CUBLAS_CHECK (
1895+ #ifdef  GGML_USE_MUSA
1896+         cudaDeviceSynchronize ();
1897+         const  void  **Aarray = (const  void  **) (ptrs_src + 0 *ne23);
1898+         const  void  **Barray = (const  void  **) (ptrs_src + 1 *ne23);
1899+               void  **Carray = (      void  **) (ptrs_dst + 0 *ne23);
1900+ #else  //  GGML_USE_MUSA
1901+         const  void  **Aarray = (const  void  **) (ptrs_src.get () + 0 *ne23);
1902+         const  void  **Barray = (const  void  **) (ptrs_src.get () + 1 *ne23);
1903+               void  **Carray = (      void  **) (ptrs_dst.get () + 0 *ne23);
1904+ #endif  //  GGML_USE_MUSA
1905+ 
1906+        CUBLAS_CHECK (
18851907        cublasGemmBatchedEx (ctx.cublas_handle (), CUBLAS_OP_T, CUBLAS_OP_N,
18861908                ne01, ne11, ne10,
1887-                 alpha, ( const   void  **) (ptrs_src. get () +  0 *ne23) , CUDA_R_16F,   nb01/nb00,
1888-                        ( const   void  **) (ptrs_src. get () +  1 *ne23) , CUDA_R_16F,   s11,
1889-                 beta,  (       void  **) (ptrs_dst. get () +  0 *ne23) , cu_data_type, ne0,
1909+                 alpha, Aarray , CUDA_R_16F,   nb01/nb00,
1910+                        Barray , CUDA_R_16F,   s11,
1911+                 beta,  Carray , cu_data_type, ne0,
18901912                ne23,
18911913                cu_compute_type,
18921914                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
1915+ 
1916+ #ifdef  GGML_USE_MUSA
1917+         CUDA_CHECK (cudaFree (ptrs_src));
1918+         CUDA_CHECK (cudaFree (ptrs_dst));
1919+ #endif  //  GGML_USE_MUSA
18931920    }
18941921#endif 
18951922
@@ -2989,12 +3016,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
29893016                if  (b->type  == GGML_TYPE_F16 && a->type  != GGML_TYPE_F16) {
29903017                    return  false ;
29913018                }
2992- #ifdef  GGML_USE_MUSA
2993-                 if  (b->type  == GGML_TYPE_F16 && b->ne [2 ]*b->ne [3 ] > 1  &&
2994-                     !ggml_is_transposed (a) && !ggml_is_transposed (b)) {
2995-                     return  false ;
2996-                 }
2997- #endif  //  GGML_USE_MUSA
3019+ //   #ifdef GGML_USE_MUSA
3020+ //                   if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
3021+ //                       !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
3022+ //                       return false;
3023+ //                   }
3024+ //   #endif // GGML_USE_MUSA
29983025                switch  (a->type ) {
29993026                    case  GGML_TYPE_F32:
30003027                    case  GGML_TYPE_F16:
@@ -3019,11 +3046,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
30193046                    case  GGML_TYPE_IQ4_NL:
30203047                    case  GGML_TYPE_IQ4_XS:
30213048                    case  GGML_TYPE_BF16:
3022- #ifdef  GGML_USE_MUSA
3023-                         if  (a->type  == GGML_TYPE_Q3_K) {
3024-                             return  false ;
3025-                         }
3026- #endif  //  GGML_USE_MUSA
3049+ //   #ifdef GGML_USE_MUSA
3050+ //                           if (a->type == GGML_TYPE_Q3_K) {
3051+ //                               return false;
3052+ //                           }
3053+ //   #endif // GGML_USE_MUSA
30273054                        return  true ;
30283055                    default :
30293056                        return  false ;
0 commit comments