Skip to content

Commit 002cb1b

Browse files
authored
kleidiai: fix unsigned overflow bug (#15150)
* kleidiai: fix unsigned overflow bug * address review comments
1 parent 79c1160 commit 002cb1b

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

ggml/src/ggml-cpu/kleidiai/kleidiai.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
259259
const int64_t m_start = 0;
260260

261261
const int64_t n_step = static_cast<int64_t>(kernel->get_n_step());
262-
const int64_t num_threads = KAI_MIN(n / n_step, nth);
262+
int64_t num_threads = KAI_MIN(n / n_step, nth);
263+
if (num_threads <= 0) {
264+
num_threads = 1;
265+
}
263266

264267
if (ith < num_threads) {
265268
const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step);
@@ -309,7 +312,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
309312
GGML_ASSERT(kernel);
310313

311314
const int ith = params->ith;
312-
const int nth = params->nth;
315+
const int nth_raw = params->nth;
316+
const int nth = nth_raw > 0 ? nth_raw : 1;
313317

314318
const size_t k = ne00;
315319
const size_t m = ne11;
@@ -327,9 +331,12 @@ class tensor_traits : public ggml::cpu::tensor_traits {
327331
const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
328332
const size_t n_start = ith * num_n_per_thread;
329333

330-
size_t n_to_process = num_n_per_thread;
331-
if ((n_start + n_to_process) > n) {
332-
n_to_process = n - n_start;
334+
size_t n_to_process = 0;
335+
if (n_start < n) {
336+
n_to_process = num_n_per_thread;
337+
if ((n_start + n_to_process) > n) {
338+
n_to_process = n - n_start;
339+
}
333340
}
334341

335342
// Calculate number of columns to be processed per thread
@@ -361,8 +368,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
361368
const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
362369
float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
363370

364-
variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
365-
sizeof(float), -FLT_MAX, FLT_MAX);
371+
if (n_to_process > 0) {
372+
variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
373+
sizeof(float), -FLT_MAX, FLT_MAX);
374+
}
366375

367376
return true;
368377
}

0 commit comments

Comments
 (0)