diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java index cecf462bd25..a89090e34d9 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java @@ -768,14 +768,9 @@ ResultSet executeQueryInternalWithOptions( rpc.getExecuteQueryRetrySettings(), rpc.getExecuteQueryRetryableCodes()) { @Override - CloseableIterator startStream( - @Nullable ByteString resumeToken, - AsyncResultSet.StreamMessageListener streamListener) { + CloseableIterator startStream(@Nullable ByteString resumeToken) { GrpcStreamIterator stream = new GrpcStreamIterator(statement, prefetchChunks, cancelQueryWhenClientIsClosed); - if (streamListener != null) { - stream.registerListener(streamListener); - } if (partitionToken != null) { request.setPartitionToken(partitionToken); } @@ -796,8 +791,8 @@ CloseableIterator startStream( getTransactionChannelHint(), isRouteToLeader()); session.markUsed(clock.instant()); - stream.setCall(call, request.getTransaction().hasBegin()); call.request(prefetchChunks); + stream.setCall(call, request.getTransaction().hasBegin()); return stream; } @@ -964,14 +959,9 @@ ResultSet readInternalWithOptions( rpc.getReadRetrySettings(), rpc.getReadRetryableCodes()) { @Override - CloseableIterator startStream( - @Nullable ByteString resumeToken, - AsyncResultSet.StreamMessageListener streamListener) { + CloseableIterator startStream(@Nullable ByteString resumeToken) { GrpcStreamIterator stream = new GrpcStreamIterator(prefetchChunks, cancelQueryWhenClientIsClosed); - if (streamListener != null) { - stream.registerListener(streamListener); - } TransactionSelector selector = null; if (resumeToken != null) { builder.setResumeToken(resumeToken); @@ -990,8 +980,8 @@ CloseableIterator startStream( getTransactionChannelHint(), isRouteToLeader()); session.markUsed(clock.instant()); - stream.setCall(call, /* withBeginTransaction = */ builder.getTransaction().hasBegin()); call.request(prefetchChunks); + stream.setCall(call, /* withBeginTransaction = */ builder.getTransaction().hasBegin()); return stream; } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java index 3dca970f96e..fdc0398d5fe 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java @@ -150,14 +150,6 @@ interface CloseableIterator extends Iterator { void close(@Nullable String message); boolean isWithBeginTransaction(); - - /** - * @param streamMessageListener A class object which implements StreamMessageListener - * @return true if streaming is supported by the iterator, otherwise false - */ - default boolean initiateStreaming(AsyncResultSet.StreamMessageListener streamMessageListener) { - return false; - } } static double valueProtoToFloat64(com.google.protobuf.Value proto) { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSet.java index 2b3225bfc59..dfedcc4f8be 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSet.java @@ -18,7 +18,6 @@ import com.google.api.core.ApiFuture; import com.google.common.base.Function; -import com.google.spanner.v1.PartialResultSet; import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; @@ -224,12 +223,4 @@ interface ReadyCallback { * @param transformer function which will be used to transform the row. It should not return null. */ List toList(Function transformer) throws SpannerException; - - /** - * An interface to register the listener for streaming gRPC request. It will be called when a - * chunk is received from gRPC streaming call. - */ - interface StreamMessageListener { - void onStreamMessage(PartialResultSet partialResultSet, boolean bufferIsFull); - } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSetImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSetImpl.java index 1161822cd10..fa7cc158c19 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSetImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSetImpl.java @@ -18,6 +18,7 @@ import com.google.api.core.ApiFuture; import com.google.api.core.ApiFutures; +import com.google.api.core.ListenableFutureToApiFuture; import com.google.api.core.SettableApiFuture; import com.google.api.gax.core.ExecutorProvider; import com.google.cloud.spanner.AbstractReadContext.ListenableAsyncResultSet; @@ -28,13 +29,13 @@ import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListeningScheduledExecutorService; import com.google.common.util.concurrent.MoreExecutors; -import com.google.spanner.v1.PartialResultSet; import com.google.spanner.v1.ResultSetMetadata; import com.google.spanner.v1.ResultSetStats; import java.util.Collection; import java.util.LinkedList; import java.util.List; import java.util.concurrent.BlockingDeque; +import java.util.concurrent.Callable; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; @@ -44,14 +45,12 @@ import java.util.logging.Logger; /** Default implementation for {@link AsyncResultSet}. */ -class AsyncResultSetImpl extends ForwardingStructReader - implements ListenableAsyncResultSet, AsyncResultSet.StreamMessageListener { +class AsyncResultSetImpl extends ForwardingStructReader implements ListenableAsyncResultSet { private static final Logger log = Logger.getLogger(AsyncResultSetImpl.class.getName()); /** State of an {@link AsyncResultSetImpl}. */ private enum State { INITIALIZED, - STREAMING_INITIALIZED, /** SYNC indicates that the {@link ResultSet} is used in sync pattern. */ SYNC, CONSUMING, @@ -116,15 +115,12 @@ private enum State { private State state = State.INITIALIZED; - /** This variable indicates that produce rows thread is initiated */ - private volatile boolean produceRowsInitiated; - /** * This variable indicates whether all the results from the underlying result set have been read. */ private volatile boolean finished; - private volatile SettableApiFuture result; + private volatile ApiFuture result; /** * This variable indicates whether {@link #tryNext()} has returned {@link CursorState#DONE} or a @@ -333,12 +329,12 @@ public void run() { private final CallbackRunnable callbackRunnable = new CallbackRunnable(); /** - * {@link ProduceRowsRunnable} reads data from the underlying {@link ResultSet}, places these in + * {@link ProduceRowsCallable} reads data from the underlying {@link ResultSet}, places these in * the buffer and dispatches the {@link CallbackRunnable} when data is ready to be consumed. */ - private class ProduceRowsRunnable implements Runnable { + private class ProduceRowsCallable implements Callable { @Override - public void run() { + public Void call() throws Exception { boolean stop = false; boolean hasNext = false; try { @@ -397,17 +393,12 @@ public void run() { } // Call the callback if there are still rows in the buffer that need to be processed. while (!stop) { - try { - waitIfPaused(); - startCallbackIfNecessary(); - // Make sure we wait until the callback runner has actually finished. - consumingLatch.await(); - synchronized (monitor) { - stop = cursorReturnedDoneOrException; - } - } catch (Throwable e) { - result.setException(e); - return; + waitIfPaused(); + startCallbackIfNecessary(); + // Make sure we wait until the callback runner has actually finished. + consumingLatch.await(); + synchronized (monitor) { + stop = cursorReturnedDoneOrException; } } } finally { @@ -419,14 +410,14 @@ public void run() { } synchronized (monitor) { if (executionException != null) { - result.setException(executionException); - } else if (state == State.CANCELLED) { - result.setException(CANCELLED_EXCEPTION); - } else { - result.set(null); + throw executionException; + } + if (state == State.CANCELLED) { + throw CANCELLED_EXCEPTION; } } } + return null; } private void waitIfPaused() throws InterruptedException { @@ -458,26 +449,6 @@ private void startCallbackWithBufferLatchIfNecessary(int bufferLatch) { } } - private class InitiateStreamingRunnable implements Runnable { - - @Override - public void run() { - try { - // This method returns true if the underlying result set is a streaming result set (e.g. a - // GrpcResultSet). - // Those result sets will trigger initiateProduceRows() when the first results are received. - // Non-streaming result sets do not trigger this callback, and for those result sets, we - // need to eagerly start the ProduceRowsRunnable. - if (!initiateStreaming(AsyncResultSetImpl.this)) { - initiateProduceRows(); - } - } catch (Throwable exception) { - executionException = SpannerExceptionFactory.asSpannerException(exception); - initiateProduceRows(); - } - } - } - /** Sets the callback for this {@link AsyncResultSet}. */ @Override public ApiFuture setCallback(Executor exec, ReadyCallback cb) { @@ -487,24 +458,16 @@ public ApiFuture setCallback(Executor exec, ReadyCallback cb) { this.state == State.INITIALIZED, "callback may not be set multiple times"); // Start to fetch data and buffer these. - this.result = SettableApiFuture.create(); - this.state = State.STREAMING_INITIALIZED; - this.service.execute(new InitiateStreamingRunnable()); + this.result = + new ListenableFutureToApiFuture<>(this.service.submit(new ProduceRowsCallable())); this.executor = MoreExecutors.newSequentialExecutor(Preconditions.checkNotNull(exec)); this.callback = Preconditions.checkNotNull(cb); + this.state = State.RUNNING; pausedLatch.countDown(); return result; } } - private void initiateProduceRows() { - if (this.state == State.STREAMING_INITIALIZED) { - this.state = State.RUNNING; - } - produceRowsInitiated = true; - this.service.execute(new ProduceRowsRunnable()); - } - Future getResult() { return result; } @@ -615,10 +578,6 @@ public ResultSetMetadata getMetadata() { return delegateResultSet.get().getMetadata(); } - boolean initiateStreaming(StreamMessageListener streamMessageListener) { - return StreamingUtil.initiateStreaming(delegateResultSet.get(), streamMessageListener); - } - @Override protected void checkValidState() { synchronized (monitor) { @@ -634,22 +593,4 @@ public Struct getCurrentRowAsStruct() { checkValidState(); return currentRow; } - - @Override - public void onStreamMessage(PartialResultSet partialResultSet, boolean bufferIsFull) { - synchronized (monitor) { - if (produceRowsInitiated) { - return; - } - // if PartialResultSet contains a resume token or buffer size is full, or - // we have reached the end of the stream, we can start the thread. - boolean startJobThread = - !partialResultSet.getResumeToken().isEmpty() - || bufferIsFull - || partialResultSet == GrpcStreamIterator.END_OF_STREAM; - if (startJobThread || state != State.STREAMING_INITIALIZED) { - initiateProduceRows(); - } - } - } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java index 3c4883e6586..babbb310a45 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java @@ -16,7 +16,6 @@ package com.google.cloud.spanner; -import com.google.api.core.InternalApi; import com.google.common.base.Preconditions; import com.google.common.base.Supplier; import com.google.common.base.Suppliers; @@ -24,8 +23,7 @@ import com.google.spanner.v1.ResultSetStats; /** Forwarding implementation of ResultSet that forwards all calls to a delegate. */ -public class ForwardingResultSet extends ForwardingStructReader - implements ProtobufResultSet, StreamingResultSet { +public class ForwardingResultSet extends ForwardingStructReader implements ProtobufResultSet { private Supplier delegate; @@ -104,10 +102,4 @@ public ResultSetStats getStats() { public ResultSetMetadata getMetadata() { return delegate.get().getMetadata(); } - - @Override - @InternalApi - public boolean initiateStreaming(AsyncResultSet.StreamMessageListener streamMessageListener) { - return StreamingUtil.initiateStreaming(delegate.get(), streamMessageListener); - } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java index c2a4ee5a585..23c9dd7c2d3 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java @@ -19,7 +19,6 @@ import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerException; import static com.google.common.base.Preconditions.checkState; -import com.google.api.core.InternalApi; import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.Value; import com.google.spanner.v1.PartialResultSet; @@ -31,8 +30,7 @@ import javax.annotation.Nullable; @VisibleForTesting -class GrpcResultSet extends AbstractResultSet> - implements ProtobufResultSet, StreamingResultSet { +class GrpcResultSet extends AbstractResultSet> implements ProtobufResultSet { private final GrpcValueIterator iterator; private final Listener listener; private final DecodeMode decodeMode; @@ -125,12 +123,6 @@ public ResultSetMetadata getMetadata() { return metadata; } - @Override - @InternalApi - public boolean initiateStreaming(AsyncResultSet.StreamMessageListener streamMessageListener) { - return iterator.initiateStreaming(streamMessageListener); - } - @Override public void close() { synchronized (this) { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java index 79c02eab58c..af6b5683502 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java @@ -20,11 +20,9 @@ import com.google.cloud.spanner.AbstractResultSet.CloseableIterator; import com.google.cloud.spanner.spi.v1.SpannerRpc; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; import com.google.common.collect.AbstractIterator; import com.google.common.util.concurrent.Uninterruptibles; import com.google.spanner.v1.PartialResultSet; -import java.util.Optional; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; @@ -38,8 +36,7 @@ class GrpcStreamIterator extends AbstractIterator implements CloseableIterator { private static final Logger logger = Logger.getLogger(GrpcStreamIterator.class.getName()); - static final PartialResultSet END_OF_STREAM = PartialResultSet.newBuilder().build(); - private AsyncResultSet.StreamMessageListener streamMessageListener; + private static final PartialResultSet END_OF_STREAM = PartialResultSet.newBuilder().build(); private final ConsumerImpl consumer; private final BlockingQueue stream; @@ -69,10 +66,6 @@ protected final SpannerRpc.ResultStreamConsumer consumer() { return consumer; } - void registerListener(AsyncResultSet.StreamMessageListener streamMessageListener) { - this.streamMessageListener = Preconditions.checkNotNull(streamMessageListener); - } - public void setCall(SpannerRpc.StreamingCall call, boolean withBeginTransaction) { this.call = call; this.withBeginTransaction = withBeginTransaction; @@ -142,7 +135,6 @@ protected final PartialResultSet computeNext() { private void addToStream(PartialResultSet results) { // We assume that nothing from the user will interrupt gRPC event threads. Uninterruptibles.putUninterruptibly(stream, results); - onStreamMessage(results); } private class ConsumerImpl implements SpannerRpc.ResultStreamConsumer { @@ -190,9 +182,4 @@ public boolean cancelQueryWhenClientIsClosed() { return this.cancelQueryWhenClientIsClosed; } } - - private void onStreamMessage(PartialResultSet partialResultSet) { - Optional.ofNullable(streamMessageListener) - .ifPresent(sl -> sl.onStreamMessage(partialResultSet, stream.remainingCapacity() <= 1)); - } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcValueIterator.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcValueIterator.java index 24c431eec31..1a3df8b9123 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcValueIterator.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcValueIterator.java @@ -127,10 +127,6 @@ ResultSetStats getStats() { return statistics; } - boolean initiateStreaming(AsyncResultSet.StreamMessageListener streamMessageListener) { - return stream.initiateStreaming(streamMessageListener); - } - Type type() { checkState(type != null, "metadata has not been received"); return type; diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResumableStreamIterator.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResumableStreamIterator.java index 39165da2d38..3e82ab7d5ff 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResumableStreamIterator.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResumableStreamIterator.java @@ -23,7 +23,6 @@ import com.google.api.client.util.BackOff; import com.google.api.client.util.ExponentialBackOff; -import com.google.api.core.InternalApi; import com.google.api.gax.grpc.GrpcStatusCode; import com.google.api.gax.retrying.RetrySettings; import com.google.api.gax.rpc.StatusCode.Code; @@ -59,7 +58,6 @@ abstract class ResumableStreamIterator extends AbstractIterator retryableCodes; private static final Logger logger = Logger.getLogger(ResumableStreamIterator.class.getName()); @@ -198,8 +196,7 @@ public void execute(Runnable command) { } } - abstract CloseableIterator startStream( - @Nullable ByteString resumeToken, AsyncResultSet.StreamMessageListener streamMessageListener); + abstract CloseableIterator startStream(@Nullable ByteString resumeToken); /** * Prepares the iterator for a retry on a different gRPC channel. Returns true if that is @@ -223,21 +220,23 @@ public boolean isWithBeginTransaction() { return stream != null && stream.isWithBeginTransaction(); } - @Override - @InternalApi - public boolean initiateStreaming(AsyncResultSet.StreamMessageListener streamMessageListener) { - this.streamMessageListener = streamMessageListener; - startGrpcStreaming(); - return true; - } - @Override protected PartialResultSet computeNext() { int numAttemptsOnOtherChannel = 0; Context context = Context.current(); while (true) { // Eagerly start stream before consuming any buffered items. - startGrpcStreaming(); + if (stream == null) { + span.addAnnotation( + "Starting/Resuming stream", + "ResumeToken", + resumeToken == null ? "null" : resumeToken.toStringUtf8()); + try (IScope scope = tracer.withSpan(span)) { + // When start a new stream set the Span as current to make the gRPC Span a child of + // this Span. + stream = checkNotNull(startStream(resumeToken)); + } + } // Buffer contains items up to a resume token or has reached capacity: flush. if (!buffer.isEmpty() && (finished || !safeToRetry || !buffer.getLast().getResumeToken().isEmpty())) { @@ -316,20 +315,6 @@ && prepareIteratorForRetryOnDifferentGrpcChannel()) { } } - private void startGrpcStreaming() { - if (stream == null) { - span.addAnnotation( - "Starting/Resuming stream", - "ResumeToken", - resumeToken == null ? "null" : resumeToken.toStringUtf8()); - try (IScope scope = tracer.withSpan(span)) { - // When start a new stream set the Span as current to make the gRPC Span a child of - // this Span. - stream = checkNotNull(startStream(resumeToken, streamMessageListener)); - } - } - } - boolean isRetryable(SpannerException spannerException) { return spannerException.isRetryable() || retryableCodes.contains( diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/StreamingResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/StreamingResultSet.java deleted file mode 100644 index 47b10d852c6..00000000000 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/StreamingResultSet.java +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.cloud.spanner; - -import com.google.api.core.InternalApi; - -/** Streaming implementation of ResultSet that supports streaming of chunks */ -interface StreamingResultSet extends ResultSet { - - /** - * Returns the {@link boolean} for this {@link ResultSet}. This method will be used by - * AsyncResultSet internally to initiate gRPC streaming. This method should not be called by the - * users. - */ - @InternalApi - boolean initiateStreaming(AsyncResultSet.StreamMessageListener streamMessageListener); -} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/StreamingUtil.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/StreamingUtil.java deleted file mode 100644 index 54496d39f96..00000000000 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/StreamingUtil.java +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.cloud.spanner; - -final class StreamingUtil { - - private StreamingUtil() {} - - static boolean initiateStreaming( - ResultSet resultSet, AsyncResultSet.StreamMessageListener streamMessageListener) { - if (resultSet instanceof StreamingResultSet) { - return ((StreamingResultSet) resultSet).initiateStreaming(streamMessageListener); - } - return false; - } -} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncResultSetImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncResultSetImplTest.java index 0ba924ef740..98497fbf140 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncResultSetImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncResultSetImplTest.java @@ -22,9 +22,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; -import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; import static org.mockito.Mockito.when; import com.google.api.core.ApiFuture; @@ -34,9 +32,6 @@ import com.google.cloud.spanner.AsyncResultSet.ReadyCallback; import com.google.common.base.Function; import com.google.common.collect.Range; -import com.google.protobuf.ByteString; -import com.google.protobuf.Value; -import com.google.spanner.v1.PartialResultSet; import java.util.List; import java.util.concurrent.BlockingDeque; import java.util.concurrent.CountDownLatch; @@ -53,7 +48,6 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -387,20 +381,13 @@ public Boolean answer(InvocationOnMock invocation) throws Throwable { public void testCallbackIsNotCalledWhilePausedAndCanceled() throws InterruptedException, ExecutionException { Executor executor = Executors.newSingleThreadExecutor(); - StreamingResultSet delegate = mock(StreamingResultSet.class); + ResultSet delegate = mock(ResultSet.class); final AtomicInteger callbackCounter = new AtomicInteger(); ApiFuture callbackResult; try (AsyncResultSetImpl rs = new AsyncResultSetImpl(simpleProvider, delegate, AsyncResultSetImpl.DEFAULT_BUFFER_SIZE)) { - - when(delegate.initiateStreaming(any(AsyncResultSet.StreamMessageListener.class))) - .thenAnswer( - answer -> { - rs.onStreamMessage(PartialResultSet.newBuilder().build(), false); - return null; - }); callbackResult = rs.setCallback( executor, @@ -511,60 +498,4 @@ public void callbackReturnsDoneBeforeEnd_shouldStopIteration() throws Exception rs.getResult().get(10L, TimeUnit.SECONDS); } } - - @Test - public void testOnStreamMessageWhenResumeTokenIsPresent() { - StreamingResultSet delegate = mock(StreamingResultSet.class); - try (AsyncResultSetImpl rs = - new AsyncResultSetImpl(mockedProvider, delegate, AsyncResultSetImpl.DEFAULT_BUFFER_SIZE)) { - // Marking Streaming as supported - Mockito.when( - delegate.initiateStreaming(Mockito.any(AsyncResultSet.StreamMessageListener.class))) - .thenReturn(true); - - rs.setCallback(Executors.newSingleThreadExecutor(), ignored -> CallbackResponse.DONE); - rs.onStreamMessage( - PartialResultSet.newBuilder().addValues(Value.newBuilder().build()).build(), false); - - rs.onStreamMessage( - PartialResultSet.newBuilder().setResumeToken(ByteString.copyFromUtf8("test")).build(), - false); - Mockito.verify(mockedProvider.getExecutor(), times(2)).execute(Mockito.any()); - } - } - - @Test - public void testOnStreamMessageWhenCurrentBufferSizeReachedPrefetchChunkSize() { - StreamingResultSet delegate = mock(StreamingResultSet.class); - try (AsyncResultSetImpl rs = - new AsyncResultSetImpl(mockedProvider, delegate, AsyncResultSetImpl.DEFAULT_BUFFER_SIZE)) { - // Marking Streaming as supported - Mockito.when( - delegate.initiateStreaming(Mockito.any(AsyncResultSet.StreamMessageListener.class))) - .thenReturn(true); - - rs.setCallback(Executors.newSingleThreadExecutor(), ignored -> CallbackResponse.DONE); - rs.onStreamMessage( - PartialResultSet.newBuilder().addValues(Value.newBuilder().build()).build(), true); - Mockito.verify(mockedProvider.getExecutor(), times(2)).execute(Mockito.any()); - } - } - - @Test - public void testOnStreamMessageWhenAsyncResultIsCancelled() { - StreamingResultSet delegate = mock(StreamingResultSet.class); - try (AsyncResultSetImpl rs = - new AsyncResultSetImpl(mockedProvider, delegate, AsyncResultSetImpl.DEFAULT_BUFFER_SIZE)) { - // Marking Streaming as supported - Mockito.when( - delegate.initiateStreaming(Mockito.any(AsyncResultSet.StreamMessageListener.class))) - .thenReturn(true); - - rs.setCallback(Executors.newSingleThreadExecutor(), ignored -> CallbackResponse.DONE); - rs.cancel(); - rs.onStreamMessage( - PartialResultSet.newBuilder().addValues(Value.newBuilder().build()).build(), false); - Mockito.verify(mockedProvider.getExecutor(), times(2)).execute(Mockito.any()); - } - } } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java index ebe86724678..d126719ebb8 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java @@ -64,8 +64,7 @@ public class ResumableStreamIteratorTest { interface Starter { AbstractResultSet.CloseableIterator startStream( - @Nullable ByteString resumeToken, - AsyncResultSet.StreamMessageListener streamMessageListener); + @Nullable ByteString resumeToken); } interface ResultSetStream { @@ -165,9 +164,8 @@ private void initWithLimit(int maxBufferSize) { SpannerStubSettings.newBuilder().executeStreamingSqlSettings().getRetryableCodes()) { @Override AbstractResultSet.CloseableIterator startStream( - @Nullable ByteString resumeToken, - AsyncResultSet.StreamMessageListener streamMessageListener) { - return starter.startStream(resumeToken, null); + @Nullable ByteString resumeToken) { + return starter.startStream(resumeToken); } }; } @@ -175,7 +173,7 @@ AbstractResultSet.CloseableIterator startStream( @Test public void simple() { ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(null, "a")) .thenReturn(resultSet(null, "b")) @@ -197,7 +195,7 @@ public void closedOTSpan() { setInternalState(ResumableStreamIterator.class, this.resumableStreamIterator, "span", span); ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) .thenReturn(resultSet(ByteString.copyFromUtf8("r2"), "b")) @@ -220,7 +218,7 @@ public void closedOCSpan() { setInternalState(ResumableStreamIterator.class, this.resumableStreamIterator, "span", span); ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) .thenReturn(resultSet(ByteString.copyFromUtf8("r2"), "b")) @@ -234,14 +232,14 @@ public void closedOCSpan() { @Test public void restart() { ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) .thenReturn(resultSet(ByteString.copyFromUtf8("r2"), "b")) .thenThrow(new RetryableException(errorCodeParameter, "failed by test")); ResultSetStream s2 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(ByteString.copyFromUtf8("r2"), null)) + Mockito.when(starter.startStream(ByteString.copyFromUtf8("r2"))) .thenReturn(new ResultSetIterator(s2)); Mockito.when(s2.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r3"), "c")) @@ -253,7 +251,7 @@ public void restart() { @Test public void restartWithHoldBack() { ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) .thenReturn(resultSet(ByteString.copyFromUtf8("r2"), "b")) @@ -262,7 +260,7 @@ public void restartWithHoldBack() { .thenThrow(new RetryableException(errorCodeParameter, "failed by test")); ResultSetStream s2 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(ByteString.copyFromUtf8("r2"), null)) + Mockito.when(starter.startStream(ByteString.copyFromUtf8("r2"))) .thenReturn(new ResultSetIterator(s2)); Mockito.when(s2.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r3"), "c")) @@ -274,7 +272,7 @@ public void restartWithHoldBack() { @Test public void restartWithHoldBackMidStream() { ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) .thenReturn(resultSet(null, "b")) @@ -283,7 +281,7 @@ public void restartWithHoldBackMidStream() { .thenThrow(new RetryableException(errorCodeParameter, "failed by test")); ResultSetStream s2 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(ByteString.copyFromUtf8("r2"), null)) + Mockito.when(starter.startStream(ByteString.copyFromUtf8("r2"))) .thenReturn(new ResultSetIterator(s2)); Mockito.when(s2.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r3"), "e")) @@ -306,7 +304,7 @@ public void retryableErrorWithoutRetryInfo() throws IOException { ResumableStreamIterator.class, this.resumableStreamIterator, "backOff", backOff); ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) .thenThrow( @@ -314,7 +312,7 @@ public void retryableErrorWithoutRetryInfo() throws IOException { ErrorCode.UNAVAILABLE, "failed by test", Status.UNAVAILABLE.asRuntimeException())); ResultSetStream s2 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(ByteString.copyFromUtf8("r1"), null)) + Mockito.when(starter.startStream(ByteString.copyFromUtf8("r1"))) .thenReturn(new ResultSetIterator(s2)); Mockito.when(s2.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r2"), "b")) @@ -326,7 +324,7 @@ public void retryableErrorWithoutRetryInfo() throws IOException { @Test public void nonRetryableError() { ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) .thenReturn(resultSet(ByteString.copyFromUtf8("r2"), "b")) @@ -345,7 +343,7 @@ public void bufferLimitSimple() { initWithLimit(1); ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(null, "a")) .thenReturn(resultSet(null, "b")) @@ -358,7 +356,7 @@ public void bufferLimitSimpleWithRestartTokens() { initWithLimit(1); ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) .thenReturn(resultSet(ByteString.copyFromUtf8("r2"), "b")) @@ -371,14 +369,14 @@ public void bufferLimitRestart() { initWithLimit(1); ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) .thenReturn(resultSet(ByteString.copyFromUtf8("r2"), "b")) .thenThrow(new RetryableException(errorCodeParameter, "failed by test")); ResultSetStream s2 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(ByteString.copyFromUtf8("r2"), null)) + Mockito.when(starter.startStream(ByteString.copyFromUtf8("r2"))) .thenReturn(new ResultSetIterator(s2)); Mockito.when(s2.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r3"), "c")) @@ -392,13 +390,13 @@ public void bufferLimitRestartWithinLimitAtStartOfResults() { initWithLimit(1); ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(null, "XXXXXX")) .thenThrow(new RetryableException(errorCodeParameter, "failed by test")); ResultSetStream s2 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s2)); + Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s2)); Mockito.when(s2.next()) .thenReturn(resultSet(null, "a")) .thenReturn(resultSet(null, "b")) @@ -411,14 +409,14 @@ public void bufferLimitRestartWithinLimitMidResults() { initWithLimit(1); ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) .thenReturn(resultSet(null, "XXXXXX")) .thenThrow(new RetryableException(errorCodeParameter, "failed by test")); ResultSetStream s2 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(ByteString.copyFromUtf8("r1"), null)) + Mockito.when(starter.startStream(ByteString.copyFromUtf8("r1"))) .thenReturn(new ResultSetIterator(s2)); Mockito.when(s2.next()) .thenReturn(resultSet(null, "b")) @@ -432,7 +430,7 @@ public void bufferLimitMissingTokensUnsafeToRetry() { initWithLimit(1); ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) .thenReturn(resultSet(null, "b")) @@ -449,7 +447,7 @@ public void bufferLimitMissingTokensSafeToRetry() { initWithLimit(1); ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) .thenReturn(resultSet(null, "b")) @@ -457,7 +455,7 @@ public void bufferLimitMissingTokensSafeToRetry() { .thenThrow(new RetryableException(errorCodeParameter, "failed by test")); ResultSetStream s2 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(ByteString.copyFromUtf8("r3"), null)) + Mockito.when(starter.startStream(ByteString.copyFromUtf8("r3"))) .thenReturn(new ResultSetIterator(s2)); Mockito.when(s2.next()).thenReturn(resultSet(null, "d")).thenReturn(null);