Skip to content

Commit 3221c43

Browse files
authored
implement when_all coroutine gate (#2342)
1 parent 8079e76 commit 3221c43

File tree

2 files changed

+290
-0
lines changed

2 files changed

+290
-0
lines changed

lib/inc/drogon/utils/coroutine.h

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ auto getAwaiter(T &&value) noexcept(
5555
return getAwaiterImpl(static_cast<T &&>(value));
5656
}
5757

58+
template <typename T>
59+
using void_to_false_t =
60+
std::conditional_t<std::is_same_v<T, void>, std::false_type, T>;
61+
5862
} // end namespace internal
5963

6064
template <typename T>
@@ -420,6 +424,11 @@ struct CallbackAwaiter : public trantor::NonCopyable
420424
return false;
421425
}
422426

427+
bool hasException() const noexcept
428+
{
429+
return exception_ != nullptr;
430+
}
431+
423432
const T &await_resume() const noexcept(false)
424433
{
425434
// await_resume() should always be called after co_await
@@ -470,6 +479,11 @@ struct CallbackAwaiter<void> : public trantor::NonCopyable
470479
std::rethrow_exception(exception_);
471480
}
472481

482+
bool hasException() const noexcept
483+
{
484+
return exception_ != nullptr;
485+
}
486+
473487
private:
474488
std::exception_ptr exception_{nullptr};
475489

