23 changes: 20 additions & 3 deletions include/stdexec/__detail/__sync_wait.hpp
@@ -32,6 +32,8 @@
 #include "__run_loop.hpp"
 #include "__type_traits.hpp"

+#include "__atomic.hpp"
+
 #include <exception>
 #include <system_error>
 #include <optional>
@@ -89,6 +91,20 @@ namespace stdexec {
     struct __state {
       std::exception_ptr __eptr_;
       run_loop __loop_;
+      stdexec::__std::atomic<bool> __done_{false};
+
+      void finish() noexcept {
+        __loop_.finish();
+        __done_.store(true, stdexec::__std::memory_order_release);
+        __done_.notify_all();
+      }
+
+      void wait() noexcept {
+        // Account for spurious wakeups
+        while (!__done_.load(stdexec::__std::memory_order_acquire)) {
Contributor:
There's still a use-after-free bug: a background thread can store true to __done_, and then, before it calls notify, the waiter thread may see that __done_ is true and break (say, after a spurious wakeup, or because the wait never reached the first iteration of the loop); sync_wait then returns and the state is destructed.

Btw, I don't think you need to loop here: std::atomic::wait only returns when the value has changed, accounting for spurious wakeups (it has a loop internally).

I was able to detect this with Relacy, which I am in the process of updating to work with stdexec. The test is https://github.com/NVIDIA/stdexec/compare/main...ccotter:stdexec:sync-wait-rrd-bug?expand=1

To run the test, follow these instructions, except use this branch instead of relacy's main branch: dvyukov/relacy#33
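
For illustration, the loop-free wait() suggested here would be just the following (sketch only; on its own it does not close the race described above):

// Sketch: wait() without the manual loop. std::atomic::wait already
// loops internally until the value differs from the argument passed in.
// Note this does NOT fix the use-after-free window: the waiter can
// still return (destroying __done_) between the notifier's store and
// its notify_all().
void wait() noexcept {
  __done_.wait(false, stdexec::__std::memory_order_acquire);
}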

Contributor:

I'm actually not sure how this can be addressed. In code like

exec::static_thread_pool pool{1};
{
  auto s = ex::schedule(pool.get_scheduler()) | ex::then([] { return 0; });
  ex::sync_wait(s);
}

when the background thread notifies the foreground thread to wake up and unblock, the communication channel (condition variable, atomic, or other) may be destroyed once the sync_wait objects on the stack created by the foreground thread go out of scope. Whether the notifier uses done.store(true); done.notify_one(); or done = true; cv.notify();, I'm not sure how to ensure that the foreground thread will wake up and continue, but only after the background thread has finished doing the notification...

Contributor:

@lewissbaker - wondering if you might be able to shed some light on this... I seem to have hit a mental dead end: I don't understand how it's possible to have a sync_wait() without allocating the internal run_loop on the heap and giving the receiver a strong reference to guarantee lifetime.
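
For illustration, the heap-based shape described here might look like the following (hypothetical sketch; __shared_state and __receiver are illustrative names, not stdexec's actual types):

#include <exception>
#include <memory>

// Hypothetical sketch only: heap-allocate the shared state and hand the
// receiver a strong reference, so the state outlives any late notifier.
struct __shared_state {
  std::exception_ptr __eptr_;
  run_loop __loop_;
};

struct __receiver {
  // Strong reference: even after sync_wait's caller has unwound its
  // stack, the state stays alive until the last receiver releases it.
  std::shared_ptr<__shared_state> __state_;

  void set_stopped() noexcept {
    __state_->__loop_.finish();  // safe: __state_ keeps the memory alive
  }
};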

Contributor Author:

@ccotter The receiver is the one allocating the shared state (currently on the stack).
In fact, the code you mentioned above is very similar to the code where I detected the bug this fix is for (see here: https://github.com/olympus-robotics/hephaestus/blob/main/modules/concurrency/tests/context_tests.cpp#L96).
The solution I came up with should actually be safe; I don't have a full proof, but I ran my test case around a million times.

Let me try to explain why the scenario you describe should be safe:

  • Both the state and the receiver are owned by thread calling sync_wait
  • The execution context on a different thread merely has a reference to the shared state

The execution context signals completion by performing the following steps (each step strongly "happens before" the next):
1.1. Set the finishing of the run loop to true
1.2. (optional) signalling the run loop
1.3. Set done to true
1.4. Signal done

The thread waiting on completion (the sync_wait caller) performs the following steps:
2.1. If finishing is true, goto 2.5.
2.2. Wait for a new task
2.3. Execute the task
2.4. Goto 2.1.
2.5. Wait for done to be set to true
2.6. Return (ends the lifetime of the state)

Without the 'done' step, we run into the situation where step 1.2 might access invalid memory because the waiter has already reached step 2.6 and ended the state's lifetime.
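
Concretely, the two sequences map onto this patch roughly as follows (annotated excerpt of the code above, not new code):

// Notifier side (execution context thread), i.e. __state::finish():
__loop_.finish();                                           // 1.1 + 1.2
__done_.store(true, stdexec::__std::memory_order_release);  // 1.3
__done_.notify_all();                                       // 1.4

// Waiter side (the thread that called sync_wait):
__local_state.__loop_.run();  // 2.1-2.4: execute tasks until finishing
__local_state.wait();         // 2.5: block until done is true
// 2.6: sync_wait returns and __local_state is destroyed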
Previously, this was ensured by actually holding the mutex while signalling the condition variable, which had the same effect.

I can see how there is still a tiny window which can cause a use after free as you described (even after removing the loop).

I think the only way this can be truly fixed is to introduce a count of 'tasks in flight' for the run_loop:

  • Each run_loop::__push_back call increments this count
  • Each run_loop::__execute_all decrements this count by the amount it handled
  • run_loop::finish increments the count before setting the finishing flag and decrements it before returning.
  • run_loop::run keeps looping as long as the count is non-zero or finishing has not yet been set to true (see the sketch below).
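
A rough sketch of that counter (illustrative only; __task, __enqueue, __signal, and __execute_all stand in for run_loop's existing queue, wakeup, and drain machinery):

#include <atomic>
#include <cstddef>

struct __task;  // illustrative forward declaration

struct __counted_run_loop {  // hypothetical, not the actual run_loop
  std::atomic<std::size_t> __count_{0};  // queued tasks + in-flight finish()
  std::atomic<bool> __finishing_{false};

  void __enqueue(__task*);      // assumed: existing queue push
  void __signal();              // assumed: existing wakeup of run()
  std::size_t __execute_all();  // assumed: drains queue, returns task count

  void __push_back(__task* __t) {
    __count_.fetch_add(1, std::memory_order_relaxed);
    __enqueue(__t);
  }

  void finish() {
    __count_.fetch_add(1, std::memory_order_relaxed);    // pin the loop
    __finishing_.store(true, std::memory_order_release);
    __signal();
    __count_.fetch_sub(1, std::memory_order_release);    // last touch of *this
  }

  void run() {
    // Loop until finish() has been called and no task or finish() call
    // is still in flight; only then may the caller destroy the loop.
    // (Blocking while the queue is empty is elided from this sketch.)
    while (!__finishing_.load(std::memory_order_acquire)
           || __count_.load(std::memory_order_acquire) != 0) {
      __count_.fetch_sub(__execute_all(), std::memory_order_release);
    }
  }
};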

I wanted to avoid this, since it sounded expensive and the solution I propose here at least seems to work...

+          __done_.wait(false, stdexec::__std::memory_order_acquire);
+        }
+      }
     };

     template <class... _Values>
@@ -108,7 +124,7 @@
         STDEXEC_CATCH_ALL {
           __state_->__eptr_ = std::current_exception();
         }
-        __state_->__loop_.finish();
+        __state_->finish();
       }

       template <class _Error>
@@ -121,11 +137,11 @@
         } else {
           __state_->__eptr_ = std::make_exception_ptr(static_cast<_Error&&>(__err));
         }
-        __state_->__loop_.finish();
+        __state_->finish();
       }

       void set_stopped() noexcept {
-        __state_->__loop_.finish();
+        __state_->finish();
       }

       [[nodiscard]]
@@ -286,6 +302,7 @@

       // Wait for the variant to be filled in.
       __local_state.__loop_.run();
+      __local_state.wait();

       if (__local_state.__eptr_) {
         std::rethrow_exception(static_cast<std::exception_ptr&&>(__local_state.__eptr_));