diff --git a/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1.Tests/RequestIdOnExceptionInterceptorTests.cs b/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1.Tests/RequestIdOnExceptionInterceptorTests.cs new file mode 100644 index 000000000000..2f356e98e6a0 --- /dev/null +++ b/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1.Tests/RequestIdOnExceptionInterceptorTests.cs @@ -0,0 +1,290 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Grpc.Core; +using Grpc.Core.Interceptors; +using System; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Google.Cloud.Spanner.V1.Tests; + +public class RequestIdOnExceptionInterceptorTests +{ + private static readonly SpannerClientBuilder.RequestIdOnExceptionInterceptor s_interceptor = + SpannerClientBuilder.RequestIdOnExceptionInterceptor.Instance; + + private readonly string _fakeRequest = "fake"; + private readonly ClientInterceptorContext _unaryContext = CreateContext(MethodType.ClientStreaming); + private readonly ClientInterceptorContext _clientStreamingContext = CreateContext(MethodType.ClientStreaming); + private readonly ClientInterceptorContext _serverStreamingContext = CreateContext(MethodType.ClientStreaming); + private readonly ClientInterceptorContext _duplexStreamingContext = CreateContext(MethodType.ClientStreaming); + + private static readonly string s_requestId = Guid.NewGuid().ToString(); + private static readonly Metadata s_metadata = new() { { "x-goog-spanner-request-id", s_requestId } }; + private static readonly CallOptions s_options = new(headers: s_metadata); + private static readonly string s_sampleSessionName = "sessionName"; + + private static readonly Exception s_exception = new(); + + [Fact] + public void BlockingUnaryCall_ContinuationThrows_ExceptionEnriched() => + AssertThrowsEnrichedException(() => + s_interceptor.BlockingUnaryCall(_fakeRequest, _unaryContext, (req, ctx) => throw s_exception)); + + [Fact] + public void AsyncUnaryCall_ContinuationThrows_ExceptionEnriched() => + AssertThrowsEnrichedException(() => + s_interceptor.AsyncUnaryCall(_fakeRequest, _unaryContext, (req, ctx) => throw s_exception)); + + [Fact] + public void AsyncClientStreamingCall_ContinuationThrows_ExceptionEnriched() => + AssertThrowsEnrichedException(() => + s_interceptor.AsyncClientStreamingCall(_clientStreamingContext, (ctx) => throw s_exception)); + + [Fact] + public void AsyncServerStreamingCall_ContinuationThrows_ExceptionEnriched() => + AssertThrowsEnrichedException(() => + s_interceptor.AsyncServerStreamingCall(_fakeRequest, _serverStreamingContext, (req, ctx) => throw s_exception)); + + [Fact] + public void AsyncDuplexStreamingCall_ContinuationThrows_ExceptionEnriched() => + AssertThrowsEnrichedException(() => + s_interceptor.AsyncDuplexStreamingCall(_duplexStreamingContext, (ctx) => throw s_exception)); + + [Fact] + public async Task AsyncUnaryCall_ResponseAsyncThrows_ExceptionEnriched() + { + var call = CreateAsyncUnaryCall(response: Task.FromException(s_exception)); + await AssertThrowsEnrichedExceptionAsync(() => + s_interceptor.AsyncUnaryCall(_fakeRequest, _unaryContext, (req, ctx) => call).ResponseAsync); + } + + [Fact] + public async Task AsyncUnaryCall_ResponseHeadersAsyncThrows_ExceptionEnriched() + { + var call = CreateAsyncUnaryCall(headers: Task.FromException(s_exception)); + await AssertThrowsEnrichedExceptionAsync(() => + s_interceptor.AsyncUnaryCall(_fakeRequest, _unaryContext, (req, ctx) => call).ResponseHeadersAsync); + } + + [Fact] + public async Task AsyncClientStreamingCall_ResponseAsyncThrows_ExceptionEnriched() + { + var call = CreateAsyncClientStreamingCall(response: Task.FromException(s_exception)); + await AssertThrowsEnrichedExceptionAsync(() => + s_interceptor.AsyncClientStreamingCall(_clientStreamingContext, (ctx) => call).ResponseAsync); + } + + [Fact] + public async Task AsyncClientStreamingCall_ResponseHeadersAsyncThrows_ExceptionEnriched() + { + var call = CreateAsyncClientStreamingCall(headers: Task.FromException(s_exception)); + await AssertThrowsEnrichedExceptionAsync(() => + s_interceptor.AsyncClientStreamingCall(_clientStreamingContext, (ctx) => call).ResponseHeadersAsync); + } + + [Fact] + public async Task AsyncClientStreamingCall_RequestStreamWriteThrows_ExceptionEnriched() + { + var call = CreateAsyncClientStreamingCall(requestStream: new ThrowingClientStreamWriter(s_exception)); + await AssertThrowsEnrichedExceptionAsync(() => + s_interceptor.AsyncClientStreamingCall(_clientStreamingContext, (ctx) => call).RequestStream.WriteAsync("1")); + } + + [Fact] + public async Task AsyncClientStreamingCall_RequestStreamCompleteThrows_ExceptionEnriched() + { + var call = CreateAsyncClientStreamingCall(requestStream: new ThrowingClientStreamWriter(s_exception)); + var context = CreateContext(MethodType.ClientStreaming); + await AssertThrowsEnrichedExceptionAsync(() => + s_interceptor.AsyncClientStreamingCall(_clientStreamingContext, (ctx) => call).RequestStream.CompleteAsync()); + } + + [Fact] + public async Task AsyncServerStreamingCall_ResponseHeadersAsyncThrows_ExceptionEnriched() + { + var call = CreateAsyncServerStreamingCall(headers: Task.FromException(s_exception)); + await AssertThrowsEnrichedExceptionAsync(() => + s_interceptor.AsyncServerStreamingCall("1", _serverStreamingContext, (req, ctx) => call).ResponseHeadersAsync); + } + + [Fact] + public async Task AsyncServerStreamingCall_ResponseStreamMoveNextThrows_ExceptionEnriched() + { + var call = CreateAsyncServerStreamingCall(responseStream: new ThrowingAsyncStreamReader(s_exception)); + await AssertThrowsEnrichedExceptionAsync(() => + s_interceptor.AsyncServerStreamingCall(_fakeRequest, _serverStreamingContext, (req, ctx) => call).ResponseStream.MoveNext(default)); + } + + [Fact] + public async Task AsyncDuplexStreamingCall_ResponseHeadersAsyncThrows_ExceptionEnriched() + { + var call = CreateAsyncDuplexStreamingCall(headers: Task.FromException(s_exception)); + await AssertThrowsEnrichedExceptionAsync(() => + s_interceptor.AsyncDuplexStreamingCall(_duplexStreamingContext, (ctx) => call).ResponseHeadersAsync); + } + + [Fact] + public async Task AsyncDuplexStreamingCall_RequestStreamWriteThrows_ExceptionEnriched() + { + var call = CreateAsyncDuplexStreamingCall(requestStream: new ThrowingClientStreamWriter(s_exception)); + await AssertThrowsEnrichedExceptionAsync(() => + s_interceptor.AsyncDuplexStreamingCall(_duplexStreamingContext, (ctx) => call).RequestStream.WriteAsync("1")); + } + + [Fact] + public async Task AsyncDuplexStreamingCall_RequestStreamCompleteThrows_ExceptionEnriched() + { + var call = CreateAsyncDuplexStreamingCall(requestStream: new ThrowingClientStreamWriter(s_exception)); + await AssertThrowsEnrichedExceptionAsync(() => + s_interceptor.AsyncDuplexStreamingCall(_duplexStreamingContext, (ctx) => call).RequestStream.CompleteAsync()); + } + + [Fact] + public async Task AsyncDuplexStreamingCall_ResponseStreamMoveNextThrows_ExceptionEnriched() + { + var call = CreateAsyncDuplexStreamingCall(responseStream: new ThrowingAsyncStreamReader(s_exception)); + await AssertThrowsEnrichedExceptionAsync(() => + s_interceptor.AsyncDuplexStreamingCall(_duplexStreamingContext, (ctx) => call).ResponseStream.MoveNext(default)); + } + + /// + /// This test validates when a gRPC error is thrown while using the + /// the exception is enriched with the RequestId. We do not cover the full set of exception flows + /// because does not implement all gRPC call types (i.e. no DuplexStreaming + /// and ClientStreaming) and we have already covered all cases with direct unit tests on + /// . This test serves to validate + /// attaches the interceptor on build. + /// + [Fact] + public async Task SpannerClient_Throws_ExceptionEnriched() + { + var callInvoker = new FakeThrowingCallInvoker(s_exception); + var client = new SpannerClientBuilder { CallInvoker = callInvoker }.Build(); + + var stream = client.ExecuteStreamingSql(new ExecuteSqlRequest { Session = s_sampleSessionName, Sql = "SELECT 1" }); + await AssertThrowsEnrichedExceptionAsync(async () => await stream.GrpcCall.ResponseStream.MoveNext(default)); + } + + // Verification Helpers + + private static void AssertThrowsEnrichedException(Action action) => + AssertEnrichedException(Assert.Throws(action)); + + private static async Task AssertThrowsEnrichedExceptionAsync(Func action) => + AssertEnrichedException(await Assert.ThrowsAsync(action)); + + private static void AssertEnrichedException(Exception ex) + { + // The Exception.Data property should contain a non-empty Request ID field + Assert.Same(s_exception, ex); + Assert.True(s_exception.Data.Contains("x-goog-spanner-request-id")); + Assert.False(string.IsNullOrEmpty((string)s_exception.Data["x-goog-spanner-request-id"])); + } + + // Creation Helpers + + private static ClientInterceptorContext CreateContext(MethodType methodType) => + new ClientInterceptorContext( + new Method(methodType, "s", "m", Marshallers.StringMarshaller, Marshallers.StringMarshaller), + null, + s_options); + + private static AsyncUnaryCall CreateAsyncUnaryCall(Task response = null, Task headers = null) => + new AsyncUnaryCall( + response ?? Task.FromResult("1"), + headers ?? Task.FromResult(new Metadata()), + () => Status.DefaultSuccess, + () => new Metadata(), + () => { }); + + private static AsyncClientStreamingCall CreateAsyncClientStreamingCall( + IClientStreamWriter requestStream = null, + Task response = null, + Task headers = null) => + new AsyncClientStreamingCall( + requestStream ?? new ThrowingClientStreamWriter(null), + response ?? Task.FromResult("1"), + headers ?? Task.FromResult(new Metadata()), + () => Status.DefaultSuccess, + () => new Metadata(), + () => { }); + + private static AsyncServerStreamingCall CreateAsyncServerStreamingCall( + IAsyncStreamReader responseStream = null, + Task headers = null) => + new AsyncServerStreamingCall( + responseStream ?? new ThrowingAsyncStreamReader(null), + headers ?? Task.FromResult(new Metadata()), + () => Status.DefaultSuccess, + () => new Metadata(), + () => { }); + + private static AsyncDuplexStreamingCall CreateAsyncDuplexStreamingCall( + IClientStreamWriter requestStream = null, + IAsyncStreamReader responseStream = null, + Task headers = null) => + new AsyncDuplexStreamingCall( + requestStream ?? new ThrowingClientStreamWriter(null), + responseStream ?? new ThrowingAsyncStreamReader(null), + headers ?? Task.FromResult(new Metadata()), + () => Status.DefaultSuccess, + () => new Metadata(), + () => { }); + + private class ThrowingAsyncStreamReader : IAsyncStreamReader + { + private readonly Exception _exception; + public ThrowingAsyncStreamReader(Exception exception) => _exception = exception; + public T Current => default; + public Task MoveNext(CancellationToken cancellationToken) => + _exception != null ? Task.FromException(_exception) : Task.FromResult(false); + } + + private class ThrowingClientStreamWriter : IClientStreamWriter + { + private readonly Exception _exception; + public ThrowingClientStreamWriter(Exception exception) => _exception = exception; + public WriteOptions WriteOptions { get; set; } + public Task CompleteAsync() => _exception != null ? Task.FromException(_exception) : Task.CompletedTask; + public Task WriteAsync(T message) => _exception != null ? Task.FromException(_exception) : Task.CompletedTask; + } + + private class FakeThrowingCallInvoker : CallInvoker + { + private readonly Exception _exceptionToThrow; + + public FakeThrowingCallInvoker(Exception exceptionToThrow) + { + _exceptionToThrow = exceptionToThrow; + } + + public override AsyncServerStreamingCall AsyncServerStreamingCall(Method method, string host, CallOptions options, TRequest request) + { + return new AsyncServerStreamingCall( + new ThrowingAsyncStreamReader(_exceptionToThrow), + Task.FromException(_exceptionToThrow), + () => Status.DefaultSuccess, + () => new Metadata(), + () => { }); + } + + public override AsyncUnaryCall AsyncUnaryCall(Method method, string host, CallOptions options, TRequest request) => throw new NotImplementedException(); + public override AsyncClientStreamingCall AsyncClientStreamingCall(Method method, string host, CallOptions options) => throw new NotImplementedException(); + public override AsyncDuplexStreamingCall AsyncDuplexStreamingCall(Method method, string host, CallOptions options) => throw new NotImplementedException(); + public override TResponse BlockingUnaryCall(Method method, string host, CallOptions options, TRequest request) => throw new NotImplementedException(); + } +} diff --git a/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1.Tests/RequestIdTests.cs b/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1.Tests/RequestIdTests.cs new file mode 100644 index 000000000000..fd912133580d --- /dev/null +++ b/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1.Tests/RequestIdTests.cs @@ -0,0 +1,281 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Grpc.Core; +using Google.Api.Gax.Grpc; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Xunit; + +namespace Google.Cloud.Spanner.V1.Tests; + +public class RequestIdTests +{ + private const string SampleDatabaseName = "projects/proj/instances/inst/databases/db"; + private const string SampleSessionName = "projects/proj/instances/inst/databases/db/sessions/sess"; + + [Theory] + [MemberData(nameof(SpannerClientActions))] + public void RequestId_Format(Action action) + { + var invoker = new SyncFailureCallInvoker(0); + var client = new SpannerClientBuilder { CallInvoker = invoker }.Build(); + + action(client); + + // The expected format is 6 parts that break down as such: + // {VersionId}.{ProcessId}.{ClientId}.{ChannelId}.{RequestId}.{AttemptNum} + Metadata headerMetadata = invoker.CapturedMetadata[0]; + var idString = headerMetadata.Get("x-goog-spanner-request-id").Value; + + var parts = new SpannerRequestIdParts(idString); + + Assert.Equal(1, parts.VersionId); + // The ClientId cannot be determined exaclty so we just ensure it's postive + Assert.True(parts.ClientId > 0); + Assert.Equal(1, parts.ChannelId); + Assert.Equal(1, parts.RequestId); + Assert.Equal(1, parts.AttemptNum); + } + + [Theory] + [MemberData(nameof(SpannerClientActions))] + public void SetsHeaderOnRpcCalls(Action action) + { + var invoker = new SyncFailureCallInvoker(0); + var client = new SpannerClientBuilder { CallInvoker = invoker }.Build(); + action(client); + Metadata.Entry entry = Assert.Single(invoker.CapturedMetadata[0], e => e.Key == "x-goog-spanner-request-id"); + Assert.NotNull(entry.Value); + } + + [Fact] + public void IncrementsRequestIdOnRetry() + { + var invoker = new SyncFailureCallInvoker(numberOfFailuresToSimulate: 1, statusCodeToThrow: StatusCode.ResourceExhausted); + var settings = new SpannerSettings + { + // Configure the CreateSession call to retry on Unavailable errors. + CallSettings = CallSettings.FromRetry(RetrySettings.FromExponentialBackoff( + maxAttempts: 3, + initialBackoff: TimeSpan.FromMilliseconds(1), + maxBackoff: TimeSpan.FromMilliseconds(1), + backoffMultiplier: 1.0, + retryFilter: RetrySettings.FilterForStatusCodes(StatusCode.Unavailable))) + }; + var client = new SpannerClientBuilder { CallInvoker = invoker, Settings = settings }.Build(); + + client.CreateSession(new CreateSessionRequest { Database = SampleDatabaseName }); + client.CreateSession(new CreateSessionRequest { Database = SampleDatabaseName }); + client.CreateSession(new CreateSessionRequest { Database = SampleDatabaseName }); + + // Assert that the invoker was called four times for the three client calls. + // The first call should have failed the first time and succeeded on retry. + Assert.Equal(4, invoker.CapturedMetadata.Count); + + var requestIds = invoker.CapturedMetadata + .Select(m => m.Single(e => e.Key == "x-goog-spanner-request-id").Value) + .Select(id => new SpannerRequestIdParts(id)) + .ToList(); + + Assert.Equal((1, 1), (requestIds[0].RequestId, requestIds[0].AttemptNum)); + Assert.Equal((1, 2), (requestIds[1].RequestId, requestIds[1].AttemptNum)); + Assert.Equal((2, 1), (requestIds[2].RequestId, requestIds[2].AttemptNum)); + Assert.Equal((3, 1), (requestIds[3].RequestId, requestIds[3].AttemptNum)); + } + + [Fact] + public void RequestIdSource_ProcessId_OverwritingBehavior() => + // The process ID should only ever be set once per process. + Assert.Throws(() => + { + // Note in the case of test re-runs within the same process the + // "first_override" may cause the exception to be thrown as it was + // set prior. + SpannerClientImpl.ProcessId = 1UL; + SpannerClientImpl.ProcessId = 2UL; + }); + + public static TheoryData> SpannerClientActions { get; } = new TheoryData> + { + client => client.ExecuteSql(new ExecuteSqlRequest { Session = SampleSessionName, Sql = "SELECT 1" }), + client => client.ExecuteStreamingSql(new ExecuteSqlRequest { Session = SampleSessionName, Sql = "SELECT 1" }), + client => client.GetSession(new GetSessionRequest { Name = SampleSessionName }), + client => client.ListSessions(new ListSessionsRequest { Database = SampleDatabaseName }).AsRawResponses().First(), + client => client.DeleteSession(new DeleteSessionRequest { Name = SampleSessionName }), + client => client.ExecuteSql(new ExecuteSqlRequest { Session = SampleSessionName }), + client => client.ExecuteBatchDml(new ExecuteBatchDmlRequest { Session = SampleSessionName }), + client => client.Read(new ReadRequest { Session = SampleSessionName }), + client => client.StreamingRead(new ReadRequest { Session = SampleSessionName }), + client => client.BeginTransaction(new BeginTransactionRequest { Session = SampleSessionName }), + client => client.Commit(new CommitRequest { Session = SampleSessionName }), + client => client.Rollback(new RollbackRequest { Session = SampleSessionName }), + client => client.PartitionQuery(new PartitionQueryRequest { Session = SampleSessionName }), + client => client.PartitionRead(new PartitionReadRequest { Session = SampleSessionName }), + }; + + private struct SpannerRequestIdParts + { + public int VersionId { get; } + public ulong ProcessId { get; } + public int ClientId { get; } + public int ChannelId { get; } + public int RequestId { get; } + public int AttemptNum { get; } + + public SpannerRequestIdParts(string requestId) + { + var parts = requestId.Split('.'); + Assert.Equal(6, parts.Length); + + VersionId = int.Parse(parts[0]); + ProcessId = ulong.Parse(parts[1]); + ClientId = int.Parse(parts[2]); + ChannelId = int.Parse(parts[3]); + RequestId = int.Parse(parts[4]); + AttemptNum = int.Parse(parts[5]); + } + } + + /// + /// CallInvoker that throws an RpcException synchronously upon method invocation. + /// Simulates immediate failures like client-side validation or connection errors. + /// Used by SpannerClient tests to verify synchronous error handling. + /// + private class SyncFailureCallInvoker : CallInvoker + { + private int _invocationCount = 0; + private readonly int _numberOfFailuresToSimulate; + private readonly StatusCode _statusCodeToThrow; + private readonly string _exceptionMessage; + + /// + /// Creates a new instance of . + /// + /// The number of times to simulate a failure before succeeding. + /// The gRPC status code to use in the thrown exception. + /// The message to use in the thrown exception. + public SyncFailureCallInvoker(int numberOfFailuresToSimulate = int.MaxValue, StatusCode statusCodeToThrow = StatusCode.Internal, string exceptionMessage = "Test exception") + { + _numberOfFailuresToSimulate = numberOfFailuresToSimulate; + _statusCodeToThrow = statusCodeToThrow; + _exceptionMessage = exceptionMessage; + } + + /// + /// The list of metadata headers captured from each method invocation. + /// + public List CapturedMetadata { get; } = new List(); + + /// + /// Records the metadata from a call. + /// + /// The metadata headers to record. + protected void RecordMetadata(Metadata headers) + { + CapturedMetadata.Add(headers); + } + + /// + public override TResponse BlockingUnaryCall(Method method, string host, CallOptions options, TRequest request) + { + RecordMetadata(options.Headers); + MaybeThrowException(); + return (TResponse)Activator.CreateInstance(typeof(TResponse)); + } + + /// + public override AsyncUnaryCall AsyncUnaryCall(Method method, string host, CallOptions options, TRequest request) + { + RecordMetadata(options.Headers); + MaybeThrowException(); + return new AsyncUnaryCall( + Task.FromResult((TResponse)Activator.CreateInstance(typeof(TResponse))), + Task.FromResult(new Metadata()), + () => Status.DefaultSuccess, + () => new Metadata(), + () => { }); + } + + /// + public override AsyncServerStreamingCall AsyncServerStreamingCall(Method method, string host, CallOptions options, TRequest request) + { + RecordMetadata(options.Headers); + MaybeThrowException(); + return new AsyncServerStreamingCall( + new EmptyAsyncStreamReader(), + Task.FromResult(new Metadata()), + () => Status.DefaultSuccess, + () => new Metadata(), + () => { }); + } + + /// + /// Determines whether the current invocation should fail based on the configured failure count. + /// + /// True if the call should fail; otherwise, false. + protected bool ShouldFail() + { + _invocationCount++; + return _invocationCount <= _numberOfFailuresToSimulate; + } + + /// + /// Creates the RpcException to be thrown or returned. + /// + /// The configured RpcException. + protected RpcException CreateRpcException() + { + return new RpcException(new Status(_statusCodeToThrow, _exceptionMessage)); + } + + /// + /// Checks whether the current invocation should fail, and throws an exception if so. + /// + protected void MaybeThrowException() + { + if (ShouldFail()) + { + throw CreateRpcException(); + } + } + + /// + public override AsyncClientStreamingCall AsyncClientStreamingCall(Method method, string host, CallOptions options) => + throw new NotImplementedException(); // SpannerClient does not support client streaming calls + + /// + public override AsyncDuplexStreamingCall AsyncDuplexStreamingCall(Method method, string host, CallOptions options) => + throw new NotImplementedException(); // SpannerClient does not support duplex streaming calls + } + + /// + /// IAsyncStreamReader{T} implementation that immediately signals the end of the stream. + /// Used as a placeholder in test CallInvokers where no data is expected. + /// + /// The message type. + public class EmptyAsyncStreamReader : IAsyncStreamReader + { + /// + public T Current => default; + + /// + public Task MoveNext(System.Threading.CancellationToken cancellationToken) + { + return Task.FromResult(false); + } + } +} diff --git a/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1/SpannerClientBuilderPartial.RequestIdOnExceptionInterceptor.cs b/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1/SpannerClientBuilderPartial.RequestIdOnExceptionInterceptor.cs new file mode 100644 index 000000000000..812e43549868 --- /dev/null +++ b/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1/SpannerClientBuilderPartial.RequestIdOnExceptionInterceptor.cs @@ -0,0 +1,226 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Grpc.Core; +using Grpc.Core.Interceptors; +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Google.Cloud.Spanner.V1; + +public partial class SpannerClientBuilder +{ + /// + /// A that wraps all calls, adding the Spanner request ID to any exceptions thrown while + /// using the . + /// + internal sealed class RequestIdOnExceptionInterceptor : Interceptor + { + /// + /// Provides access to the singleton instance of + /// + internal static RequestIdOnExceptionInterceptor Instance = new(); + + private RequestIdOnExceptionInterceptor() + { + } + + /// + public override TResponse BlockingUnaryCall( + TRequest request, + ClientInterceptorContext context, + BlockingUnaryCallContinuation continuation) => + WrapException(() => continuation(request, context), context.Options); + + /// + public override AsyncUnaryCall AsyncUnaryCall( + TRequest request, + ClientInterceptorContext context, + AsyncUnaryCallContinuation continuation) + { + var call = WrapException(() => continuation(request, context), context.Options); + return new AsyncUnaryCall( + WrapExceptionAsync(call.ResponseAsync, context.Options), + WrapExceptionAsync(call.ResponseHeadersAsync, context.Options), + call.GetStatus, + call.GetTrailers, + call.Dispose); + } + + /// + public override AsyncClientStreamingCall AsyncClientStreamingCall( + ClientInterceptorContext context, + AsyncClientStreamingCallContinuation continuation) + { + var call = WrapException(() => continuation(context), context.Options); + return new AsyncClientStreamingCall( + new SpannerRequestIdStreamWriter(call.RequestStream, context.Options), + WrapExceptionAsync(call.ResponseAsync, context.Options), + WrapExceptionAsync(call.ResponseHeadersAsync, context.Options), + call.GetStatus, + call.GetTrailers, + call.Dispose); + } + + /// + public override AsyncServerStreamingCall AsyncServerStreamingCall( + TRequest request, + ClientInterceptorContext context, + AsyncServerStreamingCallContinuation continuation) + { + var call = WrapException(() => continuation(request, context), context.Options); + var wrappedResponseStream = new SpannerRequestIdStreamReader(call.ResponseStream, context.Options); + + return new AsyncServerStreamingCall( + wrappedResponseStream, + WrapExceptionAsync(call.ResponseHeadersAsync, context.Options), + call.GetStatus, + call.GetTrailers, + call.Dispose); + } + + /// + public override AsyncDuplexStreamingCall AsyncDuplexStreamingCall( + ClientInterceptorContext context, + AsyncDuplexStreamingCallContinuation continuation) + { + var call = WrapException(() => continuation(context), context.Options); + var wrappedResponseStream = new SpannerRequestIdStreamReader(call.ResponseStream, context.Options); + var wrappedRequestStream = new SpannerRequestIdStreamWriter(call.RequestStream, context.Options); + + return new AsyncDuplexStreamingCall( + wrappedRequestStream, + wrappedResponseStream, + WrapExceptionAsync(call.ResponseHeadersAsync, context.Options), + call.GetStatus, + call.GetTrailers, + call.Dispose); + } + + /// + /// Handles the response asynchronously, adding the request ID to any exceptions thrown. + /// + private static async Task WrapExceptionAsync(Task task, CallOptions options) + { + try + { + return await task.ConfigureAwait(false); + } + catch (Exception e) + { + EnrichException(e, options); + throw; + } + } + + /// + /// Handles the response asynchronously, adding the request ID to any exceptions thrown. + /// + private static async Task WrapExceptionAsync(Task task, CallOptions options) + { + try + { + await task.ConfigureAwait(false); + } + catch (Exception e) + { + EnrichException(e, options); + throw; + } + } + + /// + /// Handles the response, adding the request ID to any exceptions thrown. + /// + private static T WrapException(Func action, CallOptions options) + { + try + { + return action(); + } + catch (Exception e) + { + EnrichException(e, options); + throw; + } + } + + /// + /// Enriches an exception with the Spanner Request ID from the provided . + /// + /// The exception to enrich. + /// The containing the request ID. + /// The enriched exception (the same instance passed in). + private static Exception EnrichException(Exception e, CallOptions options) + { + var requestId = GetRequestIdFromOptions(options); + if (requestId != null) + { + e.Data[SpannerClientImpl.RequestIdHeaderKey] = requestId; + } + return e; + + static string GetRequestIdFromOptions(CallOptions options) + { + if (options.Headers is Metadata headers) + { + return headers.GetValue(SpannerClientImpl.RequestIdHeaderKey); + } + return null; + } + } + + /// + /// A stream reader that wraps the original stream reader and adds the request ID to any exceptions thrown. + /// + private class SpannerRequestIdStreamReader : IAsyncStreamReader + { + private readonly IAsyncStreamReader _originalReader; + private readonly CallOptions _options; + + public SpannerRequestIdStreamReader(IAsyncStreamReader originalReader, CallOptions options) + { + _originalReader = originalReader; + _options = options; + } + + public T Current => _originalReader.Current; + + public Task MoveNext(CancellationToken cancellationToken) => + WrapExceptionAsync(_originalReader.MoveNext(cancellationToken), _options); + } + + /// + /// A stream writer that wraps the original stream writer and adds the request ID to any exceptions thrown. + /// + private class SpannerRequestIdStreamWriter : IClientStreamWriter + { + private readonly IClientStreamWriter _originalWriter; + private readonly CallOptions _options; + + public SpannerRequestIdStreamWriter(IClientStreamWriter originalWriter, CallOptions options) + { + _originalWriter = originalWriter; + _options = options; + } + + public WriteOptions WriteOptions { get => _originalWriter.WriteOptions; set => _originalWriter.WriteOptions = value; } + + public Task CompleteAsync() => WrapExceptionAsync(_originalWriter.CompleteAsync(), _options); + + public Task WriteAsync(T message) => WrapExceptionAsync(_originalWriter.WriteAsync(message), _options); + } + } +} diff --git a/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1/SpannerClientBuilderPartial.cs b/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1/SpannerClientBuilderPartial.cs index 35f9668ab2b1..3961c31ab5d2 100644 --- a/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1/SpannerClientBuilderPartial.cs +++ b/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1/SpannerClientBuilderPartial.cs @@ -16,6 +16,7 @@ using Google.Api.Gax.Grpc; using Google.Api.Gax.Grpc.Gcp; using Grpc.Core; +using Grpc.Core.Interceptors; using System; using System.Threading; using System.Threading.Tasks; @@ -25,6 +26,16 @@ namespace Google.Cloud.Spanner.V1 { public partial class SpannerClientBuilder { + /// + /// The process ID, assigned to each outgoing RPC to identify the request source. This can be set once, + /// but only before the process has made its first Spanner request. + /// + /// The process ID has already been set. + public static ulong ProcessId + { + set => SpannerClientImpl.ProcessId = value; + } + /// /// The Grpc.Gcp method configurations for pool options. /// @@ -150,8 +161,9 @@ partial void InterceptBuildAsync(CancellationToken cancellationToken, ref Task - protected override CallInvoker CreateCallInvoker() => - AffinityChannelPoolConfiguration is null + protected override CallInvoker CreateCallInvoker() + { + var invoker = AffinityChannelPoolConfiguration is null ? base.CreateCallInvoker() : new GcpCallInvoker( ServiceMetadata, @@ -160,10 +172,13 @@ AffinityChannelPoolConfiguration is null GetChannelOptions(), GetApiConfig(), EffectiveGrpcAdapter); + return invoker.Intercept(RequestIdOnExceptionInterceptor.Instance); + } /// - protected override async Task CreateCallInvokerAsync(CancellationToken cancellationToken) => - AffinityChannelPoolConfiguration is null + protected override async Task CreateCallInvokerAsync(CancellationToken cancellationToken) + { + var invoker = AffinityChannelPoolConfiguration is null ? await base.CreateCallInvokerAsync(cancellationToken).ConfigureAwait(false) : new GcpCallInvoker( ServiceMetadata, @@ -172,6 +187,8 @@ await GetChannelCredentialsAsync(cancellationToken).ConfigureAwait(false), GetChannelOptions(), GetApiConfig(), EffectiveGrpcAdapter); + return invoker.Intercept(RequestIdOnExceptionInterceptor.Instance); + } private ApiConfig GetApiConfig() => new ApiConfig { diff --git a/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1/SpannerClientPartial.ProcessId.cs b/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1/SpannerClientPartial.ProcessId.cs new file mode 100644 index 000000000000..79f97c398a24 --- /dev/null +++ b/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1/SpannerClientPartial.ProcessId.cs @@ -0,0 +1,66 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, + +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Threading; + +namespace Google.Cloud.Spanner.V1; + +public partial class SpannerClientImpl +{ + /// + /// The process ID. + /// + /// The process ID has already been set. + internal static ulong ProcessId + { + set => ProcessIdSource.Set(value); + } + + private static class ProcessIdSource + { + private static string s_value; + internal static string Value + { + get + { + // If not set, attempt to set it with a new random value. + // (It's okay if multiple threads generate a value; only the first one wins). + if (s_value == null) + { + Interlocked.CompareExchange(ref s_value, GenerateId(), null); + } + return s_value; + } + } + + internal static void Set(ulong id) + { + // Atomically set the value. If it was already set (i.e. returns non-null), throw. + if (Interlocked.CompareExchange(ref s_value, id.ToString(), null) != null) + { + throw new InvalidOperationException("The Process ID was already set and cannot be overwritten now."); + } + } + + private static string GenerateId() + { + var random = new Random(); + var buffer = new byte[sizeof(ulong)]; + random.NextBytes(buffer); + return BitConverter.ToUInt64(buffer, 0).ToString(); + } + } +} diff --git a/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1/SpannerClientPartial.RequestId.cs b/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1/SpannerClientPartial.RequestId.cs new file mode 100644 index 000000000000..ee33aadbc8fc --- /dev/null +++ b/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1/SpannerClientPartial.RequestId.cs @@ -0,0 +1,97 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, + +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Grpc.Core; + +namespace Google.Cloud.Spanner.V1; + +public partial class SpannerClientImpl +{ + /// + /// The header key used for Spanner request IDs. + /// + internal const string RequestIdHeaderKey = "x-goog-spanner-request-id"; + + /// + /// Represents a structured request ID for Spanner RPCs, formatted as: + /// {version}.{process}.{client}.{channel}.{request}.{attempt} + /// + private sealed class RequestId + { + /// + /// The version number of the request ID format. + /// + private const int FormatVersion = 1; + + /// + /// A unique ID for the gRPC channel being used. This is hardcoded to 1 for now. + /// See: b/459445539 + /// + private readonly int _channelId = 1; + + /// + /// A unique ID for the instance that generated the request. + /// + private readonly int _clientId; + + /// + /// A unique ID for the logical request made using the current . + /// + private readonly int _requestId; + + /// + /// The attempt count of the request, incremented before each RPC attempt. + /// + private int _attemptCount; + + /// + /// Initializes a new instance of the class + /// with a specified client and logical request identifier. + /// + /// The ID for the associated with the request. + /// The ID of the logical request within the client. + internal RequestId(int clientId, int requestId) + { + _clientId = clientId; + _requestId = requestId; + } + + /// + /// Increments the request's attempt number and then returns the request ID. + /// + /// + /// Retry attempts are expected to be sequential so we can safely + /// increment without first acquiring a lock. + /// + private void IncrementAttempt() => _attemptCount++; + + /// + /// Add a request ID header. The header mutation increments the attempt each time the header is populated which + /// happens right before a call. This allows us to increment the attempt num and assign a unique request id in the + /// case of retries. + /// + internal void AddHeader(Metadata metadata) + { + IncrementAttempt(); + metadata.Add(RequestIdHeaderKey, ToString()); + } + + /// + /// Returns the string representation of the Spanner request ID. + /// + public override string ToString() => $"{FormatVersion}.{ProcessIdSource.Value}.{_clientId}.{_channelId}.{_requestId}.{_attemptCount}"; + } + +} diff --git a/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1/SpannerClientPartial.RequestIdSource.cs b/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1/SpannerClientPartial.RequestIdSource.cs new file mode 100644 index 000000000000..e1f2793a654c --- /dev/null +++ b/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1/SpannerClientPartial.RequestIdSource.cs @@ -0,0 +1,61 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, + +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Threading; + +namespace Google.Cloud.Spanner.V1; + +public partial class SpannerClientImpl +{ + /// + /// Manages the generation of unique instances for a specific + /// instance. + /// + /// + /// This class acts as a dedicated factory for instances. + /// There should be a 1:1 mapping between a instance and a instance, + /// which simplifies the by centralizing the request ID generation logic. + /// + private sealed class RequestIdSource + { + /// + /// A counter that is incremented on the instantiation of each new used to identify + /// the . Given the one-to-one mapping between a and a , this effectively means the counter is incremented for each new client instance. + /// + private static int s_instanceCounter; + + /// + /// A counter incremented for each logical request made from the client. + /// + private int _requestCounter = 0; + + /// + /// The unique ID for the instance this generator is associated with. + /// + private readonly int _clientId; + + /// + /// Initializes a new instance of the class, + /// assigning a unique client identifier. + /// + internal RequestIdSource() => _clientId = Interlocked.Increment(ref s_instanceCounter); + + /// + /// Generates the next for a new logical request made with the associated instance. + /// + internal RequestId NewRequestId() => new RequestId(_clientId, Interlocked.Increment(ref _requestCounter)); + } +} diff --git a/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1/SpannerClientPartial.cs b/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1/SpannerClientPartial.cs index cfb8c5163913..5a259292b17c 100644 --- a/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1/SpannerClientPartial.cs +++ b/apis/Google.Cloud.Spanner.V1/Google.Cloud.Spanner.V1/SpannerClientPartial.cs @@ -14,6 +14,7 @@ using Google.Api.Gax.Grpc; using Google.Cloud.Spanner.Common.V1; +using System.Threading; namespace Google.Cloud.Spanner.V1 { @@ -48,6 +49,11 @@ internal void MaybeApplyRouteToLeaderHeader(ref CallSettings settings) public partial class SpannerClientImpl { + /// + /// Creates the request ID associated with each RPC. + /// + private RequestIdSource _requestIdSource; + /// /// The name of the header used for efficiently routing requests. /// @@ -64,78 +70,100 @@ public partial class SpannerClientImpl partial void OnConstruction(Spanner.SpannerClient grpcClient, SpannerSettings effectiveSettings, ClientHelper clientHelper) { Settings = effectiveSettings; + _requestIdSource = new RequestIdSource(); } partial void Modify_CreateSessionRequest(ref CreateSessionRequest request, ref CallSettings settings) { ApplyResourcePrefixHeaderFromDatabase(ref settings, request.Database); MaybeApplyRouteToLeaderHeader(ref settings); + ApplyRequestIdHeader(ref settings); } partial void Modify_BatchCreateSessionsRequest(ref BatchCreateSessionsRequest request, ref CallSettings settings) { ApplyResourcePrefixHeaderFromDatabase(ref settings, request.Database); MaybeApplyRouteToLeaderHeader(ref settings); + ApplyRequestIdHeader(ref settings); } partial void Modify_GetSessionRequest(ref GetSessionRequest request, ref CallSettings settings) { ApplyResourcePrefixHeaderFromSession(ref settings, request.Name); MaybeApplyRouteToLeaderHeader(ref settings); + ApplyRequestIdHeader(ref settings); } - partial void Modify_ListSessionsRequest(ref ListSessionsRequest request, ref CallSettings settings) => + partial void Modify_ListSessionsRequest(ref ListSessionsRequest request, ref CallSettings settings) + { // This operation is never routed to leader so we don't call MaybeApplyRouteToLeaderHeader. ApplyResourcePrefixHeaderFromDatabase(ref settings, request.Database); + ApplyRequestIdHeader(ref settings); + } - partial void Modify_DeleteSessionRequest(ref DeleteSessionRequest request, ref CallSettings settings) => + partial void Modify_DeleteSessionRequest(ref DeleteSessionRequest request, ref CallSettings settings) + { // This operation is never routed to leader so we don't call MaybeApplyRouteToLeaderHeader. ApplyResourcePrefixHeaderFromSession(ref settings, request.Name); + ApplyRequestIdHeader(ref settings); + } - partial void Modify_ExecuteSqlRequest(ref ExecuteSqlRequest request, ref CallSettings settings) => + partial void Modify_ExecuteSqlRequest(ref ExecuteSqlRequest request, ref CallSettings settings) + { // This operations is routed to leader only if the transaction it uses is of a certain type. // We don't have that information here so the leader routing header needs to be applied elsewhere. ApplyResourcePrefixHeaderFromSession(ref settings, request.Session); + ApplyRequestIdHeader(ref settings); + } partial void Modify_ExecuteBatchDmlRequest(ref ExecuteBatchDmlRequest request, ref CallSettings settings) { ApplyResourcePrefixHeaderFromSession(ref settings, request.Session); MaybeApplyRouteToLeaderHeader(ref settings); + ApplyRequestIdHeader(ref settings); } - partial void Modify_ReadRequest(ref ReadRequest request, ref CallSettings settings) => + partial void Modify_ReadRequest(ref ReadRequest request, ref CallSettings settings) + { // This operations is routed to leader only if the transaction it uses is of a certain type. // We don't have that information here so the leader routing header needs to be applied elsewhere. ApplyResourcePrefixHeaderFromSession(ref settings, request.Session); + ApplyRequestIdHeader(ref settings); + } partial void Modify_BeginTransactionRequest(ref BeginTransactionRequest request, ref CallSettings settings) { ApplyResourcePrefixHeaderFromSession(ref settings, request.Session); MaybeApplyRouteToLeaderHeader(ref settings, request.Options?.ModeCase ?? TransactionOptions.ModeOneofCase.None); + ApplyRequestIdHeader(ref settings); } partial void Modify_CommitRequest(ref CommitRequest request, ref CallSettings settings) { ApplyResourcePrefixHeaderFromSession(ref settings, request.Session); MaybeApplyRouteToLeaderHeader(ref settings); + ApplyRequestIdHeader(ref settings); } partial void Modify_RollbackRequest(ref RollbackRequest request, ref CallSettings settings) { ApplyResourcePrefixHeaderFromSession(ref settings, request.Session); MaybeApplyRouteToLeaderHeader(ref settings); + ApplyRequestIdHeader(ref settings); } partial void Modify_PartitionQueryRequest(ref PartitionQueryRequest request, ref CallSettings settings) { ApplyResourcePrefixHeaderFromSession(ref settings, request.Session); MaybeApplyRouteToLeaderHeader(ref settings); + ApplyRequestIdHeader(ref settings); } partial void Modify_PartitionReadRequest(ref PartitionReadRequest request, ref CallSettings settings) { ApplyResourcePrefixHeaderFromSession(ref settings, request.Session); MaybeApplyRouteToLeaderHeader(ref settings); + ApplyRequestIdHeader(ref settings); } internal static void ApplyResourcePrefixHeaderFromDatabase(ref CallSettings settings, string resource) @@ -167,5 +195,12 @@ internal static void ApplyResourcePrefixHeaderFromSession(ref CallSettings setti settings = settings.WithHeader(ResourcePrefixHeader, database.ToString()); } } + + internal void ApplyRequestIdHeader(ref CallSettings settings) + { + var requestId = _requestIdSource.NewRequestId(); + var newSettings = CallSettings.FromHeaderMutation(requestId.AddHeader); + settings = settings.MergedWith(newSettings); + } } }