@@ -138,6 +138,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
     }
 
     bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) {
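+        // Whichever thread reaches the RHS-packing section first claims it via this flag; it is cleared again after the barrier below.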
+        static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
+
         const ggml_tensor * src0 = dst->src[0];
         const ggml_tensor * src1 = dst->src[1];
 
@@ -149,32 +151,27 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
         GGML_ASSERT(kernel);
 
-        const size_t nth = params->nth;
-        const size_t ith = params->ith;
+        const int nth = params->nth;
+        const int ith = params->ith;
 
-        const size_t lhs_batch_size0 = ne12;
-        const size_t rhs_batch_size0 = ne02;
+        const int64_t lhs_batch_size0 = ne12;
+        const int64_t rhs_batch_size0 = ne02;
+        const int64_t batch_size = rhs_batch_size0;
 
-        const size_t r = lhs_batch_size0 / rhs_batch_size0;
+        const int64_t r = lhs_batch_size0 / rhs_batch_size0;
 
-        const size_t m = ne11 * r;
-        const size_t n = ne01;
-        const size_t k = ne00;
-        const size_t batch_size = rhs_batch_size0;
+        const int64_t m = ne11 * r;
+        const int64_t n = ne01;
+        const int64_t k = ne00;
 
         const size_t lhs_stride = src1->nb[1];
         const size_t rhs_stride = src0->nb[1];
         const size_t dst_stride = dst->nb[1];
 
-        const size_t mr = kernel->get_mr();
-        const size_t nr = kernel->get_nr();
-        const size_t kr = kernel->get_kr();
-        const size_t sr = kernel->get_sr();
-
-        const size_t m_step = kernel->get_m_step();
-        const size_t n_step = kernel->get_n_step();
-
-        const bool parallelize_on_m = m >= m_step * nth;
+        const int64_t mr = static_cast<int64_t>(kernel->get_mr());
+        const int64_t nr = static_cast<int64_t>(kernel->get_nr());
+        const int64_t kr = static_cast<int64_t>(kernel->get_kr());
+        const int64_t sr = static_cast<int64_t>(kernel->get_sr());
 
         const size_t lhs_packed_size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr);
         const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
@@ -189,40 +186,36 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         uint8_t * rhs_kxn = rhs_packed + rhs_packed_size;
         uint8_t * bias = rhs_kxn + kxn_size;
 
-        for (size_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
+        for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
             const uint8_t * lhs_batch = static_cast<const uint8_t *>(src1->data) + batch_idx * m * lhs_stride;
             const uint8_t * rhs_batch = static_cast<const uint8_t *>(src0->data) + batch_idx * n * rhs_stride;
             uint8_t * dst_batch = static_cast<uint8_t *>(dst->data) + batch_idx * m * dst_stride;
 
             // LHS packing
             {
-                const size_t m_roundup_mr = kai_roundup(m, mr);
-                const size_t num_threads = KAI_MIN(m_roundup_mr / mr, nth);
+                const int64_t m_roundup_mr = kai_roundup(m, mr);
+                const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth);
 
                 if (ith < num_threads) {
-                    const size_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr);
-                    const size_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;
+                    const int64_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr);
+                    const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;
 
-                    const size_t m_start = ith * num_m_per_thread0;
-                    const size_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
+                    const int64_t m_start = ith * num_m_per_thread0;
+                    const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
 
-                    const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
+                    const size_t lhs_offset        = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
                     const size_t lhs_packed_offset = variant_call<size_t>(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr);
 
                     const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
-                    void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
+                    void * dst_ptr       = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
 
                     variant_call<void>(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
                 }
-
-                // Investigate if this barrier can be removed.
-                if (parallelize_on_m == false || num_threads != nth) {
-                    ggml_barrier(params->threadpool);
-                }
             }
 
             // RHS packing
-            if (ith == 0) {
+            if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) {
+                // First thread to reach this point handles RHS packing
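+                // The whole RHS is packed once per batch by a single thread; the other threads proceed to the barrier below and wait there.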
                memset(bias, 0, n * sizeof(float));
                transpose_f32kxn_f16nxk(n, k, reinterpret_cast<float *>(rhs_kxn),
                                        reinterpret_cast<const uint16_t *>(rhs_batch), rhs_stride);
@@ -233,31 +226,36 @@ class tensor_traits : public ggml::cpu::tensor_traits {
 
             ggml_barrier(params->threadpool);
 
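+            // Every thread has passed the barrier, so the flag can safely be reset for the next batch.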
+            first_to_arrive.clear(std::memory_order_release);
+
             // Perform the matmul
             {
-                const size_t m_to_process = m;
-                const size_t m_start = 0;
+                const int64_t m_to_process = m;
+                const int64_t m_start = 0;
 
-                const size_t num_threads = KAI_MIN(n / n_step, nth);
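+                // n_step is only used in this block, so it is now queried where it is needed.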
+                const int64_t n_step = static_cast<int64_t>(kernel->get_n_step());
+                const int64_t num_threads = KAI_MIN(n / n_step, nth);
 
                 if (ith < num_threads) {
-                    const size_t num_n_per_thread0 = round_down(n / num_threads, n_step);
-                    const size_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0;
+                    const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step);
+                    const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0;
 
-                    const size_t n_start = ith * num_n_per_thread0;
-                    const size_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
+                    const int64_t n_start = ith * num_n_per_thread0;
+                    const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
 
                     const size_t lhs_packed_offset = variant_call<size_t>(kernel->get_lhs_offset, m_start, k);
                     const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k);
                     const size_t dst_offset = kernel->get_dst_offset(m_start, n_start, dst_stride);
 
                     const void * lhs_ptr = lhs_packed + lhs_packed_offset;
                     const void * rhs_ptr = rhs_packed + rhs_packed_offset;
-                    float * dst_ptr = reinterpret_cast<float *>(dst_batch + dst_offset);
+                    float * dst_ptr      = reinterpret_cast<float *>(dst_batch + dst_offset);
 
                     variant_call<void>(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
                 }
+            }
 
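+            // Synchronize between batches; after the final batch there is no further work to guard.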
+            if (batch_idx != batch_size - 1) {
                 ggml_barrier(params->threadpool);
             }
         }
@@ -311,7 +309,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             m_to_process = m - m_start;
         }
 
-        if (m_start < m) {
+        if (m_start < m) {
             // Transform LHS
             const size_t src_stride = src1->nb[1];
             const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));