Skip to content

Commit 13ea905

Browse files
committed
Add option to include response metadata in client interceptor
1 parent 4c5a11c commit 13ea905

File tree

2 files changed

+142
-12
lines changed

2 files changed

+142
-12
lines changed

Sources/GRPCOTelTracingInterceptors/Tracing/ClientOTelTracingInterceptor.swift

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ public struct ClientOTelTracingInterceptor: ClientInterceptor {
3636

3737
private let traceEachMessage: Bool
3838
private var includeRequestMetadata: Bool
39+
private var includeResponseMetadata: Bool
3940

4041
/// Create a new instance of a ``ClientOTelTracingInterceptor``.
4142
///
@@ -45,20 +46,24 @@ public struct ClientOTelTracingInterceptor: ClientInterceptor {
4546
/// `network.transport` attribute in spans.
4647
/// - traceEachMessage: If `true`, each request part sent and response part received will be recorded as a separate
4748
/// event in a tracing span.
48-
/// - includeRequestMetadata: if `true`, **all** metadata keys included in the request will be added to the span as attributes.
49+
/// - includeRequestMetadata: if `true`, **all** metadata keys with string values included in the request will be added to the span as attributes.
50+
/// - includeResponseMetadata: if `true`, **all** metadata keys with string values included in the response will be added to the span as attributes.
4951
///
50-
/// - Important: Be careful when setting `includeRequestMetadata=true`, as including all request metadata can be a security risk.
52+
/// - Important: Be careful when setting `includeRequestMetadata` or `includeResponseMetadata` to `true`,
53+
/// as including all request/response metadata can be a security risk.
5154
public init(
5255
serverHostname: String,
5356
networkTransportMethod: String,
5457
traceEachMessage: Bool = true,
55-
includeRequestMetadata: Bool = false
58+
includeRequestMetadata: Bool = false,
59+
includeResponseMetadata: Bool = false
5660
) {
5761
self.injector = ClientRequestInjector()
5862
self.serverHostname = serverHostname
5963
self.networkTransportMethod = networkTransportMethod
6064
self.traceEachMessage = traceEachMessage
6165
self.includeRequestMetadata = includeRequestMetadata
66+
self.includeResponseMetadata = includeResponseMetadata
6267
}
6368

6469
/// This interceptor will inject as the request's metadata whatever `ServiceContext` key-value pairs
@@ -142,6 +147,11 @@ public struct ClientOTelTracingInterceptor: ClientInterceptor {
142147
}
143148

144149
var response = try await next(request, context)
150+
151+
if self.includeResponseMetadata {
152+
span.setMetadataStringAttributesAsResponseSpanAttributes(response.metadata)
153+
}
154+
145155
switch response.accepted {
146156
case .success(var success):
147157
let hookedSequence:
@@ -150,14 +160,22 @@ public struct ClientOTelTracingInterceptor: ClientInterceptor {
150160
>
151161
if self.traceEachMessage {
152162
let messageReceivedCounter = Atomic(1)
153-
hookedSequence = HookedRPCAsyncSequence(wrapping: success.bodyParts) { _ in
154-
var event = SpanEvent(name: "rpc.message")
155-
event.attributes[GRPCTracingKeys.rpcMessageType] = "RECEIVED"
156-
event.attributes[GRPCTracingKeys.rpcMessageID] =
163+
hookedSequence = HookedRPCAsyncSequence(wrapping: success.bodyParts) { part in
164+
switch part {
165+
case .message(let message):
166+
var event = SpanEvent(name: "rpc.message")
167+
event.attributes[GRPCTracingKeys.rpcMessageType] = "RECEIVED"
168+
event.attributes[GRPCTracingKeys.rpcMessageID] =
157169
messageReceivedCounter
158-
.wrappingAdd(1, ordering: .sequentiallyConsistent)
159-
.oldValue
160-
span.addEvent(event)
170+
.wrappingAdd(1, ordering: .sequentiallyConsistent)
171+
.oldValue
172+
span.addEvent(event)
173+
174+
case .trailingMetadata(let trailingMetadata):
175+
if self.includeResponseMetadata {
176+
span.setMetadataStringAttributesAsResponseSpanAttributes(trailingMetadata)
177+
}
178+
}
161179
} onFinish: { error in
162180
if let error {
163181
if let errorCode = error.grpcErrorCode {

Tests/GRPCOTelTracingInterceptorsTests/GRPCOTelTracingInterceptorsTests.swift

Lines changed: 114 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ struct OTelTracingClientInterceptorTests {
180180
}
181181
}
182182

183-
@Test("All request metadata is included if opted-in")
183+
@Test("All string-valued request metadata is included if opted-in")
184184
func testRequestMetadataOptIn() async throws {
185185
var serviceContext = ServiceContext.topLevel
186186
let traceIDString = UUID().uuidString
@@ -199,7 +199,7 @@ struct OTelTracingClientInterceptorTests {
199199
)
200200
let methodDescriptor = MethodDescriptor(
201201
fullyQualifiedService: "OTelTracingClientInterceptorTests",
202-
method: "testAllEventsRecorded"
202+
method: "testRequestMetadataOptIn"
203203
)
204204
let response = try await interceptor.intercept(
205205
tracer: self.tracer,
@@ -288,6 +288,118 @@ struct OTelTracingClientInterceptorTests {
288288
}
289289
}
290290

291+
@Test("All string-valued response metadata is included if opted-in")
292+
func testResponseMetadataOptIn() async throws {
293+
var serviceContext = ServiceContext.topLevel
294+
let traceIDString = UUID().uuidString
295+
296+
let (requestStream, requestStreamContinuation) = AsyncStream<String>.makeStream()
297+
serviceContext.traceID = traceIDString
298+
299+
// FIXME: use 'ServiceContext.withValue(serviceContext)'
300+
//
301+
// This is blocked on: https://github.com/apple/swift-service-context/pull/46
302+
try await ServiceContext.$current.withValue(serviceContext) {
303+
let interceptor = ClientOTelTracingInterceptor(
304+
serverHostname: "someserver.com",
305+
networkTransportMethod: "tcp",
306+
includeResponseMetadata: true
307+
)
308+
let methodDescriptor = MethodDescriptor(
309+
fullyQualifiedService: "OTelTracingClientInterceptorTests",
310+
method: "testResponseMetadataOptIn"
311+
)
312+
let response = try await interceptor.intercept(
313+
tracer: self.tracer,
314+
request: .init(
315+
metadata: [
316+
"some-request-metadata": "some-request-value",
317+
"some-repeated-request-metadata": "some-repeated-request-value1",
318+
"some-repeated-request-metadata": "some-repeated-request-value2",
319+
"some-request-metadata-bin": .binary([1]),
320+
],
321+
producer: { writer in
322+
try await writer.write(contentsOf: ["request1"])
323+
try await writer.write(contentsOf: ["request2"])
324+
}
325+
),
326+
context: ClientContext(
327+
descriptor: methodDescriptor,
328+
remotePeer: "ipv4:10.1.2.80:567",
329+
localPeer: "ipv4:10.1.2.80:123"
330+
)
331+
) { stream, _ in
332+
// Assert the metadata contains the injected context key-value.
333+
#expect(
334+
stream.metadata.contains(where: {
335+
($0.key == "trace-id") && ($0.value == .string(traceIDString))
336+
})
337+
)
338+
339+
// Write into the request stream to make sure the `producer` closure's called.
340+
let writer = RPCWriter(wrapping: TestWriter(streamContinuation: requestStreamContinuation))
341+
try await stream.producer(writer)
342+
requestStreamContinuation.finish()
343+
344+
return .init(
345+
metadata: [
346+
"some-response-metadata": "some-response-value",
347+
"some-response-metadata-bin": .binary([2]),
348+
],
349+
bodyParts: RPCAsyncSequence(
350+
wrapping: AsyncThrowingStream<StreamingClientResponse.Contents.BodyPart, any Error> {
351+
$0.yield(.message(["response"]))
352+
$0.yield(.trailingMetadata([
353+
"some-repeated-response-metadata": "some-repeated-response-value1",
354+
"some-repeated-response-metadata": "some-repeated-response-value2"
355+
]))
356+
$0.finish()
357+
}
358+
)
359+
)
360+
}
361+
362+
await assertStreamContentsEqual(["request1", "request2"], requestStream)
363+
try await assertStreamContentsEqual([["response"]], response.messages)
364+
365+
assertTestSpanComponents(forMethod: methodDescriptor, tracer: self.tracer) { events in
366+
#expect(
367+
events == [
368+
// Recorded when `request1` is sent
369+
TestSpanEvent("rpc.message", ["rpc.message.type": "SENT", "rpc.message.id": 1]),
370+
// Recorded when `request2` is sent
371+
TestSpanEvent("rpc.message", ["rpc.message.type": "SENT", "rpc.message.id": 2]),
372+
// Recorded when receiving response part
373+
TestSpanEvent("rpc.message", ["rpc.message.type": "RECEIVED", "rpc.message.id": 1]),
374+
]
375+
)
376+
} assertAttributes: { attributes in
377+
#expect(
378+
attributes == [
379+
"rpc.system": "grpc",
380+
"rpc.method": .string(methodDescriptor.method),
381+
"rpc.service": .string(methodDescriptor.service.fullyQualifiedService),
382+
"rpc.grpc.status_code": 0,
383+
"server.address": "someserver.com",
384+
"server.port": 567,
385+
"network.peer.address": "10.1.2.80",
386+
"network.peer.port": 567,
387+
"network.transport": "tcp",
388+
"network.type": "ipv4",
389+
"rpc.grpc.response.metadata.some-response-metadata": "some-response-value",
390+
"rpc.grpc.response.metadata.some-repeated-response-metadata": .stringArray([
391+
"some-repeated-response-value1", "some-repeated-response-value2",
392+
]),
393+
]
394+
)
395+
} assertStatus: { status in
396+
#expect(status == nil)
397+
} assertErrors: { errors in
398+
#expect(errors == [])
399+
}
400+
}
401+
}
402+
291403
@Test("RPC that throws is correctly recorded")
292404
func testThrowingRPC() async throws {
293405
var serviceContext = ServiceContext.topLevel

0 commit comments

Comments
 (0)