Skip to content

Commit 858f977

Browse files
authored
Merge pull request from GHSA-r6ww-5963-7r95
Better handle client sending GOAWAY
2 parents 7648750 + e09cf66 commit 858f977

7 files changed

+139
-17
lines changed

Sources/GRPC/GRPCIdleHandler.swift

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,19 @@ internal final class GRPCIdleHandler: ChannelInboundHandler {
153153
streamID: .rootStream,
154154
payload: .goAway(lastStreamID: streamID, errorCode: .noError, opaqueData: nil)
155155
)
156-
self.context?.writeAndFlush(self.wrapOutboundOut(goAwayFrame), promise: nil)
156+
157+
self.context?.write(self.wrapOutboundOut(goAwayFrame), promise: nil)
158+
159+
// We emit a ping after some GOAWAY frames.
160+
if operations.shouldPingAfterGoAway {
161+
let pingFrame = HTTP2Frame(
162+
streamID: .rootStream,
163+
payload: .ping(self.pingHandler.pingDataGoAway, ack: false)
164+
)
165+
self.context?.write(self.wrapOutboundOut(pingFrame), promise: nil)
166+
}
167+
168+
self.context?.flush()
157169
}
158170

159171
// Close the channel, if necessary.
@@ -181,6 +193,9 @@ internal final class GRPCIdleHandler: ChannelInboundHandler {
181193
case let .reply(framePayload):
182194
let frame = HTTP2Frame(streamID: .rootStream, payload: framePayload)
183195
self.context?.writeAndFlush(self.wrapOutboundOut(frame), promise: nil)
196+
197+
case .ratchetDownLastSeenStreamID:
198+
self.perform(operations: self.stateMachine.ratchetDownGoAwayStreamID())
184199
}
185200
}
186201

