Skip to content

Commit eb7e3c0

Browse files
Commit eb7e3c0 (1 parent: 3addc2b)
Commit message: "code review fixes"
Signed-off-by: Dan Johansson <[email protected]>

File tree

2 files changed

+41
-43
lines changed

2 files changed

+41
-43
lines changed

ggml/src/ggml-cpu/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
433433
string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED)
434434
string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED)
435435

436-
set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS})
436+
set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS_TEMP})
437437

438438
list(APPEND GGML_KLEIDIAI_SOURCES
439439
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c

ggml/src/ggml-cpu/kleidiai/kleidiai.cpp

Lines changed: 40 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
138138
}
139139

140140
bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) {
141+
static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
142+
141143
const ggml_tensor * src0 = dst->src[0];
142144
const ggml_tensor * src1 = dst->src[1];
143145

@@ -149,32 +151,27 @@ class tensor_traits : public ggml::cpu::tensor_traits {
149151
kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
150152
GGML_ASSERT(kernel);
151153

152-
const size_t nth = params->nth;
153-
const size_t ith = params->ith;
154+
const int nth = params->nth;
155+
const int ith = params->ith;
154156

155-
const size_t lhs_batch_size0 = ne12;
156-
const size_t rhs_batch_size0 = ne02;
157+
const int64_t lhs_batch_size0 = ne12;
158+
const int64_t rhs_batch_size0 = ne02;
159+
const int64_t batch_size = rhs_batch_size0;
157160

158-
const size_t r = lhs_batch_size0 / rhs_batch_size0;
161+
const int64_t r = lhs_batch_size0 / rhs_batch_size0;
159162

160-
const size_t m = ne11 * r;
161-
const size_t n = ne01;
162-
const size_t k = ne00;
163-
const size_t batch_size = rhs_batch_size0;
163+
const int64_t m = ne11 * r;
164+
const int64_t n = ne01;
165+
const int64_t k = ne00;
164166

165167
const size_t lhs_stride = src1->nb[1];
166168
const size_t rhs_stride = src0->nb[1];
167169
const size_t dst_stride = dst->nb[1];
168170

169-
const size_t mr = kernel->get_mr();
170-
const size_t nr = kernel->get_nr();
171-
const size_t kr = kernel->get_kr();
172-
const size_t sr = kernel->get_sr();
173-
174-
const size_t m_step = kernel->get_m_step();
175-
const size_t n_step = kernel->get_n_step();
176-
177-
const bool parallelize_on_m = m >= m_step * nth;
171+
const int64_t mr = static_cast<int64_t>(kernel->get_mr());
172+
const int64_t nr = static_cast<int64_t>(kernel->get_nr());
173+
const int64_t kr = static_cast<int64_t>(kernel->get_kr());
174+
const int64_t sr = static_cast<int64_t>(kernel->get_sr());
178175

179176
const size_t lhs_packed_size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr);
180177
const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
@@ -189,40 +186,36 @@ class tensor_traits : public ggml::cpu::tensor_traits {
189186
uint8_t * rhs_kxn = rhs_packed + rhs_packed_size;
190187
uint8_t * bias = rhs_kxn + kxn_size;
191188

192-
for (size_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
189+
for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
193190
const uint8_t * lhs_batch = static_cast<const uint8_t *>(src1->data) + batch_idx * m * lhs_stride;
194191
const uint8_t * rhs_batch = static_cast<const uint8_t *>(src0->data) + batch_idx * n * rhs_stride;
195192
uint8_t * dst_batch = static_cast<uint8_t *>(dst->data) + batch_idx * m * dst_stride;
196193

197194
// LHS packing
198195
{
199-
const size_t m_roundup_mr = kai_roundup(m, mr);
200-
const size_t num_threads = KAI_MIN(m_roundup_mr / mr, nth);
196+
const int64_t m_roundup_mr = kai_roundup(m, mr);
197+
const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth);
201198

202199
if (ith < num_threads) {
203-
const size_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr);
204-
const size_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;
200+
const int64_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr);
201+
const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;
205202

206-
const size_t m_start = ith * num_m_per_thread0;
207-
const size_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
203+
const int64_t m_start = ith * num_m_per_thread0;
204+
const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
208205

209-
const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
206+
const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
210207
const size_t lhs_packed_offset = variant_call<size_t>(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr);
211208

212209
const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
213-
void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
210+
void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
214211

215212
variant_call<void>(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
216213
}
217-
218-
// Investigate if this barrier can be removed.
219-
if (parallelize_on_m == false || num_threads != nth) {
220-
ggml_barrier(params->threadpool);
221-
}
222214
}
223215

224216
// RHS packing
225-
if (ith == 0) {
217+
if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) {
218+
// First thread to reach this point handles RHS packing
226219
memset(bias, 0, n * sizeof(float));
227220
transpose_f32kxn_f16nxk(n, k, reinterpret_cast<float *>(rhs_kxn),
228221
reinterpret_cast<const uint16_t *>(rhs_batch), rhs_stride);
@@ -233,31 +226,36 @@ class tensor_traits : public ggml::cpu::tensor_traits {
233226

234227
ggml_barrier(params->threadpool);
235228

229+
first_to_arrive.clear(std::memory_order_release);
230+
236231
// Perform the matmul
237232
{
238-
const size_t m_to_process = m;
239-
const size_t m_start = 0;
233+
const int64_t m_to_process = m;
234+
const int64_t m_start = 0;
240235

241-
const size_t num_threads = KAI_MIN(n / n_step, nth);
236+
const int64_t n_step = static_cast<int64_t>(kernel->get_n_step());
237+
const int64_t num_threads = KAI_MIN(n / n_step, nth);
242238

243239
if (ith < num_threads) {
244-
const size_t num_n_per_thread0 = round_down(n / num_threads, n_step);
245-
const size_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0;
240+
const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step);
241+
const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0;
246242

247-
const size_t n_start = ith * num_n_per_thread0;
248-
const size_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
243+
const int64_t n_start = ith * num_n_per_thread0;
244+
const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
249245

250246
const size_t lhs_packed_offset = variant_call<size_t>(kernel->get_lhs_offset, m_start, k);
251247
const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k);
252248
const size_t dst_offset = kernel->get_dst_offset(m_start, n_start, dst_stride);
253249

254250
const void * lhs_ptr = lhs_packed + lhs_packed_offset;
255251
const void * rhs_ptr = rhs_packed + rhs_packed_offset;
256-
float * dst_ptr = reinterpret_cast<float *>(dst_batch + dst_offset);
252+
float * dst_ptr = reinterpret_cast<float *>(dst_batch + dst_offset);
257253

258254
variant_call<void>(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
259255
}
256+
}
260257

258+
if (batch_idx != batch_size - 1) {
261259
ggml_barrier(params->threadpool);
262260
}
263261
}
@@ -311,7 +309,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
311309
m_to_process = m - m_start;
312310
}
313311

314-
if(m_start < m) {
312+
if (m_start < m) {
315313
// Transform LHS
316314
const size_t src_stride = src1->nb[1];
317315
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));

Comments (0) — this commit has no comments.