Skip to content

Commit 686685c

Browse files
authored
implemented stack overflow prevention (#191)
* implemented stack overflow prevention * clang-format
1 parent 5e2c478 commit 686685c

File tree

2 files changed

+73
-32
lines changed

2 files changed

+73
-32
lines changed

examples/stackoverflow.cpp

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,25 @@ struct task {
1111
using completion_signatures = ex::completion_signatures<ex::set_value_t()>;
1212

1313
struct base {
14-
virtual void complete_value() noexcept = 0;
14+
virtual void complete_value() noexcept = 0;
15+
virtual void complete_stopped() noexcept = 0;
1516
};
1617

1718
struct promise_type {
1819
struct final_awaiter {
1920
base* data;
2021
bool await_ready() noexcept { return false; }
21-
auto await_suspend(auto h) noexcept {
22-
std::cout << "final_awaiter\n";
23-
this->data->complete_value();
24-
std::cout << "completed\n";
25-
};
26-
void await_resume() noexcept {}
22+
auto await_suspend(auto h) noexcept { this->data->complete_value(); };
23+
void await_resume() noexcept {}
2724
};
2825
std::suspend_always initial_suspend() const noexcept { return {}; }
2926
final_awaiter final_suspend() const noexcept { return {this->data}; }
3027
void unhandled_exception() const noexcept {}
31-
std::coroutine_handle<> unhandled_stopped() { return std::coroutine_handle<>(); }
32-
auto return_void() {}
28+
std::coroutine_handle<> unhandled_stopped() {
29+
this->data->complete_stopped();
30+
return std::noop_coroutine();
31+
}
32+
auto return_void() {}
3333
auto get_return_object() { return task{std::coroutine_handle<promise_type>::from_promise(*this)}; }
3434
template <::beman::execution::sender Sender>
3535
auto await_transform(Sender&& sender) noexcept {
@@ -51,7 +51,14 @@ struct task {
5151
this->handle.promise().data = this;
5252
this->handle.resume();
5353
}
54-
void complete_value() noexcept override { ex::set_value(std::move(this->r)); }
54+
void complete_value() noexcept override {
55+
this->handle.destroy();
56+
ex::set_value(std::move(this->r));
57+
}
58+
void complete_stopped() noexcept override {
59+
this->handle.destroy();
60+
ex::set_stopped(std::move(this->r));
61+
}
5562
};
5663

5764
std::coroutine_handle<promise_type> handle;
@@ -63,11 +70,23 @@ struct task {
6370
};
6471

6572
int main(int ac, char*[]) {
73+
std::cout << std::unitbuf;
74+
using on_exit = std::unique_ptr<const char, decltype([](auto msg) { std::cout << msg << "\n"; })>;
6675
static_assert(ex::sender<task>);
6776
ex::sync_wait([](int n) -> task {
68-
for (int i{}; i < n; ++i) {
69-
std::cout << "await=" << (co_await ex::just(i)) << "\n";
70-
}
71-
co_return;
77+
on_exit msg("coro run to the end");
78+
if constexpr (true)
79+
for (int i{}; i < n; ++i) {
80+
std::cout << "await just=" << (co_await ex::just(i)) << "\n";
81+
}
82+
if constexpr (false)
83+
for (int i{}; i < n; ++i) {
84+
try {
85+
co_await ex::just_error(i);
86+
} catch (int x) {
87+
std::cout << "await error=" << x << "\n";
88+
}
89+
}
90+
co_await ex::just_stopped();
7291
}(ac < 2 ? 3 : 30000));
7392
}

include/beman/execution/detail/sender_awaitable.hpp

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
#include <type_traits>
2323
#include <utility>
2424
#include <variant>
25+
#include <tuple>
26+
#include <atomic>
2527

2628
namespace beman::execution::detail {
2729
template <class Sndr, class Promise>
@@ -31,58 +33,78 @@ class sender_awaitable {
3133
::beman::execution::detail::single_sender_value_type<Sndr, ::beman::execution::env_of_t<Promise>>;
3234
using result_type = ::std::conditional_t<::std::is_void_v<value_type>, unit, value_type>;
3335
using variant_type = ::std::variant<::std::monostate, result_type, ::std::exception_ptr>;
36+
using data_type = ::std::tuple<variant_type, ::std::atomic<bool>, ::std::coroutine_handle<Promise>>;
37+
3438
struct awaitable_receiver {
3539
using receiver_concept = ::beman::execution::receiver_t;
3640

41+
void resume() {
42+
if (::std::get<1>(*result_ptr_).exchange(true, std::memory_order_acq_rel)) {
43+
::std::get<2>(*result_ptr_).resume();
44+
}
45+
}
46+
3747
template <class... Args>
3848
requires ::std::constructible_from<result_type, Args...>
3949
void set_value(Args&&... args) && noexcept {
4050
try {
41-
result_ptr_->template emplace<1>(::std::forward<Args>(args)...);
51+
::std::get<0>(*result_ptr_).template emplace<1>(::std::forward<Args>(args)...);
4252
} catch (...) {
43-
result_ptr_->template emplace<2>(::std::current_exception());
53+
::std::get<0>(*result_ptr_).template emplace<2>(::std::current_exception());
4454
}
45-
continuation_.resume();
55+
this->resume();
4656
}
47-
4857
template <class Error>
4958
void set_error(Error&& error) && noexcept {
50-
result_ptr_->template emplace<2>(::beman::execution::detail::as_except_ptr(::std::forward<Error>(error)));
51-
continuation_.resume();
59+
::std::get<0>(*result_ptr_)
60+
.template emplace<2>(::beman::execution::detail::as_except_ptr(::std::forward<Error>(error)));
61+
this->resume();
5262
}
5363

5464
void set_stopped() && noexcept {
55-
static_cast<::std::coroutine_handle<>>(continuation_.promise().unhandled_stopped()).resume();
65+
if (::std::get<1>(*result_ptr_).exchange(true, ::std::memory_order_acq_rel)) {
66+
static_cast<::std::coroutine_handle<>>(::std::get<2>(*result_ptr_).promise().unhandled_stopped())
67+
.resume();
68+
}
5669
}
5770

5871
auto get_env() const noexcept {
59-
return ::beman::execution::detail::fwd_env{::beman::execution::get_env(continuation_.promise())};
72+
return ::beman::execution::detail::fwd_env{
73+
::beman::execution::get_env(::std::get<2>(*result_ptr_).promise())};
6074
}
6175

62-
variant_type* result_ptr_;
63-
::std::coroutine_handle<Promise> continuation_;
76+
data_type* result_ptr_;
6477
};
6578
using op_state_type = ::beman::execution::connect_result_t<Sndr, awaitable_receiver>;
6679

67-
variant_type result{};
80+
data_type result{};
6881
op_state_type state;
6982

7083
public:
7184
sender_awaitable(Sndr&& sndr, Promise& p)
72-
: state{::beman::execution::connect(
73-
::std::forward<Sndr>(sndr),
74-
awaitable_receiver{::std::addressof(result), ::std::coroutine_handle<Promise>::from_promise(p)})} {}
85+
: result{::std::monostate{}, false, ::std::coroutine_handle<Promise>::from_promise(p)},
86+
state{::beman::execution::connect(::std::forward<Sndr>(sndr),
87+
sender_awaitable::awaitable_receiver{::std::addressof(result)})} {}
7588

7689
static constexpr bool await_ready() noexcept { return false; }
77-
void await_suspend(::std::coroutine_handle<Promise>) noexcept { ::beman::execution::start(state); }
90+
bool await_suspend(::std::coroutine_handle<Promise>) noexcept {
91+
::beman::execution::start(state);
92+
if (::std::get<1>(this->result).exchange(true, std::memory_order_acq_rel)) {
93+
if (::std::holds_alternative<::std::monostate>(::std::get<0>(this->result))) {
94+
return bool(::std::get<2>(this->result).promise().unhandled_stopped());
95+
}
96+
return false;
97+
}
98+
return true;
99+
}
78100
value_type await_resume() {
79-
if (::std::holds_alternative<::std::exception_ptr>(result)) {
80-
::std::rethrow_exception(::std::get<::std::exception_ptr>(result));
101+
if (::std::holds_alternative<::std::exception_ptr>(::std::get<0>(result))) {
102+
::std::rethrow_exception(::std::get<::std::exception_ptr>(::std::get<0>(result)));
81103
}
82104
if constexpr (::std::is_void_v<value_type>) {
83105
return;
84106
} else {
85-
return ::std::get<value_type>(std::move(result));
107+
return ::std::get<value_type>(std::move(::std::get<0>(result)));
86108
}
87109
}
88110
};

0 commit comments

Comments
 (0)