Sources/GRPC/GRPCIdleHandlerStateMachine.swift

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -189,10 +189,17 @@ struct GRPCIdleHandlerStateMachine {
189189
/// Whether the channel should be closed.
190190
private(set) var shouldCloseChannel: Bool
191191

192+
/// Whether a ping should be sent after a GOAWAY frame.
193+
private(set) var shouldPingAfterGoAway: Bool
194+
192195
fileprivate static let none = Operations()
193196

194-
fileprivate mutating func sendGoAwayFrame(lastPeerInitiatedStreamID streamID: HTTP2StreamID) {
197+
fileprivate mutating func sendGoAwayFrame(
198+
lastPeerInitiatedStreamID streamID: HTTP2StreamID,
199+
followWithPing: Bool = false
200+
) {
195201
self.sendGoAwayWithLastPeerInitiatedStreamID = streamID
202+
self.shouldPingAfterGoAway = followWithPing
196203
}
197204

198205
fileprivate mutating func cancelIdleTask(_ task: Scheduled<Void>) {
@@ -220,6 +227,7 @@ struct GRPCIdleHandlerStateMachine {
220227
self.idleTask = nil
221228
self.sendGoAwayWithLastPeerInitiatedStreamID = nil
222229
self.shouldCloseChannel = false
230+
self.shouldPingAfterGoAway = false
223231
}
224232
}
225233

@@ -267,12 +275,7 @@ struct GRPCIdleHandlerStateMachine {
267275
operations.cancelIdleTask(state.idleTask)
268276

269277
case var .quiescing(state):
270-
precondition(state.initiatedByUs)
271-
precondition(state.role == .client)
272-
// If we're a client and we initiated shutdown then it's possible for streams to be created in
273-
// the quiescing state as there's a delay between stream channels (i.e. `HTTP2StreamChannel`)
274-
// being created and us being notified about their creation (via a user event fired by
275-
// the `HTTP2Handler`).
278+
state.lastPeerInitiatedStreamID = streamID
276279
state.openStreams += 1
277280
self.state = .quiescing(state)
278281

@@ -466,6 +469,18 @@ struct GRPCIdleHandlerStateMachine {
466469

467470
if state.hasOpenStreams {
468471
operations.notifyConnectionManager(about: .quiescing)
472+
switch state.role {
473+
case .client:
474+
// The server sent us a GOAWAY we'll just stop opening new streams and will send a GOAWAY
475+
// frame before we close later.
476+
()
477+
case .server:
478+
// Client sent us a GOAWAY frame; we'll let the streams drain and then close. We'll tell
479+
// the client that we're going away and send them a ping. When we receive the pong we will
480+
// send another GOAWAY frame with a lower stream ID. In this case, the pong acts as an ack
481+
// for the GOAWAY.
482+
operations.sendGoAwayFrame(lastPeerInitiatedStreamID: .maxID, followWithPing: true)
483+
}
469484
self.state = .quiescing(.init(fromOperating: state, initiatedByUs: false))
470485
} else {
471486
// No open streams, we can close as well.
@@ -494,6 +509,23 @@ struct GRPCIdleHandlerStateMachine {
494509
return operations
495510
}
496511

512+
mutating func ratchetDownGoAwayStreamID() -> Operations {
513+
var operations: Operations = .none
514+
515+
switch self.state {
516+
case let .quiescing(state):
517+
let streamID = state.lastPeerInitiatedStreamID
518+
operations.sendGoAwayFrame(lastPeerInitiatedStreamID: streamID)
519+
case .operating, .waitingToIdle:
520+
// We can only ratchet down the stream ID if we're already quiescing.
521+
preconditionFailure()
522+
case .closing, .closed:
523+
()
524+
}
525+
526+
return operations
527+
}
528+
497529
mutating func receiveSettings(_ settings: HTTP2Settings) -> Operations {
498530
// Log the change in settings.
499531
self.logger.debug(

Sources/GRPC/GRPCKeepaliveHandlers.swift

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@ import NIOCore
1717
import NIOHTTP2
1818

1919
struct PingHandler {
20-
/// Code for ping
21-
private let pingCode: UInt64
20+
/// Opaque ping data used for keep-alive pings.
21+
private let pingData: HTTP2PingData
22+
23+
/// Opaque ping data used for a ping sent after a GOAWAY frame.
24+
internal let pingDataGoAway: HTTP2PingData
2225

2326
/// The amount of time to wait before sending a keepalive ping.
2427
private let interval: TimeAmount
@@ -90,6 +93,7 @@ struct PingHandler {
9093
case schedulePing(delay: TimeAmount, timeout: TimeAmount)
9194
case cancelScheduledTimeout
9295
case reply(HTTP2Frame.FramePayload)
96+
case ratchetDownLastSeenStreamID
9397
}
9498

9599
init(
@@ -102,7 +106,8 @@ struct PingHandler {
102106
minimumReceivedPingIntervalWithoutData: TimeAmount? = nil,
103107
maximumPingStrikes: UInt? = nil
104108
) {
105-
self.pingCode = pingCode
109+
self.pingData = HTTP2PingData(withInteger: pingCode)
110+
self.pingDataGoAway = HTTP2PingData(withInteger: ~pingCode)
106111
self.interval = interval
107112
self.timeout = timeout
108113
self.permitWithoutCalls = permitWithoutCalls
@@ -137,8 +142,12 @@ struct PingHandler {
137142
}
138143

139144
private func handlePong(_ pingData: HTTP2PingData) -> Action {
140-
if pingData.integer == self.pingCode {
145+
if pingData == self.pingData {
141146
return .cancelScheduledTimeout
147+
} else if pingData == self.pingDataGoAway {
148+
// We received a pong for a ping we sent to trail a GOAWAY frame: this means we can now
149+
// send another GOAWAY frame with a (possibly) lower stream ID.
150+
return .ratchetDownLastSeenStreamID
142151
} else {
143152
return .none
144153
}
@@ -161,32 +170,35 @@ struct PingHandler {
161170
// This is a valid ping, reset our strike count and reply with a pong.
162171
self.pingStrikes = 0
163172
self.lastReceivedPingDate = self.now()
164-
return .reply(self.generatePingFrame(code: pingData.integer, ack: true))
173+
return .reply(self.generatePingFrame(data: pingData, ack: true))
165174
}
166175
} else {
167176
// We don't support ping strikes. We'll just reply with a pong.
168177
//
169178
// Note: we don't need to update `pingStrikes` or `lastReceivedPingDate` as we don't
170179
// support ping strikes.
171-
return .reply(self.generatePingFrame(code: pingData.integer, ack: true))
180+
return .reply(self.generatePingFrame(data: pingData, ack: true))
172181
}
173182
}
174183

175184
mutating func pingFired() -> Action {
176185
if self.shouldBlockPing {
177186
return .none
178187
} else {
179-
return .reply(self.generatePingFrame(code: self.pingCode, ack: false))
188+
return .reply(self.generatePingFrame(data: self.pingData, ack: false))
180189
}
181190
}
182191

183-
private mutating func generatePingFrame(code: UInt64, ack: Bool) -> HTTP2Frame.FramePayload {
192+
private mutating func generatePingFrame(
193+
data: HTTP2PingData,
194+
ack: Bool
195+
) -> HTTP2Frame.FramePayload {
184196
if self.activeStreams == 0 {
185197
self.sentPingsWithoutData += 1
186198
}
187199

188200
self.lastSentPingDate = self.now()
189-
return HTTP2Frame.FramePayload.ping(HTTP2PingData(withInteger: code), ack: ack)
201+
return HTTP2Frame.FramePayload.ping(data, ack: ack)
190202
}
191203

192204
/// Returns true if, on receipt of a ping, the ping should be regarded as a ping strike.

Tests/GRPCTests/GRPCIdleHandlerStateMachineTests.swift

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ class GRPCIdleHandlerStateMachineTests: GRPCTestCase {
2424
return GRPCIdleHandlerStateMachine(role: .client, logger: self.clientLogger)
2525
}
2626

27+
private func makeServerStateMachine() -> GRPCIdleHandlerStateMachine {
28+
return GRPCIdleHandlerStateMachine(role: .server, logger: self.serverLogger)
29+
}
30+
2731
private func makeNoOpScheduled() -> Scheduled<Void> {
2832
let loop = EmbeddedEventLoop()
2933
return loop.scheduleTask(deadline: .distantFuture) { return () }
@@ -469,6 +473,43 @@ class GRPCIdleHandlerStateMachineTests: GRPCTestCase {
469473
// The peer initiated shutdown by sending GOAWAY, we'll idle.
470474
op6.assertConnectionManager(.idle)
471475
}
476+
477+
func testClientSendsGoAwayAndOpensStream() {
478+
var stateMachine = self.makeServerStateMachine()
479+
480+
let op1 = stateMachine.receiveSettings([])
481+
op1.assertConnectionManager(.ready)
482+
op1.assertScheduleIdleTimeout()
483+
484+
// Schedule the idle timeout.
485+
let op2 = stateMachine.scheduledIdleTimeoutTask(self.makeNoOpScheduled())
486+
op2.assertDoNothing()
487+
488+
// Create a stream to cancel the task.
489+
let op3 = stateMachine.streamCreated(withID: 1)
490+
op3.assertCancelIdleTimeout()
491+
492+
// Receive a GOAWAY frame from the client.
493+
let op4 = stateMachine.receiveGoAway()
494+
op4.assertGoAway(streamID: .maxID)
495+
op4.assertShouldPingAfterGoAway()
496+
497+
// Create another stream. This is fine, the client hasn't ack'd the ping yet.
498+
let op5 = stateMachine.streamCreated(withID: 7)
499+
op5.assertDoNothing()
500+
501+
// Receiving the ping is handled by a different state machine which will tell us to ratchet
502+
// down the go away stream ID.
503+
let op6 = stateMachine.ratchetDownGoAwayStreamID()
504+
op6.assertGoAway(streamID: 7)
505+
op6.assertShouldNotPingAfterGoAway()
506+
507+
let op7 = stateMachine.streamClosed(withID: 7)
508+
op7.assertDoNothing()
509+
510+
let op8 = stateMachine.streamClosed(withID: 1)
511+
op8.assertShouldClose()
512+
}
472513
}
473514

474515
extension GRPCIdleHandlerStateMachine.Operations {
@@ -477,6 +518,7 @@ extension GRPCIdleHandlerStateMachine.Operations {
477518
XCTAssertNil(self.idleTask)
478519
XCTAssertNil(self.sendGoAwayWithLastPeerInitiatedStreamID)
479520
XCTAssertFalse(self.shouldCloseChannel)
521+
XCTAssertFalse(self.shouldPingAfterGoAway)
480522
}
481523

482524
func assertGoAway(streamID: HTTP2StreamID) {
@@ -524,4 +566,12 @@ extension GRPCIdleHandlerStateMachine.Operations {
524566
func assertShouldNotClose() {
525567
XCTAssertFalse(self.shouldCloseChannel)
526568
}
569+
570+
func assertShouldPingAfterGoAway() {
571+
XCTAssert(self.shouldPingAfterGoAway)
572+
}
573+
574+
func assertShouldNotPingAfterGoAway() {
575+
XCTAssertFalse(self.shouldPingAfterGoAway)
576+
}
527577
}

Tests/GRPCTests/GRPCPingHandlerTests.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,12 @@ class GRPCPingHandlerTests: GRPCTestCase {
347347
)
348348
}
349349

350+
func testPongWithGoAwayPingData() {
351+
self.setupPingHandler()
352+
let response = self.pingHandler.read(pingData: self.pingHandler.pingDataGoAway, ack: true)
353+
XCTAssertEqual(response, .ratchetDownLastSeenStreamID)
354+
}
355+
350356
private func setupPingHandler(
351357
pingCode: UInt64 = 1,
352358
interval: TimeAmount = .seconds(15),
@@ -379,6 +385,8 @@ extension PingHandler.Action: Equatable {
379385
return lhsDelay == rhsDelay && lhsTimeout == rhsTimeout
380386
case (.cancelScheduledTimeout, .cancelScheduledTimeout):
381387
return true
388+
case (.ratchetDownLastSeenStreamID, .ratchetDownLastSeenStreamID):
389+
return true
382390
case let (.reply(lhsPayload), .reply(rhsPayload)):
383391
switch (lhsPayload, rhsPayload) {
384392
case (let .ping(lhsData, ack: lhsAck), let .ping(rhsData, ack: rhsAck)):

Tests/GRPCTests/ServerFuzzingRegressionTests.swift

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,9 @@ final class ServerFuzzingRegressionTests: GRPCTestCase {
8383
let name = "clusterfuzz-testcase-minimized-ServerFuzzer-release-5285159577452544"
8484
XCTAssertNoThrow(try self.runTest(withInputNamed: name))
8585
}
86+
87+
func testFuzzCase_release_4739158818553856() {
88+
let name = "clusterfuzz-testcase-minimized-ServerFuzzer-release-4739158818553856"
89+
XCTAssertNoThrow(try self.runTest(withInputNamed: name))
90+
}
8691
}

0 commit comments

Comments
 (0)