@@ -798,6 +812,180 @@ struct [[nodiscard]] EventLoopAwaiter : public drogon::CallbackAwaiter<T>
798812
std::function<T()> task_;
799813
trantor::EventLoop *loop_;
800814
};
815+
816+
template <typename... Tasks>
817+
struct WhenAllAwaiter
818+
: public CallbackAwaiter<
819+
std::tuple<internal::void_to_false_t<await_result_t<Tasks>>...>>
820+
{
821+
WhenAllAwaiter(Tasks... tasks)
822+
: tasks_(std::forward<Tasks>(tasks)...), counter_(sizeof...(tasks))
823+
{
824+
}
825+
826+
void await_suspend(std::coroutine_handle<> handle)
827+
{
828+
if (counter_ == 0)
829+
{
830+
handle.resume();
831+
return;
832+
}
833+
834+
await_suspend_impl(handle, std::index_sequence_for<Tasks...>{});
835+
}
836+
837+
private:
838+
std::tuple<Tasks...> tasks_;
839+
std::atomic<size_t> counter_;
840+
std::tuple<internal::void_to_false_t<await_result_t<Tasks>>...> results_;
841+
std::atomic_flag exceptionFlag_;
842+
843+
template <size_t Idx>
844+
void launch_task(std::coroutine_handle<> handle)
845+
{
846+
using Self = WhenAllAwaiter<Tasks...>;
847+
[](Self *self, std::coroutine_handle<> handle) -> AsyncTask {
848+
try
849+
{
850+
using TaskType = std::tuple_element_t<
851+
Idx,
852+
std::remove_cvref_t<decltype(results_)>>;
853+
if constexpr (std::is_same_v<TaskType, std::false_type>)
854+
{
855+
co_await std::get<Idx>(self->tasks_);
856+
std::get<Idx>(self->results_) = std::false_type{};
857+
}
858+
else
859+
{
860+
std::get<Idx>(self->results_) =
861+
co_await std::get<Idx>(self->tasks_);
862+
}
863+
}
864+
catch (...)
865+
{
866+
if (self->exceptionFlag_.test_and_set() == false)
867+
self->setException(std::current_exception());
868+
}
869+
870+
if (self->counter_.fetch_sub(1, std::memory_order_acq_rel) == 1)
871+
{
872+
if (!self->hasException())
873+
self->setValue(std::move(self->results_));
874+
handle.resume();
875+
}
876+
}(this, handle);
877+
}
878+
879+
template <size_t... Is>
880+
void await_suspend_impl(std::coroutine_handle<> handle,
881+
std::index_sequence<Is...>)
882+
{
883+
((launch_task<Is>(handle)), ...);
884+
}
885+
};
886+
887+
template <typename T>
888+
struct WhenAllAwaiter<std::vector<Task<T>>>
889+
: public CallbackAwaiter<std::vector<T>>
890+
{
891+
WhenAllAwaiter(std::vector<Task<T>> tasks)
892+
: tasks_(std::move(tasks)),
893+
counter_(tasks_.size()),
894+
results_(tasks_.size())
895+
{
896+
}
897+
898+
void await_suspend(std::coroutine_handle<> handle)
899+
{
900+
if (tasks_.empty())
901+
{
902+
this->setValue(std::vector<T>{});
903+
handle.resume();
904+
return;
905+
}
906+
907+
const size_t count = tasks_.size();
908+
for (size_t i = 0; i < count; ++i)
909+
{
910+
[](WhenAllAwaiter *self,
911+
std::coroutine_handle<> handle,
912+
Task<T> task,
913+
size_t index) -> AsyncTask {
914+
try
915+
{
916+
auto result = co_await task;
917+
self->results_[index] = std::move(result);
918+
}
919+
catch (...)
920+
{
921+
if (self->exceptionFlag_.test_and_set() == false)
922+
self->setException(std::current_exception());
923+
}
924+
925+
if (self->counter_.fetch_sub(1, std::memory_order_acq_rel) == 1)
926+
{
927+
if (!self->hasException())
928+
{
929+
self->setValue(std::move(self->results_));
930+
}
931+
handle.resume();
932+
}
933+
}(this, handle, std::move(tasks_[i]), i);
934+
}
935+
}
936+
937+
private:
938+
std::vector<Task<T>> tasks_;
939+
std::atomic<size_t> counter_;
940+
std::vector<T> results_;
941+
std::atomic_flag exceptionFlag_;
942+
};
943+
944+
template <>
945+
struct WhenAllAwaiter<std::vector<Task<void>>> : public CallbackAwaiter<void>
946+
{
947+
WhenAllAwaiter(std::vector<Task<void>> &&t)
948+
: tasks_(std::move(t)), counter_(tasks_.size())
949+
{
950+
}
951+
952+
void await_suspend(std::coroutine_handle<> handle)
953+
{
954+
if (tasks_.empty())
955+
{
956+
handle.resume();
957+
return;
958+
}
959+
960+
const size_t count =
961+
tasks_
962+
.size(); // capture the size fist (see lifetime comment beflow)
963+
for (size_t i = 0; i < count; ++i)
964+
{
965+
[](WhenAllAwaiter *self,
966+
std::coroutine_handle<> handle,
967+
Task<> task) -> AsyncTask {
968+
try
969+
{
970+
co_await task;
971+
}
972+
catch (...)
973+
{
974+
if (self->exceptionFlag_.test_and_set() == false)
975+
self->setException(std::current_exception());
976+
}
977+
if (self->counter_.fetch_sub(1, std::memory_order_acq_rel) == 1)
978+
// This line CAN delete `this` at last iteration. We MUST
979+
// NOT depend on this after last iteration
980+
handle.resume();
981+
}(this, handle, std::move(tasks_[i]));
982+
}
983+
}
984+
985+
std::vector<Task<void>> tasks_;
986+
std::atomic<size_t> counter_;
987+
std::atomic_flag exceptionFlag_;
988+
};
801989
} // namespace internal
802990

