Skip to content

Commit d9c69a5

Browse files
benhillisBen Hillis
andauthored
cleanup: VirtioNetworking refactoring (#13760)
* cleanup: update VirtioNetworking class to not rely on the WslCoreConfig struct * cleanup: simplify VirtioNetworking construction * remove old constructor and other cleanup * more minor cleanup * string cleanup in HandleVirtioModifyOpenPorts --------- Co-authored-by: Ben Hillis <benhill@ntdev.microsoft.com>
1 parent c3d369d commit d9c69a5

File tree

4 files changed

+96
-94
lines changed

4 files changed

+96
-94
lines changed

src/windows/service/exe/VirtioNetworking.cpp

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,34 +12,23 @@ using wsl::core::VirtioNetworking;
1212

1313
static constexpr auto c_loopbackDeviceName = TEXT(LX_INIT_LOOPBACK_DEVICE_NAME);
1414

15-
VirtioNetworking::VirtioNetworking(GnsChannel&& gnsChannel, const Config& config) :
16-
m_gnsChannel(std::move(gnsChannel)), m_config(config)
15+
VirtioNetworking::VirtioNetworking(
16+
GnsChannel&& gnsChannel,
17+
bool enableLocalhostRelay,
18+
AddGuestDeviceCallback addGuestDeviceCallback,
19+
ModifyOpenPortsCallback modifyOpenPortsCallback,
20+
GuestInterfaceStateChangeCallback guestInterfaceStateChangeCallback) :
21+
m_addGuestDeviceCallback(std::move(addGuestDeviceCallback)),
22+
m_gnsChannel(std::move(gnsChannel)),
23+
m_modifyOpenPortsCallback(std::move(modifyOpenPortsCallback)),
24+
m_guestInterfaceStateChangeCallback(std::move(guestInterfaceStateChangeCallback)),
25+
m_enableLocalhostRelay(enableLocalhostRelay)
1726
{
1827
}
1928

20-
VirtioNetworking& VirtioNetworking::OnAddGuestDevice(const AddGuestDeviceRoutine& addGuestDeviceRoutine)
21-
{
22-
m_addGuestDeviceRoutine = addGuestDeviceRoutine;
23-
return *this;
24-
}
25-
26-
VirtioNetworking& VirtioNetworking::OnModifyOpenPorts(const ModifyOpenPortsCallback& modifyOpenPortsCallback)
27-
{
28-
m_modifyOpenPortsCallback = modifyOpenPortsCallback;
29-
return *this;
30-
}
31-
32-
VirtioNetworking& VirtioNetworking::OnGuestInterfaceStateChanged(const GuestInterfaceStateChangeCallback& guestInterfaceStateChangedCallback)
33-
{
34-
m_guestInterfaceStateChangeCallback = guestInterfaceStateChangedCallback;
35-
return *this;
36-
}
37-
3829
void VirtioNetworking::Initialize()
3930
try
4031
{
41-
THROW_HR_IF(E_NOT_SET, !m_addGuestDeviceRoutine || !m_modifyOpenPortsCallback || !m_guestInterfaceStateChangeCallback);
42-
4332
m_networkSettings = GetHostEndpointSettings();
4433

4534
// TODO: Determine gateway MAC address
@@ -84,7 +73,7 @@ try
8473
}
8574

8675
// Add virtio net adapter to guest
87-
m_adapterId = (*m_addGuestDeviceRoutine)(c_virtioNetworkClsid, c_virtioNetworkDeviceId, L"eth0", device_options.str().c_str());
76+
m_adapterId = m_addGuestDeviceCallback(c_virtioNetworkClsid, c_virtioNetworkDeviceId, L"eth0", device_options.str().c_str());
8877

8978
auto lock = m_lock.lock_exclusive();
9079

@@ -121,7 +110,7 @@ try
121110
UpdateDns(std::move(dnsSettings));
122111
}
123112

124-
if (m_config.EnableLocalhostRelay)
113+
if (m_enableLocalhostRelay)
125114
{
126115
SetupLoopbackDevice();
127116
}
@@ -132,7 +121,7 @@ CATCH_LOG()
132121

133122
void VirtioNetworking::SetupLoopbackDevice()
134123
{
135-
m_localhostAdapterId = (*m_addGuestDeviceRoutine)(
124+
m_localhostAdapterId = m_addGuestDeviceCallback(
136125
c_virtioNetworkClsid, c_virtioNetworkDeviceId, c_loopbackDeviceName, L"client_ip=127.0.0.1;client_mac=00:11:22:33:44:55");
137126

138127
hns::HNSEndpoint endpointProperties;
@@ -162,7 +151,7 @@ void VirtioNetworking::StartPortTracker(wil::unique_socket&& socket)
162151
m_gnsPortTrackerChannel.emplace(
163152
std::move(socket),
164153
[&](const SOCKADDR_INET& addr, int protocol, bool allocate) { return HandlePortNotification(addr, protocol, allocate); },
165-
[&](_In_ const std::string& interfaceName, _In_ bool up) { (*m_guestInterfaceStateChangeCallback)(interfaceName, up); });
154+
[&](_In_ const std::string& interfaceName, _In_ bool up) { m_guestInterfaceStateChangeCallback(interfaceName, up); });
166155
}
167156

