@@ -136,7 +136,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
     return res;
 #else
 
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_HIPBLAS)
     cudaError_t err;
     if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
     {
@@ -149,7 +149,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
     return err;
 #else
     return cudaMalloc(ptr, size);
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_HIPBLAS)
 
 #endif
 }
@@ -2830,6 +2830,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
                     return false;
                 }
+#ifdef GGML_USE_MUSA
+                if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
+                    !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
+                    return false;
+                }
+#endif // GGML_USE_MUSA
                 switch (a->type) {
                     case GGML_TYPE_F32:
                     case GGML_TYPE_F16:
@@ -2853,6 +2859,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                     case GGML_TYPE_IQ3_XXS:
                     case GGML_TYPE_IQ4_NL:
                     case GGML_TYPE_IQ4_XS:
+#ifdef GGML_USE_MUSA
+                        if (a->type == GGML_TYPE_Q3_K) {
+                            return false;
+                        }
+#endif // GGML_USE_MUSA
                         return true;
                     default:
                         return false;
@@ -2978,6 +2989,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_RWKV_WKV:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
+#ifndef FLASH_ATTN_AVAILABLE
+            return false;
+#endif
             if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
                 return true;
             }
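
For reference, a minimal standalone sketch of the environment-variable-gated allocation path that the first hunk keeps CUDA-only. This is not the ggml function itself; `demo_device_malloc` and the `main` driver are illustrative, only the `GGML_CUDA_ENABLE_UNIFIED_MEMORY` check and the two CUDA calls mirror the diff.

```cpp
// Sketch of the allocation policy guarded above: opt into CUDA unified
// memory via an environment variable, otherwise use a plain device
// allocation. demo_device_malloc is a hypothetical name, not ggml API.
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

static cudaError_t demo_device_malloc(void ** ptr, size_t size) {
    // Same switch the diff checks before choosing managed memory.
    if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
        // Managed memory can migrate between host and device, which helps
        // when the data does not fit in VRAM, at the cost of paging overhead.
        return cudaMallocManaged(ptr, size);
    }
    return cudaMalloc(ptr, size);
}

int main() {
    void * buf = nullptr;
    cudaError_t err = demo_device_malloc(&buf, 1 << 20); // 1 MiB
    if (err != cudaSuccess) {
        fprintf(stderr, "alloc failed: %s\n", cudaGetErrorString(err));
        return 1;
    }
    cudaFree(buf);
    return 0;
}
```

Build with `nvcc` and set `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1` to exercise the managed-memory branch; unset it to take the `cudaMalloc` fallback.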