@@ -2013,10 +2013,11 @@ struct ggml_threadpool {
     // these are atomic as an annotation for thread-sanitizer
     atomic_bool stop;  // Used for stopping the threadpool altogether
     atomic_bool pause; // Used for pausing the threadpool or individual threads
+    atomic_bool abort; // Used for aborting processing of a graph

     struct ggml_compute_state * workers; // per thread state
     int        n_threads_max; // number of threads in the pool
-    int        n_threads_cur; // number of threads used in the current graph
+    atomic_int n_threads_cur; // number of threads used in the current graph

     int32_t  prio; // Scheduling priority
     uint32_t poll; // Polling level (0 - no polling)
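Note on the struct change above: n_threads_cur is now written by the kickoff path and read concurrently by the worker threads, so it has to be an atomic_int rather than a plain int. A minimal standalone sketch of that pattern with C11 <stdatomic.h> (hypothetical names, not ggml code):

// One thread publishes a new "active thread count" while a worker reads it
// concurrently. With a plain int this is a data race; with atomic_int and
// relaxed loads/stores it is well-defined.
#include <pthread.h>
#include <stdatomic.h>
#include <stdio.h>

static atomic_int  n_active = 4;     // analogous to n_threads_cur
static atomic_bool done     = false;

static void * worker(void * arg) {
    (void) arg;
    while (!atomic_load_explicit(&done, memory_order_relaxed)) {
        // relaxed is enough here: we only need an un-torn, race-free value,
        // not ordering against other memory
        int n = atomic_load_explicit(&n_active, memory_order_relaxed);
        (void) n;
    }
    return NULL;
}

int main(void) {
    pthread_t t;
    pthread_create(&t, NULL, worker, NULL);

    // the "kickoff" side updates the count before starting new work
    atomic_store_explicit(&n_active, 2, memory_order_relaxed);

    atomic_store_explicit(&done, true, memory_order_relaxed);
    pthread_join(t, NULL);
    printf("final n_active = %d\n", atomic_load_explicit(&n_active, memory_order_relaxed));
    return 0;
}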
@@ -3178,41 +3179,36 @@ inline static void ggml_critical_section_start(void) {
     }
 }

-#ifdef GGML_USE_OPENMP
-static void ggml_barrier(struct ggml_threadpool * threadpool) {
-    if (threadpool->n_threads_cur == 1) {
+static void ggml_barrier(struct ggml_threadpool * tp) {
+    int n_threads = atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed);
+    if (n_threads == 1) {
         return;
     }

+#ifdef GGML_USE_OPENMP
     #pragma omp barrier
-}
 #else
-static void ggml_barrier(struct ggml_threadpool * threadpool) {
-    if (threadpool->n_threads_cur == 1) {
-        return;
-    }
-
-    atomic_int * n_barrier = &threadpool->n_barrier;
-    atomic_int * n_barrier_passed = &threadpool->n_barrier_passed;
+    int n_passed = atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed);

-    int n_threads = threadpool->n_threads_cur;
-    int passed_old = atomic_load_explicit(n_barrier_passed, memory_order_relaxed);
+    // enter barrier (full seq-cst fence)
+    int n_barrier = atomic_fetch_add_explicit(&tp->n_barrier, 1, memory_order_seq_cst);

-    if (atomic_fetch_add(n_barrier, 1) == n_threads - 1) {
+    int last = 0;
+    if (n_barrier == (n_threads - 1)) {
         // last thread
-        atomic_store(n_barrier, 0);
-        atomic_fetch_add_explicit(n_barrier_passed, 1, memory_order_relaxed);
+        atomic_store_explicit(&tp->n_barrier, 0, memory_order_relaxed);
+        last = 1;
     } else {
         // wait for other threads
-        while (true) {
-            if (atomic_load_explicit(n_barrier_passed, memory_order_relaxed) != passed_old) {
-                return;
-            }
+        while (atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed) == n_passed) {
             ggml_thread_cpu_relax();
         }
     }
-}
+
+    // exit barrier (full seq-cst fence)
+    atomic_fetch_add_explicit(&tp->n_barrier_passed, last, memory_order_seq_cst);
 #endif
+}

 // TODO: make this somehow automatically executed
 // some sort of "sentry" mechanism
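For experimenting with the barrier outside of ggml, here is a self-contained sketch of the same two-counter scheme (hypothetical names, plain pthreads in place of the ggml thread wrappers): one counter tracks arrivals, the other counts completed barrier generations, and the seq-cst read-modify-writes on entry and exit provide the full fences the comments above refer to.

// Self-contained sketch of the two-counter barrier (not ggml code).
#include <pthread.h>
#include <sched.h>
#include <stdatomic.h>
#include <stdio.h>

#define N_THREADS 4
#define N_PHASES  3

static atomic_int n_barrier        = 0;  // arrivals at the current barrier
static atomic_int n_barrier_passed = 0;  // completed barrier generations

static void barrier_wait(int n_threads) {
    int passed_old = atomic_load_explicit(&n_barrier_passed, memory_order_relaxed);

    // enter barrier (full seq-cst fence)
    int arrived = atomic_fetch_add_explicit(&n_barrier, 1, memory_order_seq_cst);

    int last = 0;
    if (arrived == n_threads - 1) {
        // last thread: reset the arrival counter for the next use of the barrier
        atomic_store_explicit(&n_barrier, 0, memory_order_relaxed);
        last = 1;
    } else {
        // spin until the generation counter moves
        while (atomic_load_explicit(&n_barrier_passed, memory_order_relaxed) == passed_old) {
            sched_yield(); // stand-in for a cpu-relax hint
        }
    }

    // exit barrier (full seq-cst fence); only the last thread bumps the generation
    atomic_fetch_add_explicit(&n_barrier_passed, last, memory_order_seq_cst);
}

static void * worker(void * arg) {
    int ith = (int)(long) arg;
    for (int phase = 0; phase < N_PHASES; phase++) {
        printf("thread %d: phase %d\n", ith, phase);
        barrier_wait(N_THREADS); // no thread starts phase+1 before all finish phase
    }
    return NULL;
}

int main(void) {
    pthread_t threads[N_THREADS];
    for (long i = 0; i < N_THREADS; i++) {
        pthread_create(&threads[i], NULL, worker, (void *) i);
    }
    for (int i = 0; i < N_THREADS; i++) {
        pthread_join(threads[i], NULL);
    }
    return 0;
}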
@@ -19933,64 +19929,84 @@ struct ggml_cplan ggml_graph_plan(

 static thread_ret_t ggml_graph_compute_thread(void * data) {
     struct ggml_compute_state * state = (struct ggml_compute_state *) data;
+    struct ggml_threadpool * tp = state->threadpool;

-    const struct ggml_cgraph * cgraph = state->threadpool->cgraph;
-    const struct ggml_cplan  * cplan  = state->threadpool->cplan;
+    const struct ggml_cgraph * cgraph = tp->cgraph;
+    const struct ggml_cplan  * cplan  = tp->cplan;

     set_numa_thread_affinity(state->ith);

     struct ggml_compute_params params = {
         /*.ith       =*/ state->ith,
-        /*.nth       =*/ state->threadpool->n_threads_cur,
+        /*.nth       =*/ atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed),
         /*.wsize     =*/ cplan->work_size,
         /*.wdata     =*/ cplan->work_data,
-        /*.threadpool=*/ state->threadpool,
+        /*.threadpool=*/ tp,
     };

-    for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
+    for (int node_n = 0; node_n < cgraph->n_nodes && !tp->abort; node_n++) {
         struct ggml_tensor * node = cgraph->nodes[node_n];

         ggml_compute_forward(&params, node);

-        if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
-            state->threadpool->ec = GGML_STATUS_ABORTED;
+        if (state->ith == 0 && cplan->abort_callback &&
+            cplan->abort_callback(cplan->abort_callback_data)) {
+            tp->abort = true;
+            tp->ec = GGML_STATUS_ABORTED;
         }

         ggml_barrier(state->threadpool);
-
-        if (state->threadpool->ec != GGML_STATUS_SUCCESS) {
-            break;
-        }
     }

     return 0;
 }
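With the new abort flag, every worker stops at the next node boundary once thread 0's abort_callback returns true. As a hedged illustration of how a caller might drive that callback (abort_callback and abort_callback_data are real ggml_cplan fields; the deadline struct and helper below are invented for illustration):

// Hedged sketch: wiring a wall-clock deadline into cplan->abort_callback.
#include <stdbool.h>
#include <time.h>

struct compute_deadline {
    struct timespec end; // absolute deadline
};

// Called by thread 0 after every graph node; returning true aborts the graph.
static bool deadline_exceeded(void * data) {
    struct compute_deadline * dl = (struct compute_deadline *) data;
    struct timespec now;
    clock_gettime(CLOCK_MONOTONIC, &now);
    return (now.tv_sec > dl->end.tv_sec) ||
           (now.tv_sec == dl->end.tv_sec && now.tv_nsec >= dl->end.tv_nsec);
}

// Usage (assuming a prepared cplan):
//     struct compute_deadline dl = { /* deadline */ };
//     cplan.abort_callback      = deadline_exceeded;
//     cplan.abort_callback_data = &dl;
//     ggml_graph_compute(graph, &cplan); // returns GGML_STATUS_ABORTED on timeout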

 #ifndef GGML_USE_OPENMP

-static inline bool ggml_graph_compute_ready(struct ggml_compute_state * state) {
+// check if thread is active
+static inline bool ggml_graph_compute_thread_active(struct ggml_compute_state * state) {
+    struct ggml_threadpool * threadpool = state->threadpool;
+    int n_threads = atomic_load_explicit(&threadpool->n_threads_cur, memory_order_relaxed);
+    return (state->ith < n_threads);
+}
+
+// check if thread is ready to proceed (exit from polling or sleeping)
+static inline bool ggml_graph_compute_thread_ready(struct ggml_compute_state * state) {
     struct ggml_threadpool * threadpool = state->threadpool;

     if (state->pending || threadpool->stop || threadpool->pause) { return true; }

     // check for new graph/work
     int new_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed);
     if (new_graph != state->last_graph) {
-        state->pending    = (state->ith < threadpool->n_threads_cur);
+        state->pending    = ggml_graph_compute_thread_active(state);
         state->last_graph = new_graph;
     }

     return state->pending;
 }

+// sync thread state after polling
+static inline void ggml_graph_compute_thread_sync(struct ggml_compute_state * state) {
+    struct ggml_threadpool * threadpool = state->threadpool;
+    // this should just be atomic_thread_fence(seq_cst) but it confuses thread-sanitizer
+    // so instead we just use a dummy read-modify-write
+    atomic_fetch_add_explicit(&threadpool->n_graph, 0, memory_order_seq_cst);
+}
+
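The dummy fetch_add(..., 0, memory_order_seq_cst) in ggml_graph_compute_thread_sync stands in for atomic_thread_fence(memory_order_seq_cst) and pairs with the seq-cst increment of n_graph on the kickoff side. A minimal sketch of that publish/consume handshake in isolation (hypothetical names, not ggml code):

// The producer publishes work with a seq-cst RMW on a generation counter;
// the consumer polls it with relaxed loads and then issues a dummy seq-cst
// RMW before touching the published data.
#include <pthread.h>
#include <stdatomic.h>
#include <stdio.h>

static int        work_payload = 0; // plain data, published via the counter
static atomic_int work_gen     = 0; // generation counter (like n_graph)

static void * consumer(void * arg) {
    (void) arg;
    int last_gen = 0;

    // poll for a new generation (like ggml_graph_compute_thread_ready)
    while (atomic_load_explicit(&work_gen, memory_order_relaxed) == last_gen) {
        // spin; a real implementation would relax and eventually sleep
    }

    // "sync": dummy read-modify-write acts as a full fence
    atomic_fetch_add_explicit(&work_gen, 0, memory_order_seq_cst);

    // now it is safe to read the payload written before the publish
    printf("consumer sees payload = %d\n", work_payload);
    return NULL;
}

int main(void) {
    pthread_t t;
    pthread_create(&t, NULL, consumer, NULL);

    work_payload = 42; // write the work description first...

    // ...then publish it (like ggml_graph_compute_kickoff bumping n_graph)
    atomic_fetch_add_explicit(&work_gen, 1, memory_order_seq_cst);

    pthread_join(t, NULL);
    return 0;
}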
 static inline bool ggml_graph_compute_poll_for_work(struct ggml_compute_state * state) {
     struct ggml_threadpool * threadpool = state->threadpool;

+    // Skip polling for unused threads
+    if (!ggml_graph_compute_thread_active(state)) {
+        return state->pending;
+    }
+
     // This seems to make 0 ... 100 a decent range for polling level across modern processors.
     // Perhaps, we can adjust it dynamically based on load and things.
     const uint64_t n_rounds = 1024UL * 128 * threadpool->poll;

-    for (uint64_t i=0; !ggml_graph_compute_ready(state) && i<n_rounds; i++) {
+    for (uint64_t i=0; !ggml_graph_compute_thread_ready(state) && i < n_rounds; i++) {
         // No new work. Keep polling.
         ggml_thread_cpu_relax();
     }
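ggml_thread_cpu_relax() is ggml's spin-wait hint; the exact definition is platform-specific, but a typical wrapper looks roughly like the following (a sketch of the common pattern, not the ggml definition):

// Spin-wait hint: PAUSE on x86, YIELD on AArch64, otherwise a plain busy-wait.
#if defined(__x86_64__) || defined(_M_X64) || defined(__i386__)
#include <immintrin.h>
static inline void cpu_relax(void) { _mm_pause(); }
#elif defined(__aarch64__)
static inline void cpu_relax(void) { __asm__ volatile("yield" ::: "memory"); }
#else
static inline void cpu_relax(void) { /* no hint available */ }
#endif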
@@ -20002,13 +20018,14 @@ static inline bool ggml_graph_compute_check_for_work(struct ggml_compute_state *
     struct ggml_threadpool * threadpool = state->threadpool;

     if (ggml_graph_compute_poll_for_work(state)) {
+        ggml_graph_compute_thread_sync(state);
         return state->pending;
     }

     ggml_mutex_lock_shared(&threadpool->mutex);
-    while (!ggml_graph_compute_ready(state)) {
+    while (!ggml_graph_compute_thread_ready(state)) {
         // No new work. Wait for the signal.
-        GGML_PRINT_DEBUG("thread #%d waiting for work\n", state->ith);
+        GGML_PRINT_DEBUG("thread #%d waiting for work (sleeping)\n", state->ith);
         ggml_cond_wait(&threadpool->cond, &threadpool->mutex);
     }
     ggml_mutex_unlock_shared(&threadpool->mutex);
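The ggml_mutex_* / ggml_cond_* calls are thin wrappers over the platform primitives. A hedged sketch of the sleep/wake half of this hybrid wait with plain pthreads (invented names; a generation counter plays the role of n_graph):

#include <pthread.h>

static pthread_mutex_t mutex  = PTHREAD_MUTEX_INITIALIZER;
static pthread_cond_t  cond   = PTHREAD_COND_INITIALIZER;
static int             n_work = 0; // generation counter, like threadpool->n_graph

// worker side: called after the polling budget is exhausted
static void wait_for_work(int * last_seen) {
    pthread_mutex_lock(&mutex);
    while (n_work == *last_seen) {        // loop guards against spurious wakeups
        pthread_cond_wait(&cond, &mutex); // atomically unlocks, sleeps, relocks
    }
    *last_seen = n_work;
    pthread_mutex_unlock(&mutex);
}

// kickoff side: publish a new generation, then wake all sleepers
static void post_work(void) {
    pthread_mutex_lock(&mutex);
    n_work++;
    pthread_cond_broadcast(&cond);
    pthread_mutex_unlock(&mutex);
}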
@@ -20055,13 +20072,20 @@ static thread_ret_t ggml_graph_compute_secondary_thread(void* data) {
 }

 // Start processing new graph
-static void ggml_graph_compute_kickoff(struct ggml_threadpool * threadpool)
+static void ggml_graph_compute_kickoff(struct ggml_threadpool * threadpool, int n_threads)
 {
-    // always take the mutex here because the worker threads are doing hybrid poll/wait
+    // Always take the mutex here because the worker threads are doing hybrid poll/wait

     ggml_mutex_lock(&threadpool->mutex);

-    atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_relaxed);
+    GGML_PRINT_DEBUG("threadpool: n_threads_cur %d n_threads %d\n", threadpool->n_threads_cur, n_threads);
+
+    // Update the number of active threads
+    atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed);
+
+    // Indicate the graph is ready to be processed
+    // We need the full seq-cst fence here because of the polling threads (used in thread_sync)
+    atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_seq_cst);

     if (threadpool->pause) {
         // Update main thread prio and affinity to match the threadpool settings
@@ -20120,6 +20144,7 @@ static struct ggml_threadpool * ggml_threadpool_new_impl(
     threadpool->current_chunk = 0;
     threadpool->stop          = false;
     threadpool->pause         = tpp->paused;
+    threadpool->abort         = false;
     threadpool->workers       = NULL;
     threadpool->n_threads_max = tpp->n_threads;
     threadpool->n_threads_cur = tpp->n_threads;
@@ -20195,15 +20220,11 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
         // No worker threads should be accessing the parameters below at this stage
         threadpool->cgraph        = cgraph;
         threadpool->cplan         = cplan;
-        threadpool->n_threads_cur = n_threads;
         threadpool->current_chunk = 0;
+        threadpool->abort         = false;
         threadpool->ec            = GGML_STATUS_SUCCESS;
     }

-    if (n_threads > threadpool->n_threads_max) {
-        GGML_PRINT("WARNING: cplan is requesting more threads than the threadpool contains. Expect a bad time!\n");
-    }
-
 #ifdef GGML_USE_OPENMP
     if (n_threads > 1) {
         #pragma omp parallel num_threads(n_threads)
@@ -20212,7 +20233,7 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
             {
                 // update the number of threads from the actual number of threads that we got from OpenMP
                 n_threads = omp_get_num_threads();
-                threadpool->n_threads_cur = n_threads;
+                atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed);
             }

             ggml_graph_compute_thread(&threadpool->workers[omp_get_thread_num()]);
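The OpenMP path requests n_threads but trusts whatever the runtime actually grants, which is why n_threads_cur is rewritten inside the parallel region. The same pattern in isolation (a sketch, not ggml code):

// Request a thread count, then read back what OpenMP actually created.
#include <omp.h>
#include <stdio.h>

int main(void) {
    int n_threads_req = 8;
    int n_threads     = n_threads_req;

    #pragma omp parallel num_threads(n_threads_req)
    {
        #pragma omp single
        {
            // the runtime may grant fewer threads than requested
            n_threads = omp_get_num_threads();
            printf("requested %d threads, got %d\n", n_threads_req, n_threads);
        }
        // implicit barrier after "single": n_threads is now visible to all threads

        int ith = omp_get_thread_num();
        printf("thread %d of %d\n", ith, n_threads);
    }
    return 0;
}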
@@ -20221,8 +20242,13 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
         ggml_graph_compute_thread(&threadpool->workers[0]);
     }
 #else
+    if (n_threads > threadpool->n_threads_max) {
+        GGML_PRINT("WARNING: cplan requested more threads (%d) than available (%d)\n", n_threads, threadpool->n_threads_max);
+        n_threads = threadpool->n_threads_max;
+    }
+
     // Kick all threads to start the new graph
-    ggml_graph_compute_kickoff(threadpool);
+    ggml_graph_compute_kickoff(threadpool, n_threads);

     // This is a work thread too
     ggml_graph_compute_thread(&threadpool->workers[0]);