168157
HRESULT VirtioNetworking::HandlePortNotification(const SOCKADDR_INET& addr, int protocol, bool allocate) const noexcept
@@ -181,7 +170,7 @@ HRESULT VirtioNetworking::HandlePortNotification(const SOCKADDR_INET& addr, int
181170
}
182171
}
183172

184-
if (m_config.EnableLocalhostRelay && (unspecified || loopback))
173+
if (m_enableLocalhostRelay && (unspecified || loopback))
185174
{
186175
SOCKADDR_INET localAddr = addr;
187176
if (!loopback)
@@ -196,12 +185,12 @@ HRESULT VirtioNetworking::HandlePortNotification(const SOCKADDR_INET& addr, int
196185
localAddr.Ipv6.sin6_port = addr.Ipv6.sin6_port;
197186
}
198187
}
199-
result = (*m_modifyOpenPortsCallback)(c_virtioNetworkClsid, c_loopbackDeviceName, localAddr, protocol, allocate);
188+
result = m_modifyOpenPortsCallback(c_virtioNetworkClsid, c_loopbackDeviceName, localAddr, protocol, allocate);
200189
LOG_HR_IF_MSG(E_FAIL, result != S_OK, "Failure adding localhost relay port %d", localAddr.Ipv4.sin_port);
201190
}
202191
if (!loopback)
203192
{
204-
const int localResult = (*m_modifyOpenPortsCallback)(c_virtioNetworkClsid, L"eth0", addr, protocol, allocate);
193+
const int localResult = m_modifyOpenPortsCallback(c_virtioNetworkClsid, L"eth0", addr, protocol, allocate);
205194
LOG_HR_IF_MSG(E_FAIL, localResult != S_OK, "Failure adding relay port %d", addr.Ipv4.sin_port);
206195
if (result == 0)
207196
{

src/windows/service/exe/VirtioNetworking.h

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,21 @@
99

1010
namespace wsl::core {
1111

12-
using AddGuestDeviceRoutine = std::function<GUID(const GUID& clsid, const GUID& deviceId, PCWSTR tag, PCWSTR options)>;
12+
using AddGuestDeviceCallback = std::function<GUID(const GUID& clsid, const GUID& deviceId, PCWSTR tag, PCWSTR options)>;
1313
using ModifyOpenPortsCallback = std::function<int(const GUID& clsid, PCWSTR tag, const SOCKADDR_INET& addr, int protocol, bool isOpen)>;
1414
using GuestInterfaceStateChangeCallback = std::function<void(const std::string& name, bool isUp)>;
1515

1616
class VirtioNetworking : public INetworkingEngine
1717
{
1818
public:
19-
VirtioNetworking(GnsChannel&& gnsChannel, const Config& config);
19+
VirtioNetworking(
20+
GnsChannel&& gnsChannel,
21+
bool enableLocalhostRelay,
22+
AddGuestDeviceCallback addGuestDeviceCallback,
23+
ModifyOpenPortsCallback modifyOpenPortsCallback,
24+
GuestInterfaceStateChangeCallback guestInterfaceStateChangeCallback);
2025
~VirtioNetworking() = default;
2126

22-
VirtioNetworking& OnAddGuestDevice(const AddGuestDeviceRoutine& addGuestDeviceRoutine);
23-
VirtioNetworking& OnModifyOpenPorts(const ModifyOpenPortsCallback& modifyOpenPortsCallback);
24-
VirtioNetworking& OnGuestInterfaceStateChanged(const GuestInterfaceStateChangeCallback& guestInterfaceStateChangedCallback);
25-
2627
// Note: This class cannot be moved because m_networkNotifyHandle captures a 'this' pointer.
2728
VirtioNetworking(const VirtioNetworking&) = delete;
2829
VirtioNetworking(VirtioNetworking&&) = delete;
@@ -49,17 +50,17 @@ class VirtioNetworking : public INetworkingEngine
4950

5051
mutable wil::srwlock m_lock;
5152

52-
std::optional<AddGuestDeviceRoutine> m_addGuestDeviceRoutine;
53+
AddGuestDeviceCallback m_addGuestDeviceCallback;
5354
GnsChannel m_gnsChannel;
5455
std::optional<GnsPortTrackerChannel> m_gnsPortTrackerChannel;
5556
std::shared_ptr<networking::NetworkSettings> m_networkSettings;
56-
const Config& m_config;
57+
bool m_enableLocalhostRelay;
5758
GUID m_localhostAdapterId;
5859
GUID m_adapterId;
5960
std::optional<NL_NETWORK_CONNECTIVITY_LEVEL_HINT> m_connectivityLevel;
6061
std::optional<NL_NETWORK_CONNECTIVITY_COST_HINT> m_connectivityCost;
61-
std::optional<ModifyOpenPortsCallback> m_modifyOpenPortsCallback;
62-
std::optional<GuestInterfaceStateChangeCallback> m_guestInterfaceStateChangeCallback;
62+
ModifyOpenPortsCallback m_modifyOpenPortsCallback;
63+
GuestInterfaceStateChangeCallback m_guestInterfaceStateChangeCallback;
6364

6465
std::optional<ULONGLONG> m_interfaceLuid;
6566
ULONG m_networkMtu = 0;

src/windows/service/exe/WslCoreVm.cpp

Lines changed: 63 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -607,55 +607,16 @@ void WslCoreVm::Initialize(const GUID& VmId, const wil::shared_handle& UserToken
607607
}
608608
else if (m_vmConfig.NetworkingMode == NetworkingMode::VirtioProxy)
609609
{
610-
auto virtioNetworkingEngine = std::make_unique<wsl::core::VirtioNetworking>(std::move(gnsChannel), m_vmConfig);
611-
virtioNetworkingEngine->OnAddGuestDevice([&](const GUID& Clsid, const GUID& DeviceId, PCWSTR Tag, PCWSTR Options) {
612-
auto guestDeviceLock = m_guestDeviceLock.lock_exclusive();
613-
return AddHdvShareWithOptions(DeviceId, Clsid, Tag, {}, Options, 0, m_userToken.get());
614-
});
615-
616-
virtioNetworkingEngine->OnModifyOpenPorts([&](const GUID& Clsid, PCWSTR Tag, const SOCKADDR_INET& addr, int protocol, bool isOpen) {
617-
if (protocol != IPPROTO_TCP && protocol != IPPROTO_UDP)
618-
{
619-
LOG_HR_MSG(HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED), "Unsupported bind protocol %d", protocol);
620-
return 0;
621-
}
622-
else if (addr.si_family == AF_INET6)
623-
{
624-
// The virtio net adapter does not yet support IPv6 packets, so any traffic would arrive via
625-
// IPv4. If the caller wants IPv4 they will also likely listen on an IPv4 address, which will
626-
// be handled as a separate callback to this same code.
627-
return 0;
628-
}
629-
630-
auto guestDeviceLock = m_guestDeviceLock.lock_exclusive();
631-
const auto server = m_deviceHostSupport->GetRemoteFileSystem(Clsid, c_defaultTag);
632-
if (server)
633-
{
634-
std::wstring portString(L"tag=");
635-
portString += Tag;
636-
portString += L";port_number=";
637-
portString += std::to_wstring(addr.Ipv4.sin_port);
638-
if (protocol == IPPROTO_UDP)
639-
{
640-
portString += L";udp";
641-
}
642-
if (!isOpen)
643-
{
644-
portString += L";allocate=false";
645-
}
646-
else
647-
{
648-
std::wstring addrStr(L"000.000.000.000\0");
649-
RtlIpv4AddressToStringW(&addr.Ipv4.sin_addr, addrStr.data());
650-
portString += L";listen_addr=";
651-
portString += addrStr;
652-
}
653-
LOG_IF_FAILED(server->AddShare(portString.c_str(), nullptr, 0));
654-
}
655-
return 0;
656-
});
657-
virtioNetworkingEngine->OnGuestInterfaceStateChanged([&](const std::string& name, bool isUp) {});
658-
m_networkingEngine.reset(virtioNetworkingEngine.release());
610+
m_networkingEngine = std::make_unique<wsl::core::VirtioNetworking>(
611+
std::move(gnsChannel),
612+
m_vmConfig.EnableLocalhostRelay,
613+
[this](const GUID& Clsid, const GUID& DeviceId, PCWSTR Tag, PCWSTR Options) {
614+
return HandleVirtioAddGuestDevice(Clsid, DeviceId, Tag, Options);
615+
},
616+
[this](const GUID& Clsid, PCWSTR Tag, const SOCKADDR_INET& Addr, int Protocol, bool IsOpen) {
617+
return HandleVirtioModifyOpenPorts(Clsid, Tag, Addr, Protocol, IsOpen);
618+
},
619+
[](const std::string&, bool) {});
659620
}
660621
else if (m_vmConfig.NetworkingMode == NetworkingMode::Bridged)
661622
{
@@ -2037,6 +1998,59 @@ bool WslCoreVm::IsDnsTunnelingSupported() const
20371998
return SUCCEEDED_LOG(wsl::core::networking::DnsResolver::LoadDnsResolverMethods());
20381999
}
20392000

2001+
bool WslCoreVm::IsVhdAttached(_In_ PCWSTR VhdPath)
2002+
{
2003+
auto lock = m_lock.lock_exclusive();
2004+
return m_attachedDisks.contains({DiskType::VHD, VhdPath});
2005+
}
2006+
2007+
GUID WslCoreVm::HandleVirtioAddGuestDevice(_In_ const GUID& Clsid, _In_ const GUID& DeviceId, _In_ PCWSTR Tag, _In_ PCWSTR Options)
2008+
{
2009+
auto guestDeviceLock = m_guestDeviceLock.lock_exclusive();
2010+
return AddHdvShareWithOptions(DeviceId, Clsid, Tag, {}, Options, 0, m_userToken.get());
2011+
}
2012+
2013+
int WslCoreVm::HandleVirtioModifyOpenPorts(_In_ const GUID& Clsid, _In_ PCWSTR Tag, _In_ const SOCKADDR_INET& Addr, _In_ int Protocol, _In_ bool IsOpen)
2014+
{
2015+
if (Protocol != IPPROTO_TCP && Protocol != IPPROTO_UDP)
2016+
{
2017+
LOG_HR_MSG(HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED), "Unsupported bind protocol %d", Protocol);
2018+
return 0;
2019+
}
2020+
else if (Addr.si_family == AF_INET6)
2021+
{
2022+
// The virtio net adapter does not yet support IPv6 packets, so any traffic would arrive via
2023+
// IPv4. If the caller wants IPv4 they will also likely listen on an IPv4 address, which will
2024+
// be handled as a separate callback to this same code.
2025+
return 0;
2026+
}
2027+
2028+
auto guestDeviceLock = m_guestDeviceLock.lock_exclusive();
2029+
const auto server = m_deviceHostSupport->GetRemoteFileSystem(Clsid, c_defaultTag);
2030+
if (server)
2031+
{
2032+
std::wstring portString = std::format(L"tag={};port_number={}", Tag, Addr.Ipv4.sin_port);
2033+
if (Protocol == IPPROTO_UDP)
2034+
{
2035+
portString += L";udp";
2036+
}
2037+
2038+
if (!IsOpen)
2039+
{
2040+
portString += L";allocate=false";
2041+
}
2042+
else
2043+
{
2044+
wchar_t addrStr[16]; // "000.000.000.000" + null terminator
2045+
RtlIpv4AddressToStringW(&Addr.Ipv4.sin_addr, addrStr);
2046+
portString += std::format(L";listen_addr={}", addrStr);
2047+
}
2048+
2049+
LOG_IF_FAILED(server->AddShare(portString.c_str(), nullptr, 0));
2050+
}
2051+
return 0;
2052+
}
2053+
20402054
WslCoreVm::DiskMountResult WslCoreVm::MountDisk(
20412055
_In_ PCWSTR Disk, _In_ DiskType MountDiskType, _In_ ULONG PartitionIndex, _In_opt_ PCWSTR Name, _In_opt_ PCWSTR Type, _In_opt_ PCWSTR Options)
20422056
{
@@ -2846,12 +2860,6 @@ LX_INIT_DRVFS_MOUNT WslCoreVm::s_InitializeDrvFs(_Inout_ WslCoreVm* VmContext, _
28462860
}
28472861
}
28482862

2849-
bool WslCoreVm::IsVhdAttached(_In_ PCWSTR VhdPath)
2850-
{
2851-
auto lock = m_lock.lock_exclusive();
2852-
return m_attachedDisks.contains({DiskType::VHD, VhdPath});
2853-
}
2854-
28552863
void CALLBACK WslCoreVm::s_OnExit(_In_ HCS_EVENT* Event, _In_opt_ void* Context)
28562864
try
28572865
{

src/windows/service/exe/WslCoreVm.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ class WslCoreVm
107107

108108
bool IsVhdAttached(_In_ PCWSTR VhdPath);
109109

110+
GUID HandleVirtioAddGuestDevice(_In_ const GUID& Clsid, _In_ const GUID& DeviceId, _In_ PCWSTR Tag, _In_ PCWSTR Options);
111+
112+
int HandleVirtioModifyOpenPorts(_In_ const GUID& Clsid, _In_ PCWSTR Tag, _In_ const SOCKADDR_INET& Addr, _In_ int Protocol, _In_ bool IsOpen);
113+
110114
DiskMountResult MountDisk(
111115
_In_ PCWSTR Disk, _In_ DiskType MountDiskType, _In_ ULONG PartitionIndex, _In_opt_ PCWSTR Name, _In_opt_ PCWSTR Type, _In_opt_ PCWSTR Options);
112116

0 commit comments

Comments
 (0)