@@ -259,7 +259,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
259259 const int64_t m_start = 0 ;
260260
261261 const int64_t n_step = static_cast <int64_t >(kernel->get_n_step ());
262- const int64_t num_threads = KAI_MIN (n / n_step, nth);
262+ int64_t num_threads = KAI_MIN (n / n_step, nth);
263+ if (num_threads <= 0 ) {
264+ num_threads = 1 ;
265+ }
263266
264267 if (ith < num_threads) {
265268 const int64_t num_n_per_thread0 = round_down (n / num_threads, n_step);
@@ -309,7 +312,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
309312 GGML_ASSERT (kernel);
310313
311314 const int ith = params->ith ;
312- const int nth = params->nth ;
315+ const int nth_raw = params->nth ;
316+ const int nth = nth_raw > 0 ? nth_raw : 1 ;
313317
314318 const size_t k = ne00;
315319 const size_t m = ne11;
@@ -327,9 +331,12 @@ class tensor_traits : public ggml::cpu::tensor_traits {
327331 const size_t num_n_per_thread = kai_roundup (kai_roundup (n, nth) / nth, n_step);
328332 const size_t n_start = ith * num_n_per_thread;
329333
330- size_t n_to_process = num_n_per_thread;
331- if ((n_start + n_to_process) > n) {
332- n_to_process = n - n_start;
334+ size_t n_to_process = 0 ;
335+ if (n_start < n) {
336+ n_to_process = num_n_per_thread;
337+ if ((n_start + n_to_process) > n) {
338+ n_to_process = n - n_start;
339+ }
333340 }
334341
335342 // Calculate number of columns to be processed per thread
@@ -361,8 +368,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
361368 const void * lhs_ptr = (const void *)((const char *)lhs_packed + lhs_packed_offset);
362369 float *dst_ptr = reinterpret_cast <float *>(static_cast <uint8_t *>(dst->data ) + dst_offset);
363370
364- variant_call<void >(kernel->run_kernel , m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
365- sizeof (float ), -FLT_MAX, FLT_MAX);
371+ if (n_to_process > 0 ) {
372+ variant_call<void >(kernel->run_kernel , m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
373+ sizeof (float ), -FLT_MAX, FLT_MAX);
374+ }
366375
367376 return true ;
368377 }
0 commit comments