@@ -251,25 +251,21 @@ static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/)
251251#endif // AMD_MFMA_AVAILABLE
252252
253253#if defined(GGML_USE_HIP)
254- static int mmq_get_nwarps_host (const int cc) {
255- return amd_mfma_available (cc) ? 8 : 4 ;
254+ static int mmq_get_nwarps_host (const int cc, const int warp_size ) {
255+ return amd_mfma_available (cc) ? 8 : 256 /warp_size ;
256256}
257257#else
258- static int mmq_get_nwarps_host (const int /* cc*/ ) {
259- return 8 ;
258+ static int mmq_get_nwarps_host (const int /* cc*/ , const int warp_size ) {
259+ return 256 /warp_size ;
260260}
261261#endif // (GGML_USE_HIP)
262262
263263static constexpr __device__ int mmq_get_nwarps_device () {
264- #if defined(GGML_USE_HIP)
265264#if defined(AMD_MFMA_AVAILABLE)
266265 return 8 ;
267266#else
268- return 4 ;
267+ return 256 / ggml_cuda_get_physical_warp_size () ;
269268#endif // AMD_MFMA_AVAILABLE
270- #else
271- return 8 ;
272- #endif // defined(GGML_USE_HIP)
273269}
274270
275271// ------------------------------------------------------------
@@ -3472,7 +3468,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
34723468 const int cc = ggml_cuda_info ().devices [id].cc ;
34733469 const int nsm = ggml_cuda_info ().devices [id].nsm ;
34743470 const int warp_size = ggml_cuda_info ().devices [id].warp_size ;
3475- const int nwarps = mmq_get_nwarps_host (cc);
3471+ const int nwarps = mmq_get_nwarps_host (cc, warp_size );
34763472 const int mmq_y = get_mmq_y_host (cc);
34773473
34783474 const dim3 block_dims (warp_size, nwarps, 1 );
@@ -3559,7 +3555,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
35593555 const int cc = ggml_cuda_info ().devices [id].cc ;
35603556 const size_t smpbo = ggml_cuda_info ().devices [id].smpbo ;
35613557 const int warp_size = ggml_cuda_info ().devices [id].warp_size ;
3562- const int nwarps = mmq_get_nwarps_host (cc);
3558+ const int nwarps = mmq_get_nwarps_host (cc, warp_size );
35633559
35643560 const int mmq_x_max = get_mmq_x_max_host (cc);
35653561 const int mmq_y = get_mmq_y_host (cc);
0 commit comments