Skip to content

Commit d75ed70

Browse files
authored
Refactor QuiescingHelper to exhaustively iterate state (#193)
# Motivation Currently the `QuiescingHelper` is crashing on a precondition if you call shutdown when it already was shutdown. However, that can totally happen and we should support it. # Modification Refactor the `QuiescingHelper` to exhaustively switch over its state in every method. Furthermore, I added a few more test cases to test realistic scenarios. # Result We are now reliable checking our state and making sure to allow most transitions.
1 parent 6bd9bf5 commit d75ed70

File tree

3 files changed

+264
-71
lines changed

3 files changed

+264
-71
lines changed

Sources/NIOExtras/QuiescingHelper.swift

Lines changed: 80 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,25 @@ private enum ShutdownError: Error {
2323
/// `channelAdded` method in the same event loop tick as the `Channel` is actually created.
2424
private final class ChannelCollector {
2525
enum LifecycleState {
26-
case upAndRunning
27-
case shuttingDown
26+
case upAndRunning(
27+
openChannels: [ObjectIdentifier: Channel],
28+
serverChannel: Channel
29+
)
30+
case shuttingDown(
31+
openChannels: [ObjectIdentifier: Channel],
32+
fullyShutdownPromise: EventLoopPromise<Void>
33+
)
2834
case shutdownCompleted
2935
}
3036

31-
private var openChannels: [ObjectIdentifier: Channel] = [:]
32-
private let serverChannel: Channel
33-
private var fullyShutdownPromise: EventLoopPromise<Void>? = nil
34-
private var lifecycleState = LifecycleState.upAndRunning
37+
private var lifecycleState: LifecycleState
3538

36-
private var eventLoop: EventLoop {
37-
return self.serverChannel.eventLoop
38-
}
39+
private let eventLoop: EventLoop
3940

4041
/// Initializes a `ChannelCollector` for `Channel`s accepted by `serverChannel`.
4142
init(serverChannel: Channel) {
42-
self.serverChannel = serverChannel
43+
self.eventLoop = serverChannel.eventLoop
44+
self.lifecycleState = .upAndRunning(openChannels: [:], serverChannel: serverChannel)
4345
}
4446

4547
/// Add a channel to the `ChannelCollector`.
@@ -51,30 +53,64 @@ private final class ChannelCollector {
5153
func channelAdded(_ channel: Channel) throws {
5254
self.eventLoop.assertInEventLoop()
5355

54-
guard self.lifecycleState != .shutdownCompleted else {
56+
switch self.lifecycleState {
57+
case .upAndRunning(var openChannels, let serverChannel):
58+
openChannels[ObjectIdentifier(channel)] = channel
59+
self.lifecycleState = .upAndRunning(openChannels: openChannels, serverChannel: serverChannel)
60+
61+
case .shuttingDown(var openChannels, let fullyShutdownPromise):
62+
openChannels[ObjectIdentifier(channel)] = channel
63+
channel.eventLoop.execute {
64+
channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
65+
}
66+
self.lifecycleState = .shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise)
67+
68+
case .shutdownCompleted:
5569
channel.close(promise: nil)
5670
throw ShutdownError.alreadyShutdown
5771
}
58-
59-
self.openChannels[ObjectIdentifier(channel)] = channel
6072
}
6173

6274
private func shutdownCompleted() {
6375
self.eventLoop.assertInEventLoop()
64-
assert(self.lifecycleState == .shuttingDown)
6576

66-
self.lifecycleState = .shutdownCompleted
67-
self.fullyShutdownPromise?.succeed(())
77+
switch self.lifecycleState {
78+
case .upAndRunning:
79+
preconditionFailure("This can never happen because we transition to shuttingDown first")
80+
81+
case .shuttingDown(_, let fullyShutdownPromise):
82+
self.lifecycleState = .shutdownCompleted
83+
fullyShutdownPromise.succeed(())
84+
85+
case .shutdownCompleted:
86+
preconditionFailure("We should only complete the shutdown once")
87+
}
6888
}
6989

7090
private func channelRemoved0(_ channel: Channel) {
7191
self.eventLoop.assertInEventLoop()
72-
precondition(self.openChannels.keys.contains(ObjectIdentifier(channel)),
73-
"channel \(channel) not in ChannelCollector \(self.openChannels)")
7492

75-
self.openChannels.removeValue(forKey: ObjectIdentifier(channel))
76-
if self.lifecycleState != .upAndRunning && self.openChannels.isEmpty {
77-
shutdownCompleted()
93+
switch self.lifecycleState {
94+
case .upAndRunning(var openChannels, let serverChannel):
95+
let removedChannel = openChannels.removeValue(forKey: ObjectIdentifier(channel))
96+
97+
precondition(removedChannel != nil, "channel \(channel) not in ChannelCollector \(openChannels)")
98+
99+
self.lifecycleState = .upAndRunning(openChannels: openChannels, serverChannel: serverChannel)
100+
101+
case .shuttingDown(var openChannels, let fullyShutdownPromise):
102+
let removedChannel = openChannels.removeValue(forKey: ObjectIdentifier(channel))
103+
104+
precondition(removedChannel != nil, "channel \(channel) not in ChannelCollector \(openChannels)")
105+
106+
if openChannels.isEmpty {
107+
self.shutdownCompleted()
108+
} else {
109+
self.lifecycleState = .shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise)
110+
}
111+
112+
case .shutdownCompleted:
113+
preconditionFailure("We should not have channels removed after transitioned to completed")
78114
}
79115
}
80116

@@ -96,44 +132,39 @@ private final class ChannelCollector {
96132

97133
private func initiateShutdown0(promise: EventLoopPromise<Void>?) {
98134
self.eventLoop.assertInEventLoop()
99-
precondition(self.lifecycleState == .upAndRunning)
100135

101-
self.lifecycleState = .shuttingDown
136+
switch self.lifecycleState {
137+
case .upAndRunning(let openChannels, let serverChannel):
138+
let fullyShutdownPromise = promise ?? serverChannel.eventLoop.makePromise(of: Void.self)
102139

103-
if let promise = promise {
104-
if let alreadyExistingPromise = self.fullyShutdownPromise {
105-
alreadyExistingPromise.futureResult.cascade(to: promise)
106-
} else {
107-
self.fullyShutdownPromise = promise
108-
}
109-
}
140+
self.lifecycleState = .shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise)
110141

111-
self.serverChannel.close().cascadeFailure(to: self.fullyShutdownPromise)
142+
serverChannel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
143+
serverChannel.close().cascadeFailure(to: fullyShutdownPromise)
112144

113-
for channel in self.openChannels.values {
114-
channel.eventLoop.execute {
115-
channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
145+
for channel in openChannels.values {
146+
channel.eventLoop.execute {
147+
channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
148+
}
116149
}
117-
}
118150

119-
if self.openChannels.isEmpty {
120-
shutdownCompleted()
151+
if openChannels.isEmpty {
152+
self.shutdownCompleted()
153+
}
154+
155+
case .shuttingDown(_, let fullyShutdownPromise):
156+
fullyShutdownPromise.futureResult.cascade(to: promise)
157+
158+
case .shutdownCompleted:
159+
promise?.succeed(())
121160
}
122161
}
123162

124163
/// Initiate the shutdown fulfilling `promise` when all the previously registered `Channel`s have been closed.
125164
///
126165
/// - parameters:
127-
/// - promise: The `EventLoopPromise` to fulfill when the shutdown of all previously registered `Channel`s has been completed.
166+
/// - promise: The `EventLoopPromise` to fulfil when the shutdown of all previously registered `Channel`s has been completed.
128167
func initiateShutdown(promise: EventLoopPromise<Void>?) {
129-
if self.serverChannel.eventLoop.inEventLoop {
130-
self.serverChannel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
131-
} else {
132-
self.eventLoop.execute {
133-
self.serverChannel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
134-
}
135-
}
136-
137168
if self.eventLoop.inEventLoop {
138169
self.initiateShutdown0(promise: promise)
139170
} else {
@@ -144,7 +175,6 @@ private final class ChannelCollector {
144175
}
145176
}
146177

147-
148178
extension ChannelCollector: @unchecked Sendable {}
149179

150180
/// A `ChannelHandler` that adds all channels that it receives through the `ChannelPipeline` to a `ChannelCollector`.
@@ -173,7 +203,7 @@ private final class CollectAcceptedChannelsHandler: ChannelInboundHandler {
173203
do {
174204
try self.channelCollector.channelAdded(channel)
175205
let closeFuture = channel.closeFuture
176-
closeFuture.whenComplete { (_: Result<(), Error>) in
206+
closeFuture.whenComplete { (_: Result<Void, Error>) in
177207
self.channelCollector.channelRemoved(channel)
178208
}
179209
context.fireChannelRead(data)
@@ -231,7 +261,7 @@ public final class ServerQuiescingHelper {
231261
deinit {
232262
self.channelCollectorPromise.fail(UnusedQuiescingHelperError())
233263
}
234-
264+
235265
/// Create the `ChannelHandler` for the server `channel` to collect all accepted child `Channel`s.
236266
///
237267
/// - parameters:
@@ -262,6 +292,4 @@ public final class ServerQuiescingHelper {
262292
}
263293
}
264294

265-
extension ServerQuiescingHelper: Sendable {
266-
267-
}
295+
extension ServerQuiescingHelper: Sendable {}

Tests/NIOExtrasTests/QuiescingHelperTest+XCTest.swift

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ extension QuiescingHelperTest {
3131
("testQuiesceUserEventReceivedOnShutdown", testQuiesceUserEventReceivedOnShutdown),
3232
("testQuiescingDoesNotSwallowCloseErrorsFromAcceptHandler", testQuiescingDoesNotSwallowCloseErrorsFromAcceptHandler),
3333
("testShutdownIsImmediateWhenPromiseDoesNotSucceed", testShutdownIsImmediateWhenPromiseDoesNotSucceed),
34+
("testShutdown_whenAlreadyShutdown", testShutdown_whenAlreadyShutdown),
35+
("testShutdown_whenNoOpenChild", testShutdown_whenNoOpenChild),
36+
("testChannelClose_whenRunning", testChannelClose_whenRunning),
37+
("testChannelAdded_whenShuttingDown", testChannelAdded_whenShuttingDown),
38+
("testChannelAdded_whenShutdown", testChannelAdded_whenShutdown),
3439
]
3540
}
3641
}

0 commit comments

Comments
 (0)