@@ -259,7 +259,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
259
259
const int64_t m_start = 0 ;
260
260
261
261
const int64_t n_step = static_cast <int64_t >(kernel->get_n_step ());
262
- const int64_t num_threads = KAI_MIN (n / n_step, nth);
262
+ int64_t num_threads = KAI_MIN (n / n_step, nth);
263
+ if (num_threads <= 0 ) {
264
+ num_threads = 1 ;
265
+ }
263
266
264
267
if (ith < num_threads) {
265
268
const int64_t num_n_per_thread0 = round_down (n / num_threads, n_step);
@@ -309,7 +312,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
309
312
GGML_ASSERT (kernel);
310
313
311
314
const int ith = params->ith ;
312
- const int nth = params->nth ;
315
+ const int nth_raw = params->nth ;
316
+ const int nth = nth_raw > 0 ? nth_raw : 1 ;
313
317
314
318
const size_t k = ne00;
315
319
const size_t m = ne11;
@@ -327,9 +331,12 @@ class tensor_traits : public ggml::cpu::tensor_traits {
327
331
const size_t num_n_per_thread = kai_roundup (kai_roundup (n, nth) / nth, n_step);
328
332
const size_t n_start = ith * num_n_per_thread;
329
333
330
- size_t n_to_process = num_n_per_thread;
331
- if ((n_start + n_to_process) > n) {
332
- n_to_process = n - n_start;
334
+ size_t n_to_process = 0 ;
335
+ if (n_start < n) {
336
+ n_to_process = num_n_per_thread;
337
+ if ((n_start + n_to_process) > n) {
338
+ n_to_process = n - n_start;
339
+ }
333
340
}
334
341
335
342
// Calculate number of columns to be processed per thread
@@ -361,8 +368,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
361
368
const void * lhs_ptr = (const void *)((const char *)lhs_packed + lhs_packed_offset);
362
369
float *dst_ptr = reinterpret_cast <float *>(static_cast <uint8_t *>(dst->data ) + dst_offset);
363
370
364
- variant_call<void >(kernel->run_kernel , m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
365
- sizeof (float ), -FLT_MAX, FLT_MAX);
371
+ if (n_to_process > 0 ) {
372
+ variant_call<void >(kernel->run_kernel , m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
373
+ sizeof (float ), -FLT_MAX, FLT_MAX);
374
+ }
366
375
367
376
return true ;
368
377
}
0 commit comments