Skip to content

Commit 287c930

Browse files
feat: Add Spanner request ID in exceptions
1 parent f5a6fc9 commit 287c930

File tree

3 files changed

+286
-4
lines changed

3 files changed

+286
-4
lines changed
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
using Grpc.Core;
16+
using Grpc.Core.Interceptors;
17+
using System;
18+
using System.Threading;
19+
using System.Threading.Tasks;
20+
using Xunit;
21+
22+
namespace Google.Cloud.Spanner.V1.Tests;
23+
24+
public class SpannerRequestIdCallInterceptorTest
25+
{
26+
private static readonly Method<string, string> s_unaryMethod = new Method<string, string>(MethodType.Unary, "service", "method", Marshallers.StringMarshaller, Marshallers.StringMarshaller);
27+
private static readonly Method<string, string> s_clientStreamingMethod = new Method<string, string>(MethodType.ClientStreaming, "service", "method", Marshallers.StringMarshaller, Marshallers.StringMarshaller);
28+
private static readonly Method<string, string> s_serverStreamingMethod = new Method<string, string>(MethodType.ServerStreaming, "service", "method", Marshallers.StringMarshaller, Marshallers.StringMarshaller);
29+
private static readonly Method<string, string> s_duplexStreamingMethod = new Method<string, string>(MethodType.DuplexStreaming, "service", "method", Marshallers.StringMarshaller, Marshallers.StringMarshaller);
30+
31+
public static TheoryData<Func<CallInvoker, CallOptions, Task>> CallInvokerActions { get; } = new TheoryData<Func<CallInvoker, CallOptions, Task>>
32+
{
33+
(invoker, options) => Task.Run(() => invoker.BlockingUnaryCall(s_unaryMethod, null, options, "")),
34+
(invoker, options) => invoker.AsyncUnaryCall(s_unaryMethod, null, options, "").ResponseAsync,
35+
(invoker, options) => invoker.AsyncClientStreamingCall(s_clientStreamingMethod, null, options).ResponseAsync,
36+
(invoker, options) => invoker.AsyncServerStreamingCall(s_serverStreamingMethod, null, options, "").ResponseStream.MoveNext(),
37+
(invoker, options) => invoker.AsyncDuplexStreamingCall(s_duplexStreamingMethod, null, options).ResponseStream.MoveNext()
38+
};
39+
40+
[Theory]
41+
[MemberData(nameof(CallInvokerActions))]
42+
public async Task RpcCall_ExceptionContainsRequestId(Func<CallInvoker, CallOptions, Task> action)
43+
{
44+
var invoker = new ThrowingCallInvoker();
45+
var interceptedInvoker = invoker.Intercept(new SpannerRequestIdCallInterceptor());
46+
var headers = new Metadata { { SpannerRequestId.HeaderKey, "test-request-id" } };
47+
var options = new CallOptions(headers);
48+
49+
var exception = await Assert.ThrowsAsync<RpcException>(() => action(interceptedInvoker, options));
50+
51+
Assert.True(exception.Data.Contains(SpannerRequestId.HeaderKey));
52+
Assert.Equal("test-request-id", exception.Data[SpannerRequestId.HeaderKey]);
53+
}
54+
55+
private class ThrowingCallInvoker : CallInvoker
56+
{
57+
public override TResponse BlockingUnaryCall<TRequest, TResponse>(Method<TRequest, TResponse> method, string host, CallOptions options, TRequest request) =>
58+
throw new RpcException(new Status(StatusCode.Internal, "Test exception"));
59+
60+
public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>(Method<TRequest, TResponse> method, string host, CallOptions options, TRequest request) =>
61+
new AsyncUnaryCall<TResponse>(
62+
Task.FromException<TResponse>(new RpcException(new Status(StatusCode.Internal, "Test exception"))),
63+
Task.FromResult(new Metadata()), () => Status.DefaultSuccess, () => new Metadata(), () => { });
64+
65+
public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreamingCall<TRequest, TResponse>(Method<TRequest, TResponse> method, string host, CallOptions options) =>
66+
new AsyncClientStreamingCall<TRequest, TResponse>(
67+
new MockClientStreamWriter<TRequest>(),
68+
Task.FromException<TResponse>(new RpcException(new Status(StatusCode.Internal, "Test exception"))),
69+
Task.FromResult(new Metadata()), () => Status.DefaultSuccess, () => new Metadata(), () => { });
70+
71+
public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>(Method<TRequest, TResponse> method, string host, CallOptions options, TRequest request) =>
72+
new AsyncServerStreamingCall<TResponse>(
73+
new ThrowingStreamReader<TResponse>(),
74+
Task.FromResult(new Metadata()), () => Status.DefaultSuccess, () => new Metadata(), () => { });
75+
76+
public override AsyncDuplexStreamingCall<TRequest, TResponse> AsyncDuplexStreamingCall<TRequest, TResponse>(Method<TRequest, TResponse> method, string host, CallOptions options) =>
77+
new AsyncDuplexStreamingCall<TRequest, TResponse>(
78+
new MockClientStreamWriter<TRequest>(),
79+
new ThrowingStreamReader<TResponse>(),
80+
Task.FromResult(new Metadata()), () => Status.DefaultSuccess, () => new Metadata(), () => { });
81+
}
82+
83+
private class MockClientStreamWriter<T> : IClientStreamWriter<T>
84+
{
85+
public WriteOptions WriteOptions { get; set; }
86+
public Task CompleteAsync() => Task.CompletedTask;
87+
public Task WriteAsync(T message) => Task.CompletedTask;
88+
}
89+
90+
private class ThrowingStreamReader<T> : IAsyncStreamReader<T>
91+
{
92+
public T Current => default;
93+
94+
public Task<bool> MoveNext(CancellationToken cancellationToken) =>
95+
Task.FromException<bool>(new RpcException(new Status(StatusCode.Internal, "Test exception")));
96+
}
97+
}

apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1/SpannerClientBuilderPartial.cs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
using Google.Api.Gax.Grpc;
1717
using Google.Api.Gax.Grpc.Gcp;
1818
using Grpc.Core;
19+
using Grpc.Core.Interceptors;
1920
using System;
2021
using System.Threading;
2122
using System.Threading.Tasks;
@@ -150,8 +151,9 @@ partial void InterceptBuildAsync(CancellationToken cancellationToken, ref Task<S
150151
task = MaybeCreateEmulatorClientBuilder()?.BuildAsync(cancellationToken);
151152

152153
/// <inheritdoc/>
153-
protected override CallInvoker CreateCallInvoker() =>
154-
AffinityChannelPoolConfiguration is null
154+
protected override CallInvoker CreateCallInvoker()
155+
{
156+
var invoker = AffinityChannelPoolConfiguration is null
155157
? base.CreateCallInvoker()
156158
: new GcpCallInvoker(
157159
ServiceMetadata,
@@ -160,10 +162,13 @@ AffinityChannelPoolConfiguration is null
160162
GetChannelOptions(),
161163
GetApiConfig(),
162164
EffectiveGrpcAdapter);
165+
return invoker.Intercept(new SpannerRequestIdCallInterceptor());
166+
}
163167

164168
/// <inheritdoc/>
165-
protected override async Task<CallInvoker> CreateCallInvokerAsync(CancellationToken cancellationToken) =>
166-
AffinityChannelPoolConfiguration is null
169+
protected override async Task<CallInvoker> CreateCallInvokerAsync(CancellationToken cancellationToken)
170+
{
171+
var invoker = AffinityChannelPoolConfiguration is null
167172
? await base.CreateCallInvokerAsync(cancellationToken).ConfigureAwait(false)
168173
: new GcpCallInvoker(
169174
ServiceMetadata,
@@ -172,6 +177,8 @@ await GetChannelCredentialsAsync(cancellationToken).ConfigureAwait(false),
172177
GetChannelOptions(),
173178
GetApiConfig(),
174179
EffectiveGrpcAdapter);
180+
return invoker.Intercept(new SpannerRequestIdCallInterceptor());
181+
}
175182

176183
private ApiConfig GetApiConfig() => new ApiConfig
177184
{
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
using Grpc.Core;
16+
using Grpc.Core.Interceptors;
17+
using System;
18+
using System.Threading;
19+
using System.Threading.Tasks;
20+
21+
namespace Google.Cloud.Spanner.V1;
22+
23+
/// <summary>
24+
/// A <see cref="Interceptor"/> that wraps all calls, adding the Spanner request ID to any exceptions thrown while
25+
/// using the <see cref="SpannerClient"/>.
26+
/// </summary>
27+
internal sealed class SpannerRequestIdCallInterceptor : Interceptor
28+
{
29+
/// <inheritdoc/>
30+
public override TResponse BlockingUnaryCall<TRequest, TResponse>(
31+
TRequest request,
32+
ClientInterceptorContext<TRequest, TResponse> context,
33+
BlockingUnaryCallContinuation<TRequest, TResponse> continuation)
34+
{
35+
var requestId = GetRequestIdFromOptions(context.Options);
36+
return WrapExceptions(() => continuation(request, context), requestId);
37+
}
38+
39+
/// <inheritdoc/>
40+
public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>(
41+
TRequest request,
42+
ClientInterceptorContext<TRequest, TResponse> context,
43+
AsyncUnaryCallContinuation<TRequest, TResponse> continuation)
44+
{
45+
var call = continuation(request, context);
46+
var requestId = GetRequestIdFromOptions(context.Options);
47+
48+
return new AsyncUnaryCall<TResponse>(
49+
WrapExceptionAsync(call.ResponseAsync, requestId),
50+
call.ResponseHeadersAsync,
51+
call.GetStatus,
52+
call.GetTrailers,
53+
call.Dispose);
54+
}
55+
56+
/// <inheritdoc/>
57+
public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreamingCall<TRequest, TResponse>(
58+
ClientInterceptorContext<TRequest, TResponse> context,
59+
AsyncClientStreamingCallContinuation<TRequest, TResponse> continuation)
60+
{
61+
var call = continuation(context);
62+
var requestId = GetRequestIdFromOptions(context.Options);
63+
64+
return new AsyncClientStreamingCall<TRequest, TResponse>(
65+
call.RequestStream,
66+
WrapExceptionAsync(call.ResponseAsync, requestId),
67+
call.ResponseHeadersAsync,
68+
call.GetStatus,
69+
call.GetTrailers,
70+
call.Dispose);
71+
}
72+
73+
/// <inheritdoc/>
74+
public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>(
75+
TRequest request,
76+
ClientInterceptorContext<TRequest, TResponse> context,
77+
AsyncServerStreamingCallContinuation<TRequest, TResponse> continuation)
78+
{
79+
var call = continuation(request, context);
80+
var requestId = GetRequestIdFromOptions(context.Options);
81+
var responseStream = new SpannerRequestIdStreamReader<TResponse>(call.ResponseStream, requestId);
82+
83+
return new AsyncServerStreamingCall<TResponse>(
84+
responseStream,
85+
call.ResponseHeadersAsync,
86+
call.GetStatus,
87+
call.GetTrailers,
88+
call.Dispose);
89+
}
90+
91+
/// <inheritdoc/>
92+
public override AsyncDuplexStreamingCall<TRequest, TResponse> AsyncDuplexStreamingCall<TRequest, TResponse>(
93+
ClientInterceptorContext<TRequest, TResponse> context,
94+
AsyncDuplexStreamingCallContinuation<TRequest, TResponse> continuation)
95+
{
96+
var call = continuation(context);
97+
var requestId = GetRequestIdFromOptions(context.Options);
98+
var responseStream = new SpannerRequestIdStreamReader<TResponse>(call.ResponseStream, requestId);
99+
100+
return new AsyncDuplexStreamingCall<TRequest, TResponse>(
101+
call.RequestStream,
102+
responseStream,
103+
call.ResponseHeadersAsync,
104+
call.GetStatus,
105+
call.GetTrailers,
106+
call.Dispose);
107+
}
108+
109+
/// <summary>
110+
/// Gets the request ID from the call options.
111+
/// </summary>
112+
private static string GetRequestIdFromOptions(CallOptions options)
113+
{
114+
if (options.Headers is Metadata headers)
115+
{
116+
return headers.GetValue(SpannerRequestId.HeaderKey);
117+
}
118+
return null;
119+
}
120+
121+
/// <summary>
122+
/// Handles the response, adding the request ID to any exceptions thrown.
123+
/// </summary>
124+
private static T WrapExceptions<T>(Func<T> action, string requestId)
125+
{
126+
try
127+
{
128+
return action();
129+
}
130+
catch (Exception e)
131+
{
132+
if (requestId != null)
133+
{
134+
e.Data[SpannerRequestId.HeaderKey] = requestId;
135+
}
136+
throw;
137+
}
138+
}
139+
140+
/// <summary>
141+
/// Handles the response asynchronously, adding the request ID to any exceptions thrown.
142+
/// </summary>
143+
private static async Task<T> WrapExceptionAsync<T>(Task<T> task, string requestId)
144+
{
145+
try
146+
{
147+
return await task.ConfigureAwait(false);
148+
}
149+
catch (Exception e)
150+
{
151+
if (requestId != null)
152+
{
153+
e.Data[SpannerRequestId.HeaderKey] = requestId;
154+
}
155+
throw;
156+
}
157+
}
158+
159+
/// <summary>
160+
/// A stream reader that wraps the original stream reader and adds the request ID to any exceptions thrown.
161+
/// </summary>
162+
private class SpannerRequestIdStreamReader<T> : IAsyncStreamReader<T>
163+
{
164+
private readonly IAsyncStreamReader<T> _originalReader;
165+
private readonly string _requestId;
166+
167+
public SpannerRequestIdStreamReader(IAsyncStreamReader<T> originalReader, string requestId)
168+
{
169+
_originalReader = originalReader;
170+
_requestId = requestId;
171+
}
172+
173+
public T Current => _originalReader.Current;
174+
175+
public Task<bool> MoveNext(CancellationToken cancellationToken) =>
176+
WrapExceptionAsync(_originalReader.MoveNext(cancellationToken), _requestId);
177+
}
178+
}

0 commit comments

Comments
 (0)