diff --git a/include/stdexec/__detail/__sync_wait.hpp b/include/stdexec/__detail/__sync_wait.hpp index e36918bd3..75b05ced7 100644 --- a/include/stdexec/__detail/__sync_wait.hpp +++ b/include/stdexec/__detail/__sync_wait.hpp @@ -32,6 +32,8 @@ #include "__run_loop.hpp" #include "__type_traits.hpp" +#include "__atomic.hpp" + #include #include #include @@ -89,6 +91,17 @@ namespace stdexec { struct __state { std::exception_ptr __eptr_; run_loop __loop_; + stdexec::__std::atomic __done_{false}; + + void finish() noexcept { + __loop_.finish(); + __done_.store(true, stdexec::__std::memory_order_release); + __done_.notify_all(); + } + + void wait() noexcept { + __done_.wait(false, stdexec::__std::memory_order_acquire); + } }; template @@ -108,7 +121,7 @@ namespace stdexec { STDEXEC_CATCH_ALL { __state_->__eptr_ = std::current_exception(); } - __state_->__loop_.finish(); + __state_->finish(); } template @@ -121,11 +134,11 @@ namespace stdexec { } else { __state_->__eptr_ = std::make_exception_ptr(static_cast<_Error&&>(__err)); } - __state_->__loop_.finish(); + __state_->finish(); } void set_stopped() noexcept { - __state_->__loop_.finish(); + __state_->finish(); } [[nodiscard]] @@ -286,6 +299,7 @@ namespace stdexec { // Wait for the variant to be filled in. __local_state.__loop_.run(); + __local_state.wait(); if (__local_state.__eptr_) { std::rethrow_exception(static_cast(__local_state.__eptr_));