Skip to content

Commit 847a934

Browse files
authored
Add GRPCClientStreamHandler (#1838)
1 parent 9c91043 commit 847a934

File tree

4 files changed

+995
-15
lines changed

4 files changed

+995
-15
lines changed
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
/*
2+
* Copyright 2024, gRPC Authors All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
import GRPCCore
18+
import NIOCore
19+
import NIOHTTP2
20+
21+
@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
22+
final class GRPCClientStreamHandler: ChannelDuplexHandler {
23+
typealias InboundIn = HTTP2Frame.FramePayload
24+
typealias InboundOut = RPCResponsePart
25+
26+
typealias OutboundIn = RPCRequestPart
27+
typealias OutboundOut = HTTP2Frame.FramePayload
28+
29+
private var stateMachine: GRPCStreamStateMachine
30+
31+
private var isReading = false
32+
private var flushPending = false
33+
34+
init(
35+
methodDescriptor: MethodDescriptor,
36+
scheme: Scheme,
37+
outboundEncoding: CompressionAlgorithm,
38+
acceptedEncodings: [CompressionAlgorithm],
39+
maximumPayloadSize: Int,
40+
skipStateMachineAssertions: Bool = false
41+
) {
42+
self.stateMachine = .init(
43+
configuration: .client(
44+
.init(
45+
methodDescriptor: methodDescriptor,
46+
scheme: scheme,
47+
outboundEncoding: outboundEncoding,
48+
acceptedEncodings: acceptedEncodings
49+
)
50+
),
51+
maximumPayloadSize: maximumPayloadSize,
52+
skipAssertions: skipStateMachineAssertions
53+
)
54+
}
55+
}
56+
57+
// - MARK: ChannelInboundHandler
58+
59+
@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
60+
extension GRPCClientStreamHandler {
61+
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
62+
self.isReading = true
63+
let frame = self.unwrapInboundIn(data)
64+
switch frame {
65+
case .data(let frameData):
66+
let endStream = frameData.endStream
67+
switch frameData.data {
68+
case .byteBuffer(let buffer):
69+
do {
70+
try self.stateMachine.receive(buffer: buffer, endStream: endStream)
71+
loop: while true {
72+
switch self.stateMachine.nextInboundMessage() {
73+
case .receiveMessage(let message):
74+
context.fireChannelRead(self.wrapInboundOut(.message(message)))
75+
case .awaitMoreMessages:
76+
break loop
77+
case .noMoreMessages:
78+
context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)
79+
break loop
80+
}
81+
}
82+
} catch {
83+
context.fireErrorCaught(error)
84+
}
85+
86+
case .fileRegion:
87+
preconditionFailure("Unexpected IOData.fileRegion")
88+
}
89+
90+
case .headers(let headers):
91+
do {
92+
let action = try self.stateMachine.receive(
93+
headers: headers.headers,
94+
endStream: headers.endStream
95+
)
96+
switch action {
97+
case .receivedMetadata(let metadata):
98+
context.fireChannelRead(self.wrapInboundOut(.metadata(metadata)))
99+
100+
case .rejectRPC:
101+
throw RPCError(
102+
code: .internalError,
103+
message: "Client cannot get rejectRPC."
104+
)
105+
106+
case .receivedStatusAndMetadata(let status, let metadata):
107+
context.fireChannelRead(self.wrapInboundOut(.status(status, metadata)))
108+
109+
case .doNothing:
110+
()
111+
}
112+
} catch {
113+
context.fireErrorCaught(error)
114+
}
115+
116+
case .ping, .goAway, .priority, .rstStream, .settings, .pushPromise, .windowUpdate,
117+
.alternativeService, .origin:
118+
()
119+
}
120+
}
121+
122+
func channelReadComplete(context: ChannelHandlerContext) {
123+
self.isReading = false
124+
if self.flushPending {
125+
self.flushPending = false
126+
self.flush(context: context)
127+
}
128+
context.fireChannelReadComplete()
129+
}
130+
131+
func handlerRemoved(context: ChannelHandlerContext) {
132+
self.stateMachine.tearDown()
133+
}
134+
}
135+
136+
// - MARK: ChannelOutboundHandler
137+
138+
@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
139+
extension GRPCClientStreamHandler {
140+
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
141+
switch self.unwrapOutboundIn(data) {
142+
case .metadata(let metadata):
143+
do {
144+
self.flushPending = true
145+
let headers = try self.stateMachine.send(metadata: metadata)
146+
context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: nil)
147+
// TODO: move the promise handling into the state machine
148+
promise?.succeed()
149+
} catch {
150+
context.fireErrorCaught(error)
151+
// TODO: move the promise handling into the state machine
152+
promise?.fail(error)
153+
}
154+
155+
case .message(let message):
156+
do {
157+
try self.stateMachine.send(message: message)
158+
// TODO: move the promise handling into the state machine
159+
promise?.succeed()
160+
} catch {
161+
context.fireErrorCaught(error)
162+
// TODO: move the promise handling into the state machine
163+
promise?.fail(error)
164+
}
165+
}
166+
}
167+
168+
func close(context: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise<Void>?) {
169+
switch mode {
170+
case .output, .all:
171+
do {
172+
try self.stateMachine.closeOutbound()
173+
// Force a flush by calling _flush
174+
// (otherwise, we'd skip flushing if we're in a read loop)
175+
self._flush(context: context)
176+
context.close(mode: mode, promise: promise)
177+
} catch {
178+
promise?.fail(error)
179+
context.fireErrorCaught(error)
180+
}
181+
182+
case .input:
183+
context.close(mode: .input, promise: promise)
184+
}
185+
}
186+
187+
func flush(context: ChannelHandlerContext) {
188+
if self.isReading {
189+
// We don't want to flush yet if we're still in a read loop.
190+
self.flushPending = true
191+
return
192+
}
193+
194+
self._flush(context: context)
195+
}
196+
197+
private func _flush(context: ChannelHandlerContext) {
198+
do {
199+
loop: while true {
200+
switch try self.stateMachine.nextOutboundMessage() {
201+
case .sendMessage(let byteBuffer):
202+
self.flushPending = true
203+
context.write(
204+
self.wrapOutboundOut(.data(.init(data: .byteBuffer(byteBuffer)))),
205+
promise: nil
206+
)
207+
208+
case .noMoreMessages:
209+
// Write an empty data frame with the EOS flag set, to signal the RPC
210+
// request is now finished.
211+
context.write(
212+
self.wrapOutboundOut(
213+
HTTP2Frame.FramePayload.data(
214+
.init(
215+
data: .byteBuffer(.init()),
216+
endStream: true
217+
)
218+
)
219+
),
220+
promise: nil
221+
)
222+
223+
context.flush()
224+
break loop
225+
226+
case .awaitMoreMessages:
227+
if self.flushPending {
228+
self.flushPending = false
229+
context.flush()
230+
}
231+
break loop
232+
}
233+
}
234+
} catch {
235+
context.fireErrorCaught(error)
236+
}
237+
}
238+
}

Sources/GRPCHTTP2Core/GRPCStreamStateMachine.swift

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -373,8 +373,12 @@ struct GRPCStreamStateMachine {
373373

374374
mutating func receive(headers: HPACKHeaders, endStream: Bool) throws -> OnMetadataReceived {
375375
switch self.configuration {
376-
case .client:
377-
return try self.clientReceive(headers: headers, endStream: endStream)
376+
case .client(let clientConfiguration):
377+
return try self.clientReceive(
378+
headers: headers,
379+
endStream: endStream,
380+
configuration: clientConfiguration
381+
)
378382
case .server(let serverConfiguration):
379383
return try self.serverReceive(
380384
headers: headers,
@@ -567,9 +571,7 @@ extension GRPCStreamStateMachine {
567571
case .clientOpenServerClosed(let state):
568572
self.state = .clientClosedServerClosed(.init(previousState: state))
569573
case .clientClosedServerIdle, .clientClosedServerOpen, .clientClosedServerClosed:
570-
try self.invalidState(
571-
"Client is closed, cannot send a message."
572-
)
574+
try self.invalidState("Client is already closed.")
573575
}
574576
}
575577

@@ -665,7 +667,7 @@ extension GRPCStreamStateMachine {
665667
.receivedStatusAndMetadata(
666668
status: .init(
667669
code: .internalError,
668-
message: "Missing \(GRPCHTTP2Keys.contentType) header"
670+
message: "Missing \(GRPCHTTP2Keys.contentType.rawValue) header"
669671
),
670672
metadata: Metadata(headers: metadata)
671673
)
@@ -680,18 +682,23 @@ extension GRPCStreamStateMachine {
680682
case success(CompressionAlgorithm)
681683
}
682684

683-
private func processInboundEncoding(_ metadata: HPACKHeaders) -> ProcessInboundEncodingResult {
685+
private func processInboundEncoding(
686+
headers: HPACKHeaders,
687+
configuration: GRPCStreamStateMachineConfiguration.ClientConfiguration
688+
) -> ProcessInboundEncodingResult {
684689
let inboundEncoding: CompressionAlgorithm
685-
if let serverEncoding = metadata.first(name: GRPCHTTP2Keys.encoding.rawValue) {
686-
guard let parsedEncoding = CompressionAlgorithm(rawValue: serverEncoding) else {
690+
if let serverEncoding = headers.first(name: GRPCHTTP2Keys.encoding.rawValue) {
691+
guard let parsedEncoding = CompressionAlgorithm(rawValue: serverEncoding),
692+
configuration.acceptedEncodings.contains(parsedEncoding)
693+
else {
687694
return .error(
688695
.receivedStatusAndMetadata(
689696
status: .init(
690697
code: .internalError,
691698
message:
692699
"The server picked a compression algorithm ('\(serverEncoding)') the client does not know about."
693700
),
694-
metadata: Metadata(headers: metadata)
701+
metadata: Metadata(headers: headers)
695702
)
696703
)
697704
}
@@ -732,7 +739,8 @@ extension GRPCStreamStateMachine {
732739

733740
private mutating func clientReceive(
734741
headers: HPACKHeaders,
735-
endStream: Bool
742+
endStream: Bool,
743+
configuration: GRPCStreamStateMachineConfiguration.ClientConfiguration
736744
) throws -> OnMetadataReceived {
737745
switch self.state {
738746
case .clientOpenServerIdle(let state):
@@ -750,7 +758,7 @@ extension GRPCStreamStateMachine {
750758
self.state = .clientOpenServerClosed(.init(previousState: state))
751759
return try self.validateAndReturnStatusAndMetadata(headers)
752760
case (.valid, false):
753-
switch self.processInboundEncoding(headers) {
761+
switch self.processInboundEncoding(headers: headers, configuration: configuration) {
754762
case .error(let failure):
755763
return failure
756764
case .success(let inboundEncoding):
@@ -798,7 +806,7 @@ extension GRPCStreamStateMachine {
798806
self.state = .clientClosedServerClosed(.init(previousState: state))
799807
return try self.validateAndReturnStatusAndMetadata(headers)
800808
case (.valid, false):
801-
switch self.processInboundEncoding(headers) {
809+
switch self.processInboundEncoding(headers: headers, configuration: configuration) {
802810
case .error(let failure):
803811
return failure
804812
case .success(let inboundEncoding):

0 commit comments

Comments
 (0)