Skip to content

Commit 6f7601e

Browse files
authored
Handle EOF on one side of a UnixSocketRelay. (#323)
- Current implementation shuts down everything as soon as EOF is detected on DispatchSourceRead for the relay. This is problematic for, say an HTTP request where the client makes a request, and calls `shutdown(fd, SHUT_WR)` to close the send side, but it expects to be able to keep calling `recv()` to get the response. - Changed cancel handlers so that the one that sees that both sources have been cancelled then closes both the UDS and vsock fds. - Updated vminitd VsockProxy to also do `shutdown(fd, SHUT_WR)` for read hangup or EOF, and only close the underlying fds and unwire pollers on full hangup, broken pipe, error, or when both sides half-close.
1 parent 57ad7ad commit 6f7601e

File tree

4 files changed

+292
-54
lines changed

4 files changed

+292
-54
lines changed

Sources/Containerization/LinuxContainer.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ extension LinuxContainer {
316316

317317
try await vm.start()
318318
do {
319-
let relayManager = UnixSocketRelayManager(vm: vm)
319+
let relayManager = UnixSocketRelayManager(vm: vm, log: self.logger)
320320
try await vm.withAgent { agent in
321321
try await agent.standardSetup()
322322

Sources/Containerization/UnixSocketRelay.swift

Lines changed: 110 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,12 @@ extension SocketRelay {
159159
let hostSocket = try Socket(type: socketType)
160160
try hostSocket.listen()
161161

162+
log?.info(
163+
"listening on host UDS",
164+
metadata: [
165+
"path": "\(hostConn.path)",
166+
"vport": "\(self.port)",
167+
])
162168
let connectionStream = try hostSocket.acceptStream(closeOnDeinit: false)
163169
self.state.withLock {
164170
$0.t = Task {
@@ -185,6 +191,12 @@ extension SocketRelay {
185191
let log = self.log
186192

187193
let connectionStream = try self.vm.listen(self.port)
194+
log?.info(
195+
"listening on guest vsock",
196+
metadata: [
197+
"path": "\(hostPath)",
198+
"vport": "\(port)",
199+
])
188200
self.state.withLock {
189201
$0.t = Task {
190202
do {
@@ -212,6 +224,13 @@ extension SocketRelay {
212224
) async throws {
213225
do {
214226
let guestConn = try await vm.dial(port)
227+
log?.info(
228+
"initiating connection from host to guest",
229+
metadata: [
230+
"vport": "\(port)",
231+
"hostFd": "\(guestConn.fileDescriptor)",
232+
"guestFd": "\(hostConn.fileDescriptor)",
233+
])
215234
try await self.relay(
216235
hostConn: hostConn,
217236
guestFd: guestConn.fileDescriptor
@@ -234,6 +253,13 @@ extension SocketRelay {
234253
type: socketType,
235254
closeOnDeinit: false
236255
)
256+
log?.info(
257+
"initiating connection from host to guest",
258+
metadata: [
259+
"vport": "\(port)",
260+
"hostFd": "\(hostSocket.fileDescriptor)",
261+
"guestFd": "\(vsockConn.fileDescriptor)",
262+
])
237263
try hostSocket.connect()
238264

239265
do {
@@ -250,15 +276,19 @@ extension SocketRelay {
250276
hostConn: Socket,
251277
guestFd: Int32
252278
) async throws {
279+
// set up the source for host to guest transfers
253280
let connSource = DispatchSource.makeReadSource(
254281
fileDescriptor: hostConn.fileDescriptor,
255282
queue: self.q
256283
)
284+
285+
// set up the source for guest to host transfers
257286
let vsockConnectionSource = DispatchSource.makeReadSource(
258287
fileDescriptor: guestFd,
259288
queue: self.q
260289
)
261290

291+
// add the sources to the connection map
262292
let pairID = UUID().uuidString
263293
self.state.withLock {
264294
$0.relaySources[pairID] = ConnectionSources(
@@ -267,46 +297,72 @@ extension SocketRelay {
267297
)
268298
}
269299

270-
// `buf1` isn't used concurrently.
300+
// `buf1` is thread-safe because it is only used when servicing a serial dispatch queue
271301
nonisolated(unsafe) let buf1 = UnsafeMutableBufferPointer<UInt8>.allocate(capacity: Int(getpagesize()))
272302
connSource.setEventHandler {
273303
Self.fdCopyHandler(
274304
buffer: buf1,
275305
source: connSource,
276306
from: hostConn.fileDescriptor,
277-
to: guestFd
307+
to: guestFd,
308+
log: self.log
278309
)
279310
}
280311

312+
// `buf2` is thread-safe because it is only used when servicing a serial dispatch queue
281313
nonisolated(unsafe) let buf2 = UnsafeMutableBufferPointer<UInt8>.allocate(capacity: Int(getpagesize()))
282-
// `buf2` isn't used concurrently.
283314
vsockConnectionSource.setEventHandler {
284315
Self.fdCopyHandler(
285316
buffer: buf2,
286317
source: vsockConnectionSource,
287318
from: guestFd,
288-
to: hostConn.fileDescriptor
319+
to: hostConn.fileDescriptor,
320+
log: self.log
289321
)
290322
}
291323

292324
connSource.setCancelHandler {
293-
if !connSource.isCancelled {
325+
self.log?.info(
326+
"host cancel received",
327+
metadata: [
328+
"hostFd": "\(hostConn.fileDescriptor)",
329+
"guestFd": "\(guestFd)",
330+
])
331+
332+
// only close underlying fds when both sources are at EOF
333+
// ensure that one of the cancel handlers will see both sources cancelled
334+
self.state.withLock { _ in
294335
connSource.cancel()
336+
if vsockConnectionSource.isCancelled {
337+
try? hostConn.close()
338+
close(guestFd)
339+
}
295340
}
296-
if !vsockConnectionSource.isCancelled {
297-
vsockConnectionSource.cancel()
298-
}
299-
try? hostConn.close()
300341
}
301342

302343
vsockConnectionSource.setCancelHandler {
303-
if !vsockConnectionSource.isCancelled {
344+
self.log?.info(
345+
"guest cancel received",
346+
metadata: [
347+
"hostFd": "\(hostConn.fileDescriptor)",
348+
"guestFd": "\(guestFd)",
349+
])
350+
351+
// only close underlying fds when both sources are at EOF
352+
// ensure that one of the cancel handlers will see both sources cancelled
353+
self.state.withLock { _ in
304354
vsockConnectionSource.cancel()
355+
if connSource.isCancelled {
356+
self.log?.info(
357+
"close file descriptors",
358+
metadata: [
359+
"hostFd": "\(hostConn.fileDescriptor)",
360+
"guestFd": "\(guestFd)",
361+
])
362+
try? hostConn.close()
363+
close(guestFd)
364+
}
305365
}
306-
if !connSource.isCancelled {
307-
connSource.cancel()
308-
}
309-
close(guestFd)
310366
}
311367

312368
connSource.activate()
@@ -321,13 +377,42 @@ extension SocketRelay {
321377
log: Logger? = nil
322378
) {
323379
if source.data == 0 {
380+
log?.info(
381+
"source EOF",
382+
metadata: [
383+
"sourceFd": "\(sourceFd)",
384+
"dstFd": "\(destinationFd)",
385+
])
324386
if !source.isCancelled {
387+
log?.info(
388+
"canceling DispatchSourceRead",
389+
metadata: [
390+
"sourceFd": "\(sourceFd)",
391+
"dstFd": "\(destinationFd)",
392+
])
325393
source.cancel()
394+
if shutdown(destinationFd, SHUT_WR) != 0 {
395+
log?.info(
396+
"failed to shut down reads",
397+
metadata: [
398+
"errno": "\(errno)",
399+
"sourceFd": "\(sourceFd)",
400+
"dstFd": "\(destinationFd)",
401+
]
402+
)
403+
}
326404
}
327405
return
328406
}
329407

330408
do {
409+
log?.debug(
410+
"source copy",
411+
metadata: [
412+
"sourceFd": "\(sourceFd)",
413+
"dstFd": "\(destinationFd)",
414+
"size": "\(source.data)",
415+
])
331416
try self.fileDescriptorCopy(
332417
buffer: buffer,
333418
size: source.data,
@@ -338,6 +423,16 @@ extension SocketRelay {
338423
log?.error("file descriptor copy failed \(error)")
339424
if !source.isCancelled {
340425
source.cancel()
426+
if shutdown(destinationFd, SHUT_RDWR) != 0 {
427+
log?.info(
428+
"failed to shut down destination",
429+
metadata: [
430+
"errno": "\(errno)",
431+
"sourceFd": "\(sourceFd)",
432+
"dstFd": "\(destinationFd)",
433+
]
434+
)
435+
}
341436
}
342437
}
343438
}
@@ -374,7 +469,7 @@ extension SocketRelay {
374469
if writeResult <= 0 {
375470
throw ContainerizationError(
376471
.internalError,
377-
message: "zero byte write or error in socket relay"
472+
message: "zero byte write or error in socket relay: fd \(destinationFd), result \(writeResult)"
378473
)
379474
}
380475
writeBytesRemaining -= writeResult

Sources/ContainerizationOS/Linux/Epoll.swift

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,11 @@ public final class Epoll: Sendable {
169169

170170
extension Epoll.Mask {
171171
public var isHangup: Bool {
172-
(self & (EPOLLHUP | EPOLLERR | EPOLLRDHUP)) != 0
172+
(self & (EPOLLHUP | EPOLLERR)) != 0
173+
}
174+
175+
public var isRhangup: Bool {
176+
(self & EPOLLRDHUP) != 0
173177
}
174178

175179
public var readyToRead: Bool {

0 commit comments

Comments
 (0)