803991
/**
@@ -987,4 +1175,23 @@ class Mutex final
9871175
CoroMutexAwaiter *waiters_;
9881176
};
9891177

1178+
template <typename... Tasks>
1179+
internal::WhenAllAwaiter<Tasks...> when_all(Tasks... tasks)
1180+
{
1181+
return internal::WhenAllAwaiter<Tasks...>(std::move(tasks)...);
1182+
}
1183+
1184+
template <typename T>
1185+
internal::WhenAllAwaiter<std::vector<Task<T>>> when_all(
1186+
std::vector<Task<T>> tasks)
1187+
{
1188+
return internal::WhenAllAwaiter(std::move(tasks));
1189+
}
1190+
1191+
inline internal::WhenAllAwaiter<std::vector<Task<void>>> when_all(
1192+
std::vector<Task<void>> tasks)
1193+
{
1194+
return internal::WhenAllAwaiter(std::move(tasks));
1195+
}
1196+
9901197
} // namespace drogon

lib/tests/unittests/CoroutineTest.cc

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,14 @@
33
#include <drogon/HttpAppFramework.h>
44
#include <trantor/net/EventLoopThread.h>
55
#include <trantor/net/EventLoopThreadPool.h>
6+
#include <atomic>
67
#include <chrono>
78
#include <cstdint>
9+
#include <exception>
810
#include <future>
11+
#include <memory>
12+
#include <mutex>
13+
#include <optional>
914
#include <type_traits>
1015

1116
using namespace drogon;
@@ -245,3 +250,81 @@ DROGON_TEST(Mutex)
245250
pool.getLoop(i)->quit();
246251
pool.wait();
247252
}
253+
254+
DROGON_TEST(WhenAll)
255+
{
256+
using TestCtx = std::shared_ptr<drogon::test::Case>;
257+
[](TestCtx TEST_CTX) -> AsyncTask {
258+
size_t counter = 0;
259+
auto t1 = [](TestCtx TEST_CTX, size_t *counter) -> Task<> {
260+
co_await drogon::sleepCoro(app().getLoop(), 0.2);
261+
(*counter)++;
262+
}(TEST_CTX, &counter);
263+
auto t2 = [](TestCtx TEST_CTX, size_t *counter) -> Task<> {
264+
co_await drogon::sleepCoro(app().getLoop(), 0.1);
265+
(*counter)++;
266+
}(TEST_CTX, &counter);
267+
std::vector<Task<void>> tasks;
268+
tasks.emplace_back(std::move(t1));
269+
tasks.emplace_back(std::move(t2));
270+
271+
co_await when_all(std::move(tasks));
272+
CHECK(counter == 2);
273+
}(TEST_CTX);
274+
275+
[](TestCtx TEST_CTX) -> AsyncTask {
276+
std::vector<Task<void>> tasks;
277+
co_await when_all(std::move(tasks));
278+
SUCCESS();
279+
}(TEST_CTX);
280+
281+
[](TestCtx TEST_CTX) -> AsyncTask {
282+
auto t1 = [](TestCtx TEST_CTX) -> Task<int> { co_return 1; }(TEST_CTX);
283+
auto t2 = [](TestCtx TEST_CTX) -> Task<int> { co_return 2; }(TEST_CTX);
284+
std::vector<Task<int>> tasks;
285+
tasks.emplace_back(std::move(t1));
286+
tasks.emplace_back(std::move(t2));
287+
288+
auto res = co_await when_all(std::move(tasks));
289+
CO_REQUIRE(res.size() == 2);
290+
CHECK(res[0] == 1);
291+
CHECK(res[1] == 2);
292+
}(TEST_CTX);
293+
294+
[](TestCtx TEST_CTX) -> AsyncTask {
295+
auto t1 = [](TestCtx TEST_CTX) -> Task<int> { co_return 1; }(TEST_CTX);
296+
auto t2 = [](TestCtx TEST_CTX) -> Task<std::string> {
297+
co_return "Hello";
298+
}(TEST_CTX);
299+
auto [num, str] = co_await when_all(std::move(t1), std::move(t2));
300+
CHECK(num == 1);
301+
CHECK(str == "Hello");
302+
}(TEST_CTX);
303+
304+
[](TestCtx TEST_CTX) -> AsyncTask {
305+
size_t counter = 0;
306+
// Even on corutine throws, other coroutins run to completion
307+
auto t1 = [](TestCtx TEST_CTX, size_t *counter) -> Task<int> {
308+
co_await drogon::sleepCoro(app().getLoop(), 0.2);
309+
(*counter)++;
310+
co_return 1;
311+
}(TEST_CTX, &counter);
312+
auto t2 = [](TestCtx TEST_CTX) -> Task<std::string> {
313+
co_await drogon::sleepCoro(app().getLoop(), 0.1);
314+
throw std::runtime_error("Test exception");
315+
}(TEST_CTX);
316+
CO_REQUIRE_THROWS(co_await when_all(std::move(t1), std::move(t2)));
317+
CHECK(counter == 1);
318+
}(TEST_CTX);
319+
320+
[](TestCtx TEST_CTX) -> AsyncTask {
321+
size_t counter = 0;
322+
// void retuens gets mapped to std::false_type in the tuple API
323+
auto t1 = [](TestCtx TEST_CTX, size_t *counter) -> Task<> {
324+
(*counter)++;
325+
co_return;
326+
}(TEST_CTX, &counter);
327+
auto [res] = co_await when_all(std::move(t1));
328+
CHECK(counter == 1);
329+
}(TEST_CTX);
330+
}

0 commit comments

Comments
 (0)