Skip to content

Commit 42ee0cb

Browse files
committed
Use thread fence instead of atomic
1 parent 7c341a8 commit 42ee0cb

File tree

1 file changed

+25
-69
lines changed

1 file changed

+25
-69
lines changed

ggml/src/ggml-cann/common.h

Lines changed: 25 additions & 69 deletions
Original file line number | Diff line number | Diff line change
@@ -241,7 +241,7 @@ class cann_task_queue {
241241
*/
242242
explicit cann_task_queue(size_t capacity, int32_t device)
243243
: buffer_(capacity), capacity_(capacity), head_(0), tail_(0),
244-
running_(false), device_(device), consuming_(false) {
244+
running_(false), device_(device) {
245245
GGML_ASSERT((capacity & (capacity - 1)) == 0 && "capacity must be power of 2");
246246
mask_ = capacity_ - 1;
247247
}
@@ -253,92 +253,52 @@ class cann_task_queue {
253253
* @return true if the task was successfully enqueued, false if the queue was full.
254254
*/
255255
bool enqueue(std::unique_ptr<cann_task>&& item) {
256-
size_t tail = tail_.load(std::memory_order_relaxed);
257-
size_t next_tail = (tail + 1) & mask_;
256+
size_t next_tail = (tail_ + 1) & mask_;
258257

259-
if (next_tail == head_.load(std::memory_order_acquire)) {
258+
if (next_tail == head_) {
260259
return false;
261260
}
262261

263-
buffer_[tail] = std::move(item);
264-
tail_.store(next_tail, std::memory_order_release);
265-
266-
cv_.notify_one();
262+
buffer_[tail_] = std::move(item);
263+
std::atomic_thread_fence(std::memory_order_release);
264+
tail_ = next_tail;
267265

268266
return true;
269267
}
270268

271-
/**
272-
* @brief Dequeues all available tasks in bulk into an output vector.
273-
*
274-
* @param output Output vector that will contain the dequeued tasks.
275-
* @return Number of tasks dequeued.
276-
*/
277-
size_t dequeue_bulk(std::vector<std::unique_ptr<cann_task>>& output) {
278-
output.clear();
279-
size_t head = head_.load(std::memory_order_relaxed);
280-
size_t tail = tail_.load(std::memory_order_acquire);
281-
282-
while (running_ && head == tail) {
283-
std::unique_lock<std::mutex> lock(mutex_);
284-
cv_.wait(lock);
285-
head = head_.load(std::memory_order_relaxed);
286-
tail = tail_.load(std::memory_order_acquire);
287-
}
288-
289-
size_t count = 0;
290-
while (running_ && head != tail) {
291-
output.push_back(std::move(buffer_[head]));
292-
head = (head + 1) & mask_;
293-
++count;
294-
}
295-
296-
head_.store(head, std::memory_order_release);
297-
return count;
298-
}
299-
300269
/**
301270
* @brief Submits a task to the queue, and starts the worker thread if not already running.
302271
*
303272
* @param task Task to be submitted.
304273
*/
305274
void submit_task(std::unique_ptr<cann_task>&& task) {
306-
while(!enqueue(std::move(task))) continue;
275+
while(!enqueue(std::move(task))) {
276+
std::this_thread::yield();
277+
continue;
278+
}
307279

308280
if (!running_) {
309-
thread_ = std::thread(&cann_task_queue::execute, this);
310281
running_ = true;
282+
thread_ = std::thread(&cann_task_queue::execute, this);
311283
}
312284

313285
}
314286

315-
/**
316-
* @brief Checks whether the queue is empty.
317-
*
318-
* @return true if the queue is empty, false otherwise.
319-
*/
320-
bool empty() const {
321-
return head_.load(std::memory_order_acquire) ==
322-
tail_.load(std::memory_order_acquire);
323-
}
324-
325287
/**
326288
* @brief Waits until the queue is completely empty and no tasks are being processed.
327289
*/
328290
void wait() {
329-
if (!running_)
330-
return;
331-
332-
while (!(empty() && consuming_)) {}
291+
while (running_ && head_ != tail_) {
292+
std::this_thread::yield();
293+
continue;
294+
}
333295
}
334296

335297
/**
336298
* @brief Stops the task queue and joins the worker thread.
337299
*/
338300
void stop() {
339301
running_ = false;
340-
wait();
341-
cv_.notify_all();
342302
if (thread_.joinable()) {
343303
thread_.join();
344304
}
@@ -349,33 +309,29 @@ class cann_task_queue {
349309
* @brief Worker thread function that continuously dequeues and executes tasks.
350310
*/
351311
void execute() {
352-
std::vector<std::unique_ptr<cann_task>> tasks;
353312
ggml_cann_set_device(device_);
354313

355-
while(running_) {
356-
consuming_ = true;
357-
int count = dequeue_bulk(tasks);
358-
consuming_ = false;
359-
if (count == 0)
314+
while (running_) {
315+
if(head_ == tail_) {
316+
std::this_thread::yield();
360317
continue;
361-
362-
for(auto &task : tasks) {
363-
task->run_task();
364318
}
319+
320+
std::atomic_thread_fence(std::memory_order_acquire);
321+
buffer_[head_]->run_task();
322+
buffer_[head_].reset();
323+
head_ = (head_ + 1) & mask_;
365324
}
366325
}
367326

368327
std::vector<std::unique_ptr<cann_task>> buffer_;
369328
const size_t capacity_;
370329
size_t mask_;
371-
std::atomic<size_t> head_;
372-
std::atomic<size_t> tail_;
373-
std::mutex mutex_;
374-
std::condition_variable cv_;
330+
size_t head_;
331+
size_t tail_;
375332
bool running_;
376333
std::thread thread_;
377334
int32_t device_;
378-
bool consuming_;
379335
};
380336

381337
/**

0 commit comments

Comments
 (0)