@@ -16,8 +16,14 @@ Module Name:
1616#include " relay.hpp"
1717#pragma hdrstop
1818
19+ using wsl::windows::common::relay::EventHandle;
20+ using wsl::windows::common::relay::HandleWrapper;
21+ using wsl::windows::common::relay::IOHandleStatus;
22+ using wsl::windows::common::relay::MultiHandleWait;
23+ using wsl::windows::common::relay::OverlappedIOHandle;
1924using wsl::windows::common::relay::ScopedMultiRelay;
2025using wsl::windows::common::relay::ScopedRelay;
26+ using wsl::windows::common::relay::SingleAcceptHandle;
2127
2228namespace {
2329
@@ -108,7 +114,7 @@ wsl::windows::common::relay::InterruptableRead(
108114 return 0 ;
109115 }
110116
111- THROW_LAST_ERROR_IF (lastError != ERROR_IO_PENDING);
117+ THROW_LAST_ERROR_IF_MSG (lastError != ERROR_IO_PENDING, " Handle: 0x%p " , ( void *)InputHandle );
112118
113119 auto cancelRead = wil::scope_exit_log (WI_DIAGNOSTICS_INFO, [&] {
114120 CancelIoEx (InputHandle, Overlapped);
@@ -569,4 +575,215 @@ try
569575 }
570576 }
571577}
572- CATCH_LOG ()
578+ CATCH_LOG ()
579+
580+ void MultiHandleWait::AddHandle(std::unique_ptr<OverlappedIOHandle>&& handle, Flags flags)
581+ {
582+ m_handles.emplace_back (flags, std::move (handle));
583+ }
584+
585+ void MultiHandleWait::Cancel ()
586+ {
587+ m_cancel = true ;
588+ }
589+ bool MultiHandleWait::Run (std::optional<std::chrono::milliseconds> Timeout)
590+ {
591+ m_cancel = false ; // Run may be called multiple times.
592+
593+ std::optional<std::chrono::steady_clock::time_point> deadline;
594+
595+ if (Timeout.has_value ())
596+ {
597+ deadline = std::chrono::steady_clock::now () + Timeout.value ();
598+ }
599+
600+ // Run until all handles are completed.
601+
602+ while (!m_handles.empty () && !m_cancel)
603+ {
604+ // Schedule IO on each handle until all are either pending, or completed.
605+ for (size_t i = 0 ; i < m_handles.size (); i++)
606+ {
607+ while (m_handles[i].second ->GetState () == IOHandleStatus::Standby)
608+ {
609+ try
610+ {
611+ m_handles[i].second ->Schedule ();
612+ }
613+ catch (...)
614+ {
615+ if (WI_IsFlagSet (m_handles[i].first , Flags::IgnoreErrors))
616+ {
617+ m_handles[i].second .reset (); // Reset the handle so it can be deleted.
618+ }
619+ else
620+ {
621+ throw ;
622+ }
623+ }
624+ }
625+ }
626+
627+ // Remove completed handles from m_handles.
628+ for (auto it = m_handles.begin (); it != m_handles.end ();)
629+ {
630+ if (!it->second )
631+ {
632+ it = m_handles.erase (it);
633+ }
634+ else if (it->second ->GetState () == IOHandleStatus::Completed)
635+ {
636+ if (WI_IsFlagSet (it->first , Flags::CancelOnCompleted))
637+ {
638+ m_cancel = true ; // Cancel the IO if a handle with CancelOnCompleted is in the completed state.
639+ }
640+
641+ it = m_handles.erase (it);
642+ }
643+ else
644+ {
645+ ++it;
646+ }
647+ }
648+
649+ if (m_handles.empty () || m_cancel)
650+ {
651+ break ;
652+ }
653+
654+ // Wait for the next operation to complete.
655+ std::vector<HANDLE> waitHandles;
656+ for (const auto & e : m_handles)
657+ {
658+ waitHandles.emplace_back (e.second ->GetHandle ());
659+ }
660+
661+ DWORD waitTimeout = INFINITE;
662+ if (deadline.has_value ())
663+ {
664+ auto miliseconds =
665+ std::chrono::duration_cast<std::chrono::milliseconds>(deadline.value () - std::chrono::steady_clock::now ()).count ();
666+
667+ waitTimeout = static_cast <DWORD>(std::max (0LL , miliseconds));
668+ }
669+
670+ auto result = WaitForMultipleObjects (static_cast <DWORD>(waitHandles.size ()), waitHandles.data (), false , waitTimeout);
671+ if (result == WAIT_TIMEOUT)
672+ {
673+ THROW_WIN32 (ERROR_TIMEOUT);
674+ }
675+ else if (result >= WAIT_OBJECT_0 && result < WAIT_OBJECT_0 + m_handles.size ())
676+ {
677+ auto index = result - WAIT_OBJECT_0;
678+
679+ try
680+ {
681+ m_handles[index].second ->Collect ();
682+ }
683+ catch (...)
684+ {
685+ if (WI_IsFlagSet (m_handles[index].first , Flags::IgnoreErrors))
686+ {
687+ m_handles.erase (m_handles.begin () + index);
688+ }
689+ else
690+ {
691+ throw ;
692+ }
693+ }
694+ }
695+ else
696+ {
697+ THROW_LAST_ERROR_MSG (" Timeout: %lu, Count: %llu" , waitTimeout, waitHandles.size ());
698+ }
699+ }
700+
701+ return !m_cancel;
702+ }
703+
704+ IOHandleStatus OverlappedIOHandle::GetState () const
705+ {
706+ return State;
707+ }
708+
709+ EventHandle::EventHandle (HandleWrapper&& Handle, std::function<void ()>&& OnSignalled) :
710+ Handle(std::move(Handle)), OnSignalled(std::move(OnSignalled))
711+ {
712+ }
713+
714+ void EventHandle::Schedule ()
715+ {
716+ State = IOHandleStatus::Pending;
717+ }
718+
719+ void EventHandle::Collect ()
720+ {
721+ State = IOHandleStatus::Completed;
722+ OnSignalled ();
723+ }
724+
725+ HANDLE EventHandle::GetHandle () const
726+ {
727+ return Handle.Get ();
728+ }
729+
730+ SingleAcceptHandle::SingleAcceptHandle (HandleWrapper&& ListenSocket, HandleWrapper&& AcceptedSocket, std::function<void ()>&& OnAccepted) :
731+ ListenSocket(std::move(ListenSocket)), AcceptedSocket(std::move(AcceptedSocket)), OnAccepted(std::move(OnAccepted))
732+ {
733+ Overlapped.hEvent = Event.get ();
734+ }
735+
736+ SingleAcceptHandle::~SingleAcceptHandle ()
737+ {
738+ if (State == IOHandleStatus::Pending)
739+ {
740+ LOG_IF_WIN32_BOOL_FALSE (CancelIoEx (ListenSocket.Get (), &Overlapped));
741+
742+ DWORD bytesProcessed{};
743+ DWORD flagsReturned{};
744+ if (!WSAGetOverlappedResult ((SOCKET)ListenSocket.Get (), &Overlapped, &bytesProcessed, TRUE , &flagsReturned))
745+ {
746+ auto error = GetLastError ();
747+ LOG_LAST_ERROR_IF (error != ERROR_CONNECTION_ABORTED && error != ERROR_OPERATION_ABORTED);
748+ }
749+ }
750+ }
751+
752+ void SingleAcceptHandle::Schedule ()
753+ {
754+ WI_ASSERT (State == IOHandleStatus::Standby);
755+
756+ // Schedule the accept.
757+ DWORD bytesReturned{};
758+ if (AcceptEx ((SOCKET)ListenSocket.Get (), (SOCKET)AcceptedSocket.Get (), &AcceptBuffer, 0 , sizeof (SOCKADDR_STORAGE), sizeof (SOCKADDR_STORAGE), &bytesReturned, &Overlapped))
759+ {
760+ // Accept completed immediately.
761+ State = IOHandleStatus::Completed;
762+ OnAccepted ();
763+ }
764+ else
765+ {
766+ auto error = WSAGetLastError ();
767+ THROW_HR_IF_MSG (HRESULT_FROM_WIN32 (error), error != ERROR_IO_PENDING, " Handle: 0x%p" , (void *)ListenSocket.Get ());
768+
769+ State = IOHandleStatus::Pending;
770+ }
771+ }
772+
773+ void SingleAcceptHandle::Collect ()
774+ {
775+ WI_ASSERT (State == IOHandleStatus::Pending);
776+
777+ DWORD bytesReceived{};
778+ DWORD flagsReturned{};
779+
780+ THROW_IF_WIN32_BOOL_FALSE (WSAGetOverlappedResult ((SOCKET)ListenSocket.Get (), &Overlapped, &bytesReceived, false , &flagsReturned));
781+
782+ State = IOHandleStatus::Completed;
783+ OnAccepted ();
784+ }
785+
786+ HANDLE SingleAcceptHandle::GetHandle () const
787+ {
788+ return Event.get ();
789+ }
0 commit comments