@@ -252,21 +252,25 @@ static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/)
252252#endif // AMD_MFMA_AVAILABLE
253253
// Host-side number of warps per thread block; callers pass the result as the
// second component of block_dims(warp_size, nwarps, 1).
// This change removes the warp_size parameter and replaces 256/warp_size with
// fixed values: HIP builds return 8 when amd_mfma_available(cc) and 4 otherwise,
// non-HIP builds return 8.
// NOTE(review): 4 equals the old 256/warp_size only for warp_size == 64, and 8
// only for warp_size == 32. Confirm that no HIP device without MFMA reports
// warp_size == 32; such a device would go from 8 warps to 4 with this change.
254254#if defined(GGML_USE_HIP)
255- static int mmq_get_nwarps_host (const int cc, const int warp_size ) {
256- return amd_mfma_available (cc) ? 8 : 256 /warp_size ;
// New HIP variant: cc alone selects between the MFMA (8) and fallback (4) count.
255+ static int mmq_get_nwarps_host (const int cc) {
256+ return amd_mfma_available (cc) ? 8 : 4 ;
257257}
258258#else
259- static int mmq_get_nwarps_host (const int /* cc*/ , const int warp_size ) {
260- return 256 /warp_size ;
// New non-HIP variant: cc is unused and the count is a constant 8.
259+ static int mmq_get_nwarps_host (const int /* cc*/ ) {
260+ return 8 ;
261261}
262262#endif // (GGML_USE_HIP)
263263
// Device-side warp count, chosen entirely by the preprocessor:
//   GGML_USE_HIP and AMD_MFMA_AVAILABLE     -> 8
//   GGML_USE_HIP without AMD_MFMA_AVAILABLE -> 4
//   GGML_USE_HIP not defined                -> 8
// This change adds the outer GGML_USE_HIP guard and replaces the former
// 256 / ggml_cuda_get_physical_warp_size() fallback with the constant 4.
// NOTE(review): presumably this must return the same value as
// mmq_get_nwarps_host(cc) for the running device, since the host value is used
// to build the block dimensions; confirm both stay in sync.
264264static constexpr __device__ int mmq_get_nwarps_device () {
265+ #if defined(GGML_USE_HIP)
265266#if defined(AMD_MFMA_AVAILABLE)
266267 return 8 ;
267268#else
268- return 256 / ggml_cuda_get_physical_warp_size () ;
// HIP fallback without MFMA; equals the removed formula only when the physical warp size is 64.
269+ return 4 ;
269270#endif // AMD_MFMA_AVAILABLE
271+ #else
// Non-HIP build; equals the removed formula only when the physical warp size is 32.
272+ return 8 ;
273+ #endif // defined(GGML_USE_HIP)
270274}
271275
272276// ------------------------------------------------------------
@@ -3469,7 +3473,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
34693473 const int cc = ggml_cuda_info ().devices [id].cc ;
34703474 const int nsm = ggml_cuda_info ().devices [id].nsm ;
34713475 const int warp_size = ggml_cuda_info ().devices [id].warp_size ;
3472- const int nwarps = mmq_get_nwarps_host (cc, warp_size );
3476+ const int nwarps = mmq_get_nwarps_host (cc);
34733477 const int mmq_y = get_mmq_y_host (cc);
34743478
34753479 const dim3 block_dims (warp_size, nwarps, 1 );
@@ -3556,7 +3560,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
35563560 const int cc = ggml_cuda_info ().devices [id].cc ;
35573561 const size_t smpbo = ggml_cuda_info ().devices [id].smpbo ;
35583562 const int warp_size = ggml_cuda_info ().devices [id].warp_size ;
3559- const int nwarps = mmq_get_nwarps_host (cc, warp_size );
3563+ const int nwarps = mmq_get_nwarps_host (cc);
35603564
35613565 const int mmq_x_max = get_mmq_x_max_host (cc);
35623566 const int mmq_y = get_mmq_y_host (cc);
0 commit comments