@@ -241,7 +241,7 @@ class cann_task_queue {
241241 */
242242 explicit cann_task_queue (size_t capacity, int32_t device)
243243 : buffer_(capacity), capacity_(capacity), head_(0 ), tail_(0 ),
244- running_(false ), device_(device), consuming_( false ) {
244+ running_(false ), device_(device) {
245245 GGML_ASSERT ((capacity & (capacity - 1 )) == 0 && " capacity must be power of 2" );
246246 mask_ = capacity_ - 1 ;
247247 }
@@ -253,92 +253,52 @@ class cann_task_queue {
253253 * @return true if the task was successfully enqueued, false if the queue was full.
254254 */
255255 bool enqueue (std::unique_ptr<cann_task>&& item) {
256- size_t tail = tail_.load (std::memory_order_relaxed);
257- size_t next_tail = (tail + 1 ) & mask_;
256+ size_t next_tail = (tail_ + 1 ) & mask_;
258257
259- if (next_tail == head_. load (std::memory_order_acquire) ) {
258+ if (next_tail == head_) {
260259 return false ;
261260 }
262261
263- buffer_[tail] = std::move (item);
264- tail_.store (next_tail, std::memory_order_release);
265-
266- cv_.notify_one ();
262+ buffer_[tail_] = std::move (item);
263+ std::atomic_thread_fence (std::memory_order_release);
264+ tail_ = next_tail;
267265
268266 return true ;
269267 }
270268
271- /* *
272- * @brief Dequeues all available tasks in bulk into an output vector.
273- *
274- * @param output Output vector that will contain the dequeued tasks.
275- * @return Number of tasks dequeued.
276- */
277- size_t dequeue_bulk (std::vector<std::unique_ptr<cann_task>>& output) {
278- output.clear ();
279- size_t head = head_.load (std::memory_order_relaxed);
280- size_t tail = tail_.load (std::memory_order_acquire);
281-
282- while (running_ && head == tail) {
283- std::unique_lock<std::mutex> lock (mutex_);
284- cv_.wait (lock);
285- head = head_.load (std::memory_order_relaxed);
286- tail = tail_.load (std::memory_order_acquire);
287- }
288-
289- size_t count = 0 ;
290- while (running_ && head != tail) {
291- output.push_back (std::move (buffer_[head]));
292- head = (head + 1 ) & mask_;
293- ++count;
294- }
295-
296- head_.store (head, std::memory_order_release);
297- return count;
298- }
299-
300269 /* *
301270 * @brief Submits a task to the queue, and starts the worker thread if not already running.
302271 *
303272 * @param task Task to be submitted.
304273 */
305274 void submit_task (std::unique_ptr<cann_task>&& task) {
306- while (!enqueue (std::move (task))) continue ;
275+ while (!enqueue (std::move (task))) {
276+ std::this_thread::yield ();
277+ continue ;
278+ }
307279
308280 if (!running_) {
309- thread_ = std::thread (&cann_task_queue::execute, this );
310281 running_ = true ;
282+ thread_ = std::thread (&cann_task_queue::execute, this );
311283 }
312284
313285 }
314286
315- /* *
316- * @brief Checks whether the queue is empty.
317- *
318- * @return true if the queue is empty, false otherwise.
319- */
320- bool empty () const {
321- return head_.load (std::memory_order_acquire) ==
322- tail_.load (std::memory_order_acquire);
323- }
324-
325287 /* *
326288 * @brief Waits until the queue is completely empty and no tasks are being processed.
327289 */
328290 void wait () {
329- if (! running_)
330- return ;
331-
332- while (!( empty () && consuming_)) { }
291+ while ( running_ && head_ != tail_) {
292+ std::this_thread::yield () ;
293+ continue ;
294+ }
333295 }
334296
335297 /* *
336298 * @brief Stops the task queue and joins the worker thread.
337299 */
338300 void stop () {
339301 running_ = false ;
340- wait ();
341- cv_.notify_all ();
342302 if (thread_.joinable ()) {
343303 thread_.join ();
344304 }
@@ -349,33 +309,29 @@ class cann_task_queue {
349309 * @brief Worker thread function that continuously dequeues and executes tasks.
350310 */
351311 void execute () {
352- std::vector<std::unique_ptr<cann_task>> tasks;
353312 ggml_cann_set_device (device_);
354313
355- while (running_) {
356- consuming_ = true ;
357- int count = dequeue_bulk (tasks);
358- consuming_ = false ;
359- if (count == 0 )
314+ while (running_) {
315+ if (head_ == tail_) {
316+ std::this_thread::yield ();
360317 continue ;
361-
362- for (auto &task : tasks) {
363- task->run_task ();
364318 }
319+
320+ std::atomic_thread_fence (std::memory_order_acquire);
321+ buffer_[head_]->run_task ();
322+ buffer_[head_].reset ();
323+ head_ = (head_ + 1 ) & mask_;
365324 }
366325 }
367326
368327 std::vector<std::unique_ptr<cann_task>> buffer_;
369328 const size_t capacity_;
370329 size_t mask_;
371- std::atomic<size_t > head_;
372- std::atomic<size_t > tail_;
373- std::mutex mutex_;
374- std::condition_variable cv_;
330+ size_t head_;
331+ size_t tail_;
375332 bool running_;
376333 std::thread thread_;
377334 int32_t device_;
378- bool consuming_;
379335};
380336
381337/* *
0 commit comments