Skip to content

Commit 66a84d8

Browse files
committed
Improve the pool
1 parent ea39f26 commit 66a84d8

File tree

1 file changed

+47
-36
lines changed

1 file changed

+47
-36
lines changed

llamafile/pool.cpp

Lines changed: 47 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -48,46 +48,48 @@ struct llamafile_thread {
4848
};
4949

5050
static atomic_int g_active;
51-
static _Atomic(llamafile_thread *) g_idle;
51+
static atomic_uintptr_t g_idle;
5252

53-
static void unlock_mutex(void *arg) {
54-
pthread_mutex_t *mu = (pthread_mutex_t *)arg;
55-
pthread_mutex_unlock(mu);
56-
}
53+
#define MASQUE 0x00fffffffffffff0
54+
#define PTR(x) ((uintptr_t)(x) & MASQUE)
55+
#define TAG(x) ROL((uintptr_t)(x) & ~MASQUE, 8)
56+
#define ABA(p, t) ((uintptr_t)(p) | (ROR((uintptr_t)(t), 8) & ~MASQUE))
57+
#define ROL(x, n) (((x) << (n)) | ((x) >> (64 - (n))))
58+
#define ROR(x, n) (((x) >> (n)) | ((x) << (64 - (n))))
5759

5860
static void idle_push(llamafile_thread *thread) {
59-
int backoff = 0;
60-
thread->next = atomic_load_explicit(&g_idle, memory_order_relaxed);
61-
while (!atomic_compare_exchange_weak_explicit(&g_idle, &thread->next, thread,
62-
memory_order_acq_rel, memory_order_relaxed))
63-
backoff = pthread_delay_np(&g_idle, backoff);
61+
uintptr_t tip;
62+
unassert(!TAG(thread));
63+
tip = atomic_load_explicit(&g_idle, memory_order_relaxed);
64+
for (;;) {
65+
thread->next = (llamafile_thread *)PTR(tip);
66+
if (atomic_compare_exchange_weak_explicit(&g_idle, &tip, ABA(thread, TAG(tip) + 1),
67+
memory_order_release, memory_order_relaxed))
68+
break;
69+
}
6470
}
6571

6672
static llamafile_thread *idle_pop(void) {
67-
int backoff = 0;
73+
uintptr_t tip;
6874
llamafile_thread *thread;
69-
for (;;) {
70-
if ((thread = atomic_load_explicit(&g_idle, memory_order_relaxed))) {
71-
if (atomic_compare_exchange_weak_explicit(&g_idle, &thread, thread->next,
72-
memory_order_acq_rel, memory_order_relaxed))
73-
return thread;
74-
backoff = pthread_delay_np(g_idle, backoff);
75-
} else {
76-
return nullptr;
77-
}
78-
}
75+
tip = atomic_load_explicit(&g_idle, memory_order_relaxed);
76+
while ((thread = (llamafile_thread *)PTR(tip)))
77+
if (atomic_compare_exchange_weak_explicit(&g_idle, &tip, ABA(thread->next, TAG(tip) + 1),
78+
memory_order_acquire, memory_order_relaxed))
79+
break;
80+
return thread;
7981
}
8082

8183
static void cancel_task(llamafile_task *task) {
8284
pthread_mutex_lock(&task->mu);
8385
task->res = PTHREAD_CANCELED;
84-
task->th = 0;
86+
atomic_store_explicit(&task->th, 0, memory_order_release);
8587
pthread_cond_signal(&task->cv);
8688
pthread_mutex_unlock(&task->mu);
8789
}
8890

