Skip to content

Commit a6c173b

Browse files
feat: Add Spanner request ID in exceptions
1 parent c06ffb0 commit a6c173b

File tree

3 files changed

+286
-4
lines changed

3 files changed

+286
-4
lines changed
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
using Grpc.Core;
16+
using Grpc.Core.Interceptors;
17+
using System;
18+
using System.Threading;
19+
using System.Threading.Tasks;
20+
using Xunit;
21+
22+
namespace Google.Cloud.Spanner.V1.Tests;
23+
24+
public class RequestIdCallInterceptorTest
25+
{
26+
private static readonly Method<string, string> s_unaryMethod = new Method<string, string>(MethodType.Unary, "service", "method", Marshallers.StringMarshaller, Marshallers.StringMarshaller);
27+
private static readonly Method<string, string> s_clientStreamingMethod = new Method<string, string>(MethodType.ClientStreaming, "service", "method", Marshallers.StringMarshaller, Marshallers.StringMarshaller);
28+
private static readonly Method<string, string> s_serverStreamingMethod = new Method<string, string>(MethodType.ServerStreaming, "service", "method", Marshallers.StringMarshaller, Marshallers.StringMarshaller);
29+
private static readonly Method<string, string> s_duplexStreamingMethod = new Method<string, string>(MethodType.DuplexStreaming, "service", "method", Marshallers.StringMarshaller, Marshallers.StringMarshaller);
30+
31+
public static TheoryData<Func<CallInvoker, CallOptions, Task>> CallInvokerActions { get; } = new TheoryData<Func<CallInvoker, CallOptions, Task>>
32+
{
33+
(invoker, options) => Task.Run(() => invoker.BlockingUnaryCall(s_unaryMethod, null, options, "")),
34+
(invoker, options) => invoker.AsyncUnaryCall(s_unaryMethod, null, options, "").ResponseAsync,
35+
(invoker, options) => invoker.AsyncClientStreamingCall(s_clientStreamingMethod, null, options).ResponseAsync,
36+
(invoker, options) => invoker.AsyncServerStreamingCall(s_serverStreamingMethod, null, options, "").ResponseStream.MoveNext(),
37+
(invoker, options) => invoker.AsyncDuplexStreamingCall(s_duplexStreamingMethod, null, options).ResponseStream.MoveNext()
38+
};
39+
40+
[Theory]
41+
[MemberData(nameof(CallInvokerActions))]
42+
public async Task RpcCall_ExceptionContainsRequestId(Func<CallInvoker, CallOptions, Task> action)
43+
{
44+
var invoker = new ThrowingCallInvoker();
45+
var interceptedInvoker = invoker.Intercept(new RequestIdCallInterceptor());
46+
var headers = new Metadata { { RequestId.HeaderKey, "test-request-id" } };
47+
var options = new CallOptions(headers);
48+
49+
var exception = await Assert.ThrowsAsync<RpcException>(() => action(interceptedInvoker, options));
50+
51+
Assert.True(exception.Data.Contains(RequestId.HeaderKey));
52+
Assert.Equal("test-request-id", exception.Data[RequestId.HeaderKey]);
53+
}
54+
55+
private class ThrowingCallInvoker : CallInvoker
56+
{
57+
public override TResponse BlockingUnaryCall<TRequest, TResponse>(Method<TRequest, TResponse> method, string host, CallOptions options, TRequest request) =>
58+
throw new RpcException(new Status(StatusCode.Internal, "Test exception"));
59+
60+
public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>(Method<TRequest, TResponse> method, string host, CallOptions options, TRequest request) =>
61+
new AsyncUnaryCall<TResponse>(
62+
Task.FromException<TResponse>(new RpcException(new Status(StatusCode.Internal, "Test exception"))),
63+
Task.FromResult(new Metadata()), () => Status.DefaultSuccess, () => new Metadata(), () => { });
64+
65+
public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreamingCall<TRequest, TResponse>(Method<TRequest, TResponse> method, string host, CallOptions options) =>
66+
new AsyncClientStreamingCall<TRequest, TResponse>(
67+
new MockClientStreamWriter<TRequest>(),
68+
Task.FromException<TResponse>(new RpcException(new Status(StatusCode.Internal, "Test exception"))),
69+
Task.FromResult(new Metadata()), () => Status.DefaultSuccess, () => new Metadata(), () => { });
70+
71+
public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>(Method<TRequest, TResponse> method, string host, CallOptions options, TRequest request) =>
72+
new AsyncServerStreamingCall<TResponse>(
73+
new ThrowingStreamReader<TResponse>(),
74+
Task.FromResult(new Metadata()), () => Status.DefaultSuccess, () => new Metadata(), () => { });
75+
76+
public override AsyncDuplexStreamingCall<TRequest, TResponse> AsyncDuplexStreamingCall<TRequest, TResponse>(Method<TRequest, TResponse> method, string host, CallOptions options) =>
77+
new AsyncDuplexStreamingCall<TRequest, TResponse>(
78+
new MockClientStreamWriter<TRequest>(),
79+
new ThrowingStreamReader<TResponse>(),
80+
Task.FromResult(new Metadata()), () => Status.DefaultSuccess, () => new Metadata(), () => { });
81+
}
82+
83+
private class MockClientStreamWriter<T> : IClientStreamWriter<T>
84+
{
85+
public WriteOptions WriteOptions { get; set; }
86+
public Task CompleteAsync() => Task.CompletedTask;
87+
public Task WriteAsync(T message) => Task.CompletedTask;
88+
}
89+
90+
private class ThrowingStreamReader<T> : IAsyncStreamReader<T>
91+
{
92+
public T Current => default;
93+
94+
public Task<bool> MoveNext(CancellationToken cancellationToken) =>
95+
Task.FromException<bool>(new RpcException(new Status(StatusCode.Internal, "Test exception")));
96+
}
97+
}
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
using Grpc.Core;
16+
using Grpc.Core.Interceptors;
17+
using System;
18+
using System.Threading;
19+
using System.Threading.Tasks;
20+
21+
namespace Google.Cloud.Spanner.V1;
22+
23+
/// <summary>
24+
/// A <see cref="Interceptor"/> that wraps all calls, adding the Spanner request ID to any exceptions thrown while
25+
/// using the <see cref="SpannerClient"/>.
26+
/// </summary>
27+
internal sealed class RequestIdCallInterceptor : Interceptor
28+
{
29+
/// <inheritdoc/>
30+
public override TResponse BlockingUnaryCall<TRequest, TResponse>(
31+
TRequest request,
32+
ClientInterceptorContext<TRequest, TResponse> context,
33+
BlockingUnaryCallContinuation<TRequest, TResponse> continuation)
34+
{
35+
var requestId = GetRequestIdFromOptions(context.Options);
36+
return WrapExceptions(() => continuation(request, context), context.Options);
37+
}
38+
39+
/// <inheritdoc/>
40+
public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>(
41+
TRequest request,
42+
ClientInterceptorContext<TRequest, TResponse> context,
43+
AsyncUnaryCallContinuation<TRequest, TResponse> continuation)
44+
{
45+
var call = continuation(request, context);
46+
var requestId = GetRequestIdFromOptions(context.Options);
47+
48+
return new AsyncUnaryCall<TResponse>(
49+
WrapExceptionAsync(call.ResponseAsync, context.Options),
50+
call.ResponseHeadersAsync,
51+
call.GetStatus,
52+
call.GetTrailers,
53+
call.Dispose);
54+
}
55+
56+
/// <inheritdoc/>
57+
public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreamingCall<TRequest, TResponse>(
58+
ClientInterceptorContext<TRequest, TResponse> context,
59+
AsyncClientStreamingCallContinuation<TRequest, TResponse> continuation)
60+
{
61+
var call = continuation(context);
62+
63+
return new AsyncClientStreamingCall<TRequest, TResponse>(
64+
call.RequestStream,
65+
WrapExceptionAsync(call.ResponseAsync, context.Options),
66+
call.ResponseHeadersAsync,
67+
call.GetStatus,
68+
call.GetTrailers,
69+
call.Dispose);
70+
}
71+
72+
/// <inheritdoc/>
73+
public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>(
74+
TRequest request,
75+
ClientInterceptorContext<TRequest, TResponse> context,
76+
AsyncServerStreamingCallContinuation<TRequest, TResponse> continuation)
77+
{
78+
var call = continuation(request, context);
79+
var responseStream = new SpannerRequestIdStreamReader<TResponse>(call.ResponseStream, context.Options);
80+
81+
return new AsyncServerStreamingCall<TResponse>(
82+
responseStream,
83+
call.ResponseHeadersAsync,
84+
call.GetStatus,
85+
call.GetTrailers,
86+
call.Dispose);
87+
}
88+
89+
/// <inheritdoc/>
90+
public override AsyncDuplexStreamingCall<TRequest, TResponse> AsyncDuplexStreamingCall<TRequest, TResponse>(
91+
ClientInterceptorContext<TRequest, TResponse> context,
92+
AsyncDuplexStreamingCallContinuation<TRequest, TResponse> continuation)
93+
{
94+
var call = continuation(context);
95+
var requestId = GetRequestIdFromOptions(context.Options);
96+
var responseStream = new SpannerRequestIdStreamReader<TResponse>(call.ResponseStream, context.Options);
97+
98+
return new AsyncDuplexStreamingCall<TRequest, TResponse>(
99+
call.RequestStream,
100+
responseStream,
101+
call.ResponseHeadersAsync,
102+
call.GetStatus,
103+
call.GetTrailers,
104+
call.Dispose);
105+
}
106+
107+
/// <summary>
108+
/// Gets the request ID from the call options.
109+
/// </summary>
110+
private static string GetRequestIdFromOptions(CallOptions options)
111+
{
112+
if (options.Headers is Metadata headers)
113+
{
114+
return headers.GetValue(RequestId.HeaderKey);
115+
}
116+
return null;
117+
}
118+
119+
/// <summary>
120+
/// Handles the response, adding the request ID to any exceptions thrown.
121+
/// </summary>
122+
private static T WrapExceptions<T>(Func<T> action, CallOptions options)
123+
{
124+
try
125+
{
126+
return action();
127+
}
128+
catch (Exception e)
129+
{
130+
var requestId = GetRequestIdFromOptions(options);
131+
if (requestId != null)
132+
{
133+
e.Data[RequestId.HeaderKey] = requestId;
134+
}
135+
throw;
136+
}
137+
}
138+
139+
/// <summary>
140+
/// Handles the response asynchronously, adding the request ID to any exceptions thrown.
141+
/// </summary>
142+
private static async Task<T> WrapExceptionAsync<T>(Task<T> task, CallOptions options)
143+
{
144+
try
145+
{
146+
return await task.ConfigureAwait(false);
147+
}
148+
catch (Exception e)
149+
{
150+
var requestId = GetRequestIdFromOptions(options);
151+
if (requestId != null)
152+
{
153+
e.Data[RequestId.HeaderKey] = requestId;
154+
}
155+
throw;
156+
}
157+
}
158+
159+
/// <summary>
160+
/// A stream reader that wraps the original stream reader and adds the request ID to any exceptions thrown.
161+
/// </summary>
162+
private class SpannerRequestIdStreamReader<T> : IAsyncStreamReader<T>
163+
{
164+
private readonly IAsyncStreamReader<T> _originalReader;
165+
private readonly CallOptions _options;
166+
167+
public SpannerRequestIdStreamReader(IAsyncStreamReader<T> originalReader, CallOptions options)
168+
{
169+
_originalReader = originalReader;
170+
_options = options;
171+
}
172+
173+
public T Current => _originalReader.Current;
174+
175+
public Task<bool> MoveNext(CancellationToken cancellationToken) =>
176+
WrapExceptionAsync(_originalReader.MoveNext(cancellationToken), _options);
177+
}
178+
}

apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1/SpannerClientBuilderPartial.cs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
using Google.Api.Gax.Grpc;
1717
using Google.Api.Gax.Grpc.Gcp;
1818
using Grpc.Core;
19+
using Grpc.Core.Interceptors;
1920
using System;
2021
using System.Threading;
2122
using System.Threading.Tasks;
@@ -160,8 +161,9 @@ partial void InterceptBuildAsync(CancellationToken cancellationToken, ref Task<S
160161
task = MaybeCreateEmulatorClientBuilder()?.BuildAsync(cancellationToken);
161162

162163
/// <inheritdoc/>
163-
protected override CallInvoker CreateCallInvoker() =>
164-
AffinityChannelPoolConfiguration is null
164+
protected override CallInvoker CreateCallInvoker()
165+
{
166+
var invoker = AffinityChannelPoolConfiguration is null
165167
? base.CreateCallInvoker()
166168
: new GcpCallInvoker(
167169
ServiceMetadata,
@@ -170,10 +172,13 @@ AffinityChannelPoolConfiguration is null
170172
GetChannelOptions(),
171173
GetApiConfig(),
172174
EffectiveGrpcAdapter);
175+
return invoker.Intercept(new RequestIdCallInterceptor());
176+
}
173177

174178
/// <inheritdoc/>
175-
protected override async Task<CallInvoker> CreateCallInvokerAsync(CancellationToken cancellationToken) =>
176-
AffinityChannelPoolConfiguration is null
179+
protected override async Task<CallInvoker> CreateCallInvokerAsync(CancellationToken cancellationToken)
180+
{
181+
var invoker = AffinityChannelPoolConfiguration is null
177182
? await base.CreateCallInvokerAsync(cancellationToken).ConfigureAwait(false)
178183
: new GcpCallInvoker(
179184
ServiceMetadata,
@@ -182,6 +187,8 @@ await GetChannelCredentialsAsync(cancellationToken).ConfigureAwait(false),
182187
GetChannelOptions(),
183188
GetApiConfig(),
184189
EffectiveGrpcAdapter);
190+
return invoker.Intercept(new RequestIdCallInterceptor());
191+
}
185192

186193
private ApiConfig GetApiConfig() => new ApiConfig
187194
{

0 commit comments

Comments
 (0)