@@ -55,6 +55,10 @@ auto getAwaiter(T &&value) noexcept(
5555 return getAwaiterImpl (static_cast <T &&>(value));
5656}
5757
58+ template <typename T>
59+ using void_to_false_t =
60+ std::conditional_t <std::is_same_v<T, void >, std::false_type, T>;
61+
5862} // end namespace internal
5963
6064template <typename T>
@@ -420,6 +424,11 @@ struct CallbackAwaiter : public trantor::NonCopyable
420424 return false ;
421425 }
422426
427+ bool hasException () const noexcept
428+ {
429+ return exception_ != nullptr ;
430+ }
431+
423432 const T &await_resume () const noexcept (false )
424433 {
425434 // await_resume() should always be called after co_await
@@ -470,6 +479,11 @@ struct CallbackAwaiter<void> : public trantor::NonCopyable
470479 std::rethrow_exception (exception_);
471480 }
472481
482+ bool hasException () const noexcept
483+ {
484+ return exception_ != nullptr ;
485+ }
486+
473487 private:
474488 std::exception_ptr exception_{nullptr };
475489
@@ -798,6 +812,180 @@ struct [[nodiscard]] EventLoopAwaiter : public drogon::CallbackAwaiter<T>
798812 std::function<T ()> task_;
799813 trantor::EventLoop *loop_;
800814};
815+
816+ template <typename ... Tasks>
817+ struct WhenAllAwaiter
818+ : public CallbackAwaiter<
819+ std::tuple<internal::void_to_false_t <await_result_t <Tasks>>...>>
820+ {
821+ WhenAllAwaiter (Tasks... tasks)
822+ : tasks_(std::forward<Tasks>(tasks)...), counter_(sizeof ...(tasks))
823+ {
824+ }
825+
826+ void await_suspend (std::coroutine_handle<> handle)
827+ {
828+ if (counter_ == 0 )
829+ {
830+ handle.resume ();
831+ return ;
832+ }
833+
834+ await_suspend_impl (handle, std::index_sequence_for<Tasks...>{});
835+ }
836+
837+ private:
838+ std::tuple<Tasks...> tasks_;
839+ std::atomic<size_t > counter_;
840+ std::tuple<internal::void_to_false_t <await_result_t <Tasks>>...> results_;
841+ std::atomic_flag exceptionFlag_;
842+
843+ template <size_t Idx>
844+ void launch_task (std::coroutine_handle<> handle)
845+ {
846+ using Self = WhenAllAwaiter<Tasks...>;
847+ [](Self *self, std::coroutine_handle<> handle) -> AsyncTask {
848+ try
849+ {
850+ using TaskType = std::tuple_element_t <
851+ Idx,
852+ std::remove_cvref_t <decltype (results_)>>;
853+ if constexpr (std::is_same_v<TaskType, std::false_type>)
854+ {
855+ co_await std::get<Idx>(self->tasks_ );
856+ std::get<Idx>(self->results_ ) = std::false_type{};
857+ }
858+ else
859+ {
860+ std::get<Idx>(self->results_ ) =
861+ co_await std::get<Idx>(self->tasks_ );
862+ }
863+ }
864+ catch (...)
865+ {
866+ if (self->exceptionFlag_ .test_and_set () == false )
867+ self->setException (std::current_exception ());
868+ }
869+
870+ if (self->counter_ .fetch_sub (1 , std::memory_order_acq_rel) == 1 )
871+ {
872+ if (!self->hasException ())
873+ self->setValue (std::move (self->results_ ));
874+ handle.resume ();
875+ }
876+ }(this , handle);
877+ }
878+
879+ template <size_t ... Is>
880+ void await_suspend_impl (std::coroutine_handle<> handle,
881+ std::index_sequence<Is...>)
882+ {
883+ ((launch_task<Is>(handle)), ...);
884+ }
885+ };
886+
887+ template <typename T>
888+ struct WhenAllAwaiter <std::vector<Task<T>>>
889+ : public CallbackAwaiter<std::vector<T>>
890+ {
891+ WhenAllAwaiter (std::vector<Task<T>> tasks)
892+ : tasks_(std::move(tasks)),
893+ counter_ (tasks_.size()),
894+ results_(tasks_.size())
895+ {
896+ }
897+
898+ void await_suspend (std::coroutine_handle<> handle)
899+ {
900+ if (tasks_.empty ())
901+ {
902+ this ->setValue (std::vector<T>{});
903+ handle.resume ();
904+ return ;
905+ }
906+
907+ const size_t count = tasks_.size ();
908+ for (size_t i = 0 ; i < count; ++i)
909+ {
910+ [](WhenAllAwaiter *self,
911+ std::coroutine_handle<> handle,
912+ Task<T> task,
913+ size_t index) -> AsyncTask {
914+ try
915+ {
916+ auto result = co_await task;
917+ self->results_ [index] = std::move (result);
918+ }
919+ catch (...)
920+ {
921+ if (self->exceptionFlag_ .test_and_set () == false )
922+ self->setException (std::current_exception ());
923+ }
924+
925+ if (self->counter_ .fetch_sub (1 , std::memory_order_acq_rel) == 1 )
926+ {
927+ if (!self->hasException ())
928+ {
929+ self->setValue (std::move (self->results_ ));
930+ }
931+ handle.resume ();
932+ }
933+ }(this , handle, std::move (tasks_[i]), i);
934+ }
935+ }
936+
937+ private:
938+ std::vector<Task<T>> tasks_;
939+ std::atomic<size_t > counter_;
940+ std::vector<T> results_;
941+ std::atomic_flag exceptionFlag_;
942+ };
943+
944+ template <>
945+ struct WhenAllAwaiter <std::vector<Task<void >>> : public CallbackAwaiter<void >
946+ {
947+ WhenAllAwaiter (std::vector<Task<void >> &&t)
948+ : tasks_(std::move(t)), counter_(tasks_.size())
949+ {
950+ }
951+
952+ void await_suspend (std::coroutine_handle<> handle)
953+ {
954+ if (tasks_.empty ())
955+ {
956+ handle.resume ();
957+ return ;
958+ }
959+
960+ const size_t count =
961+ tasks_
962+ .size (); // capture the size fist (see lifetime comment beflow)
963+ for (size_t i = 0 ; i < count; ++i)
964+ {
965+ [](WhenAllAwaiter *self,
966+ std::coroutine_handle<> handle,
967+ Task<> task) -> AsyncTask {
968+ try
969+ {
970+ co_await task;
971+ }
972+ catch (...)
973+ {
974+ if (self->exceptionFlag_ .test_and_set () == false )
975+ self->setException (std::current_exception ());
976+ }
977+ if (self->counter_ .fetch_sub (1 , std::memory_order_acq_rel) == 1 )
978+ // This line CAN delete `this` at last iteration. We MUST
979+ // NOT depend on this after last iteration
980+ handle.resume ();
981+ }(this , handle, std::move (tasks_[i]));
982+ }
983+ }
984+
985+ std::vector<Task<void >> tasks_;
986+ std::atomic<size_t > counter_;
987+ std::atomic_flag exceptionFlag_;
988+ };
801989} // namespace internal
802990
803991/* *
@@ -987,4 +1175,23 @@ class Mutex final
9871175 CoroMutexAwaiter *waiters_;
9881176};
9891177
1178+ template <typename ... Tasks>
1179+ internal::WhenAllAwaiter<Tasks...> when_all (Tasks... tasks)
1180+ {
1181+ return internal::WhenAllAwaiter<Tasks...>(std::move (tasks)...);
1182+ }
1183+
1184+ template <typename T>
1185+ internal::WhenAllAwaiter<std::vector<Task<T>>> when_all (
1186+ std::vector<Task<T>> tasks)
1187+ {
1188+ return internal::WhenAllAwaiter (std::move (tasks));
1189+ }
1190+
1191+ inline internal::WhenAllAwaiter<std::vector<Task<void >>> when_all (
1192+ std::vector<Task<void >> tasks)
1193+ {
1194+ return internal::WhenAllAwaiter (std::move (tasks));
1195+ }
1196+
9901197} // namespace drogon
0 commit comments