Skip to content

Commit 47e0be1

Browse files
authored
Fix keepalive logic (#50)
This PR fixes grpc/grpc-swift#2095.

## Motivation

As per the gRPC specification, the server must keep track of pings from each client and, if they exceed a threshold, send a GOAWAY frame and close the connection. The number of ping strikes must be reset every time the server writes a HEADERS or DATA frame.

However, there was a bug in the keepalive implementation: the server did not properly track when HEADERS/DATA frames were written, so the strikes were never reset. As a result, the server always ended up closing connections when keepalive pings were enabled.

There was also a second bug: the GOAWAY frame wasn't actually sent to the client, because the connection was closed straight away and the frame never made it out.

## Modifications

This PR fixes both bugs:

- It keeps track of the appropriate `FrameStats` as described above.
- It delays the channel close by one event-loop tick after sending the GOAWAY frame, to make sure the frame gets flushed and delivered to the client.

## Results

Fewer bugs!
1 parent 289c0bc commit 47e0be1

File tree

6 files changed

+136
-12
lines changed

6 files changed

+136
-12
lines changed

Sources/GRPCNIOTransportCore/Internal/NIOChannelPipeline+GRPC.swift

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ extension ChannelPipeline.SynchronousOperations {
7171
var http2HandlerStreamConfiguration = NIOHTTP2Handler.StreamConfiguration()
7272
http2HandlerStreamConfiguration.targetWindowSize = clampedTargetWindowSize
7373

74+
let boundConnectionManagementHandler = NIOLoopBound(
75+
serverConnectionHandler.syncView,
76+
eventLoop: self.eventLoop
77+
)
7478
let streamMultiplexer = try self.configureAsyncHTTP2Pipeline(
7579
mode: .server,
7680
streamDelegate: serverConnectionHandler.http2StreamDelegate,
@@ -86,7 +90,8 @@ extension ChannelPipeline.SynchronousOperations {
8690
acceptedEncodings: compressionConfig.enabledAlgorithms,
8791
maxPayloadSize: rpcConfig.maxRequestPayloadSize,
8892
methodDescriptorPromise: methodDescriptorPromise,
89-
eventLoop: streamChannel.eventLoop
93+
eventLoop: streamChannel.eventLoop,
94+
connectionManagementHandler: boundConnectionManagementHandler.value
9095
)
9196
try streamChannel.pipeline.syncOperations.addHandler(streamHandler)
9297

Sources/GRPCNIOTransportCore/Server/Connection/ServerConnectionManagementHandler.swift

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,9 @@ package final class ServerConnectionManagementHandler: ChannelDuplexHandler {
121121
}
122122

123123
/// Stats about recently written frames. Used to determine whether to reset keep-alive state.
124-
private var frameStats: FrameStats
124+
package var frameStats: FrameStats
125125

126-
struct FrameStats {
126+
package struct FrameStats {
127127
private(set) var didWriteHeadersOrData = false
128128

129129
/// Mark that a HEADERS frame has been written.
@@ -609,7 +609,13 @@ extension ServerConnectionManagementHandler {
609609

610610
context.write(self.wrapOutboundOut(goAway), promise: nil)
611611
self.maybeFlush(context: context)
612-
context.close(promise: nil)
612+
613+
// Delay the channel close by one event-loop tick after writing the GOAWAY frame to make sure it
614+
// gets flushed and delivered to the client before the connection is closed.
615+
let loopBound = NIOLoopBound(context, eventLoop: context.eventLoop)
616+
context.eventLoop.execute {
617+
loopBound.value.close(promise: nil)
618+
}
613619

614620
case .sendAck:
615621
() // ACKs are sent by NIO's HTTP/2 handler, don't double ack.

Sources/GRPCNIOTransportCore/Server/GRPCServerStreamHandler.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ package final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChan
4242

4343
private var cancellationHandle: Optional<ServerContext.RPCCancellationHandle>
4444

45+
package let connectionManagementHandler: ServerConnectionManagementHandler.SyncView
46+
4547
// Existential errors unconditionally allocate, avoid this per-use allocation by doing it
4648
// statically.
4749
private static let handlerRemovedBeforeDescriptorResolved: any Error = RPCError(
@@ -55,6 +57,7 @@ package final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChan
5557
maxPayloadSize: Int,
5658
methodDescriptorPromise: EventLoopPromise<MethodDescriptor>,
5759
eventLoop: any EventLoop,
60+
connectionManagementHandler: ServerConnectionManagementHandler.SyncView,
5861
cancellationHandler: ServerContext.RPCCancellationHandle? = nil,
5962
skipStateMachineAssertions: Bool = false
6063
) {
@@ -66,6 +69,7 @@ package final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChan
6669
self.methodDescriptorPromise = methodDescriptorPromise
6770
self.cancellationHandle = cancellationHandler
6871
self.eventLoop = eventLoop
72+
self.connectionManagementHandler = connectionManagementHandler
6973
}
7074

7175
package func setCancellationHandle(_ handle: ServerContext.RPCCancellationHandle) {
@@ -136,13 +140,16 @@ extension GRPCServerStreamHandler {
136140
switch self.stateMachine.nextInboundMessage() {
137141
case .receiveMessage(let message):
138142
context.fireChannelRead(self.wrapInboundOut(.message(message)))
143+
139144
case .awaitMoreMessages:
140145
break loop
146+
141147
case .noMoreMessages:
142148
context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)
143149
break loop
144150
}
145151
}
152+
146153
case .doNothing:
147154
()
148155
}
@@ -261,6 +268,7 @@ extension GRPCServerStreamHandler {
261268
self.flushPending = true
262269
let headers = try self.stateMachine.send(metadata: metadata)
263270
context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: promise)
271+
self.connectionManagementHandler.wroteHeadersFrame()
264272
} catch let invalidState {
265273
let error = RPCError(invalidState)
266274
promise?.fail(error)
@@ -270,6 +278,7 @@ extension GRPCServerStreamHandler {
270278
case .message(let message):
271279
do {
272280
try self.stateMachine.send(message: message, promise: promise)
281+
self.connectionManagementHandler.wroteDataFrame()
273282
} catch let invalidState {
274283
let error = RPCError(invalidState)
275284
promise?.fail(error)

Tests/GRPCNIOTransportCoreTests/Client/Connection/Utilities/ConnectionTest.swift

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,24 @@ extension ConnectionTest {
114114
let h2 = NIOHTTP2Handler(mode: .server)
115115
let mux = HTTP2StreamMultiplexer(mode: .server, channel: channel) { stream in
116116
let sync = stream.pipeline.syncOperations
117+
let connectionManagementHandler = ServerConnectionManagementHandler(
118+
eventLoop: stream.eventLoop,
119+
maxIdleTime: nil,
120+
maxAge: nil,
121+
maxGraceTime: nil,
122+
keepaliveTime: nil,
123+
keepaliveTimeout: nil,
124+
allowKeepaliveWithoutCalls: false,
125+
minPingIntervalWithoutCalls: .minutes(5),
126+
requireALPN: false
127+
)
117128
let handler = GRPCServerStreamHandler(
118129
scheme: .http,
119130
acceptedEncodings: .none,
120131
maxPayloadSize: .max,
121132
methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self),
122-
eventLoop: stream.eventLoop
133+
eventLoop: stream.eventLoop,
134+
connectionManagementHandler: connectionManagementHandler.syncView
123135
)
124136

125137
return stream.eventLoop.makeCompletedFuture {

Tests/GRPCNIOTransportCoreTests/Client/Connection/Utilities/TestServer.swift

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,24 @@ final class TestServer: Sendable {
7070
let sync = channel.pipeline.syncOperations
7171
let multiplexer = try sync.configureAsyncHTTP2Pipeline(mode: .server) { stream in
7272
stream.eventLoop.makeCompletedFuture {
73+
let connectionManagementHandler = ServerConnectionManagementHandler(
74+
eventLoop: stream.eventLoop,
75+
maxIdleTime: nil,
76+
maxAge: nil,
77+
maxGraceTime: nil,
78+
keepaliveTime: nil,
79+
keepaliveTimeout: nil,
80+
allowKeepaliveWithoutCalls: false,
81+
minPingIntervalWithoutCalls: .minutes(5),
82+
requireALPN: false
83+
)
7384
let handler = GRPCServerStreamHandler(
7485
scheme: .http,
7586
acceptedEncodings: .all,
7687
maxPayloadSize: .max,
7788
methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self),
78-
eventLoop: stream.eventLoop
89+
eventLoop: stream.eventLoop,
90+
connectionManagementHandler: connectionManagementHandler.syncView
7991
)
8092

8193
try stream.pipeline.syncOperations.addHandlers(handler)

Tests/GRPCNIOTransportCoreTests/Server/GRPCServerStreamHandlerTests.swift

Lines changed: 86 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,25 @@ final class GRPCServerStreamHandlerTests: XCTestCase {
3333
descriptorPromise: EventLoopPromise<MethodDescriptor>? = nil,
3434
disableAssertions: Bool = false
3535
) -> GRPCServerStreamHandler {
36+
let serverConnectionManagementHandler = ServerConnectionManagementHandler(
37+
eventLoop: channel.eventLoop,
38+
maxIdleTime: nil,
39+
maxAge: nil,
40+
maxGraceTime: nil,
41+
keepaliveTime: nil,
42+
keepaliveTimeout: nil,
43+
allowKeepaliveWithoutCalls: false,
44+
minPingIntervalWithoutCalls: .minutes(5),
45+
requireALPN: false
46+
)
47+
3648
return GRPCServerStreamHandler(
3749
scheme: scheme,
3850
acceptedEncodings: acceptedEncodings,
3951
maxPayloadSize: maxPayloadSize,
4052
methodDescriptorPromise: descriptorPromise ?? channel.eventLoop.makePromise(),
4153
eventLoop: channel.eventLoop,
54+
connectionManagementHandler: serverConnectionManagementHandler.syncView,
4255
skipStateMachineAssertions: disableAssertions
4356
)
4457
}
@@ -974,28 +987,50 @@ final class GRPCServerStreamHandlerTests: XCTestCase {
974987
}
975988

976989
struct ServerStreamHandlerTests {
977-
private func makeServerStreamHandler(
990+
struct ConnectionAndStreamHandlers {
991+
let streamHandler: GRPCServerStreamHandler
992+
let connectionHandler: ServerConnectionManagementHandler
993+
}
994+
995+
private func makeServerConnectionAndStreamHandlers(
978996
channel: any Channel,
979997
scheme: Scheme = .http,
980998
acceptedEncodings: CompressionAlgorithmSet = [],
981999
maxPayloadSize: Int = .max,
9821000
descriptorPromise: EventLoopPromise<MethodDescriptor>? = nil,
9831001
disableAssertions: Bool = false
984-
) -> GRPCServerStreamHandler {
985-
return GRPCServerStreamHandler(
1002+
) -> ConnectionAndStreamHandlers {
1003+
let connectionManagementHandler = ServerConnectionManagementHandler(
1004+
eventLoop: channel.eventLoop,
1005+
maxIdleTime: nil,
1006+
maxAge: nil,
1007+
maxGraceTime: nil,
1008+
keepaliveTime: nil,
1009+
keepaliveTimeout: nil,
1010+
allowKeepaliveWithoutCalls: false,
1011+
minPingIntervalWithoutCalls: .minutes(5),
1012+
requireALPN: false
1013+
)
1014+
let streamHandler = GRPCServerStreamHandler(
9861015
scheme: scheme,
9871016
acceptedEncodings: acceptedEncodings,
9881017
maxPayloadSize: maxPayloadSize,
9891018
methodDescriptorPromise: descriptorPromise ?? channel.eventLoop.makePromise(),
9901019
eventLoop: channel.eventLoop,
1020+
connectionManagementHandler: connectionManagementHandler.syncView,
9911021
skipStateMachineAssertions: disableAssertions
9921022
)
1023+
1024+
return ConnectionAndStreamHandlers(
1025+
streamHandler: streamHandler,
1026+
connectionHandler: connectionManagementHandler
1027+
)
9931028
}
9941029

9951030
@Test("ChannelShouldQuiesceEvent is buffered and turns into RPC cancellation")
9961031
func shouldQuiesceEventIsBufferedBeforeHandleIsSet() async throws {
9971032
let channel = EmbeddedChannel()
998-
let handler = self.makeServerStreamHandler(channel: channel)
1033+
let handler = self.makeServerConnectionAndStreamHandlers(channel: channel).streamHandler
9991034
try channel.pipeline.syncOperations.addHandler(handler)
10001035
channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
10011036

@@ -1011,7 +1046,7 @@ struct ServerStreamHandlerTests {
10111046
@Test("ChannelShouldQuiesceEvent turns into RPC cancellation")
10121047
func shouldQuiesceEventTriggersCancellation() async throws {
10131048
let channel = EmbeddedChannel()
1014-
let handler = self.makeServerStreamHandler(channel: channel)
1049+
let handler = self.makeServerConnectionAndStreamHandlers(channel: channel).streamHandler
10151050
try channel.pipeline.syncOperations.addHandler(handler)
10161051

10171052
await withServerContextRPCCancellationHandle { handle in
@@ -1028,7 +1063,7 @@ struct ServerStreamHandlerTests {
10281063
@Test("RST_STREAM turns into RPC cancellation")
10291064
func rstStreamTriggersCancellation() async throws {
10301065
let channel = EmbeddedChannel()
1031-
let handler = self.makeServerStreamHandler(channel: channel)
1066+
let handler = self.makeServerConnectionAndStreamHandlers(channel: channel).streamHandler
10321067
try channel.pipeline.syncOperations.addHandler(handler)
10331068

10341069
await withServerContextRPCCancellationHandle { handle in
@@ -1045,6 +1080,51 @@ struct ServerStreamHandlerTests {
10451080
_ = try? channel.finish()
10461081
}
10471082

1083+
@Test("Connection FrameStats are updated when writing headers or data frames")
1084+
func connectionFrameStatsAreUpdatedAccordingly() async throws {
1085+
let channel = EmbeddedChannel()
1086+
let handlers = self.makeServerConnectionAndStreamHandlers(channel: channel)
1087+
try channel.pipeline.syncOperations.addHandler(handlers.streamHandler)
1088+
1089+
// We have written nothing yet, so expect FrameStats/didWriteHeadersOrData to be false
1090+
#expect(!handlers.connectionHandler.frameStats.didWriteHeadersOrData)
1091+
1092+
// FrameStats aren't affected by pings received
1093+
channel.pipeline.fireChannelRead(
1094+
NIOAny(HTTP2Frame.FramePayload.ping(.init(withInteger: 42), ack: false))
1095+
)
1096+
#expect(!handlers.connectionHandler.frameStats.didWriteHeadersOrData)
1097+
1098+
// Now write back headers and make sure FrameStats are updated accordingly:
1099+
// To do that, we first need to receive the client's initial metadata...
1100+
let clientInitialMetadata: HPACKHeaders = [
1101+
GRPCHTTP2Keys.path.rawValue: "/SomeService/SomeMethod",
1102+
GRPCHTTP2Keys.scheme.rawValue: "http",
1103+
GRPCHTTP2Keys.method.rawValue: "POST",
1104+
GRPCHTTP2Keys.contentType.rawValue: "application/grpc",
1105+
GRPCHTTP2Keys.te.rawValue: "trailers",
1106+
]
1107+
try channel.writeInbound(
1108+
HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata))
1109+
)
1110+
1111+
// Now we write back the server's initial metadata...
1112+
let serverInitialMetadata = RPCResponsePart.metadata([:])
1113+
try channel.writeOutbound(serverInitialMetadata)
1114+
1115+
// And this should have updated the FrameStats
1116+
#expect(handlers.connectionHandler.frameStats.didWriteHeadersOrData)
1117+
1118+
// Manually reset the FrameStats to make sure that writing data also updates it correctly.
1119+
handlers.connectionHandler.frameStats.reset()
1120+
#expect(!handlers.connectionHandler.frameStats.didWriteHeadersOrData)
1121+
try channel.writeOutbound(RPCResponsePart.message([42]))
1122+
#expect(handlers.connectionHandler.frameStats.didWriteHeadersOrData)
1123+
1124+
// Clean up.
1125+
// Throwing is fine: the channel is closed abruptly, errors are expected.
1126+
_ = try? channel.finish()
1127+
}
10481128
}
10491129

10501130
extension EmbeddedChannel {

0 commit comments

Comments
 (0)