Skip to content

Commit 1897d05

Browse files
authored
Rethink the Accept() logic to differentiate between errors and cancellation (#14156)
* Bring relay changes * Redesign Accept() logic to differenciate between cancellation and errors * Prepare for PR * Apply PR feedback
1 parent 8ae97be commit 1897d05

File tree

10 files changed

+439
-35
lines changed

10 files changed

+439
-35
lines changed

src/shared/inc/defs.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,18 @@ Module Name:
2222
#define _wcsicmp wcscasecmp
2323
#endif
2424

25+
#define NON_COPYABLE(Type) \
26+
Type(const Type&) = delete; \
27+
Type& operator=(const Type&) = delete;
28+
29+
#define NON_MOVABLE(Type) \
30+
Type(Type&&) = delete; \
31+
Type& operator=(Type&&) = delete;
32+
33+
#define DEFAULT_MOVABLE(Type) \
34+
Type(Type&&) = default; \
35+
Type& operator=(Type&&) = default;
36+
2537
namespace wsl::shared {
2638

2739
inline constexpr std::uint32_t VersionMajor = WSL_PACKAGE_VERSION_MAJOR;

src/windows/common/hvsocket.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,14 @@ void InitializeWildcardSocketAddress(_Out_ PSOCKADDR_HV Address)
3939
}
4040
} // namespace
4141

42-
wil::unique_socket wsl::windows::common::hvsocket::Accept(
43-
_In_ SOCKET ListenSocket, _In_ int Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location)
42+
std::optional<wil::unique_socket> wsl::windows::common::hvsocket::CancellableAccept(
43+
_In_ SOCKET ListenSocket, _In_ DWORD Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location)
4444
{
4545
wil::unique_socket Socket = Create();
46-
wsl::windows::common::socket::Accept(ListenSocket, Socket.get(), Timeout, ExitHandle, Location);
46+
if (!socket::CancellableAccept(ListenSocket, Socket.get(), Timeout, ExitHandle, Location))
47+
{
48+
return {};
49+
}
4750

4851
return Socket;
4952
}

src/windows/common/hvsocket.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ Module Name:
1919

