diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java index ed5b0179349..3ae3016af5e 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java @@ -27,6 +27,10 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.spanner.v1.BatchWriteResponse; import io.opentelemetry.api.common.Attributes; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicInteger; import javax.annotation.Nullable; class DatabaseClientImpl implements DatabaseClient { @@ -40,6 +44,8 @@ class DatabaseClientImpl implements DatabaseClient { @VisibleForTesting final MultiplexedSessionDatabaseClient multiplexedSessionDatabaseClient; @VisibleForTesting final boolean useMultiplexedSessionPartitionedOps; @VisibleForTesting final boolean useMultiplexedSessionForRW; + private final int dbId; + private final AtomicInteger nthRequest; final boolean useMultiplexedSessionBlindWrite; @@ -86,6 +92,18 @@ class DatabaseClientImpl implements DatabaseClient { this.tracer = tracer; this.useMultiplexedSessionForRW = useMultiplexedSessionForRW; this.commonAttributes = commonAttributes; + + this.dbId = this.dbIdFromClientId(this.clientId); + this.nthRequest = new AtomicInteger(0); + } + + private int dbIdFromClientId(String clientId) { + int i = clientId.indexOf("-"); + String strWithValue = clientId.substring(i + 1); + if (Objects.equals(strWithValue, "")) { + strWithValue = "0"; + } + return Integer.parseInt(strWithValue); } @VisibleForTesting @@ -179,8 +197,21 @@ public CommitResponse writeAtLeastOnceWithOptions( return getMultiplexedSessionDatabaseClient() .writeAtLeastOnceWithOptions(mutations, options); } + + int nthRequest = this.nextNthRequest(); + int channelId = 1; /* TODO: infer the channelId from the gRPC channel of the session */ + XGoogSpannerRequestId reqId = XGoogSpannerRequestId.of(this.dbId, channelId, nthRequest, 0); + return runWithSessionRetry( - session -> session.writeAtLeastOnceWithOptions(mutations, options)); + (session) -> { + reqId.incrementAttempt(); + // TODO: Update the channelId depending on the session that is inferred. + ArrayList allOptions = new ArrayList(Arrays.asList(options)); + System.out.println("\033[35msession.class: " + session.getClass() + "\033[00m"); + allOptions.add(new Options.RequestIdOption(reqId)); + return session.writeAtLeastOnceWithOptions( + mutations, allOptions.toArray(new TransactionOption[0])); + }); } catch (RuntimeException e) { span.setStatus(e); throw e; @@ -189,6 +220,10 @@ public CommitResponse writeAtLeastOnceWithOptions( } } + private int nextNthRequest() { + return this.nthRequest.incrementAndGet(); + } + @Override public ServerStream batchWriteAtLeastOnce( final Iterable mutationGroups, final TransactionOption... options) diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java index c8c588f813a..5c993965094 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java @@ -512,6 +512,7 @@ void appendToOptions(Options options) { private RpcOrderBy orderBy; private RpcLockHint lockHint; private Boolean lastStatement; + private XGoogSpannerRequestId reqId; // Construction is via factory methods below. private Options() {} @@ -568,6 +569,14 @@ String pageToken() { return pageToken; } + boolean hasReqId() { + return reqId != null; + } + + XGoogSpannerRequestId reqId() { + return reqId; + } + boolean hasFilter() { return filter != null; } @@ -1018,4 +1027,30 @@ public boolean equals(Object o) { return o instanceof LastStatementUpdateOption; } } + + static final class RequestIdOption extends InternalOption + implements TransactionOption, UpdateOption { + private final XGoogSpannerRequestId reqId; + + RequestIdOption(XGoogSpannerRequestId reqId) { + this.reqId = reqId; + } + + @Override + void appendToOptions(Options options) { + options.reqId = this.reqId; + } + + @Override + public int hashCode() { + return RequestIdOption.class.hashCode(); + } + + @Override + public boolean equals(Object o) { + // TODO: Examine why the precedent for LastStatementUpdateOption + // does not check against the actual value. + return o instanceof RequestIdOption; + } + } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionClient.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionClient.java index 2edfb66d896..405c5f86812 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionClient.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionClient.java @@ -31,10 +31,11 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicInteger; import javax.annotation.concurrent.GuardedBy; /** Client for creating single sessions and batches of sessions. */ -class SessionClient implements AutoCloseable { +class SessionClient implements AutoCloseable, XGoogSpannerRequestId.RequestIdCreator { static class SessionId { private static final PathTemplate NAME_TEMPLATE = PathTemplate.create( @@ -174,6 +175,12 @@ interface SessionConsumer { private final DatabaseId db; private final Attributes commonAttributes; + // SessionClient is created long before a DatabaseClientImpl is created, + // as batch sessions are firstly created then later attached to each Client. + private static AtomicInteger NTH_ID = new AtomicInteger(0); + private final int nthId; + private final AtomicInteger nthRequest; + @GuardedBy("this") private volatile long sessionChannelCounter; @@ -186,6 +193,8 @@ interface SessionConsumer { this.executorFactory = executorFactory; this.executor = executorFactory.get(); this.commonAttributes = spanner.getTracer().createCommonAttributes(db); + this.nthId = SessionClient.NTH_ID.incrementAndGet(); + this.nthRequest = new AtomicInteger(0); } @Override @@ -201,16 +210,24 @@ DatabaseId getDatabaseId() { return db; } + @Override + public XGoogSpannerRequestId nextRequestId(long channelId, int attempt) { + return XGoogSpannerRequestId.of(this.nthId, this.nthRequest.incrementAndGet(), channelId, 1); + } + /** Create a single session. */ SessionImpl createSession() { // The sessionChannelCounter could overflow, but that will just flip it to Integer.MIN_VALUE, // which is also a valid channel hint. final Map options; + final long channelId; synchronized (this) { options = optionMap(SessionOption.channelHint(sessionChannelCounter++)); + channelId = sessionChannelCounter; } ISpan span = spanner.getTracer().spanBuilder(SpannerImpl.CREATE_SESSION, this.commonAttributes); try (IScope s = spanner.getTracer().withSpan(span)) { + XGoogSpannerRequestId reqId = this.nextRequestId(channelId, 1); com.google.spanner.v1.Session session = spanner .getRpc() @@ -218,11 +235,13 @@ SessionImpl createSession() { db.getName(), spanner.getOptions().getDatabaseRole(), spanner.getOptions().getSessionLabels(), - options); + reqId.withOptions(options)); SessionReference sessionReference = new SessionReference( session.getName(), session.getCreateTime(), session.getMultiplexed(), options); - return new SessionImpl(spanner, sessionReference); + SessionImpl sessionImpl = new SessionImpl(spanner, sessionReference); + sessionImpl.setRequestIdCreator(this); + return sessionImpl; } catch (RuntimeException e) { span.setStatus(e); throw e; @@ -273,6 +292,7 @@ SessionImpl createMultiplexedSession() { spanner, new SessionReference( session.getName(), session.getCreateTime(), session.getMultiplexed(), null)); + sessionImpl.setRequestIdCreator(this); span.addAnnotation( String.format("Request for %d multiplexed session returned %d session", 1, 1)); return sessionImpl; @@ -387,6 +407,8 @@ private List internalBatchCreateSessions( .spanBuilderWithExplicitParent(SpannerImpl.BATCH_CREATE_SESSIONS_REQUEST, parent); span.addAnnotation(String.format("Requesting %d sessions", sessionCount)); try (IScope s = spanner.getTracer().withSpan(span)) { + XGoogSpannerRequestId reqId = + XGoogSpannerRequestId.of(this.nthId, this.nthRequest.incrementAndGet(), channelHint, 1); List sessions = spanner .getRpc() @@ -395,21 +417,20 @@ private List internalBatchCreateSessions( sessionCount, spanner.getOptions().getDatabaseRole(), spanner.getOptions().getSessionLabels(), - options); + reqId.withOptions(options)); span.addAnnotation( String.format( "Request for %d sessions returned %d sessions", sessionCount, sessions.size())); span.end(); List res = new ArrayList<>(sessionCount); for (com.google.spanner.v1.Session session : sessions) { - res.add( + SessionImpl sessionImpl = new SessionImpl( spanner, new SessionReference( - session.getName(), - session.getCreateTime(), - session.getMultiplexed(), - options))); + session.getName(), session.getCreateTime(), session.getMultiplexed(), options)); + sessionImpl.setRequestIdCreator(this); + res.add(sessionImpl); } return res; } catch (RuntimeException e) { @@ -425,6 +446,8 @@ SessionImpl sessionWithId(String name) { synchronized (this) { options = optionMap(SessionOption.channelHint(sessionChannelCounter++)); } - return new SessionImpl(spanner, new SessionReference(name, options)); + SessionImpl sessionImpl = new SessionImpl(spanner, new SessionReference(name, options)); + sessionImpl.setRequestIdCreator(this); + return sessionImpl; } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java index 454709275f8..55f656066f6 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java @@ -123,18 +123,31 @@ interface SessionTransaction { private final Clock clock; private final Map options; private final ErrorHandler errorHandler; + private XGoogSpannerRequestId.RequestIdCreator requestIdCreator; SessionImpl(SpannerImpl spanner, SessionReference sessionReference) { this(spanner, sessionReference, NO_CHANNEL_HINT); } SessionImpl(SpannerImpl spanner, SessionReference sessionReference, int channelHint) { + this(spanner, sessionReference, channelHint, new XGoogSpannerRequestId.NoopRequestIdCreator()); + } + + SessionImpl( + SpannerImpl spanner, + SessionReference sessionReference, + int channelHint, + XGoogSpannerRequestId.RequestIdCreator requestIdCreator) { this.spanner = spanner; this.tracer = spanner.getTracer(); this.sessionReference = sessionReference; this.clock = spanner.getOptions().getSessionPoolOptions().getPoolMaintainerClock(); this.options = createOptions(sessionReference, channelHint); this.errorHandler = createErrorHandler(spanner.getOptions()); + this.requestIdCreator = requestIdCreator; + if (this.requestIdCreator == null) { + this.requestIdCreator = new XGoogSpannerRequestId.NoopRequestIdCreator(); + } } static Map createOptions( @@ -269,8 +282,14 @@ public CommitResponse writeAtLeastOnceWithOptions( CommitRequest request = requestBuilder.build(); ISpan span = tracer.spanBuilder(SpannerImpl.COMMIT); try (IScope s = tracer.withSpan(span)) { + // TODO: Derive the channelId from the session being used currently. + XGoogSpannerRequestId reqId = this.requestIdCreator.nextRequestId(1 /* channelId */, 0); return SpannerRetryHelper.runTxWithRetriesOnAborted( - () -> new CommitResponse(spanner.getRpc().commit(request, getOptions()))); + () -> { + reqId.incrementAttempt(); + return new CommitResponse( + spanner.getRpc().commit(request, reqId.withOptions(getOptions()))); + }); } catch (RuntimeException e) { span.setStatus(e); throw e; @@ -530,4 +549,8 @@ void onTransactionDone() {} TraceWrapper getTracer() { return tracer; } + + public void setRequestIdCreator(XGoogSpannerRequestId.RequestIdCreator creator) { + this.requestIdCreator = creator; + } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/XGoogSpannerRequestId.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/XGoogSpannerRequestId.java index 4f6c0114750..89f84df187e 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/XGoogSpannerRequestId.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/XGoogSpannerRequestId.java @@ -17,10 +17,19 @@ package com.google.cloud.spanner; import com.google.api.core.InternalApi; +import com.google.cloud.spanner.spi.v1.SpannerRpc; import com.google.common.annotations.VisibleForTesting; +import io.grpc.Metadata; import java.math.BigInteger; import java.security.SecureRandom; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.Objects; +import java.util.regex.MatchResult; +import java.util.regex.Matcher; +import java.util.regex.Pattern; @InternalApi public class XGoogSpannerRequestId { @@ -28,6 +37,9 @@ public class XGoogSpannerRequestId { @VisibleForTesting static final String RAND_PROCESS_ID = XGoogSpannerRequestId.generateRandProcessId(); + public static final Metadata.Key REQUEST_HEADER_KEY = + Metadata.Key.of("x-goog-spanner-request-id", Metadata.ASCII_STRING_MARSHALLER); + @VisibleForTesting static final long VERSION = 1; // The version of the specification being implemented. @@ -48,6 +60,26 @@ public static XGoogSpannerRequestId of( return new XGoogSpannerRequestId(nthClientId, nthChannelId, nthRequest, attempt); } + @VisibleForTesting + static final Pattern REGEX = + Pattern.compile("^(\\d)\\.([0-9a-z]{16})\\.(\\d+)\\.(\\d+)\\.(\\d+)\\.(\\d+)$"); + + public static XGoogSpannerRequestId of(String s) { + Matcher m = XGoogSpannerRequestId.REGEX.matcher(s); + if (!m.matches()) { + throw new IllegalStateException( + s + " does not match " + XGoogSpannerRequestId.REGEX.pattern()); + } + + MatchResult mr = m.toMatchResult(); + + return new XGoogSpannerRequestId( + Long.parseLong(mr.group(3)), + Long.parseLong(mr.group(4)), + Long.parseLong(mr.group(5)), + Long.parseLong(mr.group(6))); + } + private static String generateRandProcessId() { // Expecting to use 64-bits of randomness to avoid clashes. BigInteger bigInt = new BigInteger(64, new SecureRandom()); @@ -66,6 +98,13 @@ public String toString() { this.attempt); } + private boolean isGreaterThan(XGoogSpannerRequestId other) { + return this.nthClientId > other.nthClientId + && this.nthChannelId > other.nthChannelId + && this.nthRequest > other.nthRequest + && this.attempt > other.attempt; + } + @Override public boolean equals(Object other) { // instanceof for a null object returns false. @@ -81,8 +120,55 @@ public boolean equals(Object other) { && Objects.equals(this.attempt, otherReqId.attempt); } + public void incrementAttempt() { + this.attempt++; + } + + @SuppressWarnings("unchecked") + public Map withOptions(Map options) { + Map copyOptions = new HashMap<>(); + copyOptions.putAll(options); + copyOptions.put(SpannerRpc.Option.REQUEST_ID, this.toString()); + return copyOptions; + } + @Override public int hashCode() { return Objects.hash(this.nthClientId, this.nthChannelId, this.nthRequest, this.attempt); } + + public interface RequestIdCreator { + XGoogSpannerRequestId nextRequestId(long channelId, int attempt); + } + + public static class NoopRequestIdCreator implements RequestIdCreator { + NoopRequestIdCreator() {} + + @Override + public XGoogSpannerRequestId nextRequestId(long channelId, int attempt) { + return XGoogSpannerRequestId.of(1, 1, 1, 0); + } + } + + public static void assertMonotonicityOfIds(String prefix, List reqIds) { + int size = reqIds.size(); + + List violations = new ArrayList<>(); + for (int i = 1; i < size; i++) { + XGoogSpannerRequestId prev = reqIds.get(i - 1); + XGoogSpannerRequestId curr = reqIds.get(i); + if (prev.isGreaterThan(curr)) { + violations.add(String.format("#%d(%s) > #%d(%s)", i - 1, prev, i, curr)); + } + } + + if (violations.size() == 0) { + return; + } + + throw new IllegalStateException( + prefix + + " monotonicity violation:" + + String.join("\n\t", violations.toArray(new String[0]))); + } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java index 77ce3a540b8..bc9270287fa 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java @@ -71,6 +71,7 @@ import com.google.cloud.spanner.SpannerOptions; import com.google.cloud.spanner.SpannerOptions.CallContextConfigurator; import com.google.cloud.spanner.SpannerOptions.CallCredentialsProvider; +import com.google.cloud.spanner.XGoogSpannerRequestId; import com.google.cloud.spanner.admin.database.v1.stub.DatabaseAdminStub; import com.google.cloud.spanner.admin.database.v1.stub.DatabaseAdminStubSettings; import com.google.cloud.spanner.admin.database.v1.stub.GrpcDatabaseAdminCallableFactory; @@ -88,6 +89,7 @@ import com.google.common.base.Supplier; import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.io.Resources; import com.google.common.util.concurrent.RateLimiter; @@ -2023,6 +2025,8 @@ GrpcCallContext newCallContext( // Set channel affinity in GAX. context = context.withChannelAffinity(Option.CHANNEL_HINT.getLong(options).intValue()); } + String methodName = method.getFullMethodName(); + context = withRequestId(context, options, methodName); } if (compressorName != null) { // This sets the compressor for Client -> Server. @@ -2046,6 +2050,7 @@ GrpcCallContext newCallContext( context .withStreamWaitTimeoutDuration(waitTimeout) .withStreamIdleTimeoutDuration(idleTimeout); + CallContextConfigurator configurator = SpannerOptions.CALL_CONTEXT_CONFIGURATOR_KEY.get(); ApiCallContext apiCallContextFromContext = null; if (configurator != null) { @@ -2054,6 +2059,18 @@ GrpcCallContext newCallContext( return (GrpcCallContext) context.merge(apiCallContextFromContext); } + GrpcCallContext withRequestId(GrpcCallContext context, Map options, String methodName) { + String reqId = (String) options.get(Option.REQUEST_ID); + if (reqId == null || Objects.equals(reqId, "")) { + return context; + } + + Map> withReqId = + ImmutableMap.of( + XGoogSpannerRequestId.REQUEST_HEADER_KEY.name(), Collections.singletonList(reqId)); + return context.withExtraHeaders(withReqId); + } + void registerResponseObserver(SpannerResponseObserver responseObserver) { responseObservers.add(responseObserver); } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java index 9ad94204743..d029084477a 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java @@ -78,7 +78,8 @@ public interface SpannerRpc extends ServiceRpc { /** Options passed in {@link SpannerRpc} methods to control how an RPC is issued. */ enum Option { - CHANNEL_HINT("Channel Hint"); + CHANNEL_HINT("Channel Hint"), + REQUEST_ID("Request Id"); private final String value; diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java index 87ea5c19ce9..4103703695d 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java @@ -105,6 +105,7 @@ import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Server; +import io.grpc.ServerInterceptors; import io.grpc.Status; import io.grpc.StatusRuntimeException; import io.grpc.inprocess.InProcessServerBuilder; @@ -119,6 +120,7 @@ import java.util.Arrays; import java.util.Base64; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Random; import java.util.Set; @@ -152,6 +154,7 @@ public class DatabaseClientImplTest { private static final String DATABASE_NAME = String.format( "projects/%s/instances/%s/databases/%s", TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE); + private static XGoogSpannerRequestIdTest.ServerHeaderEnforcer xGoogReqIdInterceptor; private static MockSpannerServiceImpl mockSpanner; private static Server server; private static LocalChannelProvider channelProvider; @@ -220,13 +223,16 @@ public static void startStaticServer() throws IOException { StatementResult.query(SELECT1_FROM_TABLE, MockSpannerTestUtil.SELECT1_RESULTSET)); mockSpanner.setBatchWriteResult(BATCH_WRITE_RESPONSES); + Set checkMethods = + new HashSet(Arrays.asList("google.spanner.v1.Spanner/BatchCreateSessions")); + xGoogReqIdInterceptor = new XGoogSpannerRequestIdTest.ServerHeaderEnforcer(checkMethods); executor = Executors.newSingleThreadExecutor(); String uniqueName = InProcessServerBuilder.generateName(); server = InProcessServerBuilder.forName(uniqueName) // We need to use a real executor for timeouts to occur. .scheduledExecutorService(new ScheduledThreadPoolExecutor(1)) - .addService(mockSpanner) + .addService(ServerInterceptors.intercept(mockSpanner, xGoogReqIdInterceptor)) .build() .start(); channelProvider = LocalChannelProvider.create(uniqueName); @@ -266,6 +272,7 @@ public void tearDown() { spanner.close(); spannerWithEmptySessionPool.close(); mockSpanner.reset(); + xGoogReqIdInterceptor.reset(); mockSpanner.removeAllExecutionTimes(); } @@ -1402,6 +1409,10 @@ public void testWriteAtLeastOnceAborted() { List commitRequests = mockSpanner.getRequestsOfType(CommitRequest.class); assertEquals(2, commitRequests.size()); + xGoogReqIdInterceptor.assertIntegrity(); + System.out.println( + "\033[33mGot: " + xGoogReqIdInterceptor.accumulatedUnaryValues() + "\033[00m"); + xGoogReqIdInterceptor.printAccumulatedValues(); } @Test @@ -5168,6 +5179,26 @@ public void testRetryOnResourceExhausted() { } } + @Test + public void testSelectHasXGoogRequestIdHeader() { + Statement statement = + Statement.newBuilder("select id from test where b=@p1") + .bind("p1") + .toBytesArray( + Arrays.asList(ByteArray.copyFrom("test1"), null, ByteArray.copyFrom("test2"))) + .build(); + mockSpanner.putStatementResult(StatementResult.query(statement, SELECT1_RESULTSET)); + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + try (ResultSet resultSet = client.singleUse().executeQuery(statement)) { + assertTrue(resultSet.next()); + assertEquals(1L, resultSet.getLong(0)); + assertFalse(resultSet.next()); + } finally { + xGoogReqIdInterceptor.assertIntegrity(); + } + } + @Test public void testSessionPoolExhaustedError_containsStackTraces() { assumeFalse( diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/XGoogSpannerRequestIdTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/XGoogSpannerRequestIdTest.java index 12c9213c7dc..188f0717799 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/XGoogSpannerRequestIdTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/XGoogSpannerRequestIdTest.java @@ -18,18 +18,27 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor.MethodType; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.regex.Matcher; -import java.util.regex.Pattern; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class XGoogSpannerRequestIdTest { - private static final Pattern REGEX_RAND_PROCESS_ID = - Pattern.compile("1.([0-9a-z]{16})(\\.\\d+){3}\\.(\\d+)$"); @Test public void testEquals() { @@ -48,7 +57,130 @@ public void testEquals() { @Test public void testEnsureHexadecimalFormatForRandProcessID() { String str = XGoogSpannerRequestId.of(1, 2, 3, 4).toString(); - Matcher m = XGoogSpannerRequestIdTest.REGEX_RAND_PROCESS_ID.matcher(str); + Matcher m = XGoogSpannerRequestId.REGEX.matcher(str); assertTrue(m.matches()); } + + public static class ServerHeaderEnforcer implements ServerInterceptor { + private Map> unaryResults; + private Map> streamingResults; + private List gotValues; + private Set checkMethods; + + ServerHeaderEnforcer(Set checkMethods) { + this.gotValues = new CopyOnWriteArrayList(); + this.unaryResults = + new ConcurrentHashMap>(); + this.streamingResults = + new ConcurrentHashMap>(); + this.checkMethods = checkMethods; + } + + @Override + public ServerCall.Listener interceptCall( + ServerCall call, + final Metadata requestHeaders, + ServerCallHandler next) { + boolean isUnary = call.getMethodDescriptor().getType() == MethodType.UNARY; + String methodName = call.getMethodDescriptor().getFullMethodName(); + String gotReqIdStr = requestHeaders.get(XGoogSpannerRequestId.REQUEST_HEADER_KEY); + if (!this.checkMethods.contains(methodName)) { + // System.out.println( + // "\033[35mBypassing " + methodName + " but has " + gotReqIdStr + "\033[00m"); + return next.startCall(call, requestHeaders); + } + + Map> saver = this.streamingResults; + if (isUnary) { + saver = this.unaryResults; + } + + // Firstly assert and validate that at least we've got a requestId. + Matcher m = XGoogSpannerRequestId.REGEX.matcher(gotReqIdStr); + assertNotNull(gotReqIdStr); + assertTrue(m.matches()); + + XGoogSpannerRequestId reqId = XGoogSpannerRequestId.of(gotReqIdStr); + if (!saver.containsKey(methodName)) { + saver.put(methodName, new CopyOnWriteArrayList()); + } + + saver.get(methodName).add(reqId); + + // Finally proceed with the call. + return next.startCall(call, requestHeaders); + } + + public String[] accumulatedValues() { + return this.gotValues.toArray(new String[0]); + } + + public void assertIntegrity() { + this.unaryResults.forEach( + (String method, CopyOnWriteArrayList values) -> { + // System.out.println("\033[36munary.method: " + method + "\033[00m"); + XGoogSpannerRequestId.assertMonotonicityOfIds(method, values); + }); + this.streamingResults.forEach( + (String method, CopyOnWriteArrayList values) -> { + // System.out.println("\033[36mstreaming.method: " + method + "\033[00m"); + XGoogSpannerRequestId.assertMonotonicityOfIds(method, values); + }); + } + + public static class methodAndRequestId { + String method; + String requestId; + + public methodAndRequestId(String method, String requestId) { + this.method = method; + this.requestId = requestId; + } + + public String toString() { + return "{" + this.method + ":" + this.requestId + "}"; + } + } + + public methodAndRequestId[] accumulatedUnaryValues() { + List accumulated = new ArrayList(); + this.unaryResults.forEach( + (String method, CopyOnWriteArrayList values) -> { + for (int i = 0; i < values.size(); i++) { + accumulated.add(new methodAndRequestId(method, values.get(i).toString())); + } + }); + return accumulated.toArray(new methodAndRequestId[0]); + } + + public methodAndRequestId[] accumulatedStreamingValues() { + List accumulated = new ArrayList(); + this.streamingResults.forEach( + (String method, CopyOnWriteArrayList values) -> { + for (int i = 0; i < values.size(); i++) { + accumulated.add(new methodAndRequestId(method, values.get(i).toString())); + } + }); + return accumulated.toArray(new methodAndRequestId[0]); + } + + public void printAccumulatedValues() { + methodAndRequestId[] unary = this.accumulatedUnaryValues(); + System.out.println("accumulatedUnaryvalues"); + for (int i = 0; i < unary.length; i++) { + System.out.println("\t" + unary[i].toString()); + } + methodAndRequestId[] streaming = this.accumulatedStreamingValues(); + System.out.println("accumulatedStreaminvalues"); + for (int i = 0; i < streaming.length; i++) { + System.out.println("\t" + streaming[i].toString()); + } + } + + public void reset() { + this.gotValues.clear(); + this.unaryResults.clear(); + this.streamingResults.clear(); + } + } }