diff --git a/doc/APIreference/functions.rst b/doc/APIreference/functions.rst index ca4d1c6f7b..36b2c13cdc 100644 --- a/doc/APIreference/functions.rst +++ b/doc/APIreference/functions.rst @@ -2903,6 +2903,15 @@ Adds a thread pool to mjData and configures it for multi-threaded use. Enqueue a task in a thread pool. +.. _mju_threadPoolSetBusyWait: + +`mju_threadPoolSetBusyWait <#mju_threadPoolSetBusyWait>`__ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. mujoco-include:: mju_threadPoolSetBusyWait + +Set whether the thread pool should busy-wait for its task queue. Set to 1 to busy-wait, or 0 to use sleep. + .. _mju_threadPoolDestroy: `mju_threadPoolDestroy <#mju_threadPoolDestroy>`__ diff --git a/doc/includes/references.h b/doc/includes/references.h index c8bf05fa6d..c9195c6736 100644 --- a/doc/includes/references.h +++ b/doc/includes/references.h @@ -3420,6 +3420,7 @@ const mjpResourceProvider* mjp_getResourceProvider(const char* resource_name); const mjpResourceProvider* mjp_getResourceProviderAtSlot(int slot); mjThreadPool* mju_threadPoolCreate(size_t number_of_threads); void mju_bindThreadPool(mjData* d, void* thread_pool); +void mju_threadPoolSetBusyWait(mjThreadPool* thread_pool, int busy_wait); void mju_threadPoolEnqueue(mjThreadPool* thread_pool, mjTask* task); void mju_threadPoolDestroy(mjThreadPool* thread_pool); void mju_defaultTask(mjTask* task); diff --git a/include/mujoco/mujoco.h b/include/mujoco/mujoco.h index 9beac473bd..920f2db605 100644 --- a/include/mujoco/mujoco.h +++ b/include/mujoco/mujoco.h @@ -1429,6 +1429,9 @@ MJAPI void mju_bindThreadPool(mjData* d, void* thread_pool); // Enqueue a task in a thread pool. MJAPI void mju_threadPoolEnqueue(mjThreadPool* thread_pool, mjTask* task); +// Set whether the thread pool should use busy-waiting for its task queue. +MJAPI void mju_threadPoolSetBusyWait(mjThreadPool* thread_pool, int busy_wait); + // Destroy a thread pool. MJAPI void mju_threadPoolDestroy(mjThreadPool* thread_pool); diff --git a/python/mujoco/introspect/functions.py b/python/mujoco/introspect/functions.py index b9bb91e299..04eadb840d 100644 --- a/python/mujoco/introspect/functions.py +++ b/python/mujoco/introspect/functions.py @@ -8911,6 +8911,24 @@ ), doc='Enqueue a task in a thread pool.', )), + ('mju_threadPoolSetBusyWait', + FunctionDecl( + name='mju_threadPoolSetBusyWait', + return_type=ValueType(name='void'), + parameters=( + FunctionParameterDecl( + name='thread_pool', + type=PointerType( + inner_type=ValueType(name='mjThreadPool'), + ), + ), + FunctionParameterDecl( + name='busy_wait', + type=ValueType(name='int'), + ), + ), + doc='Set whether the thread pool should busy-waiting for its task queue.', + )), ('mju_threadPoolDestroy', FunctionDecl( name='mju_threadPoolDestroy', diff --git a/src/thread/thread_pool.cc b/src/thread/thread_pool.cc index af117f6a2d..f9e350c58d 100644 --- a/src/thread/thread_pool.cc +++ b/src/thread/thread_pool.cc @@ -129,6 +129,10 @@ class ThreadPoolImpl : public mjThreadPool { thread_pool_bound_ = true; } + void SetQueueBusyWait(int busy_wait) { + lockless_queue_.SetBusyWait(busy_wait > 0); + } + ~ThreadPoolImpl() { Shutdown(); } private: @@ -278,6 +282,11 @@ size_t mju_threadPoolCurrentWorkerId(mjThreadPool* thread_pool) { return thread_pool_impl->GetWorkerId(); } +void mju_threadPoolSetBusyWait(mjThreadPool* thread_pool, int busy_wait) { + auto thread_pool_impl = static_cast(thread_pool); + thread_pool_impl->SetQueueBusyWait(busy_wait); +} + // start a task in the threadpool void mju_threadPoolEnqueue(mjThreadPool* thread_pool, mjTask* task) { auto thread_pool_impl = static_cast(thread_pool); diff --git a/src/thread/thread_pool.h b/src/thread/thread_pool.h index 54d852d93d..4f0fd3c8ca 100644 --- a/src/thread/thread_pool.h +++ b/src/thread/thread_pool.h @@ -65,6 +65,9 @@ MJAPI size_t mju_threadPoolCurrentWorkerId(mjThreadPool* thread_pool); // Enqueue a task in a thread pool. MJAPI void mju_threadPoolEnqueue(mjThreadPool* thread_pool, mjTask* task); +// Set whether the thread pool should use busy-waiting for its task queue. +MJAPI void mju_threadPoolSetBusyWait(mjThreadPool* thread_pool, int busy_wait); + // Locks the allocation mutex to protect Arena allocations. MJAPI void mju_threadPoolLockAllocMutex(mjThreadPool* thread_pool); diff --git a/src/thread/thread_queue.h b/src/thread/thread_queue.h index 3178f8ec68..e025756b3d 100644 --- a/src/thread/thread_queue.h +++ b/src/thread/thread_queue.h @@ -43,6 +43,10 @@ class LocklessQueue { return maximum_read_cursor_ == read_cursor_; } + void SetBusyWait(bool busy_wait) { + busy_wait_ = busy_wait; + } + // Push an element into the queue. void push(const T& input) { // Reserve a slot in the queue @@ -91,8 +95,13 @@ class LocklessQueue { // Wait until the queue has an element do { if (empty) { - std::this_thread::yield(); + if (busy_wait_) { + std::this_thread::yield(); + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } } + current_read_cursor = read_cursor_.load(); current_maximum_read_cursor = maximum_read_cursor_.load(); @@ -145,6 +154,8 @@ class LocklessQueue { std::atomic maximum_read_cursor_ = 0; std::atomic buffer_[(buffer_capacity + 1)]; + + int busy_wait_ = 1; }; } // namespace mujoco