Skip to content

Commit c125d48

Browse files
committed
Add ServerInterceptorTarget
1 parent 574eea6 commit c125d48

File tree

11 files changed

+465
-54
lines changed

11 files changed

+465
-54
lines changed

Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ struct ServerRPCExecutor {
3434
>,
3535
deserializer: some MessageDeserializer<Input>,
3636
serializer: some MessageSerializer<Output>,
37-
interceptors: [any ServerInterceptor],
37+
interceptors: [ServerInterceptorTarget],
3838
handler: @Sendable @escaping (
3939
_ request: StreamingServerRequest<Input>,
4040
_ context: ServerContext
@@ -72,7 +72,7 @@ struct ServerRPCExecutor {
7272
outbound: RPCWriter<RPCResponsePart>.Closable,
7373
deserializer: some MessageDeserializer<Input>,
7474
serializer: some MessageSerializer<Output>,
75-
interceptors: [any ServerInterceptor],
75+
interceptors: [ServerInterceptorTarget],
7676
handler: @escaping @Sendable (
7777
_ request: StreamingServerRequest<Input>,
7878
_ context: ServerContext
@@ -113,7 +113,7 @@ struct ServerRPCExecutor {
113113
outbound: RPCWriter<RPCResponsePart>.Closable,
114114
deserializer: some MessageDeserializer<Input>,
115115
serializer: some MessageSerializer<Output>,
116-
interceptors: [any ServerInterceptor],
116+
interceptors: [ServerInterceptorTarget],
117117
handler: @escaping @Sendable (
118118
_ request: StreamingServerRequest<Input>,
119119
_ context: ServerContext
@@ -153,7 +153,7 @@ struct ServerRPCExecutor {
153153
outbound: RPCWriter<RPCResponsePart>.Closable,
154154
deserializer: some MessageDeserializer<Input>,
155155
serializer: some MessageSerializer<Output>,
156-
interceptors: [any ServerInterceptor],
156+
interceptors: [ServerInterceptorTarget],
157157
handler: @escaping @Sendable (
158158
_ request: StreamingServerRequest<Input>,
159159
_ context: ServerContext
@@ -286,7 +286,7 @@ extension ServerRPCExecutor {
286286
static func _intercept<Input, Output>(
287287
request: StreamingServerRequest<Input>,
288288
context: ServerContext,
289-
interceptors: [any ServerInterceptor],
289+
interceptors: [ServerInterceptorTarget],
290290
finally: @escaping @Sendable (
291291
_ request: StreamingServerRequest<Input>,
292292
_ context: ServerContext
@@ -304,7 +304,7 @@ extension ServerRPCExecutor {
304304
static func _intercept<Input, Output>(
305305
request: StreamingServerRequest<Input>,
306306
context: ServerContext,
307-
iterator: Array<any ServerInterceptor>.Iterator,
307+
iterator: Array<ServerInterceptorTarget>.Iterator,
308308
finally: @escaping @Sendable (
309309
_ request: StreamingServerRequest<Input>,
310310
_ context: ServerContext
@@ -313,17 +313,29 @@ extension ServerRPCExecutor {
313313
var iterator = iterator
314314

315315
switch iterator.next() {
316-
case .some(let interceptor):
317-
let iter = iterator
318-
do {
319-
return try await interceptor.intercept(request: request, context: context) {
320-
try await self._intercept(request: $0, context: $1, iterator: iter, finally: finally)
316+
case .some(let interceptorTarget):
317+
if interceptorTarget.applies(to: context.descriptor) {
318+
let iter = iterator
319+
do {
320+
return try await interceptorTarget.interceptor.intercept(
321+
request: request,
322+
context: context
323+
) {
324+
try await self._intercept(request: $0, context: $1, iterator: iter, finally: finally)
325+
}
326+
} catch let error as RPCError {
327+
return StreamingServerResponse(error: error)
328+
} catch let other {
329+
let error = RPCError(code: .unknown, message: "", cause: other)
330+
return StreamingServerResponse(error: error)
321331
}
322-
} catch let error as RPCError {
323-
return StreamingServerResponse(error: error)
324-
} catch let other {
325-
let error = RPCError(code: .unknown, message: "", cause: other)
326-
return StreamingServerResponse(error: error)
332+
} else {
333+
return try await self._intercept(
334+
request: request,
335+
context: context,
336+
iterator: iterator,
337+
finally: finally
338+
)
327339
}
328340

329341
case .none:

Sources/GRPCCore/Call/Server/RPCRouter.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public struct RPCRouter: Sendable {
4343
RPCWriter<RPCResponsePart>.Closable
4444
>,
4545
_ context: ServerContext,
46-
_ interceptors: [any ServerInterceptor]
46+
_ interceptors: [ServerInterceptorTarget]
4747
) async -> Void
4848

4949
@inlinable
@@ -75,7 +75,7 @@ public struct RPCRouter: Sendable {
7575
RPCWriter<RPCResponsePart>.Closable
7676
>,
7777
context: ServerContext,
78-
interceptors: [any ServerInterceptor]
78+
interceptors: [ServerInterceptorTarget]
7979
) async {
8080
await self._fn(stream, context, interceptors)
8181
}
@@ -151,7 +151,7 @@ extension RPCRouter {
151151
RPCWriter<RPCResponsePart>.Closable
152152
>,
153153
context: ServerContext,
154-
interceptors: [any ServerInterceptor]
154+
interceptors: [ServerInterceptorTarget]
155155
) async {
156156
if let handler = self.handlers[stream.descriptor] {
157157
await handler.handle(stream: stream, context: context, interceptors: interceptors)

Sources/GRPCCore/Call/Server/ServerInterceptor.swift

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@
2121
/// been returned from a service. They are typically used for cross-cutting concerns like filtering
2222
/// requests, validating messages, logging additional data, and tracing.
2323
///
24-
/// Interceptors are registered with the server apply to all RPCs. If you need to modify the
25-
/// behavior of an interceptor on a per-RPC basis then you can use the
26-
/// ``ServerContext/descriptor`` to determine which RPC is being called and
27-
/// conditionalise behavior accordingly.
24+
/// Interceptors are registered with the server via ``ServerInterceptorTarget``s.
25+
/// You may register them for all services registered with a server, for RPCs directed to specific services, or
26+
/// for RPCs directed to specific methods. If you need to modify the behavior of an interceptor on a
27+
/// per-RPC basis in more detail, then you can use the ``ServerContext/descriptor`` to determine
28+
/// which RPC is being called and conditionalise behavior accordingly.
2829
///
2930
/// ## RPC filtering
3031
///
@@ -33,19 +34,19 @@
3334
/// demonstrates this.
3435
///
3536
/// ```swift
36-
/// struct AuthServerInterceptor: Sendable {
37+
/// struct AuthServerInterceptor: ServerInterceptor {
3738
/// let isAuthorized: @Sendable (String, MethodDescriptor) async throws -> Void
3839
///
3940
/// func intercept<Input: Sendable, Output: Sendable>(
4041
/// request: StreamingServerRequest<Input>,
41-
/// context: ServerInterceptorContext,
42+
/// context: ServerContext,
4243
/// next: @Sendable (
4344
/// _ request: StreamingServerRequest<Input>,
44-
/// _ context: ServerInterceptorContext
45+
/// _ context: ServerContext
4546
/// ) async throws -> StreamingServerResponse<Output>
4647
/// ) async throws -> StreamingServerResponse<Output> {
4748
/// // Extract the auth token.
48-
/// guard let token = request.metadata["authorization"] else {
49+
/// guard let token = request.metadata[stringValues: "authorization"].first(where: { _ in true }) else {
4950
/// throw RPCError(code: .unauthenticated, message: "Not authenticated")
5051
/// }
5152
///
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Copyright 2024, gRPC Authors All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
/// A `ServerInterceptorTarget` describes to which RPCs a server interceptor should be applied.
18+
///
19+
/// You can configure a server interceptor to be applied to:
20+
/// - all RPCs and services;
21+
/// - requests directed only to specific services registered with your server; or
22+
/// - requests directed only to specific methods (of a specific service).
23+
///
24+
/// - SeeAlso: ``ServerInterceptor`` for more information on server interceptors, and
25+
/// ``ClientInterceptorTarget`` for the client-side version of this type.
26+
public struct ServerInterceptorTarget: Sendable {
27+
internal enum Wrapped: Sendable {
28+
case allServices(interceptor: any ServerInterceptor)
29+
case serviceSpecific(interceptor: any ServerInterceptor, services: [String])
30+
case methodSpecific(interceptor: any ServerInterceptor, methods: [MethodDescriptor])
31+
}
32+
33+
/// A target specifying an interceptor that applies to all RPCs across all services registered with this server.
34+
/// - Parameter interceptor: The interceptor to register with the server.
35+
/// - Returns: A ``ServerInterceptorTarget``.
36+
public static func allServices(
37+
interceptor: any ServerInterceptor
38+
) -> Self {
39+
Self(wrapped: .allServices(interceptor: interceptor))
40+
}
41+
42+
/// A target specifying an interceptor that applies to RPCs directed only to the specified services.
43+
/// - Parameters:
44+
/// - interceptor: The interceptor to register with the server.
45+
/// - services: The list of service names for which this interceptor should intercept RPCs.
46+
/// - Returns: A ``ServerInterceptorTarget``.
47+
public static func serviceSpecific(
48+
interceptor: any ServerInterceptor,
49+
services: [String]
50+
) -> Self {
51+
Self(
52+
wrapped: .serviceSpecific(
53+
interceptor: interceptor,
54+
services: services
55+
)
56+
)
57+
}
58+
59+
/// A target specifying an interceptor that applies to RPCs directed only to the specified service methods.
60+
/// - Parameters:
61+
/// - interceptor: The interceptor to register with the server.
62+
/// - services: The list of method descriptors for which this interceptor should intercept RPCs.
63+
/// - Returns: A ``ServerInterceptorTarget``.
64+
public static func methodSpecific(
65+
interceptor: any ServerInterceptor,
66+
methods: [MethodDescriptor]
67+
) -> Self {
68+
Self(
69+
wrapped: .methodSpecific(
70+
interceptor: interceptor,
71+
methods: methods
72+
)
73+
)
74+
}
75+
76+
private let wrapped: Wrapped
77+
78+
private init(wrapped: Wrapped) {
79+
self.wrapped = wrapped
80+
}
81+
82+
/// Get the ``ServerInterceptor`` associated with this ``ServerInterceptorTarget``.
83+
public var interceptor: any ServerInterceptor {
84+
switch self.wrapped {
85+
case .allServices(let interceptor):
86+
return interceptor
87+
case .serviceSpecific(let interceptor, _):
88+
return interceptor
89+
case .methodSpecific(let interceptor, _):
90+
return interceptor
91+
}
92+
}
93+
94+
/// Returns whether this ``ServerInterceptorTarget`` applies to the given `descriptor`.
95+
/// - Parameter descriptor: A ``MethodDescriptor`` for which to test whether this interceptor applies.
96+
/// - Returns: `true` if this interceptor applies to the given `descriptor`, or `false` otherwise.
97+
public func applies(to descriptor: MethodDescriptor) -> Bool {
98+
switch self.wrapped {
99+
case .allServices:
100+
return true
101+
case .serviceSpecific(_, let services):
102+
return services.contains(descriptor.service)
103+
case .methodSpecific(_, let methods):
104+
return methods.contains(descriptor)
105+
}
106+
}
107+
}

Sources/GRPCCore/GRPCServer.swift

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ private import Synchronization
4848
/// let server = GRPCServer(
4949
/// transport: inProcessTransport.server,
5050
/// services: [greeter, echo],
51-
/// interceptors: [statsRecorder]
51+
/// interceptors: [.allServices(interceptor: statsRecorder)]
5252
/// )
5353
/// ```
5454
///
@@ -78,13 +78,12 @@ public final class GRPCServer: Sendable {
7878
/// The services registered which the server is serving.
7979
private let router: RPCRouter
8080

81-
/// A collection of ``ServerInterceptor`` implementations which are applied to all accepted
82-
/// RPCs.
81+
/// A collection of ``ServerInterceptorTarget``s which may be applied to all accepted RPCs.
8382
///
8483
/// RPCs are intercepted in the order that interceptors are added. That is, a request received
8584
/// from the client will first be intercepted by the first added interceptor followed by the
8685
/// second, and so on.
87-
private let interceptors: [any ServerInterceptor]
86+
private let interceptors: [ServerInterceptorTarget]
8887

8988
/// The state of the server.
9089
private let state: Mutex<State>
@@ -152,7 +151,7 @@ public final class GRPCServer: Sendable {
152151
public convenience init(
153152
transport: any ServerTransport,
154153
services: [any RegistrableRPCService],
155-
interceptors: [any ServerInterceptor] = []
154+
interceptors: [ServerInterceptorTarget] = []
156155
) {
157156
var router = RPCRouter()
158157
for service in services {
@@ -175,7 +174,7 @@ public final class GRPCServer: Sendable {
175174
public init(
176175
transport: any ServerTransport,
177176
router: RPCRouter,
178-
interceptors: [any ServerInterceptor] = []
177+
interceptors: [ServerInterceptorTarget] = []
179178
) {
180179
self.state = Mutex(.notStarted)
181180
self.transport = transport

Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTestSupport/ServerRPCExecutorTestHarness.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ struct ServerRPCExecutorTestHarness {
4646
}
4747
}
4848

49-
let interceptors: [any ServerInterceptor]
49+
let interceptors: [ServerInterceptorTarget]
5050

51-
init(interceptors: [any ServerInterceptor] = []) {
51+
init(interceptors: [ServerInterceptorTarget] = []) {
5252
self.interceptors = interceptors
5353
}
5454

Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,9 @@ final class ServerRPCExecutorTests: XCTestCase {
258258
)
259259

260260
// The interceptor skips the handler altogether.
261-
let harness = ServerRPCExecutorTestHarness(interceptors: [.rejectAll(with: error)])
261+
let harness = ServerRPCExecutorTestHarness(interceptors: [
262+
.allServices(interceptor: .rejectAll(with: error))
263+
])
262264
try await harness.execute(
263265
deserializer: IdentityDeserializer(),
264266
serializer: IdentitySerializer()
@@ -288,8 +290,8 @@ final class ServerRPCExecutorTests: XCTestCase {
288290
// The interceptor skips the handler altogether.
289291
let harness = ServerRPCExecutorTestHarness(
290292
interceptors: [
291-
.requestCounter(counter1),
292-
.requestCounter(counter2),
293+
.allServices(interceptor: .requestCounter(counter1)),
294+
.allServices(interceptor: .requestCounter(counter2)),
293295
]
294296
)
295297

@@ -312,9 +314,9 @@ final class ServerRPCExecutorTests: XCTestCase {
312314
// The interceptor skips the handler altogether.
313315
let harness = ServerRPCExecutorTestHarness(
314316
interceptors: [
315-
.requestCounter(counter1),
316-
.rejectAll(with: RPCError(code: .unavailable, message: "")),
317-
.requestCounter(counter2),
317+
.allServices(interceptor: .requestCounter(counter1)),
318+
.allServices(interceptor: .rejectAll(with: RPCError(code: .unavailable, message: ""))),
319+
.allServices(interceptor: .requestCounter(counter2)),
318320
]
319321
)
320322

@@ -333,7 +335,9 @@ final class ServerRPCExecutorTests: XCTestCase {
333335

334336
func testThrowingInterceptor() async throws {
335337
let harness = ServerRPCExecutorTestHarness(
336-
interceptors: [.throwError(RPCError(code: .unavailable, message: "Unavailable"))]
338+
interceptors: [
339+
.allServices(interceptor: .throwError(RPCError(code: .unavailable, message: "Unavailable")))
340+
]
337341
)
338342

339343
try await harness.execute(handler: .echo) { inbound in

0 commit comments

Comments
 (0)