8991
static void llamafile_thread_canceled(llamafile_thread *thread) {
90-
thread->th = 0;
92+
atomic_store_explicit(&thread->th, 0, memory_order_release);
9193
cancel_task(thread->task);
9294
delete thread;
9395
--g_active;
@@ -103,16 +105,14 @@ static void *llamafile_thread_worker(void *arg) {
103105
void *res = thread->task->func(thread->task->arg);
104106
pthread_setcancelstate(PTHREAD_CANCEL_MASKED, 0);
105107

106-
for (;;) {
107-
if (thread->th != -1)
108-
if (thread->task->th != -1)
108+
for (;;)
109+
if (atomic_load_explicit(&thread->th, memory_order_acquire) != -1)
110+
if (atomic_load_explicit(&thread->task->th, memory_order_acquire) != -1)
109111
break;
110-
pthread_pause_np();
111-
}
112112

113113
pthread_mutex_lock(&thread->task->mu);
114114
thread->task->res = res;
115-
thread->task->th = 0;
115+
atomic_store_explicit(&thread->task->th, 0, memory_order_release);
116116
pthread_cond_signal(&thread->task->cv);
117117
pthread_mutex_unlock(&thread->task->mu);
118118

@@ -131,7 +131,7 @@ static void *llamafile_thread_worker(void *arg) {
131131
if (thread->task)
132132
cancel_task(thread->task);
133133

134-
thread->th = 0;
134+
atomic_store_explicit(&thread->th, 0, memory_order_release);
135135
g_key.set(nullptr);
136136
delete thread;
137137
--g_active;
@@ -151,7 +151,8 @@ static errno_t llamafile_thread_create(llamafile_task *task) {
151151
errno_t err = pthread_create((pthread_t *)&thread->th, &attr, llamafile_thread_worker, thread);
152152
pthread_attr_destroy(&attr);
153153
if (!err) {
154-
task->th = thread->th.load();
154+
atomic_store_explicit(&task->th, atomic_load_explicit(&thread->th, memory_order_relaxed),
155+
memory_order_release);
155156
} else {
156157
delete thread;
157158
}
@@ -166,8 +167,9 @@ errno_t llamafile_task_create(llamafile_task **out_task, void *(*func)(void *),
166167
llamafile_thread *thread;
167168
if ((thread = idle_pop())) {
168169
pthread_mutex_lock(&thread->mu);
170+
atomic_store_explicit(&task->th, atomic_load_explicit(&thread->th, memory_order_relaxed),
171+
memory_order_release);
169172
thread->task = task;
170-
task->th = thread->th.load();
171173
pthread_cond_signal(&thread->cv);
172174
pthread_mutex_unlock(&thread->mu);
173175
err = 0;
@@ -182,10 +184,15 @@ errno_t llamafile_task_create(llamafile_task **out_task, void *(*func)(void *),
182184
return err;
183185
}
184186

187+
static void unlock_mutex(void *arg) {
188+
pthread_mutex_t *mu = (pthread_mutex_t *)arg;
189+
pthread_mutex_unlock(mu);
190+
}
191+
185192
errno_t llamafile_task_join(llamafile_task *task, void **out_res) {
186193
pthread_cleanup_push(unlock_mutex, &task->mu);
187194
pthread_mutex_lock(&task->mu);
188-
while (task->th)
195+
while (atomic_load_explicit(&task->th, memory_order_acquire))
189196
pthread_cond_wait(&task->cv, &task->mu);
190197
pthread_cleanup_pop(true);
191198
if (out_res)
@@ -195,9 +202,13 @@ errno_t llamafile_task_join(llamafile_task *task, void **out_res) {
195202
}
196203

197204
errno_t llamafile_task_cancel(llamafile_task *task) {
205+
pthread_t th;
198206
errno_t err = 0;
199-
if (task->th)
200-
err = pthread_cancel(task->th);
207+
if ((th = atomic_load_explicit(&task->th, memory_order_acquire))) {
208+
err = pthread_cancel(th);
209+
} else {
210+
err = ESRCH;
211+
}
201212
return err;
202213
}
203214

@@ -207,7 +218,7 @@ void llamafile_task_shutdown(void) {
207218
llamafile_thread *thread;
208219
for (;;) {
209220
while ((thread = idle_pop()))
210-
if ((th = thread->th))
221+
if ((th = atomic_load_explicit(&thread->th, memory_order_acquire)))
211222
pthread_cancel(th);
212223
if (!g_active)
213224
break;

0 commit comments

Comments
 (0)