@@ -304,6 +304,14 @@ static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */
304304}
305305#endif // NEW_MMA_AVAILABLE
306306
307+ static int mmq_get_nwarps_host (const int /* cc*/ , const int warp_size) {
308+ return 256 /warp_size;
309+ }
310+
311+ static constexpr __device__ int mmq_get_nwarps_device () {
312+ return 256 /ggml_cuda_get_physical_warp_size ();
313+ }
314+
307315// ------------------------------------------------------------
308316
309317template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0 (
@@ -4141,6 +4149,10 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
41414149 const int id = ggml_cuda_get_device ();
41424150 const int cc = ggml_cuda_info ().devices [id].cc ;
41434151 const int nsm = ggml_cuda_info ().devices [id].nsm ;
4152+
4153+ const int warp_size = ggml_cuda_info ().devices [id].warp_size ;
4154+ const int nwarps = mmq_get_nwarps_host (cc, warp_size);
4155+
41444156 const int mmq_y = get_mmq_y_host (cc);
41454157
41464158 const dim3 block_dims (WARP_SIZE, MMQ_NWARPS, 1 );
@@ -4198,6 +4210,9 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
41984210 const int cc = ggml_cuda_info ().devices [id].cc ;
41994211 const int smpbo = ggml_cuda_info ().devices [id].smpbo ;
42004212
4213+ const int warp_size = ggml_cuda_info ().devices [id].warp_size ;
4214+ const int nwarps = mmq_get_nwarps_host (cc, warp_size);
4215+
42014216 const int mmq_x_max = get_mmq_x_max_host (cc);
42024217 const int mmq_y = get_mmq_y_host (cc);
42034218 const int block_num_y = (args.ne01 + mmq_y - 1 ) / mmq_y;
0 commit comments