@@ -1195,7 +1195,8 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
 
-    if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+    if ((GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2) &&
+        src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
         ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
         if (src1->type != GGML_TYPE_BF16) {
             const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
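The new guard keys off the MThreads compute-capability range used throughout ggml-cuda/common.cuh. A minimal sketch of how such a range check works; the constants below are illustrative placeholders, not values copied from the header:

    // Sketch only: vendor compute capabilities live in disjoint offset ranges,
    // so a single int cc encodes both vendor and generation.
    #define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000                              // assumption
    #define GGML_CUDA_CC_OFFSET_AMD      0x1000000                              // assumption
    #define GGML_CUDA_CC_IS_MTHREADS(cc) \
        ((cc) >= GGML_CUDA_CC_OFFSET_MTHREADS && (cc) < GGML_CUDA_CC_OFFSET_AMD)
    #define GGML_CUDA_CC_QY2             (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // assumption

    // The BF16 GEMM path above is taken only when this holds:
    static inline bool bf16_gemm_capable(int cc) { // hypothetical helper name
        return GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2;
    }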
@@ -1865,13 +1866,24 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         // use cublasGemmBatchedEx
         const int64_t ne23 = ne12*ne13;
 
+#ifdef GGML_USE_MUSA
+        const void ** ptrs_src;
+        void ** ptrs_dst;
+        CUDA_CHECK(cudaMalloc((void **)&ptrs_src, sizeof(void *)*2*ne23));
+        CUDA_CHECK(cudaMalloc((void **)&ptrs_dst, sizeof(void *)*1*ne23));
+#else // GGML_USE_MUSA
         ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
         ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
+#endif // GGML_USE_MUSA
 
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
                 src0_f16, src1_f16, dst_t,
+#ifdef GGML_USE_MUSA
+                ptrs_src, ptrs_dst,
+#else // GGML_USE_MUSA
                 ptrs_src.get(), ptrs_dst.get(),
+#endif // GGML_USE_MUSA
                 ne12, ne13,
                 ne23,
                 nb02, nb03,
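On MUSA the patch swaps the pool-backed ptrs_src/ptrs_dst buffers for plain cudaMalloc allocations, presumably because muBLAS's batched entry point does not accept pointers obtained from the ggml CUDA pool. The tables are laid out as [A0..A(ne23-1), B0..B(ne23-1)] in ptrs_src and [C0..C(ne23-1)] in ptrs_dst. A simplified, self-contained sketch of what k_compute_batched_ptrs fills in (the real kernel additionally handles r2/r3 broadcasting and byte strides):

    #include <cuda_fp16.h>

    // Simplified illustration, not the real kernel: one thread per batch entry
    // writes the per-matrix base pointers that cublasGemmBatchedEx dereferences.
    __global__ void k_fill_batched_ptrs(const __half * A, const __half * B, float * C,
                                        const void ** ptrs_src, void ** ptrs_dst,
                                        int64_t sA, int64_t sB, int64_t sC, int64_t ne23) {
        const int64_t i = blockIdx.x*blockDim.x + threadIdx.x;
        if (i >= ne23) {
            return;
        }
        ptrs_src[0*ne23 + i] = (const void *) (A + i*sA); // A operands
        ptrs_src[1*ne23 + i] = (const void *) (B + i*sB); // B operands
        ptrs_dst[0*ne23 + i] = (      void *) (C + i*sC); // C outputs
    }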
@@ -1881,15 +1893,31 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
                 r2, r3);
         CUDA_CHECK(cudaGetLastError());
 
-        CUBLAS_CHECK(
+#ifdef GGML_USE_MUSA
+        cudaDeviceSynchronize();
+        const void **Aarray = (const void **) (ptrs_src + 0*ne23);
+        const void **Barray = (const void **) (ptrs_src + 1*ne23);
+              void **Carray = (      void **) (ptrs_dst + 0*ne23);
+#else // GGML_USE_MUSA
+        const void **Aarray = (const void **) (ptrs_src.get() + 0*ne23);
+        const void **Barray = (const void **) (ptrs_src.get() + 1*ne23);
+              void **Carray = (      void **) (ptrs_dst.get() + 0*ne23);
+#endif // GGML_USE_MUSA
+
+        CUBLAS_CHECK(
         cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F,   nb01/nb00,
-                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,   s11,
-                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
+                alpha, Aarray, CUDA_R_16F,   nb01/nb00,
+                       Barray, CUDA_R_16F,   s11,
+                beta,  Carray, cu_data_type, ne0,
                 ne23,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+#ifdef GGML_USE_MUSA
+        CUDA_CHECK(cudaFree(ptrs_src));
+        CUDA_CHECK(cudaFree(ptrs_dst));
+#endif // GGML_USE_MUSA
     }
 #endif
 
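Two costs are specific to this MUSA fallback: the explicit cudaDeviceSynchronize(), presumably needed so the pointer tables written by k_compute_batched_ptrs are complete before muBLAS consumes them, and an unpooled cudaMalloc/cudaFree pair on every call. If this path ever grows more exits, a small RAII guard would keep the free from being skipped; a hypothetical sketch, not part of the patch:

    // Hypothetical RAII wrapper for the raw device allocation above; cudaFree
    // then runs on every exit path instead of only at the explicit call site.
    struct device_ptr_table {
        void ** data = nullptr;
        explicit device_ptr_table(size_t n_ptrs) {
            CUDA_CHECK(cudaMalloc((void **) &data, n_ptrs*sizeof(void *)));
        }
        ~device_ptr_table() {
            if (data != nullptr) {
                cudaFree(data); // no CUDA_CHECK: never abort from a destructor
            }
        }
        device_ptr_table(const device_ptr_table &) = delete;
        device_ptr_table & operator=(const device_ptr_table &) = delete;
    };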
@@ -1913,6 +1941,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
 
     bool any_gpus_with_slow_fp16   = false;
     bool any_gpus_without_fp16_mma = false;
+    bool any_gpus_without_batched_cublas = false;
 
     if (split) {
         ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
@@ -1927,12 +1956,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
             use_mul_mat_q             = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
             any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
             any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+            any_gpus_without_batched_cublas = any_gpus_without_batched_cublas || !(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
         }
     } else {
         const int cc              = ggml_cuda_info().devices[ctx.device].cc;
         use_mul_mat_q             = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
         any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
         any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+        any_gpus_without_batched_cublas = any_gpus_without_batched_cublas || !(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
     }
 
     // debug helpers
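The same predicate is OR-ed into the flag in both the split and the single-device branch; factoring it out would keep the two copies from drifting apart (hypothetical helper, not in the patch):

    // Mirrors the patch's predicate: a device keeps the batched-cuBLAS path
    // only if it is an MThreads GPU of generation QY2 or newer.
    static bool device_keeps_batched_cublas(int cc) {
        return GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2;
    }
    // usage in both branches:
    //     any_gpus_without_batched_cublas |= !device_keeps_batched_cublas(cc);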
@@ -1951,7 +1982,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
         ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
+    } else if (!split && !any_gpus_without_batched_cublas && src0->type == GGML_TYPE_F16 &&
+            (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
         !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
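Reading the new condition as an all-devices requirement makes the intent clearer: the batched branch runs only when every participating GPU passed the QY2 check in the loop above, and anything else falls through to the generic mul_mat paths. An equivalent formulation, for illustration only:

    const bool all_gpus_keep_batched_cublas = !any_gpus_without_batched_cublas;
    const bool take_batched_path = !split
        && all_gpus_keep_batched_cublas
        && src0->type == GGML_TYPE_F16
        && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
        && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)
        && src1->ne[2]*src1->ne[3] > 1; // src1 actually carries multiple matrices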
@@ -2989,12 +3021,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
                     return false;
                 }
-#ifdef GGML_USE_MUSA
+#if defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1
                 if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
                     !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
                     return false;
                 }
-#endif // GGML_USE_MUSA
+#endif // defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1
                 switch (a->type) {
                     case GGML_TYPE_F32:
                     case GGML_TYPE_F16:
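Note that GGML_CUDA_MUSA_ARCH_IS_QY1 is used as a value here, not merely tested with defined(): inside #if, an identifier that is still undefined after macro expansion evaluates to 0, so builds that never define the macro skip the block. The actual definition presumably comes from the build system; a sketch of the contract the code relies on:

    // Contract assumed by "#if defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1":
    // when defined, the macro must expand to an integer constant, e.g. via
    // something like -DGGML_CUDA_MUSA_ARCH_IS_QY1=1 when targeting the QY1 arch.
    #ifndef GGML_CUDA_MUSA_ARCH_IS_QY1
    #define GGML_CUDA_MUSA_ARCH_IS_QY1 0 // defensive default (illustrative)
    #endif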
@@ -3019,11 +3051,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                     case GGML_TYPE_IQ4_NL:
                     case GGML_TYPE_IQ4_XS:
                     case GGML_TYPE_BF16:
-#ifdef GGML_USE_MUSA
-                        if (a->type == GGML_TYPE_Q3_K) {
+#if defined(GGML_USE_MUSA) && !GGML_CUDA_MUSA_ARCH_IS_QY1
+                        if (a->type == GGML_TYPE_Q2_K) {
                             return false;
                         }
-#endif // GGML_USE_MUSA
+#endif // defined(GGML_USE_MUSA) && !GGML_CUDA_MUSA_ARCH_IS_QY1
                         return true;
                     default:
                         return false;
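Taken together, the two conditional blocks in ggml_backend_cuda_device_supports_op now split the MUSA behavior by generation: QY1 keeps rejecting batched F16 mul_mat but no longer blacklists a quantized type, while newer architectures accept batched F16 and instead reject Q2_K. A condensed sketch of the resulting policy (hypothetical helper, not actual repo code):

    // Net MUSA policy after the patch, condensed for illustration.
    static bool musa_mul_mat_ok(ggml_type a_type, bool batched_f16_case) {
    #if GGML_CUDA_MUSA_ARCH_IS_QY1
        (void) a_type;
        return !batched_f16_case;        // QY1: no batched F16; all types pass here
    #else
        (void) batched_f16_case;
        return a_type != GGML_TYPE_Q2_K; // QY2+: batched F16 fine; Q2_K rejected
    #endif
    }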