@@ -109,8 +109,8 @@ void ggml_cuda_mul_mat_q(
109109 const int64_t s03 = src0->nb [3 ] / ts_src0;
110110 const int64_t s3 = dst->nb [3 ] / ts_dst;
111111
112- const bool use_stream_k = (( GGML_CUDA_CC_IS_NVIDIA (cc) && ggml_cuda_highest_compiled_arch (cc) >= GGML_CUDA_CC_VOLTA)
113- || ( GGML_CUDA_CC_IS_AMD (cc) && GGML_CUDA_CC_IS_CDNA3 (cc)) );
112+ const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA (cc) && ggml_cuda_highest_compiled_arch (cc) >= GGML_CUDA_CC_VOLTA)
113+ || GGML_CUDA_CC_IS_CDNA (cc );
114114
115115 if (!ids) {
116116 const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof (block_q8_1)/QK8_1 +
@@ -252,7 +252,7 @@ void ggml_cuda_op_mul_mat_q(
252252 // Also its fixup needs to allocate a temporary buffer in the memory pool.
253253 // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
254254 const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA (cc) && ggml_cuda_highest_compiled_arch (cc) >= GGML_CUDA_CC_VOLTA)
255- || ( GGML_CUDA_CC_IS_AMD (cc) && GGML_CUDA_CC_IS_CDNA3 (cc) ))
255+ || GGML_CUDA_CC_IS_CDNA (cc ))
256256 && src1_ncols == ne11;
257257 const mmq_args args = {
258258 src0_dd_i, src0->type , (const int *) src1_ddq_i, nullptr , nullptr , dst_dd_i,
@@ -306,7 +306,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
306306 return false ;
307307 }
308308
309- if (new_mma_available (cc) || amd_mfma_available (cc) ) {
309+ if (new_mma_available (cc)) {
310310 return true ;
311311 }
312312
@@ -322,5 +322,21 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
322322 return !fp16_mma_hardware_available (cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
323323 }
324324
325+ if (amd_mfma_available (cc)) {
326+ // As of ROCM 7.0 rocblas/tensile performs very poorly on CDNA3 and hipblaslt (via ROCBLAS_USE_HIPBLASLT)
327+ // performs better but is currently suffering from a crash on this architecture.
328+ // TODO: Revisit when hipblaslt is fixed on CDNA3
329+ if (GGML_CUDA_CC_IS_CDNA3 (cc)) {
330+ return true ;
331+ }
332+ if (ne11 <= 128 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
333+ return true ;
334+ }
335+ if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) {
336+ return true ;
337+ }
338+ return false ;
339+ }
340+
325341 return (!GGML_CUDA_CC_IS_RDNA4 (cc) && !GGML_CUDA_CC_IS_RDNA3 (cc) && !GGML_CUDA_CC_IS_CDNA (cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
326342}
0 commit comments