2020
namespace wsl::windows::common::hvsocket {
2121

22-
wil::unique_socket Accept(
22+
std::optional<wil::unique_socket> CancellableAccept(
2323
_In_ SOCKET ListenSocket,
24-
_In_ int Timeout,
24+
_In_ DWORD Timeout,
2525
_In_opt_ HANDLE ExitHandle = nullptr,
2626
const std::source_location& Location = std::source_location::current());
2727

src/windows/common/relay.cpp

Lines changed: 219 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,14 @@ Module Name:
1616
#include "relay.hpp"
1717
#pragma hdrstop
1818

19+
using wsl::windows::common::relay::EventHandle;
20+
using wsl::windows::common::relay::HandleWrapper;
21+
using wsl::windows::common::relay::IOHandleStatus;
22+
using wsl::windows::common::relay::MultiHandleWait;
23+
using wsl::windows::common::relay::OverlappedIOHandle;
1924
using wsl::windows::common::relay::ScopedMultiRelay;
2025
using wsl::windows::common::relay::ScopedRelay;
26+
using wsl::windows::common::relay::SingleAcceptHandle;
2127

2228
namespace {
2329

@@ -108,7 +114,7 @@ wsl::windows::common::relay::InterruptableRead(
108114
return 0;
109115
}
110116

111-
THROW_LAST_ERROR_IF(lastError != ERROR_IO_PENDING);
117+
THROW_LAST_ERROR_IF_MSG(lastError != ERROR_IO_PENDING, "Handle: 0x%p", (void*)InputHandle);
112118

113119
auto cancelRead = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&] {
114120
CancelIoEx(InputHandle, Overlapped);
@@ -569,4 +575,215 @@ try
569575
}
570576
}
571577
}
572-
CATCH_LOG()
578+
CATCH_LOG()
579+
580+
void MultiHandleWait::AddHandle(std::unique_ptr<OverlappedIOHandle>&& handle, Flags flags)
581+
{
582+
m_handles.emplace_back(flags, std::move(handle));
583+
}
584+
585+
void MultiHandleWait::Cancel()
586+
{
587+
m_cancel = true;
588+
}
589+
bool MultiHandleWait::Run(std::optional<std::chrono::milliseconds> Timeout)
590+
{
591+
m_cancel = false; // Run may be called multiple times.
592+
593+
std::optional<std::chrono::steady_clock::time_point> deadline;
594+
595+
if (Timeout.has_value())
596+
{
597+
deadline = std::chrono::steady_clock::now() + Timeout.value();
598+
}
599+
600+
// Run until all handles are completed.
601+
602+
while (!m_handles.empty() && !m_cancel)
603+
{
604+
// Schedule IO on each handle until all are either pending, or completed.
605+
for (size_t i = 0; i < m_handles.size(); i++)
606+
{
607+
while (m_handles[i].second->GetState() == IOHandleStatus::Standby)
608+
{
609+
try
610+
{
611+
m_handles[i].second->Schedule();
612+
}
613+
catch (...)
614+
{
615+
if (WI_IsFlagSet(m_handles[i].first, Flags::IgnoreErrors))
616+
{
617+
m_handles[i].second.reset(); // Reset the handle so it can be deleted.
618+
}
619+
else
620+
{
621+
throw;
622+
}
623+
}
624+
}
625+
}
626+
627+
// Remove completed handles from m_handles.
628+
for (auto it = m_handles.begin(); it != m_handles.end();)
629+
{
630+
if (!it->second)
631+
{
632+
it = m_handles.erase(it);
633+
}
634+
else if (it->second->GetState() == IOHandleStatus::Completed)
635+
{
636+
if (WI_IsFlagSet(it->first, Flags::CancelOnCompleted))
637+
{
638+
m_cancel = true; // Cancel the IO if a handle with CancelOnCompleted is in the completed state.
639+
}
640+
641+
it = m_handles.erase(it);
642+
}
643+
else
644+
{
645+
++it;
646+
}
647+
}
648+
649+
if (m_handles.empty() || m_cancel)
650+
{
651+
break;
652+
}
653+
654+
// Wait for the next operation to complete.
655+
std::vector<HANDLE> waitHandles;
656+
for (const auto& e : m_handles)
657+
{
658+
waitHandles.emplace_back(e.second->GetHandle());
659+
}
660+
661+
DWORD waitTimeout = INFINITE;
662+
if (deadline.has_value())
663+
{
664+
auto miliseconds =
665+
std::chrono::duration_cast<std::chrono::milliseconds>(deadline.value() - std::chrono::steady_clock::now()).count();
666+
667+
waitTimeout = static_cast<DWORD>(std::max(0LL, miliseconds));
668+
}
669+
670+
auto result = WaitForMultipleObjects(static_cast<DWORD>(waitHandles.size()), waitHandles.data(), false, waitTimeout);
671+
if (result == WAIT_TIMEOUT)
672+
{
673+
THROW_WIN32(ERROR_TIMEOUT);
674+
}
675+
else if (result >= WAIT_OBJECT_0 && result < WAIT_OBJECT_0 + m_handles.size())
676+
{
677+
auto index = result - WAIT_OBJECT_0;
678+
679+
try
680+
{
681+
m_handles[index].second->Collect();
682+
}
683+
catch (...)
684+
{
685+
if (WI_IsFlagSet(m_handles[index].first, Flags::IgnoreErrors))
686+
{
687+
m_handles.erase(m_handles.begin() + index);
688+
}
689+
else
690+
{
691+
throw;
692+
}
693+
}
694+
}
695+
else
696+
{
697+
THROW_LAST_ERROR_MSG("Timeout: %lu, Count: %llu", waitTimeout, waitHandles.size());
698+
}
699+
}
700+
701+
return !m_cancel;
702+
}
703+
704+
IOHandleStatus OverlappedIOHandle::GetState() const
705+
{
706+
return State;
707+
}
708+
709+
EventHandle::EventHandle(HandleWrapper&& Handle, std::function<void()>&& OnSignalled) :
710+
Handle(std::move(Handle)), OnSignalled(std::move(OnSignalled))
711+
{
712+
}
713+
714+
void EventHandle::Schedule()
715+
{
716+
State = IOHandleStatus::Pending;
717+
}
718+
719+
void EventHandle::Collect()
720+
{
721+
State = IOHandleStatus::Completed;
722+
OnSignalled();
723+
}
724+
725+
HANDLE EventHandle::GetHandle() const
726+
{
727+
return Handle.Get();
728+
}
729+
730+
SingleAcceptHandle::SingleAcceptHandle(HandleWrapper&& ListenSocket, HandleWrapper&& AcceptedSocket, std::function<void()>&& OnAccepted) :
731+
ListenSocket(std::move(ListenSocket)), AcceptedSocket(std::move(AcceptedSocket)), OnAccepted(std::move(OnAccepted))
732+
{
733+
Overlapped.hEvent = Event.get();
734+
}
735+
736+
SingleAcceptHandle::~SingleAcceptHandle()
737+
{
738+
if (State == IOHandleStatus::Pending)
739+
{
740+
LOG_IF_WIN32_BOOL_FALSE(CancelIoEx(ListenSocket.Get(), &Overlapped));
741+
742+
DWORD bytesProcessed{};
743+
DWORD flagsReturned{};
744+
if (!WSAGetOverlappedResult((SOCKET)ListenSocket.Get(), &Overlapped, &bytesProcessed, TRUE, &flagsReturned))
745+
{
746+
auto error = GetLastError();
747+
LOG_LAST_ERROR_IF(error != ERROR_CONNECTION_ABORTED && error != ERROR_OPERATION_ABORTED);
748+
}
749+
}
750+
}
751+
752+
void SingleAcceptHandle::Schedule()
753+
{
754+
WI_ASSERT(State == IOHandleStatus::Standby);
755+
756+
// Schedule the accept.
757+
DWORD bytesReturned{};
758+
if (AcceptEx((SOCKET)ListenSocket.Get(), (SOCKET)AcceptedSocket.Get(), &AcceptBuffer, 0, sizeof(SOCKADDR_STORAGE), sizeof(SOCKADDR_STORAGE), &bytesReturned, &Overlapped))
759+
{
760+
// Accept completed immediately.
761+
State = IOHandleStatus::Completed;
762+
OnAccepted();
763+
}
764+
else
765+
{
766+
auto error = WSAGetLastError();
767+
THROW_HR_IF_MSG(HRESULT_FROM_WIN32(error), error != ERROR_IO_PENDING, "Handle: 0x%p", (void*)ListenSocket.Get());
768+
769+
State = IOHandleStatus::Pending;
770+
}
771+
}
772+
773+
void SingleAcceptHandle::Collect()
774+
{
775+
WI_ASSERT(State == IOHandleStatus::Pending);
776+
777+
DWORD bytesReceived{};
778+
DWORD flagsReturned{};
779+
780+
THROW_IF_WIN32_BOOL_FALSE(WSAGetOverlappedResult((SOCKET)ListenSocket.Get(), &Overlapped, &bytesReceived, false, &flagsReturned));
781+
782+
State = IOHandleStatus::Completed;
783+
OnAccepted();
784+
}
785+
786+
HANDLE SingleAcceptHandle::GetHandle() const
787+
{
788+
return Event.get();
789+
}

0 commit comments

Comments
 (0)