Skip to content

Commit 12b9902

Browse files
ikawrakowIwan Kawrakow
andauthored
Fix PPL increase caused by mmq_id (#913)
Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 6a805c7 commit 12b9902

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

ggml/src/ggml-cuda/mmq_id_common.cuh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3960,7 +3960,10 @@ template <ggml_type type, int mmq_x>
39603960
static void launch_mul_mat_q_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args, cudaStream_t stream) {
39613961
const int id = ggml_cuda_get_device();
39623962
const int cc = ggml_cuda_info().devices[id].cc;
3963-
const int nsm = ggml_cuda_info().devices[id].nsm;
3963+
const int nsm_max = ggml_cuda_info().devices[id].nsm;
3964+
int nsm = 1;
3965+
//while (nsm*2 <= nsm_max) nsm *= 2;
3966+
while (nsm < nsm_max) nsm *= 2;
39643967
const int warp_size = ggml_cuda_get_physical_warp_size_host(); //ggml_cuda_info().devices[id].warp_size;
39653968
const int nwarps = mmq_get_nwarps_host(cc, warp_size);
39663969
const int mmq_y = get_mmq_y_host(cc);

0 commit comments

Comments
 (0)