@@ -259,7 +259,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
259259                const  int64_t  m_start      = 0 ;
260260
261261                const  int64_t  n_step      = static_cast <int64_t >(kernel->get_n_step ());
262-                 const  int64_t  num_threads = KAI_MIN (n / n_step, nth);
262+                 int64_t  num_threads       = KAI_MIN (n / n_step, nth);
263+                 if  (num_threads <= 0 ) {
264+                     num_threads = 1 ;
265+                 }
263266
264267                if  (ith < num_threads) {
265268                    const  int64_t  num_n_per_thread0   = round_down (n / num_threads, n_step);
@@ -309,7 +312,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
309312        GGML_ASSERT (kernel);
310313
311314        const  int  ith = params->ith ;
312-         const  int  nth = params->nth ;
315+         const  int  nth_raw = params->nth ;
316+         const  int  nth = nth_raw > 0  ? nth_raw : 1 ;
313317
314318        const  size_t  k = ne00;
315319        const  size_t  m = ne11;
@@ -327,9 +331,12 @@ class tensor_traits : public ggml::cpu::tensor_traits {
327331        const  size_t  num_n_per_thread = kai_roundup (kai_roundup (n, nth) / nth, n_step);
328332        const  size_t  n_start = ith * num_n_per_thread;
329333
330-         size_t  n_to_process = num_n_per_thread;
331-         if  ((n_start + n_to_process) > n) {
332-             n_to_process = n - n_start;
334+         size_t  n_to_process = 0 ;
335+         if  (n_start < n) {
336+             n_to_process = num_n_per_thread;
337+             if  ((n_start + n_to_process) > n) {
338+                 n_to_process = n - n_start;
339+             }
333340        }
334341
335342        //  Calculate number of columns to be processed per thread
@@ -361,8 +368,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
361368        const  void * lhs_ptr            = (const  void *)((const  char  *)lhs_packed + lhs_packed_offset);
362369        float  *dst_ptr                 = reinterpret_cast <float  *>(static_cast <uint8_t  *>(dst->data ) + dst_offset);
363370
364-         variant_call<void >(kernel->run_kernel , m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
365-                            sizeof (float ), -FLT_MAX, FLT_MAX);
371+         if  (n_to_process > 0 ) {
372+             variant_call<void >(kernel->run_kernel , m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
373+                                sizeof (float ), -FLT_MAX, FLT_MAX);
374+         }
366375
367376        return  true ;
368377    }
0 commit comments