diff --git a/driver-core/src/main/com/mongodb/internal/TimeoutContext.java b/driver-core/src/main/com/mongodb/internal/TimeoutContext.java index ba3b8eb0ac5..b3bdc65a9f7 100644 --- a/driver-core/src/main/com/mongodb/internal/TimeoutContext.java +++ b/driver-core/src/main/com/mongodb/internal/TimeoutContext.java @@ -17,8 +17,6 @@ import com.mongodb.MongoClientException; import com.mongodb.MongoOperationTimeoutException; -import com.mongodb.internal.async.AsyncRunnable; -import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.connection.CommandMessage; import com.mongodb.internal.time.StartTime; import com.mongodb.internal.time.Timeout; @@ -26,19 +24,14 @@ import com.mongodb.session.ClientSession; import java.util.Objects; -import java.util.Optional; import java.util.function.LongConsumer; import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.assertions.Assertions.assertNull; import static com.mongodb.assertions.Assertions.isTrue; import static com.mongodb.internal.VisibleForTesting.AccessModifier.PRIVATE; -import static com.mongodb.internal.async.AsyncRunnable.beginAsync; import static com.mongodb.internal.time.Timeout.ZeroSemantics.ZERO_DURATION_MEANS_INFINITE; -import static java.util.Optional.empty; -import static java.util.Optional.ofNullable; import static java.util.concurrent.TimeUnit.MILLISECONDS; -import static java.util.concurrent.TimeUnit.NANOSECONDS; /** * Timeout Context. @@ -46,18 +39,15 @@ *

The context for handling timeouts in relation to the Client Side Operation Timeout specification.

*/ public class TimeoutContext { - - private final boolean isMaintenanceContext; private final TimeoutSettings timeoutSettings; - @Nullable - private Timeout timeout; + private final Timeout timeout; @Nullable - private Timeout computedServerSelectionTimeout; - private long minRoundTripTimeMS = 0; - + private final Timeout computedServerSelectionTimeout; @Nullable - private MaxTimeSupplier maxTimeSupplier = null; + private final MaxTimeSupplier maxTimeSupplier; + private final boolean isMaintenanceContext; + private final long minRoundTripTimeMS; public static MongoOperationTimeoutException createMongoRoundTripTimeoutException() { return createMongoTimeoutException("Remaining timeoutMS is less than or equal to the server's minimum round trip time."); @@ -116,11 +106,6 @@ public static TimeoutContext createTimeoutContext(final ClientSession session, f return new TimeoutContext(timeoutSettings); } - // Creates a copy of the timeout context that can be reset without resetting the original. - public TimeoutContext copyTimeoutContext() { - return new TimeoutContext(getTimeoutSettings(), getTimeout()); - } - public TimeoutContext(final TimeoutSettings timeoutSettings) { this(false, timeoutSettings, startTimeout(timeoutSettings.getTimeoutMS())); } @@ -129,9 +114,41 @@ private TimeoutContext(final TimeoutSettings timeoutSettings, @Nullable final Ti this(false, timeoutSettings, timeout); } - private TimeoutContext(final boolean isMaintenanceContext, final TimeoutSettings timeoutSettings, @Nullable final Timeout timeout) { + private TimeoutContext(final boolean isMaintenanceContext, + final TimeoutSettings timeoutSettings, + @Nullable final Timeout timeout) { + this(isMaintenanceContext, + null, + 0, + timeoutSettings, + null, + timeout); + } + + private TimeoutContext(final boolean isMaintenanceContext, + @Nullable final Timeout computedServerSelectionTimeout, + final long minRoundTripTimeMS, + final TimeoutSettings timeoutSettings, + @Nullable final MaxTimeSupplier maxTimeSupplier) { + this(isMaintenanceContext, + computedServerSelectionTimeout, + minRoundTripTimeMS, + timeoutSettings, + maxTimeSupplier, + startTimeout(timeoutSettings.getTimeoutMS())); + } + + private TimeoutContext(final boolean isMaintenanceContext, + @Nullable final Timeout computedServerSelectionTimeout, + final long minRoundTripTimeMS, + final TimeoutSettings timeoutSettings, + @Nullable final MaxTimeSupplier maxTimeSupplier, + @Nullable final Timeout timeout) { this.isMaintenanceContext = isMaintenanceContext; this.timeoutSettings = timeoutSettings; + this.computedServerSelectionTimeout = computedServerSelectionTimeout; + this.minRoundTripTimeMS = minRoundTripTimeMS; + this.maxTimeSupplier = maxTimeSupplier; this.timeout = timeout; } @@ -152,17 +169,6 @@ public void onExpired(final Runnable onExpired) { Timeout.nullAsInfinite(timeout).onExpired(onExpired); } - /** - * Sets the recent min round trip time - * @param minRoundTripTimeMS the min round trip time - * @return this - */ - public TimeoutContext minRoundTripTimeMS(final long minRoundTripTimeMS) { - isTrue("'minRoundTripTimeMS' must be a positive number", minRoundTripTimeMS >= 0); - this.minRoundTripTimeMS = minRoundTripTimeMS; - return this; - } - @Nullable public Timeout timeoutIncludingRoundTrip() { return timeout == null ? null : timeout.shortenBy(minRoundTripTimeMS, MILLISECONDS); @@ -237,8 +243,19 @@ private static void runWithFixedTimeout(final long ms, final LongConsumer onRema } } - public void resetToDefaultMaxTime() { - this.maxTimeSupplier = null; + /** + * Creates a new {@link TimeoutContext} with the same settings, but with the + * {@link TimeoutSettings#getMaxAwaitTimeMS()} as the maxTimeMS override which will be used + * in {@link #runMaxTimeMS(LongConsumer)}. + */ + public TimeoutContext withMaxTimeAsMaxAwaitTimeOverride() { + return new TimeoutContext( + isMaintenanceContext, + computedServerSelectionTimeout, + minRoundTripTimeMS, + timeoutSettings, + timeoutSettings::getMaxAwaitTimeMS, + timeout); } /** @@ -253,26 +270,77 @@ public void resetToDefaultMaxTime() { * If remaining CSOT timeout is less than this static timeout, then CSOT timeout will be used. * */ - public void setMaxTimeOverride(final long maxTimeMS) { - this.maxTimeSupplier = () -> maxTimeMS; + public TimeoutContext withMaxTimeOverride(final long maxTimeMS) { + return new TimeoutContext( + isMaintenanceContext, + computedServerSelectionTimeout, + minRoundTripTimeMS, + timeoutSettings, + () -> maxTimeMS, + timeout); + } + + /** + * Creates {@link TimeoutContext} with the default maxTimeMS behaviour in {@link #runMaxTimeMS(LongConsumer)}: + * - if timeoutMS is set, the remaining timeoutMS will be used as the maxTimeMS. + * - if timeoutMS is not set, the {@link TimeoutSettings#getMaxTimeMS()} will be used. + */ + public TimeoutContext withDefaultMaxTime() { + return new TimeoutContext( + isMaintenanceContext, + computedServerSelectionTimeout, + minRoundTripTimeMS, + timeoutSettings, + null, + timeout); } /** * Disable the maxTimeMS override. This way the maxTimeMS will not * be appended to the command in the {@link CommandMessage}. */ - public void disableMaxTimeOverride() { - this.maxTimeSupplier = () -> 0; + public TimeoutContext withDisabledMaxTimeOverride() { + return new TimeoutContext( + isMaintenanceContext, + computedServerSelectionTimeout, + minRoundTripTimeMS, + timeoutSettings, + () -> 0, + timeout); } /** * The override will be provided as the remaining value in * {@link #runMaxTimeMS}, where 0 is ignored. */ - public void setMaxTimeOverrideToMaxCommitTime() { - this.maxTimeSupplier = () -> getMaxCommitTimeMS(); + public TimeoutContext withMaxTimeOverrideAsMaxCommitTime() { + return new TimeoutContext( + isMaintenanceContext, + computedServerSelectionTimeout, + minRoundTripTimeMS, + timeoutSettings, + () -> getMaxCommitTimeMS(), + timeout); + } + + + /** + * Creates {@link TimeoutContext} with the recent min round trip time. + * + * @param minRoundTripTimeMS the min round trip time + * @return this + */ + public TimeoutContext withMinRoundTripTime(final long minRoundTripTimeMS) { + return new TimeoutContext( + isMaintenanceContext, + computedServerSelectionTimeout, + minRoundTripTimeMS, + timeoutSettings, + maxTimeSupplier, + timeout); } + @VisibleForTesting(otherwise = PRIVATE) public long getMaxCommitTimeMS() { Long maxCommitTimeMS = timeoutSettings.getMaxCommitTimeMS(); @@ -296,65 +364,44 @@ public int getConnectTimeoutMs() { } /** - * @see #hasTimeoutMS() - * @see #doWithResetTimeout(Runnable) - * @see #doWithResetTimeout(AsyncRunnable, SingleResultCallback) - */ - public void resetTimeoutIfPresent() { - getAndResetTimeoutIfPresent(); - } - - /** - * @see #hasTimeoutMS() - * @return A {@linkplain Optional#isPresent() non-empty} previous {@linkplain Timeout} iff {@link #hasTimeoutMS()}, - * i.e., iff it was reset. + * Resets the timeout if this timeout context is being used by pool maintenance */ - private Optional getAndResetTimeoutIfPresent() { - Timeout result = timeout; - if (hasTimeoutMS()) { - timeout = startTimeout(timeoutSettings.getTimeoutMS()); - return ofNullable(result); + public TimeoutContext withNewlyStartedTimeoutMaintenanceTimeout() { + if (!isMaintenanceContext) { + return this; } - return empty(); + + return new TimeoutContext( + true, + computedServerSelectionTimeout, + minRoundTripTimeMS, + timeoutSettings, + maxTimeSupplier); } - /** - * @see #resetTimeoutIfPresent() - */ - public void doWithResetTimeout(final Runnable action) { - Optional originalTimeout = getAndResetTimeoutIfPresent(); - try { - action.run(); - } finally { - originalTimeout.ifPresent(original -> timeout = original); - } + + public TimeoutContext withMinRoundTripTimeMS(final long minRoundTripTimeMS) { + isTrue("'minRoundTripTimeMS' must be a positive number", minRoundTripTimeMS >= 0); + return new TimeoutContext(isMaintenanceContext, + computedServerSelectionTimeout, + minRoundTripTimeMS, + timeoutSettings, + maxTimeSupplier, + timeout); } - /** - * @see #resetTimeoutIfPresent() - */ - public void doWithResetTimeout(final AsyncRunnable action, final SingleResultCallback callback) { - beginAsync().thenRun(c -> { - Optional originalTimeout = getAndResetTimeoutIfPresent(); - beginAsync().thenRun(c2 -> { - action.finish(c2); - }).thenAlwaysRunAndFinish(() -> { - originalTimeout.ifPresent(original -> timeout = original); - }, c); - }).finish(callback); + // Creates a copy of the timeout context that can be reset without resetting the original. + public TimeoutContext copyTimeoutContext() { + return new TimeoutContext(getTimeoutSettings(), getTimeout()); } - /** - * Resets the timeout if this timeout context is being used by pool maintenance - */ - public void resetMaintenanceTimeout() { - if (!isMaintenanceContext) { - return; - } - timeout = Timeout.nullAsInfinite(timeout).call(NANOSECONDS, - () -> timeout, - (ms) -> startTimeout(timeoutSettings.getTimeoutMS()), - () -> startTimeout(timeoutSettings.getTimeoutMS())); + public TimeoutContext withNewlyStartedTimeout() { + return new TimeoutContext( + isMaintenanceContext, + computedServerSelectionTimeout, + minRoundTripTimeMS, + timeoutSettings, + maxTimeSupplier); } public TimeoutContext withAdditionalReadTimeout(final int additionalReadTimeout) { @@ -421,20 +468,11 @@ public static Timeout startTimeout(@Nullable final Long timeoutMS) { * @return the timeout context */ public Timeout computeServerSelectionTimeout() { - Timeout serverSelectionTimeout = StartTime.now() - .timeoutAfterOrInfiniteIfNegative(getTimeoutSettings().getServerSelectionTimeoutMS(), MILLISECONDS); - - - if (isMaintenanceContext || !hasTimeoutMS()) { - return serverSelectionTimeout; - } - - if (timeout != null && Timeout.earliest(serverSelectionTimeout, timeout) == timeout) { - return timeout; + if (hasTimeoutMS()) { + return assertNotNull(timeout); } - computedServerSelectionTimeout = serverSelectionTimeout; - return computedServerSelectionTimeout; + return StartTime.now().timeoutAfterOrInfiniteIfNegative(getTimeoutSettings().getServerSelectionTimeoutMS(), MILLISECONDS); } /** @@ -442,10 +480,16 @@ public Timeout computeServerSelectionTimeout() { * * @return a new timeout context with the cached computed server selection timeout if available or this */ - public TimeoutContext withComputedServerSelectionTimeoutContext() { - if (this.hasTimeoutMS() && computedServerSelectionTimeout != null) { - return new TimeoutContext(false, timeoutSettings, computedServerSelectionTimeout); + public TimeoutContext withComputedServerSelectionTimeoutContextNew() { + if (this.hasTimeoutMS()) { + Timeout serverSelectionTimeout = StartTime.now() + .timeoutAfterOrInfiniteIfNegative(getTimeoutSettings().getServerSelectionTimeoutMS(), MILLISECONDS); + if (isMaintenanceContext) { + return new TimeoutContext(false, timeoutSettings, serverSelectionTimeout); + } + return new TimeoutContext(false, timeoutSettings, Timeout.earliest(serverSelectionTimeout, timeout)); } + return this; } diff --git a/driver-core/src/main/com/mongodb/internal/async/SingleResultCallback.java b/driver-core/src/main/com/mongodb/internal/async/SingleResultCallback.java index 632e453d0c0..11da1c97f75 100644 --- a/driver-core/src/main/com/mongodb/internal/async/SingleResultCallback.java +++ b/driver-core/src/main/com/mongodb/internal/async/SingleResultCallback.java @@ -29,6 +29,8 @@ *

This class is not part of the public API and may be removed or changed at any time

*/ public interface SingleResultCallback { + SingleResultCallback THEN_DO_NOTHING = (r, t) -> {}; + /** * Called when the function completes. This method must not complete abruptly, see {@link AsyncCallbackFunction} for more details. * diff --git a/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackTriFunction.java b/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackTriFunction.java new file mode 100644 index 00000000000..0df5ff8c358 --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackTriFunction.java @@ -0,0 +1,40 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mongodb.internal.async.function; + +import com.mongodb.internal.async.SingleResultCallback; +import com.mongodb.lang.Nullable; + +/** + * An {@linkplain AsyncCallbackFunction asynchronous callback-based function} of three parameters. + * This class is a callback-based. + * + *

This class is not part of the public API and may be removed or changed at any time

+ * + * @param The type of the first parameter to the function. + * @param The type of the second parameter to the function. + * @param See {@link AsyncCallbackFunction} + * @see AsyncCallbackFunction + */ +@FunctionalInterface +public interface AsyncCallbackTriFunction { + /** + * @param p1 The first {@code @}{@link Nullable} argument of the asynchronous function. + * @param p2 The second {@code @}{@link Nullable} argument of the asynchronous function. + * @see AsyncCallbackFunction#apply(Object, SingleResultCallback) + */ + void apply(P1 p1, P2 p2, P3 p3, SingleResultCallback callback); +} diff --git a/driver-core/src/main/com/mongodb/internal/async/function/RetryState.java b/driver-core/src/main/com/mongodb/internal/async/function/RetryState.java index e1cecf721fc..5c9f756a94d 100644 --- a/driver-core/src/main/com/mongodb/internal/async/function/RetryState.java +++ b/driver-core/src/main/com/mongodb/internal/async/function/RetryState.java @@ -47,7 +47,7 @@ @NotThreadSafe public final class RetryState { public static final int RETRIES = 1; - private static final int INFINITE_ATTEMPTS = Integer.MAX_VALUE; + public static final int INFINITE_ATTEMPTS = Integer.MAX_VALUE; private final LoopState loopState; private final int attempts; @@ -67,19 +67,16 @@ public final class RetryState { *

* * @param retries A positive number of allowed retries. {@link Integer#MAX_VALUE} is a special value interpreted as being unlimited. - * @param timeoutContext A timeout context that will be used to determine if the operation has timed out. + * @param retryUntilTimeoutThrowsException // TODO-JAVA-5640 shouldn't a timeout always stop retries? * @see #attempts() */ - public static RetryState withRetryableState(final int retries, final TimeoutContext timeoutContext) { + public static RetryState withRetryableState(final int retries, final boolean retryUntilTimeoutThrowsException) { assertTrue(retries > 0); - if (timeoutContext.hasTimeoutMS()){ - return new RetryState(INFINITE_ATTEMPTS, timeoutContext); - } - return new RetryState(retries, null); + return new RetryState(retries, retryUntilTimeoutThrowsException); } public static RetryState withNonRetryableState() { - return new RetryState(0, null); + return new RetryState(0, false); } /** @@ -94,19 +91,19 @@ public static RetryState withNonRetryableState() { * @see #attempts() */ public RetryState(final TimeoutContext timeoutContext) { - this(INFINITE_ATTEMPTS, timeoutContext); + this(INFINITE_ATTEMPTS, timeoutContext.hasTimeoutMS()); } /** * @param retries A non-negative number of allowed retries. {@link Integer#MAX_VALUE} is a special value interpreted as being unlimited. - * @param timeoutContext A timeout context that will be used to determine if the operation has timed out. + * @param retryUntilTimeoutThrowsException * @see #attempts() */ - private RetryState(final int retries, @Nullable final TimeoutContext timeoutContext) { + private RetryState(final int retries, final boolean retryUntilTimeoutThrowsException) { assertTrue(retries >= 0); loopState = new LoopState(); attempts = retries == INFINITE_ATTEMPTS ? INFINITE_ATTEMPTS : retries + 1; - this.retryUntilTimeoutThrowsException = timeoutContext != null && timeoutContext.hasTimeoutMS(); + this.retryUntilTimeoutThrowsException = retryUntilTimeoutThrowsException; } /** diff --git a/driver-core/src/main/com/mongodb/internal/binding/AsyncClusterAwareReadWriteBinding.java b/driver-core/src/main/com/mongodb/internal/binding/AsyncClusterAwareReadWriteBinding.java index c66dc321513..5de9df43174 100644 --- a/driver-core/src/main/com/mongodb/internal/binding/AsyncClusterAwareReadWriteBinding.java +++ b/driver-core/src/main/com/mongodb/internal/binding/AsyncClusterAwareReadWriteBinding.java @@ -18,6 +18,7 @@ import com.mongodb.ServerAddress; import com.mongodb.internal.async.SingleResultCallback; +import com.mongodb.internal.connection.OperationContext; /** *

This class is not part of the public API and may be removed or changed at any time

@@ -30,7 +31,7 @@ public interface AsyncClusterAwareReadWriteBinding extends AsyncReadWriteBinding * @param serverAddress the server address * @param callback the to be passed the connection source */ - void getConnectionSource(ServerAddress serverAddress, SingleResultCallback callback); + void getConnectionSource(ServerAddress serverAddress, OperationContext operationContext, SingleResultCallback callback); @Override AsyncClusterAwareReadWriteBinding retain(); diff --git a/driver-core/src/main/com/mongodb/internal/binding/AsyncClusterBinding.java b/driver-core/src/main/com/mongodb/internal/binding/AsyncClusterBinding.java index fd46261a6df..322ef374381 100644 --- a/driver-core/src/main/com/mongodb/internal/binding/AsyncClusterBinding.java +++ b/driver-core/src/main/com/mongodb/internal/binding/AsyncClusterBinding.java @@ -25,6 +25,7 @@ import com.mongodb.internal.connection.AsyncConnection; import com.mongodb.internal.connection.Cluster; import com.mongodb.internal.connection.OperationContext; +import com.mongodb.internal.connection.ReadConcernAwareNoOpSessionContext; import com.mongodb.internal.connection.Server; import com.mongodb.internal.selector.ReadPreferenceServerSelector; import com.mongodb.internal.selector.ReadPreferenceWithFallbackServerSelector; @@ -33,7 +34,6 @@ import com.mongodb.selector.ServerSelector; import static com.mongodb.assertions.Assertions.notNull; -import static java.util.concurrent.TimeUnit.NANOSECONDS; /** * A simple ReadWriteBinding implementation that supplies write connection sources bound to a possibly different primary each time, and a @@ -44,24 +44,17 @@ public class AsyncClusterBinding extends AbstractReferenceCounted implements AsyncClusterAwareReadWriteBinding { private final Cluster cluster; private final ReadPreference readPreference; - private final ReadConcern readConcern; - private final OperationContext operationContext; /** * Creates an instance. * * @param cluster a non-null Cluster which will be used to select a server to bind to * @param readPreference a non-null ReadPreference for read operations - * @param readConcern a non-null read concern - * @param operationContext the operation context *

This class is not part of the public API and may be removed or changed at any time

*/ - public AsyncClusterBinding(final Cluster cluster, final ReadPreference readPreference, final ReadConcern readConcern, - final OperationContext operationContext) { + public AsyncClusterBinding(final Cluster cluster, final ReadPreference readPreference) { this.cluster = notNull("cluster", cluster); this.readPreference = notNull("readPreference", readPreference); - this.readConcern = notNull("readConcern", readConcern); - this.operationContext = notNull("operationContext", operationContext); } @Override @@ -76,21 +69,18 @@ public ReadPreference getReadPreference() { } @Override - public OperationContext getOperationContext() { - return operationContext; - } - - @Override - public void getReadConnectionSource(final SingleResultCallback callback) { - getAsyncClusterBindingConnectionSource(new ReadPreferenceServerSelector(readPreference), callback); + public void getReadConnectionSource(final OperationContext operationContext, + final SingleResultCallback callback) { + getAsyncClusterBindingConnectionSource(new ReadPreferenceServerSelector(readPreference), operationContext, callback); } @Override public void getReadConnectionSource(final int minWireVersion, final ReadPreference fallbackReadPreference, + final OperationContext operationContext, final SingleResultCallback callback) { // Assume 5.0+ for load-balanced mode if (cluster.getSettings().getMode() == ClusterConnectionMode.LOAD_BALANCED) { - getReadConnectionSource(callback); + getReadConnectionSource(operationContext, callback); } else { ReadPreferenceWithFallbackServerSelector readPreferenceWithFallbackServerSelector = new ReadPreferenceWithFallbackServerSelector(readPreference, minWireVersion, fallbackReadPreference); @@ -106,16 +96,19 @@ public void getReadConnectionSource(final int minWireVersion, final ReadPreferen } @Override - public void getWriteConnectionSource(final SingleResultCallback callback) { - getAsyncClusterBindingConnectionSource(new WritableServerSelector(), callback); + public void getWriteConnectionSource(final OperationContext operationContext, + final SingleResultCallback callback) { + getAsyncClusterBindingConnectionSource(new WritableServerSelector(), operationContext, callback); } @Override - public void getConnectionSource(final ServerAddress serverAddress, final SingleResultCallback callback) { - getAsyncClusterBindingConnectionSource(new ServerAddressSelector(serverAddress), callback); + public void getConnectionSource(final ServerAddress serverAddress, final OperationContext operationContext, + final SingleResultCallback callback) { + getAsyncClusterBindingConnectionSource(new ServerAddressSelector(serverAddress), operationContext, callback); } private void getAsyncClusterBindingConnectionSource(final ServerSelector serverSelector, + final OperationContext operationContext, final SingleResultCallback callback) { cluster.selectServerAsync(serverSelector, operationContext, (result, t) -> { if (t != null) { @@ -132,12 +125,14 @@ private final class AsyncClusterBindingConnectionSource extends AbstractReferenc private final ServerDescription serverDescription; private final ReadPreference appliedReadPreference; - private AsyncClusterBindingConnectionSource(final Server server, final ServerDescription serverDescription, - final ReadPreference appliedReadPreference) { + private AsyncClusterBindingConnectionSource(final Server server, + final ServerDescription serverDescription, + final ReadPreference appliedReadPreference) { this.server = server; this.serverDescription = serverDescription; this.appliedReadPreference = appliedReadPreference; - operationContext.getTimeoutContext().minRoundTripTimeMS(NANOSECONDS.toMillis(serverDescription.getMinRoundTripTimeNanos())); + // TODO should be calculated externaly + // operationContext.getTimeoutContext().minRoundTripTimeMS(NANOSECONDS.toMillis(serverDescription.getMinRoundTripTimeNanos())); AsyncClusterBinding.this.retain(); } @@ -146,19 +141,18 @@ public ServerDescription getServerDescription() { return serverDescription; } - @Override - public OperationContext getOperationContext() { - return operationContext; - } - @Override public ReadPreference getReadPreference() { return appliedReadPreference; } @Override - public void getConnection(final SingleResultCallback callback) { - server.getConnectionAsync(operationContext, callback); + public void getConnection(final OperationContext operationContext, final SingleResultCallback callback) { + // The first read in a causally consistent session MUST not send afterClusterTime to the server + // (because the operationTime has not yet been determined). Therefore, we use ReadConcernAwareNoOpSessionContext to + // so that we do not advance clusterTime on ClientSession in given operationContext because it might not be yet set. + ReadConcern readConcern = operationContext.getSessionContext().getReadConcern(); + server.getConnectionAsync(operationContext.withSessionContext(new ReadConcernAwareNoOpSessionContext(readConcern)), callback); } public AsyncConnectionSource retain() { diff --git a/driver-core/src/main/com/mongodb/internal/binding/AsyncConnectionSource.java b/driver-core/src/main/com/mongodb/internal/binding/AsyncConnectionSource.java index 5d70faf598e..aae7d6b7419 100644 --- a/driver-core/src/main/com/mongodb/internal/binding/AsyncConnectionSource.java +++ b/driver-core/src/main/com/mongodb/internal/binding/AsyncConnectionSource.java @@ -20,6 +20,7 @@ import com.mongodb.connection.ServerDescription; import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.connection.AsyncConnection; +import com.mongodb.internal.connection.OperationContext; /** * A source of connections to a single MongoDB server. @@ -42,12 +43,7 @@ public interface AsyncConnectionSource extends BindingContext, ReferenceCounted */ ReadPreference getReadPreference(); - /** - * Gets a connection from this source. - * - * @param callback the to be passed the connection - */ - void getConnection(SingleResultCallback callback); + void getConnection(OperationContext operationContext, SingleResultCallback callback); @Override AsyncConnectionSource retain(); diff --git a/driver-core/src/main/com/mongodb/internal/binding/AsyncReadBinding.java b/driver-core/src/main/com/mongodb/internal/binding/AsyncReadBinding.java index 633091b3efb..e6d5cdfbbcf 100644 --- a/driver-core/src/main/com/mongodb/internal/binding/AsyncReadBinding.java +++ b/driver-core/src/main/com/mongodb/internal/binding/AsyncReadBinding.java @@ -18,6 +18,7 @@ import com.mongodb.ReadPreference; import com.mongodb.internal.async.SingleResultCallback; +import com.mongodb.internal.connection.OperationContext; /** * An asynchronous factory of connection sources to servers that can be read from and that satisfy the specified read preference. @@ -35,7 +36,7 @@ public interface AsyncReadBinding extends BindingContext, ReferenceCounted { * Returns a connection source to a server that satisfies the read preference with which this instance is configured. * @param callback the to be passed the connection source */ - void getReadConnectionSource(SingleResultCallback callback); + void getReadConnectionSource(OperationContext operationContext, SingleResultCallback callback); /** * Return a connection source that satisfies the read preference with which this instance is configured, if all connected servers have @@ -48,6 +49,7 @@ public interface AsyncReadBinding extends BindingContext, ReferenceCounted { * @see com.mongodb.internal.operation.AggregateToCollectionOperation */ void getReadConnectionSource(int minWireVersion, ReadPreference fallbackReadPreference, + OperationContext operationContext, SingleResultCallback callback); @Override diff --git a/driver-core/src/main/com/mongodb/internal/binding/AsyncWriteBinding.java b/driver-core/src/main/com/mongodb/internal/binding/AsyncWriteBinding.java index 39bdf4729c2..1dba1c245e6 100644 --- a/driver-core/src/main/com/mongodb/internal/binding/AsyncWriteBinding.java +++ b/driver-core/src/main/com/mongodb/internal/binding/AsyncWriteBinding.java @@ -17,6 +17,7 @@ package com.mongodb.internal.binding; import com.mongodb.internal.async.SingleResultCallback; +import com.mongodb.internal.connection.OperationContext; /** * An asynchronous factory of connection sources to servers that can be written to, e.g, a standalone, a mongos, or a replica set primary. @@ -30,7 +31,7 @@ public interface AsyncWriteBinding extends BindingContext, ReferenceCounted { * * @param callback the to be passed the connection source */ - void getWriteConnectionSource(SingleResultCallback callback); + void getWriteConnectionSource(OperationContext operationContext, SingleResultCallback callback); @Override AsyncWriteBinding retain(); diff --git a/driver-core/src/main/com/mongodb/internal/binding/BindingContext.java b/driver-core/src/main/com/mongodb/internal/binding/BindingContext.java index c10f0fb16ac..289d9070c5e 100644 --- a/driver-core/src/main/com/mongodb/internal/binding/BindingContext.java +++ b/driver-core/src/main/com/mongodb/internal/binding/BindingContext.java @@ -16,9 +16,6 @@ package com.mongodb.internal.binding; -import com.mongodb.internal.connection.OperationContext; - - /** *

This class is not part of the public API and may be removed or changed at any time

*/ @@ -29,5 +26,5 @@ public interface BindingContext { * * @return the operation context for the binding context. */ - OperationContext getOperationContext(); + // OperationContext getOperationContext(); } diff --git a/driver-core/src/main/com/mongodb/internal/binding/ClusterAwareReadWriteBinding.java b/driver-core/src/main/com/mongodb/internal/binding/ClusterAwareReadWriteBinding.java index 8f7552341a7..b97b22c3a06 100644 --- a/driver-core/src/main/com/mongodb/internal/binding/ClusterAwareReadWriteBinding.java +++ b/driver-core/src/main/com/mongodb/internal/binding/ClusterAwareReadWriteBinding.java @@ -17,6 +17,7 @@ package com.mongodb.internal.binding; import com.mongodb.ServerAddress; +import com.mongodb.internal.connection.OperationContext; /** * This interface is not part of the public API and may be removed or changed at any time. @@ -27,5 +28,5 @@ public interface ClusterAwareReadWriteBinding extends ReadWriteBinding { * Returns a connection source to the specified server address. * @return the connection source */ - ConnectionSource getConnectionSource(ServerAddress serverAddress); + ConnectionSource getConnectionSource(ServerAddress serverAddress, OperationContext operationContext); } diff --git a/driver-core/src/main/com/mongodb/internal/binding/ClusterBinding.java b/driver-core/src/main/com/mongodb/internal/binding/ClusterBinding.java index cd3f8473bbb..d6f39d308ea 100644 --- a/driver-core/src/main/com/mongodb/internal/binding/ClusterBinding.java +++ b/driver-core/src/main/com/mongodb/internal/binding/ClusterBinding.java @@ -24,6 +24,7 @@ import com.mongodb.internal.connection.Cluster; import com.mongodb.internal.connection.Connection; import com.mongodb.internal.connection.OperationContext; +import com.mongodb.internal.connection.ReadConcernAwareNoOpSessionContext; import com.mongodb.internal.connection.Server; import com.mongodb.internal.connection.ServerTuple; import com.mongodb.internal.selector.ReadPreferenceServerSelector; @@ -32,7 +33,6 @@ import com.mongodb.internal.selector.WritableServerSelector; import static com.mongodb.assertions.Assertions.notNull; -import static java.util.concurrent.TimeUnit.NANOSECONDS; /** * A simple ReadWriteBinding implementation that supplies write connection sources bound to a possibly different primary each time, and a @@ -43,22 +43,15 @@ public class ClusterBinding extends AbstractReferenceCounted implements ClusterAwareReadWriteBinding { private final Cluster cluster; private final ReadPreference readPreference; - private final ReadConcern readConcern; - private final OperationContext operationContext; /** * Creates an instance. * @param cluster a non-null Cluster which will be used to select a server to bind to * @param readPreference a non-null ReadPreference for read operations - * @param readConcern a non-null read concern - * @param operationContext the operation context */ - public ClusterBinding(final Cluster cluster, final ReadPreference readPreference, final ReadConcern readConcern, - final OperationContext operationContext) { + public ClusterBinding(final Cluster cluster, final ReadPreference readPreference) { this.cluster = notNull("cluster", cluster); this.readPreference = notNull("readPreference", readPreference); - this.readConcern = notNull("readConcern", readConcern); - this.operationContext = notNull("operationContext", operationContext); } @Override @@ -73,36 +66,39 @@ public ReadPreference getReadPreference() { } @Override - public OperationContext getOperationContext() { - return operationContext; + public ConnectionSource getReadConnectionSource(final OperationContext operationContext) { + return new ClusterBindingConnectionSource( + cluster.selectServer(new ReadPreferenceServerSelector(readPreference), operationContext), + readPreference); } @Override - public ConnectionSource getReadConnectionSource() { - return new ClusterBindingConnectionSource(cluster.selectServer(new ReadPreferenceServerSelector(readPreference), operationContext), readPreference); - } - - @Override - public ConnectionSource getReadConnectionSource(final int minWireVersion, final ReadPreference fallbackReadPreference) { + public ConnectionSource getReadConnectionSource(final int minWireVersion, final ReadPreference fallbackReadPreference, + final OperationContext operationContext) { // Assume 5.0+ for load-balanced mode if (cluster.getSettings().getMode() == ClusterConnectionMode.LOAD_BALANCED) { - return getReadConnectionSource(); + return getReadConnectionSource(operationContext); } else { ReadPreferenceWithFallbackServerSelector readPreferenceWithFallbackServerSelector = new ReadPreferenceWithFallbackServerSelector(readPreference, minWireVersion, fallbackReadPreference); ServerTuple serverTuple = cluster.selectServer(readPreferenceWithFallbackServerSelector, operationContext); - return new ClusterBindingConnectionSource(serverTuple, readPreferenceWithFallbackServerSelector.getAppliedReadPreference()); + return new ClusterBindingConnectionSource(serverTuple, + readPreferenceWithFallbackServerSelector.getAppliedReadPreference()); } } @Override - public ConnectionSource getWriteConnectionSource() { - return new ClusterBindingConnectionSource(cluster.selectServer(new WritableServerSelector(), operationContext), readPreference); + public ConnectionSource getWriteConnectionSource(final OperationContext operationContext) { + return new ClusterBindingConnectionSource( + cluster.selectServer(new WritableServerSelector(), operationContext), + readPreference); } @Override - public ConnectionSource getConnectionSource(final ServerAddress serverAddress) { - return new ClusterBindingConnectionSource(cluster.selectServer(new ServerAddressSelector(serverAddress), operationContext), readPreference); + public ConnectionSource getConnectionSource(final ServerAddress serverAddress, final OperationContext operationContext) { + return new ClusterBindingConnectionSource( + cluster.selectServer(new ServerAddressSelector(serverAddress), operationContext), + readPreference); } private final class ClusterBindingConnectionSource extends AbstractReferenceCounted implements ConnectionSource { @@ -114,7 +110,8 @@ private ClusterBindingConnectionSource(final ServerTuple serverTuple, final Read this.server = serverTuple.getServer(); this.serverDescription = serverTuple.getServerDescription(); this.appliedReadPreference = appliedReadPreference; - operationContext.getTimeoutContext().minRoundTripTimeMS(NANOSECONDS.toMillis(serverDescription.getMinRoundTripTimeNanos())); + //TODO THis has to be moved outside of the consutructor to the place where getConnectionSource is called to create a new OperationContet to use further + // operationContext.getTimeoutContext().minRoundTripTimeMS(NANOSECONDS.toMillis(serverDescription.getMinRoundTripTimeNanos())); ClusterBinding.this.retain(); } @@ -123,19 +120,18 @@ public ServerDescription getServerDescription() { return serverDescription; } - @Override - public OperationContext getOperationContext() { - return operationContext; - } - @Override public ReadPreference getReadPreference() { return appliedReadPreference; } @Override - public Connection getConnection() { - return server.getConnection(operationContext); + public Connection getConnection(final OperationContext operationContext) { + // The first read in a causally consistent session MUST not send afterClusterTime to the server + // (because the operationTime has not yet been determined). Therefore, we use ReadConcernAwareNoOpSessionContext to + // so that we do not advance clusterTime on ClientSession in given operationContext because it might not be yet set. + ReadConcern readConcern = operationContext.getSessionContext().getReadConcern(); + return server.getConnection(operationContext.withSessionContext(new ReadConcernAwareNoOpSessionContext(readConcern))); } public ConnectionSource retain() { diff --git a/driver-core/src/main/com/mongodb/internal/binding/ConnectionSource.java b/driver-core/src/main/com/mongodb/internal/binding/ConnectionSource.java index 90c8b85cf16..5ba3c562895 100644 --- a/driver-core/src/main/com/mongodb/internal/binding/ConnectionSource.java +++ b/driver-core/src/main/com/mongodb/internal/binding/ConnectionSource.java @@ -19,19 +19,20 @@ import com.mongodb.ReadPreference; import com.mongodb.connection.ServerDescription; import com.mongodb.internal.connection.Connection; +import com.mongodb.internal.connection.OperationContext; /** * A source of connections to a single MongoDB server. * *

This class is not part of the public API and may be removed or changed at any time

*/ -public interface ConnectionSource extends BindingContext, ReferenceCounted { +public interface ConnectionSource extends ReferenceCounted { ServerDescription getServerDescription(); ReadPreference getReadPreference(); - Connection getConnection(); + Connection getConnection(OperationContext operationContext); @Override ConnectionSource retain(); diff --git a/driver-core/src/main/com/mongodb/internal/binding/ReadBinding.java b/driver-core/src/main/com/mongodb/internal/binding/ReadBinding.java index ffdde848382..67c10ccbc5c 100644 --- a/driver-core/src/main/com/mongodb/internal/binding/ReadBinding.java +++ b/driver-core/src/main/com/mongodb/internal/binding/ReadBinding.java @@ -17,6 +17,7 @@ package com.mongodb.internal.binding; import com.mongodb.ReadPreference; +import com.mongodb.internal.connection.OperationContext; /** * A factory of connection sources to servers that can be read from and that satisfy the specified read preference. @@ -30,7 +31,7 @@ public interface ReadBinding extends BindingContext, ReferenceCounted { * Returns a connection source to a server that satisfies the read preference with which this instance is configured. * @return the connection source */ - ConnectionSource getReadConnectionSource(); + ConnectionSource getReadConnectionSource(OperationContext operationContext); /** * Return a connection source that satisfies the read preference with which this instance is configured, if all connected servers have @@ -42,7 +43,7 @@ public interface ReadBinding extends BindingContext, ReferenceCounted { * * @see com.mongodb.internal.operation.AggregateToCollectionOperation */ - ConnectionSource getReadConnectionSource(int minWireVersion, ReadPreference fallbackReadPreference); + ConnectionSource getReadConnectionSource(int minWireVersion, ReadPreference fallbackReadPreference, OperationContext operationContext); @Override ReadBinding retain(); diff --git a/driver-core/src/main/com/mongodb/internal/binding/SingleServerBinding.java b/driver-core/src/main/com/mongodb/internal/binding/SingleServerBinding.java index 7d7e948c344..50497ae2526 100644 --- a/driver-core/src/main/com/mongodb/internal/binding/SingleServerBinding.java +++ b/driver-core/src/main/com/mongodb/internal/binding/SingleServerBinding.java @@ -35,23 +35,21 @@ public class SingleServerBinding extends AbstractReferenceCounted implements ReadWriteBinding { private final Cluster cluster; private final ServerAddress serverAddress; - private final OperationContext operationContext; /** * Creates an instance, defaulting to {@link com.mongodb.ReadPreference#primary()} for reads. * @param cluster a non-null Cluster which will be used to select a server to bind to * @param serverAddress a non-null address of the server to bind to - * @param operationContext the operation context */ - public SingleServerBinding(final Cluster cluster, final ServerAddress serverAddress, final OperationContext operationContext) { + public SingleServerBinding(final Cluster cluster, final ServerAddress serverAddress) { this.cluster = notNull("cluster", cluster); this.serverAddress = notNull("serverAddress", serverAddress); - this.operationContext = notNull("operationContext", operationContext); } @Override - public ConnectionSource getWriteConnectionSource() { - return new SingleServerBindingConnectionSource(); + public ConnectionSource getWriteConnectionSource(final OperationContext operationContext) { + ServerTuple serverTuple = cluster.selectServer(new ServerAddressSelector(serverAddress), operationContext); + return new SingleServerBindingConnectionSource(serverTuple); } @Override @@ -60,20 +58,17 @@ public ReadPreference getReadPreference() { } @Override - public ConnectionSource getReadConnectionSource() { - return new SingleServerBindingConnectionSource(); + public ConnectionSource getReadConnectionSource(final OperationContext operationContext) { + ServerTuple serverTuple = cluster.selectServer(new ServerAddressSelector(serverAddress), operationContext); + return new SingleServerBindingConnectionSource(serverTuple); } @Override - public ConnectionSource getReadConnectionSource(final int minWireVersion, final ReadPreference fallbackReadPreference) { + public ConnectionSource getReadConnectionSource(final int minWireVersion, final ReadPreference fallbackReadPreference, + final OperationContext operationContext) { throw new UnsupportedOperationException(); } - @Override - public OperationContext getOperationContext() { - return operationContext; - } - @Override public SingleServerBinding retain() { super.retain(); @@ -83,10 +78,9 @@ public SingleServerBinding retain() { private final class SingleServerBindingConnectionSource extends AbstractReferenceCounted implements ConnectionSource { private final ServerDescription serverDescription; - private SingleServerBindingConnectionSource() { + private SingleServerBindingConnectionSource(final ServerTuple serverTuple) { SingleServerBinding.this.retain(); - ServerTuple serverTuple = cluster.selectServer(new ServerAddressSelector(serverAddress), operationContext); - serverDescription = serverTuple.getServerDescription(); + this.serverDescription = serverTuple.getServerDescription(); } @Override @@ -94,18 +88,13 @@ public ServerDescription getServerDescription() { return serverDescription; } - @Override - public OperationContext getOperationContext() { - return operationContext; - } - @Override public ReadPreference getReadPreference() { return ReadPreference.primary(); } @Override - public Connection getConnection() { + public Connection getConnection(final OperationContext operationContext) { return cluster .selectServer(new ServerAddressSelector(serverAddress), operationContext) .getServer() diff --git a/driver-core/src/main/com/mongodb/internal/binding/WriteBinding.java b/driver-core/src/main/com/mongodb/internal/binding/WriteBinding.java index b0ac674489c..beeee4c3bf2 100644 --- a/driver-core/src/main/com/mongodb/internal/binding/WriteBinding.java +++ b/driver-core/src/main/com/mongodb/internal/binding/WriteBinding.java @@ -16,6 +16,8 @@ package com.mongodb.internal.binding; +import com.mongodb.internal.connection.OperationContext; + /** * A factory of connection sources to servers that can be written to, e.g, a standalone, a mongos, or a replica set primary. * @@ -27,7 +29,7 @@ public interface WriteBinding extends BindingContext, ReferenceCounted { * * @return a connection source */ - ConnectionSource getWriteConnectionSource(); + ConnectionSource getWriteConnectionSource(OperationContext operationContext); @Override WriteBinding retain(); diff --git a/driver-core/src/main/com/mongodb/internal/connection/CommandHelper.java b/driver-core/src/main/com/mongodb/internal/connection/CommandHelper.java index fa7c1f0739d..9670db5af2c 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/CommandHelper.java +++ b/driver-core/src/main/com/mongodb/internal/connection/CommandHelper.java @@ -129,7 +129,6 @@ private static CommandMessage getCommandMessage(final String database, final Bso public static void applyMaxTimeMS(final TimeoutContext timeoutContext, final BsonDocument command) { if (!timeoutContext.hasTimeoutMS()) { command.append("maxTimeMS", new BsonInt64(timeoutContext.getTimeoutSettings().getMaxTimeMS())); - timeoutContext.disableMaxTimeOverride(); } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/DefaultConnectionPool.java b/driver-core/src/main/com/mongodb/internal/connection/DefaultConnectionPool.java index 13e7ad987b5..ec611ee7b9b 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/DefaultConnectionPool.java +++ b/driver-core/src/main/com/mongodb/internal/connection/DefaultConnectionPool.java @@ -451,7 +451,7 @@ void doMaintenance() { } private boolean shouldEnsureMinSize() { - return settings.getMinSize() > 0; + return settings.getMinSize() > -1; } private boolean shouldPrune(final UsageTrackingInternalConnection connection) { diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java index bf009aa1b07..f231f795a97 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java @@ -226,13 +226,15 @@ public int getGeneration() { public void open(final OperationContext originalOperationContext) { isTrue("Open already called", stream == null); stream = streamFactory.create(serverId.getAddress()); + OperationContext operationContext = originalOperationContext; try { - OperationContext operationContext = originalOperationContext - .withTimeoutContext(originalOperationContext.getTimeoutContext().withComputedServerSelectionTimeoutContext()); - + //COMMENT given that we already use serverSelection timeout in SyncOperationHelper, this step is not needed. +// OperationContext operationContext = originalOperationContext +// .withTimeoutContext(originalOperationContext.getTimeoutContext().withComputedServerSelectionTimeoutContext()); stream.open(operationContext); - InternalConnectionInitializationDescription initializationDescription = connectionInitializer.startHandshake(this, operationContext); + + operationContext = operationContext.withTimeoutContextOverride(TimeoutContext::withNewlyStartedTimeoutMaintenanceTimeout); initAfterHandshakeStart(initializationDescription); initializationDescription = connectionInitializer.finishHandshake(this, initializationDescription, operationContext); @@ -250,9 +252,11 @@ public void open(final OperationContext originalOperationContext) { @Override public void openAsync(final OperationContext originalOperationContext, final SingleResultCallback callback) { assertNull(stream); + OperationContext operationContext = originalOperationContext; try { - OperationContext operationContext = originalOperationContext - .withTimeoutContext(originalOperationContext.getTimeoutContext().withComputedServerSelectionTimeoutContext()); + //COMMENT given that we already use serverSelection timeout in SyncOperationHelper, this step is not needed. +// OperationContext operationContext = originalOperationContext +// .withTimeoutContext(originalOperationContext.getTimeoutContext().withComputedServerSelectionTimeoutContext()); stream = streamFactory.create(serverId.getAddress()); stream.openAsync(operationContext, new AsyncCompletionHandler() { @@ -268,7 +272,8 @@ public void completed(@Nullable final Void aVoid) { assertNotNull(initialResult); initAfterHandshakeStart(initialResult); connectionInitializer.finishHandshakeAsync(InternalStreamConnection.this, - initialResult, operationContext, (completedResult, completedException) -> { + initialResult, operationContext.withTimeoutContextOverride(TimeoutContext::withNewlyStartedTimeoutMaintenanceTimeout), + (completedResult, completedException) -> { if (completedException != null) { close(); callback.onResult(null, completedException); diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java index 79c21f33356..574a85669d0 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java @@ -144,8 +144,6 @@ private InternalConnectionInitializationDescription initializeConnectionDescript helloResult = executeCommand("admin", helloCommandDocument, clusterConnectionMode, serverApi, internalConnection, operationContext); } catch (MongoException e) { throw mapHelloException(e); - } finally { - operationContext.getTimeoutContext().resetMaintenanceTimeout(); } setSpeculativeAuthenticateResponse(helloResult); return createInitializationDescription(helloResult, internalConnection, start); diff --git a/driver-core/src/main/com/mongodb/internal/connection/OperationContext.java b/driver-core/src/main/com/mongodb/internal/connection/OperationContext.java index 7e0de92da1d..ecceed7c099 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OperationContext.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OperationContext.java @@ -15,6 +15,7 @@ */ package com.mongodb.internal.connection; +import com.mongodb.Function; import com.mongodb.MongoConnectionPoolClearedException; import com.mongodb.RequestContext; import com.mongodb.ServerAddress; @@ -33,6 +34,7 @@ import java.util.HashSet; import java.util.List; import java.util.Set; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import static java.util.stream.Collectors.toList; @@ -160,6 +162,19 @@ public ServerDeprioritization getServerDeprioritization() { return serverDeprioritization; } + public OperationContext withNewlyStartedTimeout() { + TimeoutContext tc = this.timeoutContext.withNewlyStartedTimeout(); + return this.withTimeoutContext(tc); + } + + public OperationContext withMinRoundTripTime(final ServerDescription serverDescription) { + return this.withTimeoutContext(this.timeoutContext.withMinRoundTripTime(TimeUnit.NANOSECONDS.toMillis(serverDescription.getMinRoundTripTimeNanos()))); + } + + public OperationContext withTimeoutContextOverride(final Function timeoutContextOverrideFunction) { + return this.withTimeoutContext(timeoutContextOverrideFunction.apply(timeoutContext)); + } + public static final class ServerDeprioritization { @Nullable private ServerAddress candidate; diff --git a/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java index eeee3a31abd..059f810c611 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java @@ -27,6 +27,7 @@ import com.mongodb.SubjectProvider; import com.mongodb.connection.ClusterConnectionMode; import com.mongodb.connection.ConnectionDescription; +import com.mongodb.internal.TimeoutContext; import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.diagnostics.logging.Logger; import com.mongodb.internal.diagnostics.logging.Loggers; @@ -63,12 +64,18 @@ abstract class SaslAuthenticator extends Authenticator implements SpeculativeAut } public void authenticate(final InternalConnection connection, final ConnectionDescription connectionDescription, - final OperationContext operationContext) { + final OperationContext originalOperationContext) { doAsSubject(() -> { + OperationContext operationContext = originalOperationContext; SaslClient saslClient = createSaslClient(connection.getDescription().getServerAddress(), operationContext); throwIfSaslClientIsNull(saslClient); try { - BsonDocument responseDocument = getNextSaslResponse(saslClient, connection, operationContext); + BsonDocument responseDocument = connection.opened() ? null : getSpeculativeAuthenticateResponse(); + if (responseDocument == null) { + responseDocument = getNextSaslResponse(saslClient, connection, operationContext); + operationContext = operationContext.withTimeoutContextOverride(TimeoutContext::withNewlyStartedTimeoutMaintenanceTimeout); + } + BsonInt32 conversationId = responseDocument.getInt32("conversationId"); while (!(responseDocument.getBoolean("done")).getValue()) { @@ -81,7 +88,7 @@ public void authenticate(final InternalConnection connection, final ConnectionDe } responseDocument = sendSaslContinue(conversationId, response, connection, operationContext); - operationContext.getTimeoutContext().resetMaintenanceTimeout(); + operationContext = operationContext.withTimeoutContextOverride(TimeoutContext::withNewlyStartedTimeoutMaintenanceTimeout); } if (!saslClient.isComplete()) { saslClient.evaluateChallenge((responseDocument.getBinary("payload")).getData()); @@ -117,6 +124,9 @@ void authenticateAsync(final InternalConnection connection, final ConnectionDesc public abstract String getMechanismName(); + /** + * Does not send any commands to the server + */ protected abstract SaslClient createSaslClient(ServerAddress serverAddress, OperationContext operationContext); protected void appendSaslStartOptions(final BsonDocument saslStartCommand) { @@ -131,11 +141,6 @@ private void throwIfSaslClientIsNull(@Nullable final SaslClient saslClient) { private BsonDocument getNextSaslResponse(final SaslClient saslClient, final InternalConnection connection, final OperationContext operationContext) { - BsonDocument response = connection.opened() ? null : getSpeculativeAuthenticateResponse(); - if (response != null) { - return response; - } - try { byte[] serverResponse = saslClient.hasInitialResponse() ? saslClient.evaluateChallenge(new byte[0]) : null; return sendSaslStart(serverResponse, connection, operationContext); @@ -160,7 +165,9 @@ private void getNextSaslResponseAsync(final SaslClient saslClient, final Interna if (result.getBoolean("done").getValue()) { verifySaslClientComplete(saslClient, result, errHandlingCallback); } else { - new Continuator(saslClient, result, connection, operationContext, errHandlingCallback).start(); + OperationContext saslContinueOperationContext = + operationContext.withTimeoutContextOverride(TimeoutContext::withNewlyStartedTimeoutMaintenanceTimeout); + new Continuator(saslClient, result, connection, saslContinueOperationContext, errHandlingCallback).start(); } }); } else if (response.getBoolean("done").getValue()) { @@ -232,22 +239,14 @@ private BsonDocument sendSaslStart(@Nullable final byte[] outToken, final Intern final OperationContext operationContext) { BsonDocument startDocument = createSaslStartCommandDocument(outToken); appendSaslStartOptions(startDocument); - try { return executeCommand(getMongoCredential().getSource(), startDocument, getClusterConnectionMode(), getServerApi(), connection, operationContext); - } finally { - operationContext.getTimeoutContext().resetMaintenanceTimeout(); - } } private BsonDocument sendSaslContinue(final BsonInt32 conversationId, final byte[] outToken, final InternalConnection connection, final OperationContext operationContext) { - try { return executeCommand(getMongoCredential().getSource(), createSaslContinueDocument(conversationId, outToken), getClusterConnectionMode(), getServerApi(), connection, operationContext); - } finally { - operationContext.getTimeoutContext().resetMaintenanceTimeout(); - } } private void sendSaslStartAsync(@Nullable final byte[] outToken, final InternalConnection connection, @@ -256,19 +255,13 @@ private void sendSaslStartAsync(@Nullable final byte[] outToken, final InternalC appendSaslStartOptions(startDocument); executeCommandAsync(getMongoCredential().getSource(), startDocument, getClusterConnectionMode(), getServerApi(), connection, - operationContext, (r, t) -> { - operationContext.getTimeoutContext().resetMaintenanceTimeout(); - callback.onResult(r, t); - }); + operationContext, callback::onResult); } private void sendSaslContinueAsync(final BsonInt32 conversationId, final byte[] outToken, final InternalConnection connection, final OperationContext operationContext, final SingleResultCallback callback) { executeCommandAsync(getMongoCredential().getSource(), createSaslContinueDocument(conversationId, outToken), - getClusterConnectionMode(), getServerApi(), connection, operationContext, (r, t) -> { - operationContext.getTimeoutContext().resetMaintenanceTimeout(); - callback.onResult(r, t); - }); + getClusterConnectionMode(), getServerApi(), connection, operationContext, callback::onResult); } protected BsonDocument createSaslStartCommandDocument(@Nullable final byte[] outToken) { @@ -323,7 +316,7 @@ private final class Continuator implements SingleResultCallback { private final SaslClient saslClient; private final BsonDocument saslStartDocument; private final InternalConnection connection; - private final OperationContext operationContext; + private OperationContext operationContext; private final SingleResultCallback callback; Continuator(final SaslClient saslClient, final BsonDocument saslStartDocument, final InternalConnection connection, @@ -347,6 +340,7 @@ public void onResult(@Nullable final BsonDocument result, @Nullable final Throwa verifySaslClientComplete(saslClient, result, callback); disposeOfSaslClient(saslClient); } else { + operationContext = operationContext.withTimeoutContextOverride(TimeoutContext::withNewlyStartedTimeoutMaintenanceTimeout); continueConversation(result); } } diff --git a/driver-core/src/main/com/mongodb/internal/operation/AbortTransactionOperation.java b/driver-core/src/main/com/mongodb/internal/operation/AbortTransactionOperation.java index bc7e6655bc7..c2598a9bddd 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/AbortTransactionOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/AbortTransactionOperation.java @@ -51,9 +51,10 @@ public String getCommandName() { @Override CommandCreator getCommandCreator() { return (operationContext, serverDescription, connectionDescription) -> { - operationContext.getTimeoutContext().resetToDefaultMaxTime(); BsonDocument command = AbortTransactionOperation.super.getCommandCreator() - .create(operationContext, serverDescription, connectionDescription); + .create(operationContext.withTimeoutContextOverride(TimeoutContext::withDefaultMaxTime), + serverDescription, + connectionDescription); putIfNotNull(command, "recoveryToken", recoveryToken); return command; }; diff --git a/driver-core/src/main/com/mongodb/internal/operation/AbstractWriteSearchIndexOperation.java b/driver-core/src/main/com/mongodb/internal/operation/AbstractWriteSearchIndexOperation.java index 6ebcfda6dbe..87092c02a24 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/AbstractWriteSearchIndexOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/AbstractWriteSearchIndexOperation.java @@ -22,6 +22,7 @@ import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.binding.AsyncWriteBinding; import com.mongodb.internal.binding.WriteBinding; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; @@ -45,12 +46,12 @@ abstract class AbstractWriteSearchIndexOperation implements WriteOperation } @Override - public Void execute(final WriteBinding binding) { - return withConnection(binding, connection -> { + public Void execute(final WriteBinding binding, final OperationContext operationContext) { + return withConnection(binding, operationContext, (connection, operationContextWithMinRtt) -> { try { - executeCommand(binding, namespace.getDatabaseName(), buildCommand(), + executeCommand(binding, operationContextWithMinRtt, namespace.getDatabaseName(), buildCommand(), connection, - writeConcernErrorTransformer(binding.getOperationContext().getTimeoutContext())); + writeConcernErrorTransformer(operationContextWithMinRtt.getTimeoutContext())); } catch (MongoCommandException mongoCommandException) { swallowOrThrow(mongoCommandException); } @@ -59,20 +60,21 @@ public Void execute(final WriteBinding binding) { } @Override - public void executeAsync(final AsyncWriteBinding binding, final SingleResultCallback callback) { - withAsyncSourceAndConnection(binding::getWriteConnectionSource, false, callback, - (connectionSource, connection, cb) -> - executeCommandAsync(binding, namespace.getDatabaseName(), buildCommand(), connection, - writeConcernErrorTransformerAsync(binding.getOperationContext().getTimeoutContext()), (result, commandExecutionError) -> { + public void executeAsync(final AsyncWriteBinding binding, final OperationContext operationContext, final SingleResultCallback callback) { + withAsyncSourceAndConnection(binding::getWriteConnectionSource, false, operationContext, callback, + (connectionSource, connection, operationContextWithMinRtt, cb) -> + executeCommandAsync(binding, operationContextWithMinRtt, namespace.getDatabaseName(), buildCommand(), connection, + writeConcernErrorTransformerAsync(operationContextWithMinRtt.getTimeoutContext()), (result, commandExecutionError) -> { try { swallowOrThrow(commandExecutionError); + //TODO why call callback and not cb? callback.onResult(result, null); } catch (Throwable mongoCommandException) { + //TODO why call callback and not cb? callback.onResult(null, mongoCommandException); } } - ) - ); + )); } /** diff --git a/driver-core/src/main/com/mongodb/internal/operation/AggregateOperation.java b/driver-core/src/main/com/mongodb/internal/operation/AggregateOperation.java index 1c9abfc68ca..8bbf52fe9ce 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/AggregateOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/AggregateOperation.java @@ -25,6 +25,7 @@ import com.mongodb.internal.binding.AsyncReadBinding; import com.mongodb.internal.binding.ReadBinding; import com.mongodb.internal.client.model.AggregationLevel; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonValue; @@ -141,13 +142,13 @@ public String getCommandName() { } @Override - public BatchCursor execute(final ReadBinding binding) { - return wrapped.execute(binding); + public BatchCursor execute(final ReadBinding binding, final OperationContext operationContext) { + return wrapped.execute(binding, operationContext); } @Override - public void executeAsync(final AsyncReadBinding binding, final SingleResultCallback> callback) { - wrapped.executeAsync(binding, callback); + public void executeAsync(final AsyncReadBinding binding, final OperationContext operationContext, final SingleResultCallback> callback) { + wrapped.executeAsync(binding, operationContext, callback); } @Override @@ -156,7 +157,7 @@ public ReadOperationSimple asExplainableOperation(@Nullable final Explain } CommandReadOperation createExplainableOperation(@Nullable final ExplainVerbosity verbosity, final Decoder resultDecoder) { - return new CommandReadOperation<>(getNamespace().getDatabaseName(), wrapped.getCommandName(), + return new ExplainCommandOperation<>(getNamespace().getDatabaseName(), getCommandName(), (operationContext, serverDescription, connectionDescription) -> { BsonDocument command = wrapped.getCommand(operationContext, UNKNOWN_WIRE_VERSION); applyMaxTimeMS(operationContext.getTimeoutContext(), command); diff --git a/driver-core/src/main/com/mongodb/internal/operation/AggregateOperationImpl.java b/driver-core/src/main/com/mongodb/internal/operation/AggregateOperationImpl.java index 4c9bc3828b7..91cc5581599 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/AggregateOperationImpl.java +++ b/driver-core/src/main/com/mongodb/internal/operation/AggregateOperationImpl.java @@ -47,7 +47,7 @@ import static com.mongodb.internal.operation.AsyncOperationHelper.executeRetryableReadAsync; import static com.mongodb.internal.operation.CommandOperationHelper.CommandCreator; import static com.mongodb.internal.operation.OperationHelper.LOGGER; -import static com.mongodb.internal.operation.OperationHelper.setNonTailableCursorMaxTimeSupplier; +import static com.mongodb.internal.operation.OperationHelper.applyTimeoutModeToOperationContext; import static com.mongodb.internal.operation.OperationReadConcernHelper.appendReadConcernToCommand; import static com.mongodb.internal.operation.SyncOperationHelper.CommandReadTransformer; import static com.mongodb.internal.operation.SyncOperationHelper.executeRetryableRead; @@ -192,16 +192,16 @@ public String getCommandName() { } @Override - public BatchCursor execute(final ReadBinding binding) { - return executeRetryableRead(binding, namespace.getDatabaseName(), + public BatchCursor execute(final ReadBinding binding, final OperationContext operationContext) { + return executeRetryableRead(binding, applyTimeoutModeToOperationContext(timeoutMode, operationContext), namespace.getDatabaseName(), getCommandCreator(), CommandResultDocumentCodec.create(decoder, FIELD_NAMES_WITH_RESULT), transformer(), retryReads); } @Override - public void executeAsync(final AsyncReadBinding binding, final SingleResultCallback> callback) { + public void executeAsync(final AsyncReadBinding binding, final OperationContext operationContext, final SingleResultCallback> callback) { SingleResultCallback> errHandlingCallback = errorHandlingCallback(callback, LOGGER); - executeRetryableReadAsync(binding, namespace.getDatabaseName(), + executeRetryableReadAsync(binding, applyTimeoutModeToOperationContext(timeoutMode, operationContext), namespace.getDatabaseName(), getCommandCreator(), CommandResultDocumentCodec.create(decoder, FIELD_NAMES_WITH_RESULT), asyncTransformer(), retryReads, errHandlingCallback); @@ -216,7 +216,6 @@ BsonDocument getCommand(final OperationContext operationContext, final int maxWi BsonDocument commandDocument = new BsonDocument(getCommandName(), aggregateTarget.create()); appendReadConcernToCommand(operationContext.getSessionContext(), maxWireVersion, commandDocument); commandDocument.put("pipeline", pipelineCreator.create()); - setNonTailableCursorMaxTimeSupplier(timeoutMode, operationContext); BsonDocument cursor = new BsonDocument(); if (batchSize != null) { cursor.put("batchSize", new BsonInt32(batchSize)); @@ -242,15 +241,19 @@ BsonDocument getCommand(final OperationContext operationContext, final int maxWi } private CommandReadTransformer> transformer() { - return (result, source, connection) -> - new CommandBatchCursor<>(getTimeoutMode(), result, batchSize != null ? batchSize : 0, - getMaxTimeForCursor(source.getOperationContext().getTimeoutContext()), decoder, comment, source, connection); + return (result, source, connection, operationContext) -> + new CommandBatchCursor<>(getTimeoutMode(), getMaxTimeForCursor(operationContext.getTimeoutContext()), operationContext, new CommandCoreCursor<>( + result, batchSize != null ? batchSize : 0, + decoder, comment, source, connection + )); } private CommandReadTransformerAsync> asyncTransformer() { - return (result, source, connection) -> - new AsyncCommandBatchCursor<>(getTimeoutMode(), result, batchSize != null ? batchSize : 0, - getMaxTimeForCursor(source.getOperationContext().getTimeoutContext()), decoder, comment, source, connection); + return (result, source, connection, operationContext) -> + new AsyncCommandBatchCursor<>(getTimeoutMode(), getMaxTimeForCursor(operationContext.getTimeoutContext()), + operationContext, new AsyncCommandCoreCursor<>( + result, batchSize != null ? batchSize : 0, decoder, comment, source, connection + )); } private TimeoutMode getTimeoutMode() { diff --git a/driver-core/src/main/com/mongodb/internal/operation/AggregateToCollectionOperation.java b/driver-core/src/main/com/mongodb/internal/operation/AggregateToCollectionOperation.java index 16f33ad45e5..36e7e1aad5f 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/AggregateToCollectionOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/AggregateToCollectionOperation.java @@ -26,6 +26,7 @@ import com.mongodb.internal.binding.AsyncReadBinding; import com.mongodb.internal.binding.ReadBinding; import com.mongodb.internal.client.model.AggregationLevel; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.lang.Nullable; import org.bson.BsonArray; import org.bson.BsonBoolean; @@ -39,8 +40,10 @@ import static com.mongodb.assertions.Assertions.isTrueArgument; import static com.mongodb.assertions.Assertions.notNull; +import static com.mongodb.internal.operation.AsyncOperationHelper.CommandReadTransformerAsync; import static com.mongodb.internal.operation.AsyncOperationHelper.executeRetryableReadAsync; import static com.mongodb.internal.operation.ServerVersionHelper.FIVE_DOT_ZERO_WIRE_VERSION; +import static com.mongodb.internal.operation.SyncOperationHelper.CommandReadTransformer; import static com.mongodb.internal.operation.SyncOperationHelper.executeRetryableRead; import static com.mongodb.internal.operation.WriteConcernHelper.appendWriteConcernToCommand; import static com.mongodb.internal.operation.WriteConcernHelper.throwOnWriteConcernError; @@ -158,30 +161,33 @@ public String getCommandName() { } @Override - public Void execute(final ReadBinding binding) { - return executeRetryableRead(binding, - () -> binding.getReadConnectionSource(FIVE_DOT_ZERO_WIRE_VERSION, ReadPreference.primary()), - namespace.getDatabaseName(), - getCommandCreator(), - new BsonDocumentCodec(), (result, source, connection) -> { - throwOnWriteConcernError(result, connection.getDescription().getServerAddress(), - connection.getDescription().getMaxWireVersion(), binding.getOperationContext().getTimeoutContext()); - return null; - }, false); + public Void execute(final ReadBinding binding, final OperationContext operationContext) { + return executeRetryableRead( + operationContext, + (serverSelectionOperationContext) -> + binding.getReadConnectionSource( + FIVE_DOT_ZERO_WIRE_VERSION, + ReadPreference.primary(), + serverSelectionOperationContext), + namespace.getDatabaseName(), + getCommandCreator(), + new BsonDocumentCodec(), transformer(), false); } @Override - public void executeAsync(final AsyncReadBinding binding, final SingleResultCallback callback) { - executeRetryableReadAsync(binding, - (connectionSourceCallback) -> - binding.getReadConnectionSource(FIVE_DOT_ZERO_WIRE_VERSION, ReadPreference.primary(), connectionSourceCallback), - namespace.getDatabaseName(), - getCommandCreator(), - new BsonDocumentCodec(), (result, source, connection) -> { - throwOnWriteConcernError(result, connection.getDescription().getServerAddress(), - connection.getDescription().getMaxWireVersion(), binding.getOperationContext().getTimeoutContext()); - return null; - }, false, callback); + public void executeAsync(final AsyncReadBinding binding, final OperationContext operationContext, + final SingleResultCallback callback) { + executeRetryableReadAsync( + binding, + operationContext, + (serverSelectionOperationContext, connectionSourceCallback) -> + binding.getReadConnectionSource(FIVE_DOT_ZERO_WIRE_VERSION, ReadPreference.primary(), serverSelectionOperationContext, connectionSourceCallback), + namespace.getDatabaseName(), + getCommandCreator(), + new BsonDocumentCodec(), + asyncTransformer(), + false, + callback); } private CommandOperationHelper.CommandCreator getCommandCreator() { @@ -220,4 +226,20 @@ private CommandOperationHelper.CommandCreator getCommandCreator() { return commandDocument; }; } + + private static CommandReadTransformer transformer() { + return (result, source, connection, operationContext) -> { + throwOnWriteConcernError(result, connection.getDescription().getServerAddress(), + connection.getDescription().getMaxWireVersion(), operationContext.getTimeoutContext()); + return null; + }; + } + + private static CommandReadTransformerAsync asyncTransformer() { + return (result, source, connection, operationContext) -> { + throwOnWriteConcernError(result, connection.getDescription().getServerAddress(), + connection.getDescription().getMaxWireVersion(), operationContext.getTimeoutContext()); + return null; + }; + } } diff --git a/driver-core/src/main/com/mongodb/internal/operation/AsyncChangeStreamBatchCursor.java b/driver-core/src/main/com/mongodb/internal/operation/AsyncChangeStreamBatchCursor.java index a4cfbafedb6..78b09c36539 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/AsyncChangeStreamBatchCursor.java +++ b/driver-core/src/main/com/mongodb/internal/operation/AsyncChangeStreamBatchCursor.java @@ -1,3 +1,5 @@ +package com.mongodb.internal.operation; + /* * Copyright 2008-present MongoDB, Inc. * @@ -14,14 +16,13 @@ * limitations under the License. */ -package com.mongodb.internal.operation; import com.mongodb.MongoException; import com.mongodb.internal.TimeoutContext; import com.mongodb.internal.async.AsyncAggregateResponseBatchCursor; -import com.mongodb.internal.async.AsyncBatchCursor; import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.binding.AsyncReadBinding; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.lang.NonNull; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; @@ -43,7 +44,7 @@ final class AsyncChangeStreamBatchCursor implements AsyncAggregateResponseBatchCursor { private final AsyncReadBinding binding; - private final TimeoutContext timeoutContext; + private OperationContext operationContext; private final ChangeStreamOperation changeStreamOperation; private final int maxWireVersion; @@ -53,40 +54,41 @@ final class AsyncChangeStreamBatchCursor implements AsyncAggregateResponseBat * {@code wrapped} containing {@code null} and {@link #isClosed} being {@code false}. * This represents a situation in which the wrapped object was closed by {@code this} but {@code this} remained open. */ - private final AtomicReference> wrapped; + private final AtomicReference> wrapped; private final AtomicBoolean isClosed; AsyncChangeStreamBatchCursor(final ChangeStreamOperation changeStreamOperation, - final AsyncCommandBatchCursor wrapped, + final AsyncCoreCursor wrapped, final AsyncReadBinding binding, + final OperationContext operationContext, @Nullable final BsonDocument resumeToken, final int maxWireVersion) { this.changeStreamOperation = changeStreamOperation; this.wrapped = new AtomicReference<>(assertNotNull(wrapped)); this.binding = binding; binding.retain(); - this.timeoutContext = binding.getOperationContext().getTimeoutContext(); + this.operationContext = operationContext.withTimeoutContextOverride(TimeoutContext::withMaxTimeAsMaxAwaitTimeOverride); this.resumeToken = resumeToken; this.maxWireVersion = maxWireVersion; isClosed = new AtomicBoolean(); } @NonNull - AsyncCommandBatchCursor getWrapped() { + AsyncCoreCursor getWrapped() { return assertNotNull(wrapped.get()); } @Override public void next(final SingleResultCallback> callback) { - resumeableOperation(AsyncBatchCursor::next, callback, false); + OperationContext operationContext = resetTimeout(); + resumeableOperation(AsyncCoreCursor::next, callback, operationContext, false); } @Override public void close() { - timeoutContext.resetTimeoutIfPresent(); if (isClosed.compareAndSet(false, true)) { try { - nullifyAndCloseWrapped(); + nullifyAndCloseWrapped(operationContext.withNewlyStartedTimeout()); } finally { binding.release(); } @@ -116,7 +118,7 @@ public boolean isClosed() { } private boolean wrappedClosedItself() { - AsyncAggregateResponseBatchCursor observedWrapped = wrapped.get(); + AsyncCoreCursor observedWrapped = wrapped.get(); return observedWrapped != null && observedWrapped.isClosed(); } @@ -125,10 +127,10 @@ private boolean wrappedClosedItself() { * if {@link #wrappedClosedItself()} observes a {@linkplain AsyncAggregateResponseBatchCursor#isClosed() closed} wrapped object, * then it closed itself as opposed to being closed by {@code this}. */ - private void nullifyAndCloseWrapped() { - AsyncAggregateResponseBatchCursor observedWrapped = wrapped.getAndSet(null); + private void nullifyAndCloseWrapped(final OperationContext operationContext) { + AsyncCoreCursor observedWrapped = wrapped.getAndSet(null); if (observedWrapped != null) { - observedWrapped.close(); + observedWrapped.close(operationContext); } } @@ -137,14 +139,14 @@ private void nullifyAndCloseWrapped() { * {@code setWrappedOrCloseIt(AsyncCommandBatchCursor)} is called concurrently with or after (in the happens-before order) * the method {@link #close()}. */ - private void setWrappedOrCloseIt(final AsyncCommandBatchCursor newValue) { + private void setWrappedOrCloseIt(final AsyncCoreCursor newValue, final OperationContext operationContext) { if (isClosed()) { assertNull(wrapped.get()); - newValue.close(); + newValue.close(operationContext); } else { assertNull(wrapped.getAndSet(newValue)); if (isClosed()) { - nullifyAndCloseWrapped(); + nullifyAndCloseWrapped(operationContext); } } } @@ -169,7 +171,7 @@ public int getMaxWireVersion() { return maxWireVersion; } - private void cachePostBatchResumeToken(final AsyncCommandBatchCursor cursor) { + private void cachePostBatchResumeToken(final AsyncCoreCursor cursor) { BsonDocument resumeToken = cursor.getPostBatchResumeToken(); if (resumeToken != null) { this.resumeToken = resumeToken; @@ -177,19 +179,23 @@ private void cachePostBatchResumeToken(final AsyncCommandBatchCursor cursor, SingleResultCallback> callback); + void apply(AsyncCoreCursor cursor, OperationContext operationContext, + SingleResultCallback> callback); } - private void resumeableOperation(final AsyncBlock asyncBlock, final SingleResultCallback> callback, final boolean tryNext) { - timeoutContext.resetTimeoutIfPresent(); + private void resumeableOperation(final AsyncBlock asyncBlock, + final SingleResultCallback> callback, + final OperationContext operationContext, + final boolean tryNext) { + //timeoutContext.resetTimeoutIfPresent(); //FIXme it was a bug, we reset timeout on retry which is against the spec. Moved to next() method. SingleResultCallback> errHandlingCallback = errorHandlingCallback(callback, LOGGER); if (isClosed()) { errHandlingCallback.onResult(null, new MongoException(format("%s called after the cursor was closed.", tryNext ? "tryNext()" : "next()"))); return; } - AsyncCommandBatchCursor wrappedCursor = getWrapped(); - asyncBlock.apply(wrappedCursor, (result, t) -> { + AsyncCoreCursor wrappedCursor = getWrapped(); + asyncBlock.apply(wrappedCursor, operationContext, (result, t) -> { if (t == null) { try { List convertedResults; @@ -206,8 +212,8 @@ private void resumeableOperation(final AsyncBlock asyncBlock, final SingleResult } else { cachePostBatchResumeToken(wrappedCursor); if (isResumableError(t, maxWireVersion)) { - nullifyAndCloseWrapped(); - retryOperation(asyncBlock, errHandlingCallback, tryNext); + nullifyAndCloseWrapped(operationContext); + retryOperation(asyncBlock, errHandlingCallback, operationContext, tryNext); } else { errHandlingCallback.onResult(null, t); } @@ -215,26 +221,29 @@ private void resumeableOperation(final AsyncBlock asyncBlock, final SingleResult }); } - private void retryOperation(final AsyncBlock asyncBlock, final SingleResultCallback> callback, + private void retryOperation(final AsyncBlock asyncBlock, + final SingleResultCallback> callback, + final OperationContext operationContext, final boolean tryNext) { - withAsyncReadConnectionSource(binding, (source, t) -> { + withAsyncReadConnectionSource(binding, operationContext, (source, t) -> { if (t != null) { callback.onResult(null, t); } else { changeStreamOperation.setChangeStreamOptionsForResume(resumeToken, assertNotNull(source).getServerDescription().getMaxWireVersion()); source.release(); - changeStreamOperation.executeAsync(binding, (asyncBatchCursor, t1) -> { + changeStreamOperation.executeAsync(binding, operationContext, (asyncBatchCursor, t1) -> { if (t1 != null) { callback.onResult(null, t1); } else { try { - setWrappedOrCloseIt(assertNotNull((AsyncChangeStreamBatchCursor) asyncBatchCursor).getWrapped()); + setWrappedOrCloseIt(assertNotNull((AsyncChangeStreamBatchCursor) asyncBatchCursor).getWrapped(), + operationContext); } finally { try { binding.release(); // release the new change stream batch cursor's reference to the binding } finally { - resumeableOperation(asyncBlock, callback, tryNext); + resumeableOperation(asyncBlock, callback, operationContext, tryNext); } } } @@ -242,4 +251,10 @@ private void retryOperation(final AsyncBlock asyncBlock, final SingleResultCallb } }); } + + private OperationContext resetTimeout() { + operationContext = operationContext.withNewlyStartedTimeout(); + return operationContext; + } } + diff --git a/driver-core/src/main/com/mongodb/internal/operation/AsyncCommandBatchCursor.java b/driver-core/src/main/com/mongodb/internal/operation/AsyncCommandBatchCursor.java index 942721a27ad..036d91e1530 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/AsyncCommandBatchCursor.java +++ b/driver-core/src/main/com/mongodb/internal/operation/AsyncCommandBatchCursor.java @@ -16,354 +16,94 @@ package com.mongodb.internal.operation; -import com.mongodb.MongoCommandException; -import com.mongodb.MongoException; -import com.mongodb.MongoNamespace; -import com.mongodb.MongoOperationTimeoutException; -import com.mongodb.MongoSocketException; -import com.mongodb.ReadPreference; -import com.mongodb.ServerAddress; -import com.mongodb.ServerCursor; -import com.mongodb.annotations.ThreadSafe; import com.mongodb.client.cursor.TimeoutMode; -import com.mongodb.connection.ConnectionDescription; -import com.mongodb.connection.ServerType; -import com.mongodb.internal.TimeoutContext; -import com.mongodb.internal.VisibleForTesting; import com.mongodb.internal.async.AsyncAggregateResponseBatchCursor; import com.mongodb.internal.async.SingleResultCallback; -import com.mongodb.internal.async.function.AsyncCallbackSupplier; -import com.mongodb.internal.binding.AsyncConnectionSource; -import com.mongodb.internal.connection.AsyncConnection; -import com.mongodb.internal.connection.Connection; import com.mongodb.internal.connection.OperationContext; -import com.mongodb.internal.operation.AsyncOperationHelper.AsyncCallableConnectionWithCallback; -import com.mongodb.internal.validator.NoOpFieldNameValidator; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonTimestamp; -import org.bson.BsonValue; -import org.bson.codecs.BsonDocumentCodec; -import org.bson.codecs.Decoder; import java.util.List; -import java.util.concurrent.atomic.AtomicBoolean; -import static com.mongodb.assertions.Assertions.assertNotNull; -import static com.mongodb.assertions.Assertions.assertTrue; -import static com.mongodb.assertions.Assertions.doesNotThrow; -import static com.mongodb.internal.async.AsyncRunnable.beginAsync; -import static com.mongodb.internal.operation.CommandBatchCursorHelper.FIRST_BATCH; -import static com.mongodb.internal.operation.CommandBatchCursorHelper.MESSAGE_IF_CLOSED_AS_CURSOR; -import static com.mongodb.internal.operation.CommandBatchCursorHelper.NEXT_BATCH; -import static com.mongodb.internal.operation.CommandBatchCursorHelper.getKillCursorsCommand; -import static com.mongodb.internal.operation.CommandBatchCursorHelper.getMoreCommandDocument; -import static com.mongodb.internal.operation.CommandBatchCursorHelper.logCommandCursorResult; -import static com.mongodb.internal.operation.CommandBatchCursorHelper.translateCommandException; -import static java.util.Collections.emptyList; +public class AsyncCommandBatchCursor implements AsyncAggregateResponseBatchCursor { -class AsyncCommandBatchCursor implements AsyncAggregateResponseBatchCursor { - - private final MongoNamespace namespace; - private final Decoder decoder; - @Nullable - private final BsonValue comment; - private final int maxWireVersion; - private final boolean firstBatchEmpty; - private final ResourceManager resourceManager; - private final OperationContext operationContext; private final TimeoutMode timeoutMode; - private final AtomicBoolean processedInitial = new AtomicBoolean(); - private int batchSize; - private volatile CommandCursorResult commandCursorResult; - private boolean resetTimeoutWhenClosing; + private OperationContext operationContext; + + private AsyncCoreCursor wrapped; AsyncCommandBatchCursor( final TimeoutMode timeoutMode, - final BsonDocument commandCursorDocument, - final int batchSize, final long maxTimeMS, - final Decoder decoder, - @Nullable final BsonValue comment, - final AsyncConnectionSource connectionSource, - final AsyncConnection connection) { - ConnectionDescription connectionDescription = connection.getDescription(); - this.commandCursorResult = toCommandCursorResult(connectionDescription.getServerAddress(), FIRST_BATCH, commandCursorDocument); - this.namespace = commandCursorResult.getNamespace(); - this.batchSize = batchSize; - this.decoder = decoder; - this.comment = comment; - this.maxWireVersion = connectionDescription.getMaxWireVersion(); - this.firstBatchEmpty = commandCursorResult.getResults().isEmpty(); - operationContext = connectionSource.getOperationContext(); + final long maxTimeMs, + final OperationContext operationContext, + final AsyncCoreCursor wrapped) { + this.operationContext = operationContext.withTimeoutContextOverride(timeoutContext -> + timeoutContext.withMaxTimeOverride(maxTimeMs)); this.timeoutMode = timeoutMode; - - operationContext.getTimeoutContext().setMaxTimeOverride(maxTimeMS); - - AsyncConnection connectionToPin = connectionSource.getServerDescription().getType() == ServerType.LOAD_BALANCER - ? connection : null; - resourceManager = new ResourceManager(namespace, connectionSource, connectionToPin, commandCursorResult.getServerCursor()); - resetTimeoutWhenClosing = true; + this.wrapped = wrapped; } @Override public void next(final SingleResultCallback> callback) { - resourceManager.execute(funcCallback -> { - checkTimeoutModeAndResetTimeoutContextIfIteration(); - ServerCursor localServerCursor = resourceManager.getServerCursor(); - boolean serverCursorIsNull = localServerCursor == null; - List batchResults = emptyList(); - if (!processedInitial.getAndSet(true) && !firstBatchEmpty) { - batchResults = commandCursorResult.getResults(); - } - - if (serverCursorIsNull || !batchResults.isEmpty()) { - funcCallback.onResult(batchResults, null); - } else { - getMore(localServerCursor, funcCallback); - } - }, callback); + resetTimeout(); + wrapped.next(operationContext, callback); } @Override - public boolean isClosed() { - return !resourceManager.operable(); + public void setBatchSize(final int batchSize) { + wrapped.setBatchSize(batchSize); } @Override - public void setBatchSize(final int batchSize) { - this.batchSize = batchSize; + public int getBatchSize() { + return wrapped.getBatchSize(); } @Override - public int getBatchSize() { - return batchSize; + public boolean isClosed() { + return wrapped.isClosed(); } @Override public void close() { - resourceManager.close(); + wrapped.close(operationContext + .withTimeoutContextOverride(timeoutContext -> timeoutContext + .withNewlyStartedTimeout() + .withDefaultMaxTime() + )); } @Nullable - @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) - ServerCursor getServerCursor() { - if (!resourceManager.operable()) { - return null; - } - return resourceManager.getServerCursor(); - } - @Override public BsonDocument getPostBatchResumeToken() { - return commandCursorResult.getPostBatchResumeToken(); + return wrapped.getPostBatchResumeToken(); } + @Nullable @Override public BsonTimestamp getOperationTime() { - return commandCursorResult.getOperationTime(); + return wrapped.getOperationTime(); } @Override public boolean isFirstBatchEmpty() { - return firstBatchEmpty; + return wrapped.isFirstBatchEmpty(); } @Override public int getMaxWireVersion() { - return maxWireVersion; + return wrapped.getMaxWireVersion(); } - void checkTimeoutModeAndResetTimeoutContextIfIteration() { + private void resetTimeout() { if (timeoutMode == TimeoutMode.ITERATION) { - operationContext.getTimeoutContext().resetTimeoutIfPresent(); + operationContext = operationContext.withNewlyStartedTimeout(); } } - private void getMore(final ServerCursor cursor, final SingleResultCallback> callback) { - resourceManager.executeWithConnection((connection, wrappedCallback) -> - getMoreLoop(assertNotNull(connection), cursor, wrappedCallback), callback); - } - - private void getMoreLoop(final AsyncConnection connection, final ServerCursor serverCursor, - final SingleResultCallback> callback) { - connection.commandAsync(namespace.getDatabaseName(), - getMoreCommandDocument(serverCursor.getId(), connection.getDescription(), namespace, batchSize, comment), - NoOpFieldNameValidator.INSTANCE, ReadPreference.primary(), - CommandResultDocumentCodec.create(decoder, NEXT_BATCH), - assertNotNull(resourceManager.getConnectionSource()).getOperationContext(), - (commandResult, t) -> { - if (t != null) { - Throwable translatedException = - t instanceof MongoCommandException - ? translateCommandException((MongoCommandException) t, serverCursor) - : t; - callback.onResult(null, translatedException); - return; - } - commandCursorResult = toCommandCursorResult( - connection.getDescription().getServerAddress(), NEXT_BATCH, assertNotNull(commandResult)); - ServerCursor nextServerCursor = commandCursorResult.getServerCursor(); - resourceManager.setServerCursor(nextServerCursor); - List nextBatch = commandCursorResult.getResults(); - if (nextServerCursor == null || !nextBatch.isEmpty()) { - callback.onResult(nextBatch, null); - return; - } - - if (!resourceManager.operable()) { - callback.onResult(emptyList(), null); - return; - } - - getMoreLoop(connection, nextServerCursor, callback); - }); - } - - private CommandCursorResult toCommandCursorResult(final ServerAddress serverAddress, final String fieldNameContainingBatch, - final BsonDocument commandCursorDocument) { - CommandCursorResult commandCursorResult = new CommandCursorResult<>(serverAddress, fieldNameContainingBatch, - commandCursorDocument); - logCommandCursorResult(commandCursorResult); - return commandCursorResult; - } - - /** - * Configures the cursor to {@link #close()} - * without {@linkplain TimeoutContext#resetTimeoutIfPresent() resetting} its {@linkplain TimeoutContext#getTimeout() timeout}. - * This is useful when managing the {@link #close()} behavior externally. - */ - AsyncCommandBatchCursor disableTimeoutResetWhenClosing() { - resetTimeoutWhenClosing = false; - return this; - } - - @ThreadSafe - private final class ResourceManager extends CursorResourceManager { - ResourceManager( - final MongoNamespace namespace, - final AsyncConnectionSource connectionSource, - @Nullable final AsyncConnection connectionToPin, - @Nullable final ServerCursor serverCursor) { - super(namespace, connectionSource, connectionToPin, serverCursor); - } - - /** - * Thread-safe. - * Executes {@code operation} within the {@link #tryStartOperation()}/{@link #endOperation()} bounds. - */ - void execute(final AsyncCallbackSupplier operation, final SingleResultCallback callback) { - boolean canStartOperation = doesNotThrow(this::tryStartOperation); - if (!canStartOperation) { - callback.onResult(null, new IllegalStateException(MESSAGE_IF_CLOSED_AS_CURSOR)); - } else { - operation.whenComplete(() -> { - endOperation(); - if (super.getServerCursor() == null) { - // At this point all resources have been released, - // but `isClose` may still be returning `false` if `close` have not been called. - // Self-close to update the state managed by `ResourceManger`, and so that `isClosed` return `true`. - close(); - } - }).get(callback); - } - } - - @Override - void markAsPinned(final AsyncConnection connectionToPin, final Connection.PinningMode pinningMode) { - connectionToPin.markAsPinned(pinningMode); - } - - @Override - void doClose() { - TimeoutContext timeoutContext = operationContext.getTimeoutContext(); - timeoutContext.resetToDefaultMaxTime(); - SingleResultCallback thenDoNothing = (r, t) -> {}; - if (resetTimeoutWhenClosing) { - timeoutContext.doWithResetTimeout(this::releaseResourcesAsync, thenDoNothing); - } else { - releaseResourcesAsync(thenDoNothing); - } - } - - private void releaseResourcesAsync(final SingleResultCallback callback) { - beginAsync().thenRunTryCatchAsyncBlocks(c -> { - if (isSkipReleasingServerResourcesOnClose()) { - unsetServerCursor(); - } - if (super.getServerCursor() != null) { - beginAsync().thenSupply(c2 -> { - getConnection(c2); - }).thenConsume((connection, c3) -> { - beginAsync().thenRun(c4 -> { - releaseServerResourcesAsync(connection, c4); - }).thenAlwaysRunAndFinish(() -> { - connection.release(); - }, c3); - }).finish(c); - } else { - c.complete(c); - } - }, MongoException.class, (e, c5) -> { - c5.complete(c5); // ignore exceptions when releasing server resources - }).thenAlwaysRunAndFinish(() -> { - // guarantee that regardless of exceptions, `serverCursor` is null and client resources are released - unsetServerCursor(); - releaseClientResources(); - }, callback); - } - - void executeWithConnection(final AsyncCallableConnectionWithCallback callable, final SingleResultCallback callback) { - getConnection((connection, t) -> { - if (t != null) { - callback.onResult(null, t); - return; - } - callable.call(assertNotNull(connection), (result, t1) -> { - if (t1 != null) { - handleException(connection, t1); - } - connection.release(); - callback.onResult(result, t1); - }); - }); - } - - private void handleException(final AsyncConnection connection, final Throwable exception) { - if (exception instanceof MongoOperationTimeoutException && exception.getCause() instanceof MongoSocketException) { - onCorruptedConnection(connection, (MongoSocketException) exception.getCause()); - } else if (exception instanceof MongoSocketException) { - onCorruptedConnection(connection, (MongoSocketException) exception); - } - } - - private void getConnection(final SingleResultCallback callback) { - assertTrue(getState() != State.IDLE); - AsyncConnection pinnedConnection = getPinnedConnection(); - if (pinnedConnection != null) { - callback.onResult(assertNotNull(pinnedConnection).retain(), null); - } else { - assertNotNull(getConnectionSource()).getConnection(callback); - } - } - - private void releaseServerResourcesAsync(final AsyncConnection connection, final SingleResultCallback callback) { - beginAsync().thenRun((c) -> { - ServerCursor localServerCursor = super.getServerCursor(); - if (localServerCursor != null) { - killServerCursorAsync(getNamespace(), localServerCursor, connection, callback); - } else { - c.complete(c); - } - }).thenAlwaysRunAndFinish(() -> { - unsetServerCursor(); - }, callback); - } - - private void killServerCursorAsync(final MongoNamespace namespace, final ServerCursor localServerCursor, - final AsyncConnection localConnection, final SingleResultCallback callback) { - localConnection.commandAsync(namespace.getDatabaseName(), getKillCursorsCommand(namespace, localServerCursor), - NoOpFieldNameValidator.INSTANCE, ReadPreference.primary(), new BsonDocumentCodec(), - operationContext, (r, t) -> callback.onResult(null, null)); - } + AsyncCoreCursor getWrapped() { + return wrapped; } } + diff --git a/driver-core/src/main/com/mongodb/internal/operation/AsyncCommandCoreCursor.java b/driver-core/src/main/com/mongodb/internal/operation/AsyncCommandCoreCursor.java new file mode 100644 index 00000000000..754c63208fd --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/operation/AsyncCommandCoreCursor.java @@ -0,0 +1,348 @@ +package com.mongodb.internal.operation; + +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +import com.mongodb.MongoCommandException; +import com.mongodb.MongoException; +import com.mongodb.MongoNamespace; +import com.mongodb.MongoOperationTimeoutException; +import com.mongodb.MongoSocketException; +import com.mongodb.ReadPreference; +import com.mongodb.ServerAddress; +import com.mongodb.ServerCursor; +import com.mongodb.annotations.ThreadSafe; +import com.mongodb.connection.ConnectionDescription; +import com.mongodb.connection.ServerType; +import com.mongodb.internal.async.SingleResultCallback; +import com.mongodb.internal.async.function.AsyncCallbackSupplier; +import com.mongodb.internal.binding.AsyncConnectionSource; +import com.mongodb.internal.connection.AsyncConnection; +import com.mongodb.internal.connection.Connection; +import com.mongodb.internal.connection.OperationContext; +import com.mongodb.internal.operation.AsyncOperationHelper.AsyncCallableConnectionWithCallback; +import com.mongodb.internal.validator.NoOpFieldNameValidator; +import com.mongodb.lang.Nullable; +import org.bson.BsonDocument; +import org.bson.BsonTimestamp; +import org.bson.BsonValue; +import org.bson.codecs.BsonDocumentCodec; +import org.bson.codecs.Decoder; + +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; + +import static com.mongodb.assertions.Assertions.assertNotNull; +import static com.mongodb.assertions.Assertions.assertTrue; +import static com.mongodb.assertions.Assertions.doesNotThrow; +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; +import static com.mongodb.internal.async.SingleResultCallback.THEN_DO_NOTHING; +import static com.mongodb.internal.operation.CommandBatchCursorHelper.FIRST_BATCH; +import static com.mongodb.internal.operation.CommandBatchCursorHelper.MESSAGE_IF_CLOSED_AS_CURSOR; +import static com.mongodb.internal.operation.CommandBatchCursorHelper.NEXT_BATCH; +import static com.mongodb.internal.operation.CommandBatchCursorHelper.getKillCursorsCommand; +import static com.mongodb.internal.operation.CommandBatchCursorHelper.getMoreCommandDocument; +import static com.mongodb.internal.operation.CommandBatchCursorHelper.logCommandCursorResult; +import static com.mongodb.internal.operation.CommandBatchCursorHelper.translateCommandException; +import static java.util.Collections.emptyList; + +class AsyncCommandCoreCursor implements AsyncCoreCursor { + + private final MongoNamespace namespace; + private final Decoder decoder; + @Nullable + private final BsonValue comment; + private final int maxWireVersion; + private final boolean firstBatchEmpty; + private final ResourceManager resourceManager; + private final AtomicBoolean processedInitial = new AtomicBoolean(); + private int batchSize; + private volatile CommandCursorResult commandCursorResult; + + AsyncCommandCoreCursor( + final BsonDocument commandCursorDocument, + final int batchSize, + final Decoder decoder, + @Nullable final BsonValue comment, + final AsyncConnectionSource connectionSource, + final AsyncConnection connection) { + ConnectionDescription connectionDescription = connection.getDescription(); + this.commandCursorResult = toCommandCursorResult(connectionDescription.getServerAddress(), FIRST_BATCH, commandCursorDocument); + this.namespace = commandCursorResult.getNamespace(); + this.batchSize = batchSize; + this.decoder = decoder; + this.comment = comment; + this.maxWireVersion = connectionDescription.getMaxWireVersion(); + this.firstBatchEmpty = commandCursorResult.getResults().isEmpty(); + AsyncConnection connectionToPin = connectionSource.getServerDescription().getType() == ServerType.LOAD_BALANCER + ? connection : null; + resourceManager = new ResourceManager(namespace, connectionSource, connectionToPin, commandCursorResult.getServerCursor()); + } + + @Override + public void next(final OperationContext operationContext, final SingleResultCallback> callback) { + resourceManager.execute(funcCallback -> { + //checkTimeoutModeAndResetTimeoutContextIfIteration(); //FIXME it was a bug? we should have reset the timeout when next was request to execute connection checkout and subsequent read wait on one timeout + ServerCursor localServerCursor = resourceManager.getServerCursor(); + boolean serverCursorIsNull = localServerCursor == null; + List batchResults = emptyList(); + if (!processedInitial.getAndSet(true) && !firstBatchEmpty) { + batchResults = commandCursorResult.getResults(); + } + + if (serverCursorIsNull || !batchResults.isEmpty()) { + funcCallback.onResult(batchResults, null); + } else { + getMore(localServerCursor, operationContext, funcCallback); + } + }, operationContext, callback); + } + + @Override + public boolean isClosed() { + return !resourceManager.operable(); + } + + @Override + public void setBatchSize(final int batchSize) { + this.batchSize = batchSize; + } + + @Override + public int getBatchSize() { + return batchSize; + } + + @Override + public void close(final OperationContext operationContext) { + resourceManager.close(operationContext); + } + + @Nullable + @Override + public ServerCursor getServerCursor() { + if (!resourceManager.operable()) { + return null; + } + return resourceManager.getServerCursor(); + } + + @Override + public BsonDocument getPostBatchResumeToken() { + return commandCursorResult.getPostBatchResumeToken(); + } + + @Override + public BsonTimestamp getOperationTime() { + return commandCursorResult.getOperationTime(); + } + + @Override + public boolean isFirstBatchEmpty() { + return firstBatchEmpty; + } + + @Override + public int getMaxWireVersion() { + return maxWireVersion; + } + + private void getMore(final ServerCursor cursor, final OperationContext operationContext, final SingleResultCallback> callback) { + resourceManager.executeWithConnection(operationContext, (connection, wrappedCallback) -> + getMoreLoop(assertNotNull(connection), cursor, operationContext, wrappedCallback), callback); + } + + private void getMoreLoop(final AsyncConnection connection, final ServerCursor serverCursor, + final OperationContext operationContext, + final SingleResultCallback> callback) { + connection.commandAsync(namespace.getDatabaseName(), + getMoreCommandDocument(serverCursor.getId(), connection.getDescription(), namespace, batchSize, comment), + NoOpFieldNameValidator.INSTANCE, ReadPreference.primary(), + CommandResultDocumentCodec.create(decoder, NEXT_BATCH), + operationContext, + (commandResult, t) -> { + if (t != null) { + Throwable translatedException = + t instanceof MongoCommandException + ? translateCommandException((MongoCommandException) t, serverCursor) + : t; + callback.onResult(null, translatedException); + return; + } + commandCursorResult = toCommandCursorResult( + connection.getDescription().getServerAddress(), NEXT_BATCH, assertNotNull(commandResult)); + ServerCursor nextServerCursor = commandCursorResult.getServerCursor(); + resourceManager.setServerCursor(nextServerCursor); + List nextBatch = commandCursorResult.getResults(); + if (nextServerCursor == null || !nextBatch.isEmpty()) { + callback.onResult(nextBatch, null); + return; + } + + if (!resourceManager.operable()) { + callback.onResult(emptyList(), null); + return; + } + + getMoreLoop(connection, nextServerCursor, operationContext, callback); + }); + } + + private CommandCursorResult toCommandCursorResult(final ServerAddress serverAddress, final String fieldNameContainingBatch, + final BsonDocument commandCursorDocument) { + CommandCursorResult commandCursorResult = new CommandCursorResult<>(serverAddress, fieldNameContainingBatch, + commandCursorDocument); + logCommandCursorResult(commandCursorResult); + return commandCursorResult; + } + + @ThreadSafe + private final class ResourceManager extends CursorResourceManagerNew { + ResourceManager( + final MongoNamespace namespace, + final AsyncConnectionSource connectionSource, + @Nullable final AsyncConnection connectionToPin, + @Nullable final ServerCursor serverCursor) { + super(namespace, connectionSource, connectionToPin, serverCursor); + } + + /** + * Thread-safe. + */ + void execute(final AsyncCallbackSupplier operation, final OperationContext operationContext, final SingleResultCallback callback) { + boolean canStartOperation = doesNotThrow(this::tryStartOperation); + if (!canStartOperation) { + callback.onResult(null, new IllegalStateException(MESSAGE_IF_CLOSED_AS_CURSOR)); + } else { + operation.whenComplete(() -> { + endOperation(operationContext); + if (super.getServerCursor() == null) { + // At this point all resources have been released, + // but `isClose` may still be returning `false` if `close` have not been called. + // Self-close to update the state managed by `ResourceManger`, and so that `isClosed` return `true`. + close(operationContext); + } + }).get(callback); + } + } + + @Override + void markAsPinned(final AsyncConnection connectionToPin, final Connection.PinningMode pinningMode) { + connectionToPin.markAsPinned(pinningMode); + } + + @Override + void doClose(final OperationContext operationContext) { + releaseResourcesAsync(operationContext, THEN_DO_NOTHING); + } + + private void releaseResourcesAsync(final OperationContext operationContext, final SingleResultCallback callback) { + beginAsync().thenRunTryCatchAsyncBlocks(c -> { + if (isSkipReleasingServerResourcesOnClose()) { + unsetServerCursor(); + } + if (super.getServerCursor() != null) { + beginAsync().thenSupply(c2 -> { + getConnection(operationContext, c2); + }).thenConsume((connection, c3) -> { + beginAsync().thenRun(c4 -> { + releaseServerResourcesAsync(connection, operationContext, c4); + }).thenAlwaysRunAndFinish(() -> { + connection.release(); + }, c3); + }).finish(c); + } else { + c.complete(c); + } + }, MongoException.class, (e, c5) -> { + c5.complete(c5); // ignore exceptions when releasing server resources + }).thenAlwaysRunAndFinish(() -> { + // guarantee that regardless of exceptions, `serverCursor` is null and client resources are released + unsetServerCursor(); + releaseClientResources(); + }, callback); + } + + void executeWithConnection(final OperationContext operationContext, final AsyncCallableConnectionWithCallback callable, + final SingleResultCallback callback) { + getConnection(operationContext, (connection, t) -> { + if (t != null) { + callback.onResult(null, t); + return; + } + callable.call(assertNotNull(connection), (result, t1) -> { + if (t1 != null) { + handleException(connection, t1); + } + connection.release(); + callback.onResult(result, t1); + }); + }); + } + + private void handleException(final AsyncConnection connection, final Throwable exception) { + if (exception instanceof MongoOperationTimeoutException && exception.getCause() instanceof MongoSocketException) { + onCorruptedConnection(connection, (MongoSocketException) exception.getCause()); + } else if (exception instanceof MongoSocketException) { + onCorruptedConnection(connection, (MongoSocketException) exception); + } + } + + private void getConnection(final OperationContext operationContext, final SingleResultCallback callback) { + assertTrue(getState() != State.IDLE); + AsyncConnection pinnedConnection = getPinnedConnection(); + if (pinnedConnection != null) { + callback.onResult(assertNotNull(pinnedConnection).retain(), null); + } else { + assertNotNull(getConnectionSource()).getConnection(operationContext, callback); + } + } + + private void releaseServerResourcesAsync(final AsyncConnection connection, final OperationContext operationContext, + final SingleResultCallback callback) { + beginAsync().thenRun((c) -> { + ServerCursor localServerCursor = super.getServerCursor(); + if (localServerCursor != null) { + killServerCursorAsync(getNamespace(), localServerCursor, connection, operationContext, callback); + } else { + c.complete(c); + } + }).thenAlwaysRunAndFinish(() -> { + unsetServerCursor(); + }, callback); + } + + private void killServerCursorAsync( + final MongoNamespace namespace, + final ServerCursor localServerCursor, + final AsyncConnection localConnection, + final OperationContext operationContext, + final SingleResultCallback callback) { + beginAsync().thenRun(c -> { + localConnection.commandAsync( + namespace.getDatabaseName(), + getKillCursorsCommand(namespace, localServerCursor), + NoOpFieldNameValidator.INSTANCE, + ReadPreference.primary(), + new BsonDocumentCodec(), + operationContext, + (r, t) -> c.complete(c)); + }).finish(callback); + } + } +} + diff --git a/driver-core/src/main/com/mongodb/internal/operation/AsyncCoreCursor.java b/driver-core/src/main/com/mongodb/internal/operation/AsyncCoreCursor.java new file mode 100644 index 00000000000..c689d48f7df --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/operation/AsyncCoreCursor.java @@ -0,0 +1,68 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.operation; + +import com.mongodb.ServerCursor; +import com.mongodb.internal.async.AsyncBatchCursor; +import com.mongodb.internal.async.SingleResultCallback; +import com.mongodb.internal.connection.OperationContext; +import com.mongodb.lang.Nullable; +import org.bson.BsonDocument; +import org.bson.BsonTimestamp; + +import java.util.List; + +public interface AsyncCoreCursor { + void close(OperationContext operationContext); + void next(OperationContext operationContext, SingleResultCallback> callback); + + /** + * Sets the batch size to use when requesting the next batch. This is the number of documents to request in the next batch. + * + * @param batchSize the non-negative batch size. 0 means to use the server default. + */ + void setBatchSize(int batchSize); + + /** + * Gets the batch size to use when requesting the next batch. This is the number of documents to request in the next batch. + * + * @return the non-negative batch size. 0 means to use the server default. + */ + int getBatchSize(); + + @Nullable + ServerCursor getServerCursor(); + + + @Nullable + BsonDocument getPostBatchResumeToken(); + + @Nullable + BsonTimestamp getOperationTime(); + + boolean isFirstBatchEmpty(); + + int getMaxWireVersion(); + + + /** + * Implementations of {@link AsyncBatchCursor} are allowed to close themselves, see {@link #close()} for more details. + * + * @return {@code true} if {@code this} has been closed or has closed itself. + */ + boolean isClosed(); +} diff --git a/driver-core/src/main/com/mongodb/internal/operation/AsyncOperationHelper.java b/driver-core/src/main/com/mongodb/internal/operation/AsyncOperationHelper.java index f158b3944ae..7b134df2ee8 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/AsyncOperationHelper.java +++ b/driver-core/src/main/com/mongodb/internal/operation/AsyncOperationHelper.java @@ -21,12 +21,13 @@ import com.mongodb.ReadPreference; import com.mongodb.assertions.Assertions; import com.mongodb.client.cursor.TimeoutMode; +import com.mongodb.connection.ServerDescription; import com.mongodb.internal.TimeoutContext; import com.mongodb.internal.async.AsyncBatchCursor; import com.mongodb.internal.async.SingleResultCallback; -import com.mongodb.internal.async.function.AsyncCallbackBiFunction; import com.mongodb.internal.async.function.AsyncCallbackFunction; import com.mongodb.internal.async.function.AsyncCallbackSupplier; +import com.mongodb.internal.async.function.AsyncCallbackTriFunction; import com.mongodb.internal.async.function.RetryState; import com.mongodb.internal.async.function.RetryingAsyncCallbackSupplier; import com.mongodb.internal.binding.AsyncConnectionSource; @@ -62,7 +63,7 @@ final class AsyncOperationHelper { interface AsyncCallableWithConnection { - void call(@Nullable AsyncConnection connection, @Nullable Throwable t); + void call(@Nullable AsyncConnection connection, OperationContext operationContext, @Nullable Throwable t); } interface AsyncCallableConnectionWithCallback { @@ -94,39 +95,69 @@ interface CommandReadTransformerAsync { * @return the function result */ @Nullable - R apply(T t, AsyncConnectionSource source, AsyncConnection connection); + R apply(T t, AsyncConnectionSource source, AsyncConnection connection, OperationContext operationContext); } - static void withAsyncReadConnectionSource(final AsyncReadBinding binding, final AsyncCallableWithSource callable) { - binding.getReadConnectionSource(errorHandlingCallback(new AsyncCallableWithSourceCallback(callable), OperationHelper.LOGGER)); + static void withAsyncReadConnectionSource(final AsyncReadBinding binding, final OperationContext operationContext, + final AsyncCallableWithSource callable) { + binding.getReadConnectionSource(operationContext, + errorHandlingCallback(new AsyncCallableWithSourceCallback(callable), OperationHelper.LOGGER)); } - static void withAsyncConnection(final AsyncWriteBinding binding, final AsyncCallableWithConnection callable) { - binding.getWriteConnectionSource(errorHandlingCallback(new AsyncCallableWithConnectionCallback(callable), OperationHelper.LOGGER)); + static void withAsyncConnection(final AsyncWriteBinding binding, + final OperationContext originalOperationContext, + final AsyncCallableWithConnection callable) { + OperationContext serverSelectionOperationContext = originalOperationContext.withTimeoutContextOverride(TimeoutContext::withComputedServerSelectionTimeoutContextNew); + binding.getWriteConnectionSource( + serverSelectionOperationContext, + errorHandlingCallback( + new AsyncCallableWithConnectionCallback(callable, serverSelectionOperationContext, originalOperationContext), + OperationHelper.LOGGER)); } /** - * @see #withAsyncSuppliedResource(AsyncCallbackSupplier, boolean, SingleResultCallback, AsyncCallbackFunction) + * @see #withAsyncSuppliedResource(AsyncCallbackFunction, boolean, OperationContext, SingleResultCallback, AsyncCallbackFunction) */ - static void withAsyncSourceAndConnection(final AsyncCallbackSupplier sourceSupplier, - final boolean wrapConnectionSourceException, final SingleResultCallback callback, - final AsyncCallbackBiFunction asyncFunction) + static void withAsyncSourceAndConnection( + final AsyncCallbackFunction sourceAsyncFunction, + final boolean wrapConnectionSourceException, + final OperationContext operationContext, + final SingleResultCallback callback, + final AsyncCallbackTriFunction asyncFunction) throws OperationHelper.ResourceSupplierInternalException { SingleResultCallback errorHandlingCallback = errorHandlingCallback(callback, OperationHelper.LOGGER); - withAsyncSuppliedResource(sourceSupplier, wrapConnectionSourceException, errorHandlingCallback, + + OperationContext serverSelectionOperationContext = + operationContext.withTimeoutContextOverride(TimeoutContext::withComputedServerSelectionTimeoutContextNew); + withAsyncSuppliedResource( + sourceAsyncFunction, + wrapConnectionSourceException, + serverSelectionOperationContext, + errorHandlingCallback, (source, sourceReleasingCallback) -> - withAsyncSuppliedResource(source::getConnection, wrapConnectionSourceException, sourceReleasingCallback, + withAsyncSuppliedResource( + source::getConnection, + wrapConnectionSourceException, + serverSelectionOperationContext.withMinRoundTripTime(source.getServerDescription()), + sourceReleasingCallback, (connection, connectionAndSourceReleasingCallback) -> - asyncFunction.apply(source, connection, connectionAndSourceReleasingCallback))); + asyncFunction.apply( + source, + connection, + operationContext.withMinRoundTripTime(source.getServerDescription()), + connectionAndSourceReleasingCallback))); } - static void withAsyncSuppliedResource(final AsyncCallbackSupplier resourceSupplier, - final boolean wrapSourceConnectionException, final SingleResultCallback callback, - final AsyncCallbackFunction function) throws OperationHelper.ResourceSupplierInternalException { + static void withAsyncSuppliedResource(final AsyncCallbackFunction resourceSupplier, + final boolean wrapSourceConnectionException, + final OperationContext operationContext, + final SingleResultCallback callback, + final AsyncCallbackFunction function) + throws OperationHelper.ResourceSupplierInternalException { SingleResultCallback errorHandlingCallback = errorHandlingCallback(callback, OperationHelper.LOGGER); - resourceSupplier.get((resource, supplierException) -> { + resourceSupplier.apply(operationContext, (resource, supplierException) -> { if (supplierException != null) { if (wrapSourceConnectionException) { supplierException = new OperationHelper.ResourceSupplierInternalException(supplierException); @@ -144,57 +175,47 @@ static void withAsyncSuppliedResource(final Asyn }); } - static void withAsyncConnectionSourceCallableConnection(final AsyncConnectionSource source, - final AsyncCallableWithConnection callable) { - source.getConnection((connection, t) -> { - source.release(); - if (t != null) { - callable.call(null, t); - } else { - callable.call(connection, null); - } - }); - } - - static void withAsyncConnectionSource(final AsyncConnectionSource source, final AsyncCallableWithSource callable) { + static void withAsyncConnectionSource(final AsyncConnectionSource source, + final AsyncCallableWithSource callable) { callable.call(source, null); } static void executeRetryableReadAsync( final AsyncReadBinding binding, + final OperationContext operationContext, final String database, final CommandCreator commandCreator, final Decoder decoder, final CommandReadTransformerAsync transformer, final boolean retryReads, final SingleResultCallback callback) { - executeRetryableReadAsync(binding, binding::getReadConnectionSource, database, commandCreator, + executeRetryableReadAsync(binding, operationContext, binding::getReadConnectionSource, database, commandCreator, decoder, transformer, retryReads, callback); } static void executeRetryableReadAsync( final AsyncReadBinding binding, - final AsyncCallbackSupplier sourceAsyncSupplier, + final OperationContext operationContext, + final AsyncCallbackFunction sourceAsyncFunction, final String database, final CommandCreator commandCreator, final Decoder decoder, final CommandReadTransformerAsync transformer, final boolean retryReads, final SingleResultCallback callback) { - RetryState retryState = initialRetryState(retryReads, binding.getOperationContext().getTimeoutContext()); + RetryState retryState = initialRetryState(retryReads, operationContext.getTimeoutContext()); binding.retain(); - OperationContext operationContext = binding.getOperationContext(); - AsyncCallbackSupplier asyncRead = decorateReadWithRetriesAsync(retryState, binding.getOperationContext(), + AsyncCallbackSupplier asyncRead = decorateReadWithRetriesAsync(retryState, operationContext, (AsyncCallbackSupplier) funcCallback -> - withAsyncSourceAndConnection(sourceAsyncSupplier, false, funcCallback, - (source, connection, releasingCallback) -> { + withAsyncSourceAndConnection(sourceAsyncFunction, false, operationContext, funcCallback, + (source, connection, operationContextWithMinRtt, releasingCallback) -> { if (retryState.breakAndCompleteIfRetryAnd( () -> !OperationHelper.canRetryRead(source.getServerDescription(), - operationContext), + operationContextWithMinRtt), releasingCallback)) { return; } - createReadCommandAndExecuteAsync(retryState, operationContext, source, database, + createReadCommandAndExecuteAsync(retryState, operationContextWithMinRtt, source, database, commandCreator, decoder, transformer, connection, releasingCallback); }) ).whenComplete(binding::release); @@ -203,20 +224,31 @@ static void executeRetryableReadAsync( static void executeCommandAsync( final AsyncWriteBinding binding, + final OperationContext operationContext, final String database, final CommandCreator commandCreator, final CommandWriteTransformerAsync transformer, final SingleResultCallback callback) { Assertions.notNull("binding", binding); - withAsyncSourceAndConnection(binding::getWriteConnectionSource, false, callback, - (source, connection, releasingCallback) -> - executeCommandAsync(binding, database, commandCreator.create( - binding.getOperationContext(), source.getServerDescription(), connection.getDescription()), - connection, transformer, releasingCallback) - ); + withAsyncSourceAndConnection( + binding::getWriteConnectionSource, + false, + operationContext, + callback, + (source, connection, operationContextWithMinRtt, releasingCallback) -> + executeCommandAsync( + binding, + operationContextWithMinRtt, + database, + commandCreator.create( + operationContextWithMinRtt, source.getServerDescription(), connection.getDescription()), + connection, + transformer, + releasingCallback)); } static void executeCommandAsync(final AsyncWriteBinding binding, + final OperationContext operationContext, final String database, final BsonDocument command, final AsyncConnection connection, @@ -226,11 +258,12 @@ static void executeCommandAsync(final AsyncWriteBinding binding, SingleResultCallback addingRetryableLabelCallback = addingRetryableLabelCallback(callback, connection.getDescription().getMaxWireVersion()); connection.commandAsync(database, command, NoOpFieldNameValidator.INSTANCE, ReadPreference.primary(), new BsonDocumentCodec(), - binding.getOperationContext(), transformingWriteCallback(transformer, connection, addingRetryableLabelCallback)); + operationContext, transformingWriteCallback(transformer, connection, addingRetryableLabelCallback)); } static void executeRetryableWriteAsync( final AsyncWriteBinding binding, + final OperationContext operationContext, final String database, @Nullable final ReadPreference readPreference, final FieldNameValidator fieldNameValidator, @@ -240,9 +273,8 @@ static void executeRetryableWriteAsync( final Function retryCommandModifier, final SingleResultCallback callback) { - RetryState retryState = initialRetryState(true, binding.getOperationContext().getTimeoutContext()); + RetryState retryState = initialRetryState(true, operationContext.getTimeoutContext()); binding.retain(); - OperationContext operationContext = binding.getOperationContext(); AsyncCallbackSupplier asyncWrite = decorateWriteWithRetriesAsync(retryState, operationContext, (AsyncCallbackSupplier) funcCallback -> { @@ -250,14 +282,14 @@ static void executeRetryableWriteAsync( if (!firstAttempt && operationContext.getSessionContext().hasActiveTransaction()) { operationContext.getSessionContext().clearTransactionContext(); } - withAsyncSourceAndConnection(binding::getWriteConnectionSource, true, funcCallback, - (source, connection, releasingCallback) -> { + withAsyncSourceAndConnection(binding::getWriteConnectionSource, true, operationContext, funcCallback, + (source, connection, operationContextWithMinRtt, releasingCallback) -> { int maxWireVersion = connection.getDescription().getMaxWireVersion(); SingleResultCallback addingRetryableLabelCallback = firstAttempt ? releasingCallback : addingRetryableLabelCallback(releasingCallback, maxWireVersion); if (retryState.breakAndCompleteIfRetryAnd(() -> - !OperationHelper.canRetryWrite(connection.getDescription(), operationContext.getSessionContext()), + !OperationHelper.canRetryWrite(connection.getDescription(), operationContextWithMinRtt.getSessionContext()), addingRetryableLabelCallback)) { return; } @@ -268,7 +300,7 @@ static void executeRetryableWriteAsync( Assertions.assertFalse(firstAttempt); return retryCommandModifier.apply(previousAttemptCommand); }).orElseGet(() -> commandCreator.create( - operationContext, + operationContextWithMinRtt, source.getServerDescription(), connection.getDescription())); // attach `maxWireVersion`, `retryableCommandFlag` ASAP because they are used to check whether we should retry @@ -281,7 +313,8 @@ static void executeRetryableWriteAsync( return; } connection.commandAsync(database, command, fieldNameValidator, readPreference, commandResultDecoder, - operationContext, transformingWriteCallback(transformer, connection, addingRetryableLabelCallback)); + operationContextWithMinRtt, + transformingWriteCallback(transformer, connection, addingRetryableLabelCallback)); }); }).whenComplete(binding::release); @@ -307,7 +340,7 @@ static void createReadCommandAndExecuteAsync( return; } connection.commandAsync(database, command, NoOpFieldNameValidator.INSTANCE, source.getReadPreference(), decoder, - operationContext, transformingReadCallback(transformer, source, connection, callback)); + operationContext, transformingReadCallback(transformer, source, connection, operationContext, callback)); } static AsyncCallbackSupplier decorateReadWithRetriesAsync(final RetryState retryState, final OperationContext operationContext, @@ -339,14 +372,21 @@ static CommandWriteTransformerAsync writeConcernErrorTransfo } static CommandReadTransformerAsync> asyncSingleBatchCursorTransformer(final String fieldName) { - return (result, source, connection) -> + return (result, source, connection, operationContext) -> new AsyncSingleBatchCursor<>(BsonDocumentWrapperHelper.toList(result, fieldName), 0); } - static AsyncBatchCursor cursorDocumentToAsyncBatchCursor(final TimeoutMode timeoutMode, final BsonDocument cursorDocument, - final int batchSize, final Decoder decoder, @Nullable final BsonValue comment, final AsyncConnectionSource source, - final AsyncConnection connection) { - return new AsyncCommandBatchCursor<>(timeoutMode, cursorDocument, batchSize, 0, decoder, comment, source, connection); + static AsyncBatchCursor cursorDocumentToAsyncBatchCursor(final TimeoutMode timeoutMode, + final BsonDocument cursorDocument, + final int batchSize, + final Decoder decoder, + @Nullable final BsonValue comment, + final AsyncConnectionSource source, + final AsyncConnection connection, + final OperationContext operationContext) { + return new AsyncCommandBatchCursor<>(timeoutMode, 0, operationContext, new AsyncCommandCoreCursor<>( + cursorDocument, batchSize, decoder, comment, source, connection + )); } static SingleResultCallback releasingCallback(final SingleResultCallback wrapped, final AsyncConnection connection) { @@ -388,19 +428,37 @@ private static SingleResultCallback transformingWriteCallback(final Co private static class AsyncCallableWithConnectionCallback implements SingleResultCallback { private final AsyncCallableWithConnection callable; + private final OperationContext serverSelectionOperationContext; + private final OperationContext originalOperationContext; - AsyncCallableWithConnectionCallback(final AsyncCallableWithConnection callable) { + AsyncCallableWithConnectionCallback(final AsyncCallableWithConnection callable, + final OperationContext serverSelectionOperationContext, + final OperationContext originalOperationContext) { this.callable = callable; + this.serverSelectionOperationContext = serverSelectionOperationContext; + this.originalOperationContext = originalOperationContext; } @Override public void onResult(@Nullable final AsyncConnectionSource source, @Nullable final Throwable t) { if (t != null) { - callable.call(null, t); + callable.call(null, originalOperationContext, t); } else { - withAsyncConnectionSourceCallableConnection(Assertions.assertNotNull(source), callable); + withAsyncConnectionSourceCallableConnection(assertNotNull(source)); } } + + void withAsyncConnectionSourceCallableConnection(final AsyncConnectionSource source) { + source.getConnection(serverSelectionOperationContext, (connection, t) -> { + source.release(); + ServerDescription serverDescription = source.getServerDescription(); + if (t != null) { + callable.call(null, originalOperationContext.withMinRoundTripTime(serverDescription), t); + } else { + callable.call(connection, originalOperationContext.withMinRoundTripTime(serverDescription), null); + } + }); + } } private static class AsyncCallableWithSourceCallback implements SingleResultCallback { @@ -456,14 +514,14 @@ private static SingleResultCallback addingRetryableLabelCallback(final Si } private static SingleResultCallback transformingReadCallback(final CommandReadTransformerAsync transformer, - final AsyncConnectionSource source, final AsyncConnection connection, final SingleResultCallback callback) { + final AsyncConnectionSource source, final AsyncConnection connection, final OperationContext operationContext, final SingleResultCallback callback) { return (result, t) -> { if (t != null) { callback.onResult(null, t); } else { R transformedResult; try { - transformedResult = transformer.apply(assertNotNull(result), source, connection); + transformedResult = transformer.apply(assertNotNull(result), source, connection, operationContext); } catch (Throwable e) { callback.onResult(null, e); return; diff --git a/driver-core/src/main/com/mongodb/internal/operation/BaseFindAndModifyOperation.java b/driver-core/src/main/com/mongodb/internal/operation/BaseFindAndModifyOperation.java index c5d56fda81c..0f7a7a5196d 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/BaseFindAndModifyOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/BaseFindAndModifyOperation.java @@ -23,6 +23,7 @@ import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.binding.AsyncWriteBinding; import com.mongodb.internal.binding.WriteBinding; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.internal.session.SessionContext; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; @@ -76,8 +77,8 @@ public String getCommandName() { @Override - public T execute(final WriteBinding binding) { - return executeRetryableWrite(binding, getDatabaseName(), null, getFieldNameValidator(), + public T execute(final WriteBinding binding, final OperationContext operationContext) { + return executeRetryableWrite(binding, operationContext, getDatabaseName(), null, getFieldNameValidator(), CommandResultDocumentCodec.create(getDecoder(), "value"), getCommandCreator(), FindAndModifyHelper.transformer(), @@ -85,8 +86,8 @@ public T execute(final WriteBinding binding) { } @Override - public void executeAsync(final AsyncWriteBinding binding, final SingleResultCallback callback) { - executeRetryableWriteAsync(binding, getDatabaseName(), null, getFieldNameValidator(), + public void executeAsync(final AsyncWriteBinding binding, final OperationContext operationContext, final SingleResultCallback callback) { + executeRetryableWriteAsync(binding, operationContext, getDatabaseName(), null, getFieldNameValidator(), CommandResultDocumentCodec.create(getDecoder(), "value"), getCommandCreator(), FindAndModifyHelper.asyncTransformer(), cmd -> cmd, callback); diff --git a/driver-core/src/main/com/mongodb/internal/operation/ChangeStreamBatchCursor.java b/driver-core/src/main/com/mongodb/internal/operation/ChangeStreamBatchCursor.java index c4bd72a4775..103751e7df3 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/ChangeStreamBatchCursor.java +++ b/driver-core/src/main/com/mongodb/internal/operation/ChangeStreamBatchCursor.java @@ -1,3 +1,5 @@ +package com.mongodb.internal.operation; + /* * Copyright 2008-present MongoDB, Inc. * @@ -14,7 +16,6 @@ * limitations under the License. */ -package com.mongodb.internal.operation; import com.mongodb.MongoChangeStreamException; import com.mongodb.MongoException; @@ -23,6 +24,7 @@ import com.mongodb.ServerCursor; import com.mongodb.internal.TimeoutContext; import com.mongodb.internal.binding.ReadBinding; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonTimestamp; @@ -55,11 +57,11 @@ *

*/ final class ChangeStreamBatchCursor implements AggregateResponseBatchCursor { - private final ReadBinding binding; + private ReadBinding binding; + private OperationContext operationContext; private final ChangeStreamOperation changeStreamOperation; private final int maxWireVersion; - private final TimeoutContext timeoutContext; - private CommandBatchCursor wrapped; + private CoreCursor wrapped; private BsonDocument resumeToken; private final AtomicBoolean closed; @@ -71,13 +73,14 @@ final class ChangeStreamBatchCursor implements AggregateResponseBatchCursor changeStreamOperation, - final CommandBatchCursor wrapped, + final CoreCursor wrapped, final ReadBinding binding, + final OperationContext operationContext, @Nullable final BsonDocument resumeToken, final int maxWireVersion) { - this.timeoutContext = binding.getOperationContext().getTimeoutContext(); this.changeStreamOperation = changeStreamOperation; this.binding = binding.retain(); + this.operationContext = operationContext.withTimeoutContextOverride(TimeoutContext::withMaxTimeAsMaxAwaitTimeOverride); this.wrapped = wrapped; this.resumeToken = resumeToken; this.maxWireVersion = maxWireVersion; @@ -85,7 +88,7 @@ final class ChangeStreamBatchCursor implements AggregateResponseBatchCursor getWrapped() { + CoreCursor getWrapped() { return wrapped; } @@ -93,7 +96,7 @@ CommandBatchCursor getWrapped() { public boolean hasNext() { return resumeableOperation(commandBatchCursor -> { try { - return commandBatchCursor.hasNext(); + return commandBatchCursor.hasNext(operationContext); } finally { cachePostBatchResumeToken(commandBatchCursor); } @@ -104,7 +107,7 @@ public boolean hasNext() { public List next() { return resumeableOperation(commandBatchCursor -> { try { - return convertAndProduceLastId(commandBatchCursor.next(), changeStreamOperation.getDecoder(), + return convertAndProduceLastId(commandBatchCursor.next(operationContext), changeStreamOperation.getDecoder(), lastId -> resumeToken = lastId); } finally { cachePostBatchResumeToken(commandBatchCursor); @@ -112,6 +115,10 @@ public List next() { }); } + private void restartTimeout() { + operationContext = operationContext.withNewlyStartedTimeout(); + } + @Override public int available() { return wrapped.available(); @@ -121,7 +128,7 @@ public int available() { public List tryNext() { return resumeableOperation(commandBatchCursor -> { try { - List tryNext = commandBatchCursor.tryNext(); + List tryNext = commandBatchCursor.tryNext(operationContext); return tryNext == null ? null : convertAndProduceLastId(tryNext, changeStreamOperation.getDecoder(), lastId -> resumeToken = lastId); } finally { @@ -133,8 +140,7 @@ public List tryNext() { @Override public void close() { if (!closed.getAndSet(true)) { - timeoutContext.resetTimeoutIfPresent(); - wrapped.close(); + wrapped.close(operationContext); binding.release(); } } @@ -184,7 +190,7 @@ public int getMaxWireVersion() { return maxWireVersion; } - private void cachePostBatchResumeToken(final AggregateResponseBatchCursor commandBatchCursor) { + private void cachePostBatchResumeToken(final CoreCursor commandBatchCursor) { if (commandBatchCursor.getPostBatchResumeToken() != null) { resumeToken = commandBatchCursor.getPostBatchResumeToken(); } @@ -210,8 +216,8 @@ static List convertAndProduceLastId(final List rawDocume return results; } - R resumeableOperation(final Function, R> function) { - timeoutContext.resetTimeoutIfPresent(); + R resumeableOperation(final Function, R> function) { + restartTimeout(); try { R result = execute(function); lastOperationTimedOut = false; @@ -222,7 +228,7 @@ R resumeableOperation(final Function R execute(final Function, R> function) { + private R execute(final Function, R> function) { boolean shouldBeResumed = hasPreviousNextTimedOut(); while (true) { if (shouldBeResumed) { @@ -240,13 +246,16 @@ private R execute(final Function { + wrapped.close(operationContextWithDefaultMaxTime); + //TODO why do we initiate server selection here just to get server description and then ignore the selected server further on line 259? So we do two server selections? + withReadConnectionSource(binding, operationContext, (source, operationContextWithMinRtt) -> { changeStreamOperation.setChangeStreamOptionsForResume(resumeToken, source.getServerDescription().getMaxWireVersion()); return null; }); - wrapped = ((ChangeStreamBatchCursor) changeStreamOperation.execute(binding)).getWrapped(); + wrapped = ((ChangeStreamBatchCursor) changeStreamOperation.execute(binding, operationContextWithDefaultMaxTime)).getWrapped(); binding.release(); // release the new change stream batch cursor's reference to the binding } diff --git a/driver-core/src/main/com/mongodb/internal/operation/ChangeStreamOperation.java b/driver-core/src/main/com/mongodb/internal/operation/ChangeStreamOperation.java index f4c896ba6e9..e0ea7063d7f 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/ChangeStreamOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/ChangeStreamOperation.java @@ -22,11 +22,13 @@ import com.mongodb.client.model.changestream.FullDocument; import com.mongodb.client.model.changestream.FullDocumentBeforeChange; import com.mongodb.internal.TimeoutContext; +import com.mongodb.internal.VisibleForTesting; import com.mongodb.internal.async.AsyncBatchCursor; import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.binding.AsyncReadBinding; import com.mongodb.internal.binding.ReadBinding; import com.mongodb.internal.client.model.changestream.ChangeStreamLevel; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.lang.Nullable; import org.bson.BsonArray; import org.bson.BsonBoolean; @@ -64,7 +66,7 @@ public class ChangeStreamOperation implements ReadOperationCursor { private BsonTimestamp startAtOperationTime; private boolean showExpandedEvents; - + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) public ChangeStreamOperation(final MongoNamespace namespace, final FullDocument fullDocument, final FullDocumentBeforeChange fullDocumentBeforeChange, final List pipeline, final Decoder decoder) { this(namespace, fullDocument, fullDocumentBeforeChange, pipeline, decoder, ChangeStreamLevel.COLLECTION); @@ -198,28 +200,34 @@ public String getCommandName() { } @Override - public BatchCursor execute(final ReadBinding binding) { - TimeoutContext timeoutContext = binding.getOperationContext().getTimeoutContext(); - CommandBatchCursor cursor = ((CommandBatchCursor) getAggregateOperation(timeoutContext).execute(binding)) - .disableTimeoutResetWhenClosing(); - - return new ChangeStreamBatchCursor<>(ChangeStreamOperation.this, cursor, binding, - setChangeStreamOptions(cursor.getPostBatchResumeToken(), cursor.getOperationTime(), - cursor.getMaxWireVersion(), cursor.isFirstBatchEmpty()), cursor.getMaxWireVersion()); + public BatchCursor execute(final ReadBinding binding, final OperationContext operationContext) { + CoreCursor cursor = ((CommandBatchCursor) getAggregateOperation(operationContext.getTimeoutContext()) + .execute(binding, operationContext)) + .getWrapped(); + + return new ChangeStreamBatchCursor<>(ChangeStreamOperation.this, + cursor, + binding, + operationContext, + setChangeStreamOptions( + cursor.getPostBatchResumeToken(), + cursor.getOperationTime(), + cursor.getMaxWireVersion(), + cursor.isFirstBatchEmpty()), + cursor.getMaxWireVersion()); } @Override - public void executeAsync(final AsyncReadBinding binding, final SingleResultCallback> callback) { - TimeoutContext timeoutContext = binding.getOperationContext().getTimeoutContext(); - getAggregateOperation(timeoutContext).executeAsync(binding, (result, t) -> { + public void executeAsync(final AsyncReadBinding binding, final OperationContext operationContext, final SingleResultCallback> callback) { + getAggregateOperation(operationContext.getTimeoutContext()).executeAsync(binding, operationContext, (result, t) -> { if (t != null) { callback.onResult(null, t); } else { - AsyncCommandBatchCursor cursor = ((AsyncCommandBatchCursor) assertNotNull(result)) - .disableTimeoutResetWhenClosing(); + AsyncCoreCursor cursor = ((AsyncCommandBatchCursor) assertNotNull(result)) + .getWrapped(); callback.onResult(new AsyncChangeStreamBatchCursor<>(ChangeStreamOperation.this, cursor, binding, - setChangeStreamOptions(cursor.getPostBatchResumeToken(), cursor.getOperationTime(), + operationContext, setChangeStreamOptions(cursor.getPostBatchResumeToken(), cursor.getOperationTime(), cursor.getMaxWireVersion(), cursor.isFirstBatchEmpty()), cursor.getMaxWireVersion()), null); } }); diff --git a/driver-core/src/main/com/mongodb/internal/operation/ClientBulkWriteOperation.java b/driver-core/src/main/com/mongodb/internal/operation/ClientBulkWriteOperation.java index 2b9e79f6f06..14adf9727cd 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/ClientBulkWriteOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/ClientBulkWriteOperation.java @@ -183,13 +183,13 @@ public String getCommandName() { } @Override - public ClientBulkWriteResult execute(final WriteBinding binding) throws ClientBulkWriteException { - WriteConcern effectiveWriteConcern = validateAndGetEffectiveWriteConcern(binding.getOperationContext().getSessionContext()); + public ClientBulkWriteResult execute(final WriteBinding binding, final OperationContext operationContext) throws ClientBulkWriteException { + WriteConcern effectiveWriteConcern = validateAndGetEffectiveWriteConcern(operationContext.getSessionContext()); ResultAccumulator resultAccumulator = new ResultAccumulator(); MongoException transformedTopLevelError = null; try { - executeAllBatches(effectiveWriteConcern, binding, resultAccumulator); + executeAllBatches(effectiveWriteConcern, binding, operationContext, resultAccumulator); } catch (MongoException topLevelError) { transformedTopLevelError = transformWriteException(topLevelError); } @@ -199,13 +199,14 @@ public ClientBulkWriteResult execute(final WriteBinding binding) throws ClientBu @Override public void executeAsync(final AsyncWriteBinding binding, + final OperationContext operationContext, final SingleResultCallback finalCallback) { - WriteConcern effectiveWriteConcern = validateAndGetEffectiveWriteConcern(binding.getOperationContext().getSessionContext()); + WriteConcern effectiveWriteConcern = validateAndGetEffectiveWriteConcern(operationContext.getSessionContext()); ResultAccumulator resultAccumulator = new ResultAccumulator(); MutableValue transformedTopLevelError = new MutableValue<>(); beginAsync().thenSupply(c -> { - executeAllBatchesAsync(effectiveWriteConcern, binding, resultAccumulator, c); + executeAllBatchesAsync(effectiveWriteConcern, binding, operationContext, resultAccumulator, c); }).onErrorIf(topLevelError -> topLevelError instanceof MongoException, (topLevelError, c) -> { transformedTopLevelError.set(transformWriteException((MongoException) topLevelError)); c.complete(c); @@ -226,27 +227,29 @@ public void executeAsync(final AsyncWriteBinding binding, private void executeAllBatches( final WriteConcern effectiveWriteConcern, final WriteBinding binding, + final OperationContext operationContext, final ResultAccumulator resultAccumulator) throws MongoException { Integer nextBatchStartModelIndex = INITIAL_BATCH_MODEL_START_INDEX; do { - nextBatchStartModelIndex = executeBatch(nextBatchStartModelIndex, effectiveWriteConcern, binding, resultAccumulator); + nextBatchStartModelIndex = executeBatch(nextBatchStartModelIndex, effectiveWriteConcern, binding, operationContext, resultAccumulator); } while (nextBatchStartModelIndex != null); } /** - * @see #executeAllBatches(WriteConcern, WriteBinding, ResultAccumulator) + * @see #executeAllBatches(WriteConcern, WriteBinding, OperationContext, ResultAccumulator) */ private void executeAllBatchesAsync( final WriteConcern effectiveWriteConcern, final AsyncWriteBinding binding, + final OperationContext operationContext, final ResultAccumulator resultAccumulator, final SingleResultCallback finalCallback) { MutableValue nextBatchStartModelIndex = new MutableValue<>(INITIAL_BATCH_MODEL_START_INDEX); beginAsync().thenRunDoWhileLoop(iterationCallback -> { beginAsync().thenSupply(c -> { - executeBatchAsync(nextBatchStartModelIndex.get(), effectiveWriteConcern, binding, resultAccumulator, c); + executeBatchAsync(nextBatchStartModelIndex.get(), effectiveWriteConcern, binding, operationContext, resultAccumulator, c); }).thenApply((nextBatchStartModelIdx, c) -> { nextBatchStartModelIndex.set(nextBatchStartModelIdx); c.complete(c); @@ -265,10 +268,10 @@ private Integer executeBatch( final int batchStartModelIndex, final WriteConcern effectiveWriteConcern, final WriteBinding binding, + final OperationContext operationContext, final ResultAccumulator resultAccumulator) { List unexecutedModels = models.subList(batchStartModelIndex, models.size()); assertFalse(unexecutedModels.isEmpty()); - OperationContext operationContext = binding.getOperationContext(); SessionContext sessionContext = operationContext.getSessionContext(); TimeoutContext timeoutContext = operationContext.getTimeoutContext(); RetryState retryState = initialRetryState(retryWritesSetting, timeoutContext); @@ -281,7 +284,7 @@ private Integer executeBatch( // If connection pinning is required, `binding` handles that, // and `ClientSession`, `TransactionContext` are aware of that. () -> withSourceAndConnection(binding::getWriteConnectionSource, true, - (connectionSource, connection) -> { + (connectionSource, connection, commandOperationContext) -> { ConnectionDescription connectionDescription = connection.getDescription(); boolean effectiveRetryWrites = isRetryableWrite( retryWritesSetting, effectiveWriteConcern, connectionDescription, sessionContext); @@ -293,8 +296,8 @@ private Integer executeBatch( retryState, effectiveRetryWrites, effectiveWriteConcern, sessionContext, unexecutedModels, batchEncoder, () -> retryState.attach(AttachmentKeys.retryableCommandFlag(), true, true)); return executeBulkWriteCommandAndExhaustOkResponse( - retryState, connectionSource, connection, bulkWriteCommand, effectiveWriteConcern, operationContext); - }) + retryState, connectionSource, connection, bulkWriteCommand, effectiveWriteConcern, commandOperationContext); + }, operationContext) ); try { @@ -318,17 +321,17 @@ private Integer executeBatch( } /** - * @see #executeBatch(int, WriteConcern, WriteBinding, ResultAccumulator) + * @see #executeBatch(int, WriteConcern, WriteBinding, OperationContext, ResultAccumulator) */ private void executeBatchAsync( final int batchStartModelIndex, final WriteConcern effectiveWriteConcern, final AsyncWriteBinding binding, + final OperationContext operationContext, final ResultAccumulator resultAccumulator, final SingleResultCallback finalCallback) { List unexecutedModels = models.subList(batchStartModelIndex, models.size()); assertFalse(unexecutedModels.isEmpty()); - OperationContext operationContext = binding.getOperationContext(); SessionContext sessionContext = operationContext.getSessionContext(); TimeoutContext timeoutContext = operationContext.getTimeoutContext(); RetryState retryState = initialRetryState(retryWritesSetting, timeoutContext); @@ -340,8 +343,8 @@ private void executeBatchAsync( // and it is allowed by https://jira.mongodb.org/browse/DRIVERS-2502. // If connection pinning is required, `binding` handles that, // and `ClientSession`, `TransactionContext` are aware of that. - funcCallback -> withAsyncSourceAndConnection(binding::getWriteConnectionSource, true, funcCallback, - (connectionSource, connection, resultCallback) -> { + funcCallback -> withAsyncSourceAndConnection(binding::getWriteConnectionSource, true, operationContext, funcCallback, + (connectionSource, connection, operationContextWithMinRtt, resultCallback) -> { ConnectionDescription connectionDescription = connection.getDescription(); boolean effectiveRetryWrites = isRetryableWrite( retryWritesSetting, effectiveWriteConcern, connectionDescription, sessionContext); @@ -353,7 +356,7 @@ private void executeBatchAsync( retryState, effectiveRetryWrites, effectiveWriteConcern, sessionContext, unexecutedModels, batchEncoder, () -> retryState.attach(AttachmentKeys.retryableCommandFlag(), true, true)); executeBulkWriteCommandAndExhaustOkResponseAsync( - retryState, connectionSource, connection, bulkWriteCommand, effectiveWriteConcern, operationContext, resultCallback); + retryState, connectionSource, connection, bulkWriteCommand, effectiveWriteConcern, operationContextWithMinRtt, resultCallback); }) ); @@ -413,7 +416,7 @@ private ExhaustiveClientBulkWriteCommandOkResponse executeBulkWriteCommandAndExh return null; } List> cursorExhaustBatches = doWithRetriesDisabledForCommand(retryState, "getMore", () -> - exhaustBulkWriteCommandOkResponseCursor(connectionSource, connection, bulkWriteCommandOkResponse)); + exhaustBulkWriteCommandOkResponseCursor(connectionSource, operationContext, connection, bulkWriteCommandOkResponse)); return createExhaustiveClientBulkWriteCommandOkResponse( bulkWriteCommandOkResponse, cursorExhaustBatches, @@ -448,7 +451,7 @@ private void executeBulkWriteCommandAndExhaustOkResponseAsync( } beginAsync().>>thenSupply(c -> { doWithRetriesDisabledForCommandAsync(retryState, "getMore", (c1) -> { - exhaustBulkWriteCommandOkResponseCursorAsync(connectionSource, connection, bulkWriteCommandOkResponse, c1); + exhaustBulkWriteCommandOkResponseCursorAsync(connectionSource, connection, bulkWriteCommandOkResponse, operationContext, c1); }, c); }).thenApply((cursorExhaustBatches, c) -> { c.complete(createExhaustiveClientBulkWriteCommandOkResponse( @@ -514,6 +517,7 @@ private void doWithRetriesDisabledForCommandAsync( private List> exhaustBulkWriteCommandOkResponseCursor( final ConnectionSource connectionSource, + final OperationContext operationContext, final Connection connection, final BsonDocument response) { try (CommandBatchCursor cursor = cursorDocumentToBatchCursor( @@ -523,7 +527,8 @@ private List> exhaustBulkWriteCommandOkResponseCursor( codecRegistry.get(BsonDocument.class), options.getComment().orElse(null), connectionSource, - connection)) { + connection, + operationContext)) { return cursor.exhaust(); } @@ -532,6 +537,7 @@ private List> exhaustBulkWriteCommandOkResponseCursor( private void exhaustBulkWriteCommandOkResponseCursorAsync(final AsyncConnectionSource connectionSource, final AsyncConnection connection, final BsonDocument bulkWriteCommandOkResponse, + final OperationContext operationContext, final SingleResultCallback>> finalCallback) { AsyncBatchCursor cursor = cursorDocumentToAsyncBatchCursor( TimeoutMode.CURSOR_LIFETIME, @@ -540,7 +546,8 @@ private void exhaustBulkWriteCommandOkResponseCursorAsync(final AsyncConnectionS codecRegistry.get(BsonDocument.class), options.getComment().orElse(null), connectionSource, - connection); + connection, + operationContext); beginAsync().>>thenSupply(callback -> { cursor.exhaust(callback); @@ -838,7 +845,7 @@ Integer onBulkWriteCommandOkResponseOrNoResponse( } /** - * @return See {@link #executeBatch(int, WriteConcern, WriteBinding, ResultAccumulator)}. + * @return See {@link #executeBatch(int, WriteConcern, WriteBinding, OperationContext, ResultAccumulator)}. */ @Nullable Integer onBulkWriteCommandOkResponseWithWriteConcernError( @@ -852,7 +859,7 @@ Integer onBulkWriteCommandOkResponseWithWriteConcernError( } /** - * @return See {@link #executeBatch(int, WriteConcern, WriteBinding, ResultAccumulator)}. + * @return See {@link #executeBatch(int, WriteConcern, WriteBinding,OperationContext, ResultAccumulator)}. */ @Nullable private Integer onBulkWriteCommandOkResponseOrNoResponse( @@ -1165,7 +1172,7 @@ Map getInsertModelDocumentIds() { } /** - * Exactly one instance must be used per {@linkplain #executeBatch(int, WriteConcern, WriteBinding, ResultAccumulator) batch}. + * Exactly one instance must be used per {@linkplain #executeBatch(int, WriteConcern, WriteBinding, OperationContext, ResultAccumulator) batch}. */ @VisibleForTesting(otherwise = PRIVATE) public final class BatchEncoder { diff --git a/driver-core/src/main/com/mongodb/internal/operation/CommandBatchCursor.java b/driver-core/src/main/com/mongodb/internal/operation/CommandBatchCursor.java index d201976e5ed..78de7f33b93 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/CommandBatchCursor.java +++ b/driver-core/src/main/com/mongodb/internal/operation/CommandBatchCursor.java @@ -16,150 +16,58 @@ package com.mongodb.internal.operation; -import com.mongodb.MongoCommandException; -import com.mongodb.MongoException; -import com.mongodb.MongoNamespace; -import com.mongodb.MongoOperationTimeoutException; -import com.mongodb.MongoSocketException; -import com.mongodb.ReadPreference; import com.mongodb.ServerAddress; import com.mongodb.ServerCursor; -import com.mongodb.annotations.ThreadSafe; import com.mongodb.client.cursor.TimeoutMode; -import com.mongodb.connection.ConnectionDescription; -import com.mongodb.connection.ServerType; -import com.mongodb.internal.TimeoutContext; -import com.mongodb.internal.VisibleForTesting; -import com.mongodb.internal.binding.ConnectionSource; -import com.mongodb.internal.connection.Connection; import com.mongodb.internal.connection.OperationContext; -import com.mongodb.internal.validator.NoOpFieldNameValidator; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonTimestamp; -import org.bson.BsonValue; -import org.bson.codecs.BsonDocumentCodec; -import org.bson.codecs.Decoder; import java.util.List; -import java.util.NoSuchElementException; -import java.util.function.Consumer; -import java.util.function.Supplier; - -import static com.mongodb.assertions.Assertions.assertNotNull; -import static com.mongodb.assertions.Assertions.assertTrue; -import static com.mongodb.internal.VisibleForTesting.AccessModifier.PRIVATE; -import static com.mongodb.internal.operation.CommandBatchCursorHelper.FIRST_BATCH; -import static com.mongodb.internal.operation.CommandBatchCursorHelper.MESSAGE_IF_CLOSED_AS_CURSOR; -import static com.mongodb.internal.operation.CommandBatchCursorHelper.MESSAGE_IF_CLOSED_AS_ITERATOR; -import static com.mongodb.internal.operation.CommandBatchCursorHelper.NEXT_BATCH; -import static com.mongodb.internal.operation.CommandBatchCursorHelper.getKillCursorsCommand; -import static com.mongodb.internal.operation.CommandBatchCursorHelper.getMoreCommandDocument; -import static com.mongodb.internal.operation.CommandBatchCursorHelper.logCommandCursorResult; -import static com.mongodb.internal.operation.CommandBatchCursorHelper.translateCommandException; class CommandBatchCursor implements AggregateResponseBatchCursor { - private final MongoNamespace namespace; - private final Decoder decoder; - @Nullable - private final BsonValue comment; - private final int maxWireVersion; - private final boolean firstBatchEmpty; - private final ResourceManager resourceManager; - private final OperationContext operationContext; private final TimeoutMode timeoutMode; - - private int batchSize; - private CommandCursorResult commandCursorResult; - @Nullable - private List nextBatch; - private boolean resetTimeoutWhenClosing; + private OperationContext operationContext; + private CoreCursor wrapped; CommandBatchCursor( final TimeoutMode timeoutMode, - final BsonDocument commandCursorDocument, - final int batchSize, final long maxTimeMS, - final Decoder decoder, - @Nullable final BsonValue comment, - final ConnectionSource connectionSource, - final Connection connection) { - ConnectionDescription connectionDescription = connection.getDescription(); - this.commandCursorResult = toCommandCursorResult(connectionDescription.getServerAddress(), FIRST_BATCH, commandCursorDocument); - this.namespace = commandCursorResult.getNamespace(); - this.batchSize = batchSize; - this.decoder = decoder; - this.comment = comment; - this.maxWireVersion = connectionDescription.getMaxWireVersion(); - this.firstBatchEmpty = commandCursorResult.getResults().isEmpty(); - operationContext = connectionSource.getOperationContext(); + final long maxTimeMs, + final OperationContext operationContext, + final CoreCursor wrapped) { + this.operationContext = operationContext.withTimeoutContextOverride(timeoutContext -> + timeoutContext.withMaxTimeOverride(maxTimeMs)); this.timeoutMode = timeoutMode; - - operationContext.getTimeoutContext().setMaxTimeOverride(maxTimeMS); - - Connection connectionToPin = connectionSource.getServerDescription().getType() == ServerType.LOAD_BALANCER ? connection : null; - resourceManager = new ResourceManager(namespace, connectionSource, connectionToPin, commandCursorResult.getServerCursor()); - resetTimeoutWhenClosing = true; + this.wrapped = wrapped; } @Override public boolean hasNext() { - return assertNotNull(resourceManager.execute(MESSAGE_IF_CLOSED_AS_CURSOR, this::doHasNext)); - } - - private boolean doHasNext() { - if (nextBatch != null) { - return true; - } - - checkTimeoutModeAndResetTimeoutContextIfIteration(); - while (resourceManager.getServerCursor() != null) { - getMore(); - if (!resourceManager.operable()) { - throw new IllegalStateException(MESSAGE_IF_CLOSED_AS_CURSOR); - } - if (nextBatch != null) { - return true; - } - } - - return false; + resetTimeout(); + return wrapped.hasNext(operationContext); } @Override public List next() { - return assertNotNull(resourceManager.execute(MESSAGE_IF_CLOSED_AS_ITERATOR, this::doNext)); + resetTimeout(); + return wrapped.next(operationContext); } @Override public int available() { - return !resourceManager.operable() || nextBatch == null ? 0 : nextBatch.size(); - } - - @Nullable - private List doNext() { - if (!doHasNext()) { - throw new NoSuchElementException(); - } - - List retVal = nextBatch; - nextBatch = null; - return retVal; - } - - @VisibleForTesting(otherwise = PRIVATE) - boolean isClosed() { - return !resourceManager.operable(); + return wrapped.available(); } @Override public void setBatchSize(final int batchSize) { - this.batchSize = batchSize; + wrapped.setBatchSize(batchSize); } @Override public int getBatchSize() { - return batchSize; + return wrapped.getBatchSize(); } @Override @@ -169,225 +77,60 @@ public void remove() { @Override public void close() { - resourceManager.close(); + operationContext = operationContext.withTimeoutContextOverride(timeoutContext -> timeoutContext + .withNewlyStartedTimeout() + .withDefaultMaxTime()); + wrapped.close(operationContext); } @Nullable @Override public List tryNext() { - return resourceManager.execute(MESSAGE_IF_CLOSED_AS_CURSOR, () -> { - if (!tryHasNext()) { - return null; - } - return doNext(); - }); - } - - private boolean tryHasNext() { - if (nextBatch != null) { - return true; - } - - if (resourceManager.getServerCursor() != null) { - getMore(); - } - - return nextBatch != null; + resetTimeout(); + return wrapped.tryNext(operationContext); } @Override @Nullable public ServerCursor getServerCursor() { - if (!resourceManager.operable()) { - throw new IllegalStateException(MESSAGE_IF_CLOSED_AS_ITERATOR); - } - return resourceManager.getServerCursor(); + return wrapped.getServerCursor(); } @Override public ServerAddress getServerAddress() { - if (!resourceManager.operable()) { - throw new IllegalStateException(MESSAGE_IF_CLOSED_AS_ITERATOR); - } - - return commandCursorResult.getServerAddress(); + return wrapped.getServerAddress(); } @Override + @Nullable public BsonDocument getPostBatchResumeToken() { - return commandCursorResult.getPostBatchResumeToken(); + return wrapped.getPostBatchResumeToken(); } @Override + @Nullable public BsonTimestamp getOperationTime() { - return commandCursorResult.getOperationTime(); + return wrapped.getOperationTime(); } @Override public boolean isFirstBatchEmpty() { - return firstBatchEmpty; + return wrapped.isFirstBatchEmpty(); } @Override public int getMaxWireVersion() { - return maxWireVersion; + return wrapped.getMaxWireVersion(); } - void checkTimeoutModeAndResetTimeoutContextIfIteration() { + private void resetTimeout() { if (timeoutMode == TimeoutMode.ITERATION) { - operationContext.getTimeoutContext().resetTimeoutIfPresent(); + operationContext = operationContext.withNewlyStartedTimeout(); } } - private void getMore() { - ServerCursor serverCursor = assertNotNull(resourceManager.getServerCursor()); - resourceManager.executeWithConnection(connection -> { - ServerCursor nextServerCursor; - try { - this.commandCursorResult = toCommandCursorResult(connection.getDescription().getServerAddress(), NEXT_BATCH, - assertNotNull( - connection.command(namespace.getDatabaseName(), - getMoreCommandDocument(serverCursor.getId(), connection.getDescription(), namespace, batchSize, comment), - NoOpFieldNameValidator.INSTANCE, - ReadPreference.primary(), - CommandResultDocumentCodec.create(decoder, NEXT_BATCH), - assertNotNull(resourceManager.getConnectionSource()).getOperationContext()))); - nextServerCursor = commandCursorResult.getServerCursor(); - } catch (MongoCommandException e) { - throw translateCommandException(e, serverCursor); - } - resourceManager.setServerCursor(nextServerCursor); - }); - } - - private CommandCursorResult toCommandCursorResult(final ServerAddress serverAddress, final String fieldNameContainingBatch, - final BsonDocument commandCursorDocument) { - CommandCursorResult commandCursorResult = new CommandCursorResult<>(serverAddress, fieldNameContainingBatch, - commandCursorDocument); - logCommandCursorResult(commandCursorResult); - this.nextBatch = commandCursorResult.getResults().isEmpty() ? null : commandCursorResult.getResults(); - return commandCursorResult; - } - - /** - * Configures the cursor to {@link #close()} - * without {@linkplain TimeoutContext#resetTimeoutIfPresent() resetting} its {@linkplain TimeoutContext#getTimeout() timeout}. - * This is useful when managing the {@link #close()} behavior externally. - */ - CommandBatchCursor disableTimeoutResetWhenClosing() { - resetTimeoutWhenClosing = false; - return this; - } - - @ThreadSafe - private final class ResourceManager extends CursorResourceManager { - ResourceManager( - final MongoNamespace namespace, - final ConnectionSource connectionSource, - @Nullable final Connection connectionToPin, - @Nullable final ServerCursor serverCursor) { - super(namespace, connectionSource, connectionToPin, serverCursor); - } - - /** - * Thread-safe. - * Executes {@code operation} within the {@link #tryStartOperation()}/{@link #endOperation()} bounds. - * - * @throws IllegalStateException If {@linkplain CommandBatchCursor#close() closed}. - */ - @Nullable - R execute(final String exceptionMessageIfClosed, final Supplier operation) throws IllegalStateException { - if (!tryStartOperation()) { - throw new IllegalStateException(exceptionMessageIfClosed); - } - try { - return operation.get(); - } finally { - endOperation(); - } - } - - @Override - void markAsPinned(final Connection connectionToPin, final Connection.PinningMode pinningMode) { - connectionToPin.markAsPinned(pinningMode); - } - - @Override - void doClose() { - TimeoutContext timeoutContext = operationContext.getTimeoutContext(); - timeoutContext.resetToDefaultMaxTime(); - if (resetTimeoutWhenClosing) { - timeoutContext.doWithResetTimeout(this::releaseResources); - } else { - releaseResources(); - } - } - - private void releaseResources() { - try { - if (isSkipReleasingServerResourcesOnClose()) { - unsetServerCursor(); - } - if (super.getServerCursor() != null) { - Connection connection = getConnection(); - try { - releaseServerResources(connection); - } finally { - connection.release(); - } - } - } catch (MongoException e) { - // ignore exceptions when releasing server resources - } finally { - // guarantee that regardless of exceptions, `serverCursor` is null and client resources are released - unsetServerCursor(); - releaseClientResources(); - } - } - - void executeWithConnection(final Consumer action) { - Connection connection = getConnection(); - try { - action.accept(connection); - } catch (MongoSocketException e) { - onCorruptedConnection(connection, e); - throw e; - } catch (MongoOperationTimeoutException e) { - Throwable cause = e.getCause(); - if (cause instanceof MongoSocketException) { - onCorruptedConnection(connection, (MongoSocketException) cause); - } - throw e; - } finally { - connection.release(); - } - } - - private Connection getConnection() { - assertTrue(getState() != State.IDLE); - Connection pinnedConnection = getPinnedConnection(); - if (pinnedConnection == null) { - return assertNotNull(getConnectionSource()).getConnection(); - } else { - return pinnedConnection.retain(); - } - } - - private void releaseServerResources(final Connection connection) { - try { - ServerCursor localServerCursor = super.getServerCursor(); - if (localServerCursor != null) { - killServerCursor(getNamespace(), localServerCursor, connection); - } - } finally { - unsetServerCursor(); - } - } - - private void killServerCursor(final MongoNamespace namespace, final ServerCursor localServerCursor, - final Connection localConnection) { - localConnection.command(namespace.getDatabaseName(), getKillCursorsCommand(namespace, localServerCursor), - NoOpFieldNameValidator.INSTANCE, ReadPreference.primary(), new BsonDocumentCodec(), operationContext); - } + CoreCursor getWrapped() { + return wrapped; } } + diff --git a/driver-core/src/main/com/mongodb/internal/operation/CommandCoreCursor.java b/driver-core/src/main/com/mongodb/internal/operation/CommandCoreCursor.java new file mode 100644 index 00000000000..624c4b9bb83 --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/operation/CommandCoreCursor.java @@ -0,0 +1,364 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.operation; + +import com.mongodb.MongoCommandException; +import com.mongodb.MongoException; +import com.mongodb.MongoNamespace; +import com.mongodb.MongoOperationTimeoutException; +import com.mongodb.MongoSocketException; +import com.mongodb.ReadPreference; +import com.mongodb.ServerAddress; +import com.mongodb.ServerCursor; +import com.mongodb.annotations.ThreadSafe; +import com.mongodb.connection.ConnectionDescription; +import com.mongodb.connection.ServerType; +import com.mongodb.internal.VisibleForTesting; +import com.mongodb.internal.binding.ConnectionSource; +import com.mongodb.internal.connection.Connection; +import com.mongodb.internal.connection.OperationContext; +import com.mongodb.internal.validator.NoOpFieldNameValidator; +import com.mongodb.lang.Nullable; +import org.bson.BsonDocument; +import org.bson.BsonTimestamp; +import org.bson.BsonValue; +import org.bson.codecs.BsonDocumentCodec; +import org.bson.codecs.Decoder; + +import java.util.List; +import java.util.NoSuchElementException; +import java.util.function.Consumer; +import java.util.function.Supplier; + +import static com.mongodb.assertions.Assertions.assertNotNull; +import static com.mongodb.assertions.Assertions.assertTrue; +import static com.mongodb.internal.VisibleForTesting.AccessModifier.PRIVATE; +import static com.mongodb.internal.operation.CommandBatchCursorHelper.FIRST_BATCH; +import static com.mongodb.internal.operation.CommandBatchCursorHelper.MESSAGE_IF_CLOSED_AS_CURSOR; +import static com.mongodb.internal.operation.CommandBatchCursorHelper.MESSAGE_IF_CLOSED_AS_ITERATOR; +import static com.mongodb.internal.operation.CommandBatchCursorHelper.NEXT_BATCH; +import static com.mongodb.internal.operation.CommandBatchCursorHelper.getKillCursorsCommand; +import static com.mongodb.internal.operation.CommandBatchCursorHelper.getMoreCommandDocument; +import static com.mongodb.internal.operation.CommandBatchCursorHelper.logCommandCursorResult; +import static com.mongodb.internal.operation.CommandBatchCursorHelper.translateCommandException; + +class CommandCoreCursor implements CoreCursor { + + private final MongoNamespace namespace; + private final Decoder decoder; + @Nullable + private final BsonValue comment; + private final int maxWireVersion; + private final boolean firstBatchEmpty; + private final ResourceManager resourceManager; + + private int batchSize; + private CommandCursorResult commandCursorResult; + @Nullable + private List nextBatch; + + CommandCoreCursor( + final BsonDocument commandCursorDocument, + final int batchSize, + final Decoder decoder, + @Nullable final BsonValue comment, + final ConnectionSource connectionSource, + final Connection connection) { + ConnectionDescription connectionDescription = connection.getDescription(); + this.commandCursorResult = toCommandCursorResult(connectionDescription.getServerAddress(), FIRST_BATCH, commandCursorDocument); + this.namespace = commandCursorResult.getNamespace(); + this.batchSize = batchSize; + this.decoder = decoder; + this.comment = comment; + this.maxWireVersion = connectionDescription.getMaxWireVersion(); + this.firstBatchEmpty = commandCursorResult.getResults().isEmpty(); + + Connection connectionToPin = connectionSource.getServerDescription().getType() == ServerType.LOAD_BALANCER ? connection : null; + resourceManager = new ResourceManager(namespace, connectionSource, connectionToPin, commandCursorResult.getServerCursor()); + } + + @Override + public boolean hasNext(final OperationContext operationContext) { + return assertNotNull(resourceManager.execute(MESSAGE_IF_CLOSED_AS_CURSOR, () -> doHasNext(operationContext), operationContext)); + } + + + private boolean doHasNext(final OperationContext operationContext) { + if (nextBatch != null) { + return true; + } + + while (resourceManager.getServerCursor() != null) { + getMore(operationContext); + if (!resourceManager.operable()) { + throw new IllegalStateException(MESSAGE_IF_CLOSED_AS_CURSOR); + } + if (nextBatch != null) { + return true; + } + } + + return false; + } + + @Override + public List next(final OperationContext operationContext) { + return assertNotNull(resourceManager.execute(MESSAGE_IF_CLOSED_AS_ITERATOR, () -> doNext(operationContext), operationContext)); + } + + @Override + public int available() { + return !resourceManager.operable() || nextBatch == null ? 0 : nextBatch.size(); + } + + @Nullable + private List doNext(final OperationContext operationContext) { + if (!doHasNext(operationContext)) { + throw new NoSuchElementException(); + } + + List retVal = nextBatch; + nextBatch = null; + return retVal; + } + + @VisibleForTesting(otherwise = PRIVATE) + boolean isClosed() { + return !resourceManager.operable(); + } + + @Override + public void setBatchSize(final int batchSize) { + this.batchSize = batchSize; + } + + @Override + public int getBatchSize() { + return batchSize; + } + + + @Override + public void close(final OperationContext operationContext) { + resourceManager.close(operationContext); + } + + @Nullable + @Override + public List tryNext(final OperationContext operationContext) { + return resourceManager.execute(MESSAGE_IF_CLOSED_AS_CURSOR, () -> { + if (!tryHasNext(operationContext)) { + return null; + } + return doNext(operationContext); + }, operationContext); + } + + private boolean tryHasNext(final OperationContext operationContext) { + if (nextBatch != null) { + return true; + } + + if (resourceManager.getServerCursor() != null) { + getMore(operationContext); + } + + return nextBatch != null; + } + + @Override + @Nullable + public ServerCursor getServerCursor() { + if (!resourceManager.operable()) { + throw new IllegalStateException(MESSAGE_IF_CLOSED_AS_ITERATOR); + } + return resourceManager.getServerCursor(); + } + + @Override + public ServerAddress getServerAddress() { + if (!resourceManager.operable()) { + throw new IllegalStateException(MESSAGE_IF_CLOSED_AS_ITERATOR); + } + + return commandCursorResult.getServerAddress(); + } + + @Override + public BsonDocument getPostBatchResumeToken() { + return commandCursorResult.getPostBatchResumeToken(); + } + + @Override + public BsonTimestamp getOperationTime() { + return commandCursorResult.getOperationTime(); + } + + @Override + public boolean isFirstBatchEmpty() { + return firstBatchEmpty; + } + + @Override + public int getMaxWireVersion() { + return maxWireVersion; + } + + private void getMore(final OperationContext operationContext) { + ServerCursor serverCursor = assertNotNull(resourceManager.getServerCursor()); + resourceManager.executeWithConnection(connection -> { + ServerCursor nextServerCursor; + try { + this.commandCursorResult = toCommandCursorResult(connection.getDescription().getServerAddress(), NEXT_BATCH, + assertNotNull( + connection.command(namespace.getDatabaseName(), + getMoreCommandDocument(serverCursor.getId(), connection.getDescription(), namespace, batchSize, + comment), + NoOpFieldNameValidator.INSTANCE, + ReadPreference.primary(), + CommandResultDocumentCodec.create(decoder, NEXT_BATCH), + operationContext))); + nextServerCursor = commandCursorResult.getServerCursor(); + } catch (MongoCommandException e) { + throw translateCommandException(e, serverCursor); + } + resourceManager.setServerCursor(nextServerCursor); + }, operationContext); + } + + private CommandCursorResult toCommandCursorResult(final ServerAddress serverAddress, final String fieldNameContainingBatch, + final BsonDocument commandCursorDocument) { + CommandCursorResult commandCursorResult = new CommandCursorResult<>(serverAddress, fieldNameContainingBatch, + commandCursorDocument); + logCommandCursorResult(commandCursorResult); + this.nextBatch = commandCursorResult.getResults().isEmpty() ? null : commandCursorResult.getResults(); + return commandCursorResult; + } + + + @ThreadSafe + private final class ResourceManager extends CursorResourceManagerNew { + ResourceManager( + final MongoNamespace namespace, + final ConnectionSource connectionSource, + @Nullable final Connection connectionToPin, + @Nullable final ServerCursor serverCursor) { + super(namespace, connectionSource, connectionToPin, serverCursor); + } + + /** + * Thread-safe. + */ + @Nullable + R execute(final String exceptionMessageIfClosed, final Supplier operation, final OperationContext operationContext) + throws IllegalStateException { + if (!tryStartOperation()) { + throw new IllegalStateException(exceptionMessageIfClosed); + } + try { + return operation.get(); + } finally { + endOperation(operationContext); + } + } + + @Override + void markAsPinned(final Connection connectionToPin, final Connection.PinningMode pinningMode) { + connectionToPin.markAsPinned(pinningMode); + } + + @Override + void doClose(final OperationContext operationContext) { + releaseResources(operationContext); + } + + private void releaseResources(final OperationContext operationContext) { + try { + if (isSkipReleasingServerResourcesOnClose()) { + unsetServerCursor(); + } + if (super.getServerCursor() != null) { + Connection connection = getConnection(operationContext); + try { + releaseServerResources(connection, operationContext); + } finally { + connection.release(); + } + } + } catch (MongoException e) { + // ignore exceptions when releasing server resources + } finally { + // guarantee that regardless of exceptions, `serverCursor` is null and client resources are released + unsetServerCursor(); + releaseClientResources(); + } + } + + void executeWithConnection(final Consumer action, final OperationContext operationContext) { + Connection connection = getConnection(operationContext); + try { + action.accept(connection); + } catch (MongoSocketException e) { + onCorruptedConnection(connection, e); + throw e; + } catch (MongoOperationTimeoutException e) { + Throwable cause = e.getCause(); + if (cause instanceof MongoSocketException) { + onCorruptedConnection(connection, (MongoSocketException) cause); + } + throw e; + } finally { + connection.release(); + } + } + + private Connection getConnection(final OperationContext operationContext) { + assertTrue(getState() != State.IDLE); + Connection pinnedConnection = getPinnedConnection(); + if (pinnedConnection == null) { + return assertNotNull(getConnectionSource()).getConnection(operationContext); + } else { + return pinnedConnection.retain(); + } + } + + private void releaseServerResources(final Connection connection, final OperationContext operationContext) { + try { + ServerCursor localServerCursor = super.getServerCursor(); + if (localServerCursor != null) { + killServerCursor(getNamespace(), localServerCursor, connection, operationContext); + } + } finally { + unsetServerCursor(); + } + } + + private void killServerCursor( + final MongoNamespace namespace, + final ServerCursor localServerCursor, + final Connection localConnection, + final OperationContext operationContext) { + localConnection.command( + namespace.getDatabaseName(), + getKillCursorsCommand(namespace, localServerCursor), + NoOpFieldNameValidator.INSTANCE, + ReadPreference.primary(), + new BsonDocumentCodec(), + operationContext); + } + } +} diff --git a/driver-core/src/main/com/mongodb/internal/operation/CommandOperationHelper.java b/driver-core/src/main/com/mongodb/internal/operation/CommandOperationHelper.java index db6870f52e8..8332ad916fb 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/CommandOperationHelper.java +++ b/driver-core/src/main/com/mongodb/internal/operation/CommandOperationHelper.java @@ -45,6 +45,7 @@ import static com.mongodb.assertions.Assertions.assertFalse; import static com.mongodb.assertions.Assertions.assertNotNull; +import static com.mongodb.internal.async.function.RetryState.INFINITE_ATTEMPTS; import static com.mongodb.internal.operation.OperationHelper.LOGGER; import static java.lang.String.format; import static java.util.Arrays.asList; @@ -122,7 +123,9 @@ private static Throwable chooseRetryableWriteException( static RetryState initialRetryState(final boolean retry, final TimeoutContext timeoutContext) { if (retry) { - return RetryState.withRetryableState(RetryState.RETRIES, timeoutContext); + boolean retryUntilTimeoutThrowsException = timeoutContext.hasTimeoutMS(); + int retries = retryUntilTimeoutThrowsException ? INFINITE_ATTEMPTS : RetryState.RETRIES; + return RetryState.withRetryableState(retries, retryUntilTimeoutThrowsException); } return RetryState.withNonRetryableState(); } diff --git a/driver-core/src/main/com/mongodb/internal/operation/CommandReadOperation.java b/driver-core/src/main/com/mongodb/internal/operation/CommandReadOperation.java index 6965bfc34a3..0a5d7ffadc7 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/CommandReadOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/CommandReadOperation.java @@ -19,12 +19,16 @@ import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.binding.AsyncReadBinding; import com.mongodb.internal.binding.ReadBinding; +import com.mongodb.internal.connection.OperationContext; import org.bson.BsonDocument; import org.bson.codecs.Decoder; +import org.jetbrains.annotations.NotNull; import static com.mongodb.assertions.Assertions.notNull; +import static com.mongodb.internal.operation.AsyncOperationHelper.CommandReadTransformerAsync; import static com.mongodb.internal.operation.AsyncOperationHelper.executeRetryableReadAsync; import static com.mongodb.internal.operation.CommandOperationHelper.CommandCreator; +import static com.mongodb.internal.operation.SyncOperationHelper.CommandReadTransformer; import static com.mongodb.internal.operation.SyncOperationHelper.executeRetryableRead; /** @@ -56,14 +60,23 @@ public String getCommandName() { } @Override - public T execute(final ReadBinding binding) { - return executeRetryableRead(binding, databaseName, commandCreator, decoder, - (result, source, connection) -> result, false); + public T execute(final ReadBinding binding, final OperationContext operationContext) { + return executeRetryableRead(binding, operationContext, databaseName, commandCreator, decoder, + transformer(), false); } @Override - public void executeAsync(final AsyncReadBinding binding, final SingleResultCallback callback) { - executeRetryableReadAsync(binding, databaseName, commandCreator, decoder, - (result, source, connection) -> result, false, callback); + public void executeAsync(final AsyncReadBinding binding, final OperationContext operationContext, + final SingleResultCallback callback) { + executeRetryableReadAsync(binding, operationContext, databaseName, commandCreator, decoder, + asyncTransformer(), false, callback); + } + + private static CommandReadTransformer transformer() { + return (result, source, connection, operationContext) -> result; + } + + private static @NotNull CommandReadTransformerAsync asyncTransformer() { + return (result, source, connection, operationContext) -> result; } } diff --git a/driver-core/src/main/com/mongodb/internal/operation/CommitTransactionOperation.java b/driver-core/src/main/com/mongodb/internal/operation/CommitTransactionOperation.java index 998a002f348..c99102919f5 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/CommitTransactionOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/CommitTransactionOperation.java @@ -29,6 +29,7 @@ import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.binding.AsyncWriteBinding; import com.mongodb.internal.binding.WriteBinding; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; @@ -65,9 +66,11 @@ public CommitTransactionOperation recoveryToken(@Nullable final BsonDocument rec } @Override - public Void execute(final WriteBinding binding) { + public Void execute(final WriteBinding binding, final OperationContext operationContext) { try { - return super.execute(binding); + return super.execute( + binding, + operationContext.withTimeoutContextOverride(TimeoutContext::withMaxTimeOverrideAsMaxCommitTime)); } catch (MongoException e) { addErrorLabels(e); throw e; @@ -75,8 +78,11 @@ public Void execute(final WriteBinding binding) { } @Override - public void executeAsync(final AsyncWriteBinding binding, final SingleResultCallback callback) { - super.executeAsync(binding, (result, t) -> { + public void executeAsync(final AsyncWriteBinding binding, final OperationContext operationContext, final SingleResultCallback callback) { + super.executeAsync( + binding, + operationContext.withTimeoutContextOverride(TimeoutContext::withMaxTimeOverrideAsMaxCommitTime), + (result, t) -> { if (t instanceof MongoException) { addErrorLabels((MongoException) t); } @@ -121,7 +127,6 @@ CommandCreator getCommandCreator() { CommandCreator creator = (operationContext, serverDescription, connectionDescription) -> { BsonDocument command = CommitTransactionOperation.super.getCommandCreator() .create(operationContext, serverDescription, connectionDescription); - operationContext.getTimeoutContext().setMaxTimeOverrideToMaxCommitTime(); return command; }; if (alreadyCommitted) { diff --git a/driver-core/src/main/com/mongodb/internal/operation/CoreCursor.java b/driver-core/src/main/com/mongodb/internal/operation/CoreCursor.java new file mode 100644 index 00000000000..5f797c1089f --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/operation/CoreCursor.java @@ -0,0 +1,78 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.operation; + +import com.mongodb.ServerAddress; +import com.mongodb.ServerCursor; +import com.mongodb.internal.connection.OperationContext; +import com.mongodb.lang.Nullable; +import org.bson.BsonDocument; +import org.bson.BsonTimestamp; + +import java.util.List; + +interface CoreCursor { + void close(OperationContext operationContext); + + boolean hasNext(OperationContext operationContext); + + List next(OperationContext operationContext); + + /** + * A special {@code next()} case that returns the next batch if available or null. + * + *

Tailable cursors are an example where this is useful. A call to {@code tryNext()} may return null, but in the future calling + * {@code tryNext()} would return a new batch if a document had been added to the capped collection.

+ * + * @mongodb.driver.manual reference/glossary/#term-tailable-cursor Tailable Cursor + */ + @Nullable + List tryNext(OperationContext operationContext); + + + int available(); + + /** + * Sets the batch size to use when requesting the next batch. This is the number of documents to request in the next batch. + * + * @param batchSize the non-negative batch size. 0 means to use the server default. + */ + void setBatchSize(int batchSize); + + /** + * Gets the batch size to use when requesting the next batch. This is the number of documents to request in the next batch. + * + * @return the non-negative batch size. 0 means to use the server default. + */ + int getBatchSize(); + + @Nullable + ServerCursor getServerCursor(); + + ServerAddress getServerAddress(); + + @Nullable + BsonDocument getPostBatchResumeToken(); + + @Nullable + BsonTimestamp getOperationTime(); + + boolean isFirstBatchEmpty(); + + int getMaxWireVersion(); +} + diff --git a/driver-core/src/main/com/mongodb/internal/operation/CountDocumentsOperation.java b/driver-core/src/main/com/mongodb/internal/operation/CountDocumentsOperation.java index 9460026062a..d4791d2d0bb 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/CountDocumentsOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/CountDocumentsOperation.java @@ -21,6 +21,7 @@ import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.binding.AsyncReadBinding; import com.mongodb.internal.binding.ReadBinding; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonInt32; @@ -126,15 +127,16 @@ public String getCommandName() { } @Override - public Long execute(final ReadBinding binding) { - try (BatchCursor cursor = getAggregateOperation().execute(binding)) { + public Long execute(final ReadBinding binding, final OperationContext operationContext) { + try (BatchCursor cursor = getAggregateOperation().execute(binding, operationContext)) { return cursor.hasNext() ? getCountFromAggregateResults(cursor.next()) : 0; } } @Override - public void executeAsync(final AsyncReadBinding binding, final SingleResultCallback callback) { - getAggregateOperation().executeAsync(binding, (result, t) -> { + public void executeAsync(final AsyncReadBinding binding, final OperationContext operationContext, + final SingleResultCallback callback) { + getAggregateOperation().executeAsync(binding, operationContext, (result, t) -> { if (t != null) { callback.onResult(null, t); } else { diff --git a/driver-core/src/main/com/mongodb/internal/operation/CountOperation.java b/driver-core/src/main/com/mongodb/internal/operation/CountOperation.java index 6d0b7b78f93..aa9603982c1 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/CountOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/CountOperation.java @@ -21,6 +21,7 @@ import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.binding.AsyncReadBinding; import com.mongodb.internal.binding.ReadBinding; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonString; @@ -116,23 +117,23 @@ public String getCommandName() { } @Override - public Long execute(final ReadBinding binding) { - return executeRetryableRead(binding, namespace.getDatabaseName(), + public Long execute(final ReadBinding binding, final OperationContext operationContext) { + return executeRetryableRead(binding, operationContext, namespace.getDatabaseName(), getCommandCreator(), DECODER, transformer(), retryReads); } @Override - public void executeAsync(final AsyncReadBinding binding, final SingleResultCallback callback) { - executeRetryableReadAsync(binding, namespace.getDatabaseName(), + public void executeAsync(final AsyncReadBinding binding, final OperationContext operationContext, final SingleResultCallback callback) { + executeRetryableReadAsync(binding, operationContext, namespace.getDatabaseName(), getCommandCreator(), DECODER, asyncTransformer(), retryReads, callback); } private CommandReadTransformer transformer() { - return (result, source, connection) -> (result.getNumber("n")).longValue(); + return (result, source, connection, operationContext) -> (result.getNumber("n")).longValue(); } private CommandReadTransformerAsync asyncTransformer() { - return (result, source, connection) -> (result.getNumber("n")).longValue(); + return (result, source, connection, operationContext) -> (result.getNumber("n")).longValue(); } private CommandCreator getCommandCreator() { diff --git a/driver-core/src/main/com/mongodb/internal/operation/CreateCollectionOperation.java b/driver-core/src/main/com/mongodb/internal/operation/CreateCollectionOperation.java index 5284076eecb..024a97b7b64 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/CreateCollectionOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/CreateCollectionOperation.java @@ -30,6 +30,7 @@ import com.mongodb.internal.binding.AsyncWriteBinding; import com.mongodb.internal.binding.WriteBinding; import com.mongodb.internal.connection.AsyncConnection; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.lang.Nullable; import org.bson.BsonArray; import org.bson.BsonBoolean; @@ -237,20 +238,20 @@ public String getCommandName() { } @Override - public Void execute(final WriteBinding binding) { - return withConnection(binding, connection -> { + public Void execute(final WriteBinding binding, final OperationContext operationContext) { + return withConnection(binding, operationContext, (connection, operationContextWithMinRtt)-> { checkEncryptedFieldsSupported(connection.getDescription()); getCommandFunctions().forEach(commandCreator -> - executeCommand(binding, databaseName, commandCreator.get(), connection, - writeConcernErrorTransformer(binding.getOperationContext().getTimeoutContext())) + executeCommand(binding, operationContextWithMinRtt, databaseName, commandCreator.get(), connection, + writeConcernErrorTransformer(operationContextWithMinRtt.getTimeoutContext())) ); return null; }); } @Override - public void executeAsync(final AsyncWriteBinding binding, final SingleResultCallback callback) { - withAsyncConnection(binding, (connection, t) -> { + public void executeAsync(final AsyncWriteBinding binding, final OperationContext operationContext, final SingleResultCallback callback) { + withAsyncConnection(binding, operationContext, (connection, operationContextWithMinRtt, t) -> { SingleResultCallback errHandlingCallback = errorHandlingCallback(callback, LOGGER); if (t != null) { errHandlingCallback.onResult(null, t); @@ -259,7 +260,7 @@ public void executeAsync(final AsyncWriteBinding binding, final SingleResultCall if (!checkEncryptedFieldsSupported(connection.getDescription(), releasingCallback)) { return; } - new ProcessCommandsCallback(binding, connection, releasingCallback) + new ProcessCommandsCallback(binding, operationContextWithMinRtt, connection, releasingCallback) .onResult(null, null); } }); @@ -403,13 +404,15 @@ private boolean checkEncryptedFieldsSupported(final ConnectionDescription connec */ class ProcessCommandsCallback implements SingleResultCallback { private final AsyncWriteBinding binding; + private final OperationContext operationContext; private final AsyncConnection connection; private final SingleResultCallback finalCallback; private final Deque> commands; ProcessCommandsCallback( - final AsyncWriteBinding binding, final AsyncConnection connection, final SingleResultCallback finalCallback) { + final AsyncWriteBinding binding, final OperationContext operationContext, final AsyncConnection connection, final SingleResultCallback finalCallback) { this.binding = binding; + this.operationContext = operationContext; this.connection = connection; this.finalCallback = finalCallback; this.commands = new ArrayDeque<>(getCommandFunctions()); @@ -425,8 +428,8 @@ public void onResult(@Nullable final Void result, @Nullable final Throwable t) { if (nextCommandFunction == null) { finalCallback.onResult(null, null); } else { - executeCommandAsync(binding, databaseName, nextCommandFunction.get(), - connection, writeConcernErrorTransformerAsync(binding.getOperationContext().getTimeoutContext()), this); + executeCommandAsync(binding, operationContext, databaseName, nextCommandFunction.get(), + connection, writeConcernErrorTransformerAsync(operationContext.getTimeoutContext()), this); } } } diff --git a/driver-core/src/main/com/mongodb/internal/operation/CreateIndexesOperation.java b/driver-core/src/main/com/mongodb/internal/operation/CreateIndexesOperation.java index b9b4242a3f4..945d6b325b2 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/CreateIndexesOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/CreateIndexesOperation.java @@ -29,6 +29,7 @@ import com.mongodb.internal.binding.AsyncWriteBinding; import com.mongodb.internal.binding.WriteBinding; import com.mongodb.internal.bulk.IndexRequest; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.lang.Nullable; import org.bson.BsonArray; import org.bson.BsonBoolean; @@ -106,19 +107,18 @@ public String getCommandName() { } @Override - public Void execute(final WriteBinding binding) { + public Void execute(final WriteBinding binding, final OperationContext operationContext) { try { - return executeCommand(binding, namespace.getDatabaseName(), getCommandCreator(), writeConcernErrorTransformer( - binding.getOperationContext().getTimeoutContext())); + return executeCommand(binding, operationContext, namespace.getDatabaseName(), getCommandCreator(), writeConcernErrorTransformer( + operationContext.getTimeoutContext())); } catch (MongoCommandException e) { throw checkForDuplicateKeyError(e); } } @Override - public void executeAsync(final AsyncWriteBinding binding, final SingleResultCallback callback) { - executeCommandAsync(binding, namespace.getDatabaseName(), getCommandCreator(), writeConcernErrorTransformerAsync(binding - .getOperationContext().getTimeoutContext()), + public void executeAsync(final AsyncWriteBinding binding, final OperationContext operationContext, final SingleResultCallback callback) { + executeCommandAsync(binding, operationContext, namespace.getDatabaseName(), getCommandCreator(), writeConcernErrorTransformerAsync(operationContext.getTimeoutContext()), ((result, t) -> { if (t != null) { callback.onResult(null, translateException(t)); diff --git a/driver-core/src/main/com/mongodb/internal/operation/CreateViewOperation.java b/driver-core/src/main/com/mongodb/internal/operation/CreateViewOperation.java index 49b47fb7e9c..32861e6be16 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/CreateViewOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/CreateViewOperation.java @@ -21,6 +21,7 @@ import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.binding.AsyncWriteBinding; import com.mongodb.internal.binding.WriteBinding; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.lang.Nullable; import org.bson.BsonArray; import org.bson.BsonDocument; @@ -129,24 +130,24 @@ public String getCommandName() { } @Override - public Void execute(final WriteBinding binding) { - return withConnection(binding, connection -> { - executeCommand(binding, databaseName, getCommand(), new BsonDocumentCodec(), - writeConcernErrorTransformer(binding.getOperationContext().getTimeoutContext())); + public Void execute(final WriteBinding binding, final OperationContext operationContext) { + return withConnection(binding, operationContext, (connection, operationContextWithMinRtt) -> { + executeCommand(binding, operationContextWithMinRtt, databaseName, getCommand(), new BsonDocumentCodec(), + writeConcernErrorTransformer(operationContextWithMinRtt.getTimeoutContext())); return null; }); } @Override - public void executeAsync(final AsyncWriteBinding binding, final SingleResultCallback callback) { - withAsyncConnection(binding, (connection, t) -> { + public void executeAsync(final AsyncWriteBinding binding, final OperationContext operationContext, final SingleResultCallback callback) { + withAsyncConnection(binding, operationContext, (connection, operationContextWithMinRtt, t) -> { SingleResultCallback errHandlingCallback = errorHandlingCallback(callback, LOGGER); if (t != null) { errHandlingCallback.onResult(null, t); } else { SingleResultCallback wrappedCallback = releasingCallback(errHandlingCallback, connection); - executeCommandAsync(binding, databaseName, getCommand(), connection, - writeConcernErrorTransformerAsync(binding.getOperationContext().getTimeoutContext()), + executeCommandAsync(binding, operationContextWithMinRtt, databaseName, getCommand(), connection, + writeConcernErrorTransformerAsync(operationContextWithMinRtt.getTimeoutContext()), wrappedCallback); } }); diff --git a/driver-core/src/main/com/mongodb/internal/operation/CursorResourceManagerNew.java b/driver-core/src/main/com/mongodb/internal/operation/CursorResourceManagerNew.java new file mode 100644 index 00000000000..4cf08fa008a --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/operation/CursorResourceManagerNew.java @@ -0,0 +1,262 @@ +package com.mongodb.internal.operation; + +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import com.mongodb.MongoNamespace; +import com.mongodb.MongoSocketException; +import com.mongodb.ServerCursor; +import com.mongodb.annotations.ThreadSafe; +import com.mongodb.internal.binding.ReferenceCounted; +import com.mongodb.internal.connection.Connection; +import com.mongodb.internal.connection.OperationContext; +import com.mongodb.lang.Nullable; + +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; + +import static com.mongodb.assertions.Assertions.assertNotNull; +import static com.mongodb.assertions.Assertions.assertNull; +import static com.mongodb.assertions.Assertions.assertTrue; +import static com.mongodb.assertions.Assertions.fail; +import static com.mongodb.internal.Locks.withLock; +import static com.mongodb.internal.operation.CommandBatchCursorHelper.MESSAGE_IF_CONCURRENT_OPERATION; + +@ThreadSafe +abstract class CursorResourceManagerNew { + private final Lock lock; + private final MongoNamespace namespace; + private volatile State state; + @Nullable + private volatile CS connectionSource; + @Nullable + private volatile C pinnedConnection; + @Nullable + private volatile ServerCursor serverCursor; + private volatile boolean skipReleasingServerResourcesOnClose; + + CursorResourceManagerNew( + final MongoNamespace namespace, + final CS connectionSource, + @Nullable final C connectionToPin, + @Nullable final ServerCursor serverCursor) { + this.lock = new ReentrantLock(); + this.namespace = namespace; + this.state = State.IDLE; + if (serverCursor != null) { + connectionSource.retain(); + this.connectionSource = connectionSource; + if (connectionToPin != null) { + connectionToPin.retain(); + markAsPinned(connectionToPin, Connection.PinningMode.CURSOR); + this.pinnedConnection = connectionToPin; + } + } + this.skipReleasingServerResourcesOnClose = false; + this.serverCursor = serverCursor; + } + + /** + * Thread-safe. + */ + MongoNamespace getNamespace() { + return namespace; + } + + /** + * Thread-safe. + */ + State getState() { + return state; + } + + /** + * Thread-safe. + */ + @Nullable + CS getConnectionSource() { + return connectionSource; + } + + /** + * Thread-safe. + */ + @Nullable + C getPinnedConnection() { + return pinnedConnection; + } + + /** + * Thread-safe. + */ + boolean isSkipReleasingServerResourcesOnClose() { + return skipReleasingServerResourcesOnClose; + } + + @SuppressWarnings("SameParameterValue") + abstract void markAsPinned(C connectionToPin, Connection.PinningMode pinningMode); + + /** + * Thread-safe. + */ + boolean operable() { + return state.operable(); + } + + /** + * Thread-safe. + * Returns {@code true} iff started an operation. + * If {@linkplain #operable() closed}, then returns false, otherwise completes abruptly. + * + * @throws IllegalStateException Iff another operation is in progress. + */ + boolean tryStartOperation() throws IllegalStateException { + return withLock(lock, () -> { + State localState = state; + if (!localState.operable()) { + return false; + } else if (localState == State.IDLE) { + state = State.OPERATION_IN_PROGRESS; + return true; + } else if (localState == State.OPERATION_IN_PROGRESS) { + throw new IllegalStateException(MESSAGE_IF_CONCURRENT_OPERATION); + } else { + throw fail(state.toString()); + } + }); + } + + /** + * Thread-safe. + */ + void endOperation(final OperationContext operationContext) { + boolean doClose = withLock(lock, () -> { + State localState = state; + if (localState == State.OPERATION_IN_PROGRESS) { + state = State.IDLE; + } else if (localState == State.CLOSE_PENDING) { + state = State.CLOSED; + return true; + } else if (localState != State.CLOSED) { + throw fail(localState.toString()); + } + return false; + }); + if (doClose) { + doClose(operationContext); + } + } + + /** + * Thread-safe. + */ + void close(final OperationContext operationContext) { + boolean doClose = withLock(lock, () -> { + State localState = state; + if (localState.isOperationInProgress()) { + state = State.CLOSE_PENDING; + } else if (localState != State.CLOSED) { + state = State.CLOSED; + return true; + } + return false; + }); + if (doClose) { + doClose(operationContext); + } + } + + // /** +// * This method is never executed concurrently with either itself or other operations +// * demarcated by {@link #tryStartOperation()}/{@link #endOperation()}. +// */ + abstract void doClose(OperationContext operationContext); + + void onCorruptedConnection(@Nullable final C corruptedConnection, final MongoSocketException e) { + // if `pinnedConnection` is corrupted, then we cannot kill `serverCursor` via such a connection + C localPinnedConnection = pinnedConnection; + if (localPinnedConnection != null) { + if (corruptedConnection != localPinnedConnection) { + e.addSuppressed(new AssertionError("Corrupted connection does not equal the pinned connection.")); + } + skipReleasingServerResourcesOnClose = true; + } + } + + /** + * Thread-safe. + */ + @Nullable + final ServerCursor getServerCursor() { + return serverCursor; + } + + void setServerCursor(@Nullable final ServerCursor serverCursor) { + assertTrue(state.isOperationInProgress()); + assertNotNull(this.serverCursor); + // without `connectionSource` we will not be able to kill `serverCursor` later + assertNotNull(connectionSource); + this.serverCursor = serverCursor; + if (serverCursor == null) { + releaseClientResources(); + } + } + + void unsetServerCursor() { + this.serverCursor = null; + } + + void releaseClientResources() { + assertNull(serverCursor); + CS localConnectionSource = connectionSource; + if (localConnectionSource != null) { + localConnectionSource.release(); + connectionSource = null; + } + C localPinnedConnection = pinnedConnection; + if (localPinnedConnection != null) { + localPinnedConnection.release(); + pinnedConnection = null; + } + } + + enum State { + IDLE(true, false), + OPERATION_IN_PROGRESS(true, true), + /** + * Implies {@link #OPERATION_IN_PROGRESS}. + */ + CLOSE_PENDING(false, true), + CLOSED(false, false); + + private final boolean operable; + private final boolean operationInProgress; + + State(final boolean operable, final boolean operationInProgress) { + this.operable = operable; + this.operationInProgress = operationInProgress; + } + + boolean operable() { + return operable; + } + + boolean isOperationInProgress() { + return operationInProgress; + } + } +} + diff --git a/driver-core/src/main/com/mongodb/internal/operation/DistinctOperation.java b/driver-core/src/main/com/mongodb/internal/operation/DistinctOperation.java index 489e3923bdc..0ae7f257e7e 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/DistinctOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/DistinctOperation.java @@ -22,6 +22,7 @@ import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.binding.AsyncReadBinding; import com.mongodb.internal.binding.ReadBinding; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonString; @@ -114,14 +115,14 @@ public String getCommandName() { } @Override - public BatchCursor execute(final ReadBinding binding) { - return executeRetryableRead(binding, namespace.getDatabaseName(), getCommandCreator(), createCommandDecoder(), + public BatchCursor execute(final ReadBinding binding, final OperationContext operationContext) { + return executeRetryableRead(binding, operationContext, namespace.getDatabaseName(), getCommandCreator(), createCommandDecoder(), singleBatchCursorTransformer(VALUES), retryReads); } @Override - public void executeAsync(final AsyncReadBinding binding, final SingleResultCallback> callback) { - executeRetryableReadAsync(binding, namespace.getDatabaseName(), + public void executeAsync(final AsyncReadBinding binding, final OperationContext operationContext, final SingleResultCallback> callback) { + executeRetryableReadAsync(binding, operationContext, namespace.getDatabaseName(), getCommandCreator(), createCommandDecoder(), asyncSingleBatchCursorTransformer(VALUES), retryReads, errorHandlingCallback(callback, LOGGER)); } diff --git a/driver-core/src/main/com/mongodb/internal/operation/DropCollectionOperation.java b/driver-core/src/main/com/mongodb/internal/operation/DropCollectionOperation.java index 5f61f2980f8..cfc4fa175d6 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/DropCollectionOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/DropCollectionOperation.java @@ -26,6 +26,7 @@ import com.mongodb.internal.binding.ReadWriteBinding; import com.mongodb.internal.binding.WriteBinding; import com.mongodb.internal.connection.AsyncConnection; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonString; @@ -92,13 +93,13 @@ public String getCommandName() { } @Override - public Void execute(final WriteBinding binding) { - BsonDocument localEncryptedFields = getEncryptedFields((ReadWriteBinding) binding); - return withConnection(binding, connection -> { + public Void execute(final WriteBinding binding, final OperationContext operationContext) { + BsonDocument localEncryptedFields = getEncryptedFields((ReadWriteBinding) binding, operationContext); + return withConnection(binding, operationContext, (connection, operationContextWithMinRtt) -> { getCommands(localEncryptedFields).forEach(command -> { try { - executeCommand(binding, namespace.getDatabaseName(), command.get(), - connection, writeConcernErrorTransformer(binding.getOperationContext().getTimeoutContext())); + executeCommand(binding, operationContextWithMinRtt, namespace.getDatabaseName(), command.get(), + connection, writeConcernErrorTransformer(operationContextWithMinRtt.getTimeoutContext())); } catch (MongoCommandException e) { rethrowIfNotNamespaceError(e); } @@ -108,17 +109,19 @@ public Void execute(final WriteBinding binding) { } @Override - public void executeAsync(final AsyncWriteBinding binding, final SingleResultCallback callback) { + public void executeAsync(final AsyncWriteBinding binding, final OperationContext operationContext, + final SingleResultCallback callback) { SingleResultCallback errHandlingCallback = errorHandlingCallback(callback, LOGGER); - getEncryptedFields((AsyncReadWriteBinding) binding, (result, t) -> { + getEncryptedFields((AsyncReadWriteBinding) binding, operationContext, (result, t) -> { if (t != null) { errHandlingCallback.onResult(null, t); } else { - withAsyncConnection(binding, (connection, t1) -> { + withAsyncConnection(binding, operationContext, (connection, operationContextWithMinRtt, t1) -> { if (t1 != null) { errHandlingCallback.onResult(null, t1); } else { - new ProcessCommandsCallback(binding, connection, getCommands(result), releasingCallback(errHandlingCallback, + new ProcessCommandsCallback(binding, operationContextWithMinRtt, connection, getCommands(result), + releasingCallback(errHandlingCallback, connection)) .onResult(null, null); } @@ -178,9 +181,9 @@ private BsonDocument dropCollectionCommand() { } @Nullable - private BsonDocument getEncryptedFields(final ReadWriteBinding readWriteBinding) { + private BsonDocument getEncryptedFields(final ReadWriteBinding readWriteBinding, final OperationContext operationContext) { if (encryptedFields == null && autoEncryptedFields) { - try (BatchCursor cursor = listCollectionOperation().execute(readWriteBinding)) { + try (BatchCursor cursor = listCollectionOperation().execute(readWriteBinding, operationContext)) { return getCollectionEncryptedFields(encryptedFields, cursor.tryNext()); } } @@ -189,9 +192,10 @@ private BsonDocument getEncryptedFields(final ReadWriteBinding readWriteBinding) private void getEncryptedFields( final AsyncReadWriteBinding asyncReadWriteBinding, + final OperationContext operationContext, final SingleResultCallback callback) { if (encryptedFields == null && autoEncryptedFields) { - listCollectionOperation().executeAsync(asyncReadWriteBinding, (cursor, t) -> { + listCollectionOperation().executeAsync(asyncReadWriteBinding, operationContext, (cursor, t) -> { if (t != null) { callback.onResult(null, t); } else { @@ -229,15 +233,19 @@ private ListCollectionsOperation listCollectionOperation() { */ class ProcessCommandsCallback implements SingleResultCallback { private final AsyncWriteBinding binding; + private final OperationContext operationContext; private final AsyncConnection connection; private final SingleResultCallback finalCallback; private final Deque> commands; ProcessCommandsCallback( - final AsyncWriteBinding binding, final AsyncConnection connection, + final AsyncWriteBinding binding, + final OperationContext operationContext, + final AsyncConnection connection, final List> commands, final SingleResultCallback finalCallback) { this.binding = binding; + this.operationContext = operationContext; this.connection = connection; this.finalCallback = finalCallback; this.commands = new ArrayDeque<>(commands); @@ -254,8 +262,8 @@ public void onResult(@Nullable final Void result, @Nullable final Throwable t) { finalCallback.onResult(null, null); } else { try { - executeCommandAsync(binding, namespace.getDatabaseName(), nextCommandFunction.get(), - connection, writeConcernErrorTransformerAsync(binding.getOperationContext().getTimeoutContext()), this); + executeCommandAsync(binding, operationContext, namespace.getDatabaseName(), nextCommandFunction.get(), + connection, writeConcernErrorTransformerAsync(operationContext.getTimeoutContext()), this); } catch (MongoOperationTimeoutException operationTimeoutException) { finalCallback.onResult(null, operationTimeoutException); } diff --git a/driver-core/src/main/com/mongodb/internal/operation/DropDatabaseOperation.java b/driver-core/src/main/com/mongodb/internal/operation/DropDatabaseOperation.java index d619176e8a3..e0a4417a3f4 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/DropDatabaseOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/DropDatabaseOperation.java @@ -20,6 +20,7 @@ import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.binding.AsyncWriteBinding; import com.mongodb.internal.binding.WriteBinding; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonInt32; @@ -61,23 +62,23 @@ public String getCommandName() { } @Override - public Void execute(final WriteBinding binding) { - return withConnection(binding, connection -> { - executeCommand(binding, databaseName, getCommand(), connection, writeConcernErrorTransformer(binding.getOperationContext() + public Void execute(final WriteBinding binding, final OperationContext operationContext) { + return withConnection(binding, operationContext, (connection, operationContextWithMinRtt) -> { + executeCommand(binding, operationContextWithMinRtt, databaseName, getCommand(), connection, writeConcernErrorTransformer(operationContextWithMinRtt .getTimeoutContext())); return null; }); } @Override - public void executeAsync(final AsyncWriteBinding binding, final SingleResultCallback callback) { - withAsyncConnection(binding, (connection, t) -> { + public void executeAsync(final AsyncWriteBinding binding, final OperationContext operationContext, final SingleResultCallback callback) { + withAsyncConnection(binding, operationContext, (connection, operationContextWithMinRtt, t) -> { SingleResultCallback errHandlingCallback = errorHandlingCallback(callback, LOGGER); if (t != null) { errHandlingCallback.onResult(null, t); } else { - executeCommandAsync(binding, databaseName, getCommand(), connection, - writeConcernErrorTransformerAsync(binding.getOperationContext().getTimeoutContext()), + executeCommandAsync(binding, operationContextWithMinRtt, databaseName, getCommand(), connection, + writeConcernErrorTransformerAsync(operationContextWithMinRtt.getTimeoutContext()), releasingCallback(errHandlingCallback, connection)); } diff --git a/driver-core/src/main/com/mongodb/internal/operation/DropIndexOperation.java b/driver-core/src/main/com/mongodb/internal/operation/DropIndexOperation.java index 3671a90aa56..5970b17265d 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/DropIndexOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/DropIndexOperation.java @@ -22,6 +22,7 @@ import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.binding.AsyncWriteBinding; import com.mongodb.internal.binding.WriteBinding; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonString; @@ -71,11 +72,10 @@ public String getCommandName() { } @Override - public Void execute(final WriteBinding binding) { + public Void execute(final WriteBinding binding, final OperationContext operationContext) { try { - executeCommand(binding, namespace.getDatabaseName(), getCommandCreator(), writeConcernErrorTransformer(binding - .getOperationContext() - .getTimeoutContext())); + executeCommand(binding, operationContext, namespace.getDatabaseName(), getCommandCreator(), writeConcernErrorTransformer( + operationContext.getTimeoutContext())); } catch (MongoCommandException e) { rethrowIfNotNamespaceError(e); } @@ -83,9 +83,10 @@ public Void execute(final WriteBinding binding) { } @Override - public void executeAsync(final AsyncWriteBinding binding, final SingleResultCallback callback) { - executeCommandAsync(binding, namespace.getDatabaseName(), getCommandCreator(), - writeConcernErrorTransformerAsync(binding.getOperationContext().getTimeoutContext()), (result, t) -> { + public void executeAsync(final AsyncWriteBinding binding, final OperationContext operationContext, + final SingleResultCallback callback) { + executeCommandAsync(binding, operationContext, namespace.getDatabaseName(), getCommandCreator(), + writeConcernErrorTransformerAsync(operationContext.getTimeoutContext()), (result, t) -> { if (t != null && !isNamespaceError(t)) { callback.onResult(null, t); } else { diff --git a/driver-core/src/main/com/mongodb/internal/operation/EstimatedDocumentCountOperation.java b/driver-core/src/main/com/mongodb/internal/operation/EstimatedDocumentCountOperation.java index 427cd40dc40..2fadba49fde 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/EstimatedDocumentCountOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/EstimatedDocumentCountOperation.java @@ -22,6 +22,7 @@ import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.binding.AsyncReadBinding; import com.mongodb.internal.binding.ReadBinding; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonString; @@ -76,9 +77,9 @@ public String getCommandName() { } @Override - public Long execute(final ReadBinding binding) { + public Long execute(final ReadBinding binding, final OperationContext operationContext) { try { - return executeRetryableRead(binding, namespace.getDatabaseName(), + return executeRetryableRead(binding, operationContext, namespace.getDatabaseName(), getCommandCreator(), CommandResultDocumentCodec.create(DECODER, singletonList("firstBatch")), transformer(), retryReads); } catch (MongoCommandException e) { @@ -87,8 +88,8 @@ public Long execute(final ReadBinding binding) { } @Override - public void executeAsync(final AsyncReadBinding binding, final SingleResultCallback callback) { - executeRetryableReadAsync(binding, namespace.getDatabaseName(), + public void executeAsync(final AsyncReadBinding binding, final OperationContext operationContext, final SingleResultCallback callback) { + executeRetryableReadAsync(binding, operationContext, namespace.getDatabaseName(), getCommandCreator(), CommandResultDocumentCodec.create(DECODER, singletonList("firstBatch")), asyncTransformer(), retryReads, (result, t) -> { @@ -101,11 +102,11 @@ public void executeAsync(final AsyncReadBinding binding, final SingleResultCallb } private CommandReadTransformer transformer() { - return (result, source, connection) -> transformResult(result, connection.getDescription()); + return (result, source, connection, operationContext) -> transformResult(result, connection.getDescription()); } private CommandReadTransformerAsync asyncTransformer() { - return (result, source, connection) -> transformResult(result, connection.getDescription()); + return (result, source, connection, operationContext) -> transformResult(result, connection.getDescription()); } private long transformResult(final BsonDocument result, final ConnectionDescription connectionDescription) { diff --git a/driver-core/src/main/com/mongodb/internal/operation/ExplainCommandOperation.java b/driver-core/src/main/com/mongodb/internal/operation/ExplainCommandOperation.java new file mode 100644 index 00000000000..b63c9ae4419 --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/operation/ExplainCommandOperation.java @@ -0,0 +1,63 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.operation; + +import com.mongodb.internal.async.SingleResultCallback; +import com.mongodb.internal.binding.AsyncReadBinding; +import com.mongodb.internal.binding.ReadBinding; +import com.mongodb.internal.connection.OperationContext; +import org.bson.BsonDocument; +import org.bson.codecs.Decoder; + +import static com.mongodb.internal.operation.CommandOperationHelper.CommandCreator; + +/** + *

This class is not part of the public API and may be removed or changed at any time

+ */ +public class ExplainCommandOperation extends CommandReadOperation { + + public ExplainCommandOperation(final String databaseName, final BsonDocument command, final Decoder decoder) { + super(databaseName, command, decoder); + } + + public ExplainCommandOperation(final String databaseName, final String commandName, final CommandCreator commandCreator, final Decoder decoder) { + super(databaseName, commandName, commandCreator, decoder); + } + + @Override + public T execute(final ReadBinding binding, final OperationContext operationContext) { + return super.execute(binding, operationContext.withTimeoutContextOverride(timeoutContext -> { + if (!timeoutContext.hasTimeoutMS()) { + return timeoutContext.withDisabledMaxTimeOverride(); + } + + return timeoutContext; + })); + } + + @Override + public void executeAsync(final AsyncReadBinding binding, final OperationContext operationContext, + final SingleResultCallback callback) { + super.executeAsync(binding, operationContext.withTimeoutContextOverride(timeoutContext -> { + if (!timeoutContext.hasTimeoutMS()) { + return timeoutContext.withDisabledMaxTimeOverride(); + } + + return timeoutContext; + }), callback); + } +} diff --git a/driver-core/src/main/com/mongodb/internal/operation/FindOperation.java b/driver-core/src/main/com/mongodb/internal/operation/FindOperation.java index 04d4d7afd67..c5dbd2251ff 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/FindOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/FindOperation.java @@ -23,6 +23,7 @@ import com.mongodb.MongoQueryException; import com.mongodb.client.cursor.TimeoutMode; import com.mongodb.client.model.Collation; +import com.mongodb.internal.TimeoutContext; import com.mongodb.internal.async.AsyncBatchCursor; import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.async.function.AsyncCallbackSupplier; @@ -54,7 +55,6 @@ import static com.mongodb.internal.operation.ExplainHelper.asExplainCommand; import static com.mongodb.internal.operation.OperationHelper.LOGGER; import static com.mongodb.internal.operation.OperationHelper.canRetryRead; -import static com.mongodb.internal.operation.OperationHelper.setNonTailableCursorMaxTimeSupplier; import static com.mongodb.internal.operation.OperationReadConcernHelper.appendReadConcernToCommand; import static com.mongodb.internal.operation.ServerVersionHelper.UNKNOWN_WIRE_VERSION; import static com.mongodb.internal.operation.SyncOperationHelper.CommandReadTransformer; @@ -291,48 +291,67 @@ public String getCommandName() { } @Override - public BatchCursor execute(final ReadBinding binding) { + public BatchCursor execute(final ReadBinding binding, final OperationContext operationContext) { IllegalStateException invalidTimeoutModeException = invalidTimeoutModeException(); if (invalidTimeoutModeException != null) { throw invalidTimeoutModeException; } - RetryState retryState = initialRetryState(retryReads, binding.getOperationContext().getTimeoutContext()); - Supplier> read = decorateReadWithRetries(retryState, binding.getOperationContext(), () -> - withSourceAndConnection(binding::getReadConnectionSource, false, (source, connection) -> { - retryState.breakAndThrowIfRetryAnd(() -> !canRetryRead(source.getServerDescription(), binding.getOperationContext())); + OperationContext findOperationContext; + if (shouldDisableMaxTimeMS()) { + findOperationContext = operationContext.withTimeoutContextOverride(TimeoutContext::withDisabledMaxTimeOverride); + } else { + findOperationContext = operationContext; + } + + RetryState retryState = initialRetryState(retryReads, findOperationContext.getTimeoutContext()); + Supplier> read = decorateReadWithRetries(retryState, findOperationContext, () -> + withSourceAndConnection(binding::getReadConnectionSource, false, + (source, connection, commandOperationContext) -> { + retryState.breakAndThrowIfRetryAnd(() -> !canRetryRead(source.getServerDescription(), commandOperationContext)); try { - return createReadCommandAndExecute(retryState, binding.getOperationContext(), source, namespace.getDatabaseName(), + return createReadCommandAndExecute(retryState, commandOperationContext, source, namespace.getDatabaseName(), getCommandCreator(), CommandResultDocumentCodec.create(decoder, FIRST_BATCH), transformer(), connection); } catch (MongoCommandException e) { throw new MongoQueryException(e.getResponse(), e.getServerAddress()); } - }) + }, findOperationContext) ); return read.get(); } + private boolean shouldDisableMaxTimeMS() { + return isTailableCursor() && !isAwaitData() || timeoutMode == TimeoutMode.ITERATION; + } + @Override - public void executeAsync(final AsyncReadBinding binding, final SingleResultCallback> callback) { + public void executeAsync(final AsyncReadBinding binding, final OperationContext operationContext, final SingleResultCallback> callback) { IllegalStateException invalidTimeoutModeException = invalidTimeoutModeException(); if (invalidTimeoutModeException != null) { callback.onResult(null, invalidTimeoutModeException); return; } - RetryState retryState = initialRetryState(retryReads, binding.getOperationContext().getTimeoutContext()); + OperationContext findOperationContext; + if (shouldDisableMaxTimeMS()) { + findOperationContext = operationContext.withTimeoutContextOverride(TimeoutContext::withDisabledMaxTimeOverride); + } else { + findOperationContext = operationContext; + } + + RetryState retryState = initialRetryState(retryReads, findOperationContext.getTimeoutContext()); binding.retain(); AsyncCallbackSupplier> asyncRead = decorateReadWithRetriesAsync( - retryState, binding.getOperationContext(), (AsyncCallbackSupplier>) funcCallback -> - withAsyncSourceAndConnection(binding::getReadConnectionSource, false, funcCallback, - (source, connection, releasingCallback) -> { + retryState, operationContext, (AsyncCallbackSupplier>) funcCallback -> + withAsyncSourceAndConnection(binding::getReadConnectionSource, false, findOperationContext, funcCallback, + (source, connection, operationContextWithMinRTT, releasingCallback) -> { if (retryState.breakAndCompleteIfRetryAnd(() -> !canRetryRead(source.getServerDescription(), - binding.getOperationContext()), releasingCallback)) { + findOperationContext), releasingCallback)) { return; } SingleResultCallback> wrappedCallback = exceptionTransformingCallback(releasingCallback); - createReadCommandAndExecuteAsync(retryState, binding.getOperationContext(), source, + createReadCommandAndExecuteAsync(retryState, operationContextWithMinRTT, source, namespace.getDatabaseName(), getCommandCreator(), CommandResultDocumentCodec.create(decoder, FIRST_BATCH), asyncTransformer(), connection, wrappedCallback); @@ -363,7 +382,7 @@ public CommandReadOperation asExplainableOperation(@Nullable final Explai } CommandReadOperation createExplainableOperation(@Nullable final ExplainVerbosity verbosity, final Decoder resultDecoder) { - return new CommandReadOperation<>(getNamespace().getDatabaseName(), getCommandName(), + return new ExplainCommandOperation<>(getNamespace().getDatabaseName(), getCommandName(), (operationContext, serverDescription, connectionDescription) -> { BsonDocument command = getCommand(operationContext, UNKNOWN_WIRE_VERSION); applyMaxTimeMS(operationContext.getTimeoutContext(), command); @@ -404,11 +423,7 @@ private BsonDocument getCommand(final OperationContext operationContext, final i commandDocument.put("tailable", BsonBoolean.TRUE); if (isAwaitData()) { commandDocument.put("awaitData", BsonBoolean.TRUE); - } else { - operationContext.getTimeoutContext().disableMaxTimeOverride(); } - } else { - setNonTailableCursorMaxTimeSupplier(timeoutMode, operationContext); } if (noCursorTimeout) { @@ -468,15 +483,20 @@ private TimeoutMode getTimeoutMode() { } private CommandReadTransformer> transformer() { - return (result, source, connection) -> - new CommandBatchCursor<>(getTimeoutMode(), result, batchSize, getMaxTimeForCursor(source.getOperationContext()), decoder, - comment, source, connection); + return (result, source, connection, operationContext) -> + new CommandBatchCursor<>(getTimeoutMode(), getMaxTimeForCursor(operationContext), operationContext, + new CommandCoreCursor<>( + result, batchSize, decoder, comment, source, connection + )); } private CommandReadTransformerAsync> asyncTransformer() { - return (result, source, connection) -> - new AsyncCommandBatchCursor<>(getTimeoutMode(), result, batchSize, getMaxTimeForCursor(source.getOperationContext()), decoder, - comment, source, connection); + return (result, source, connection, operationContext) -> + new AsyncCommandBatchCursor<>(getTimeoutMode(), getMaxTimeForCursor(operationContext), operationContext, + new AsyncCommandCoreCursor<>( + result, batchSize, decoder, + comment, source, connection + )); } private long getMaxTimeForCursor(final OperationContext operationContext) { diff --git a/driver-core/src/main/com/mongodb/internal/operation/ListCollectionsOperation.java b/driver-core/src/main/com/mongodb/internal/operation/ListCollectionsOperation.java index 8740986b23f..610e8d7e022 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/ListCollectionsOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/ListCollectionsOperation.java @@ -25,6 +25,7 @@ import com.mongodb.internal.async.function.RetryState; import com.mongodb.internal.binding.AsyncReadBinding; import com.mongodb.internal.binding.ReadBinding; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonInt32; @@ -51,8 +52,8 @@ import static com.mongodb.internal.operation.DocumentHelper.putIfNotNull; import static com.mongodb.internal.operation.DocumentHelper.putIfTrue; import static com.mongodb.internal.operation.OperationHelper.LOGGER; +import static com.mongodb.internal.operation.OperationHelper.applyTimeoutModeToOperationContext; import static com.mongodb.internal.operation.OperationHelper.canRetryRead; -import static com.mongodb.internal.operation.OperationHelper.setNonTailableCursorMaxTimeSupplier; import static com.mongodb.internal.operation.SingleBatchCursor.createEmptySingleBatchCursor; import static com.mongodb.internal.operation.SyncOperationHelper.CommandReadTransformer; import static com.mongodb.internal.operation.SyncOperationHelper.createReadCommandAndExecute; @@ -164,36 +165,41 @@ public String getCommandName() { } @Override - public BatchCursor execute(final ReadBinding binding) { - RetryState retryState = initialRetryState(retryReads, binding.getOperationContext().getTimeoutContext()); - Supplier> read = decorateReadWithRetries(retryState, binding.getOperationContext(), () -> - withSourceAndConnection(binding::getReadConnectionSource, false, (source, connection) -> { - retryState.breakAndThrowIfRetryAnd(() -> !canRetryRead(source.getServerDescription(), binding.getOperationContext())); + public BatchCursor execute(final ReadBinding binding, final OperationContext operationContext) { + OperationContext listCollectionsOperationContext = applyTimeoutModeToOperationContext(timeoutMode, operationContext); + + RetryState retryState = initialRetryState(retryReads, listCollectionsOperationContext.getTimeoutContext()); + Supplier> read = decorateReadWithRetries(retryState, listCollectionsOperationContext, () -> + withSourceAndConnection(binding::getReadConnectionSource, false, (source, connection, operationContextWithMinRTT) -> { + retryState.breakAndThrowIfRetryAnd(() -> !canRetryRead(source.getServerDescription(), operationContextWithMinRTT)); try { - return createReadCommandAndExecute(retryState, binding.getOperationContext(), source, databaseName, + return createReadCommandAndExecute(retryState, operationContextWithMinRTT, source, databaseName, getCommandCreator(), createCommandDecoder(), transformer(), connection); } catch (MongoCommandException e) { return rethrowIfNotNamespaceError(e, createEmptySingleBatchCursor(source.getServerDescription().getAddress(), batchSize)); } - }) + }, listCollectionsOperationContext) ); return read.get(); } @Override - public void executeAsync(final AsyncReadBinding binding, final SingleResultCallback> callback) { - RetryState retryState = initialRetryState(retryReads, binding.getOperationContext().getTimeoutContext()); + public void executeAsync(final AsyncReadBinding binding, final OperationContext operationContext, + final SingleResultCallback> callback) { + OperationContext listCollectionsOperationContext = applyTimeoutModeToOperationContext(timeoutMode, operationContext); + + RetryState retryState = initialRetryState(retryReads, listCollectionsOperationContext.getTimeoutContext()); binding.retain(); AsyncCallbackSupplier> asyncRead = decorateReadWithRetriesAsync( - retryState, binding.getOperationContext(), (AsyncCallbackSupplier>) funcCallback -> - withAsyncSourceAndConnection(binding::getReadConnectionSource, false, funcCallback, - (source, connection, releasingCallback) -> { + retryState, listCollectionsOperationContext, (AsyncCallbackSupplier>) funcCallback -> + withAsyncSourceAndConnection(binding::getReadConnectionSource, false, listCollectionsOperationContext, funcCallback, + (source, connection, operationContextWithMinRtt, releasingCallback) -> { if (retryState.breakAndCompleteIfRetryAnd(() -> !canRetryRead(source.getServerDescription(), - binding.getOperationContext()), releasingCallback)) { + operationContextWithMinRtt), releasingCallback)) { return; } - createReadCommandAndExecuteAsync(retryState, binding.getOperationContext(), source, databaseName, + createReadCommandAndExecuteAsync(retryState, operationContextWithMinRtt, source, databaseName, getCommandCreator(), createCommandDecoder(), asyncTransformer(), connection, (result, t) -> { if (t != null && !isNamespaceError(t)) { @@ -209,13 +215,13 @@ public void executeAsync(final AsyncReadBinding binding, final SingleResultCallb } private CommandReadTransformer> transformer() { - return (result, source, connection) -> - cursorDocumentToBatchCursor(timeoutMode, result, batchSize, decoder, comment, source, connection); + return (result, source, connection, operationContext) -> + cursorDocumentToBatchCursor(timeoutMode, result, batchSize, decoder, comment, source, connection, operationContext); } private CommandReadTransformerAsync> asyncTransformer() { - return (result, source, connection) -> - cursorDocumentToAsyncBatchCursor(timeoutMode, result, batchSize, decoder, comment, source, connection); + return (result, source, connection, operationContext) -> + cursorDocumentToAsyncBatchCursor(timeoutMode, result, batchSize, decoder, comment, source, connection, operationContext); } @@ -226,7 +232,6 @@ private CommandCreator getCommandCreator() { putIfNotNull(commandDocument, "filter", filter); putIfTrue(commandDocument, "nameOnly", nameOnly); putIfTrue(commandDocument, "authorizedCollections", authorizedCollections); - setNonTailableCursorMaxTimeSupplier(timeoutMode, operationContext); putIfNotNull(commandDocument, "comment", comment); return commandDocument; }; diff --git a/driver-core/src/main/com/mongodb/internal/operation/ListDatabasesOperation.java b/driver-core/src/main/com/mongodb/internal/operation/ListDatabasesOperation.java index 4787153190b..ec58bd977ba 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/ListDatabasesOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/ListDatabasesOperation.java @@ -20,6 +20,7 @@ import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.binding.AsyncReadBinding; import com.mongodb.internal.binding.ReadBinding; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonInt32; @@ -108,14 +109,16 @@ public String getCommandName() { } @Override - public BatchCursor execute(final ReadBinding binding) { - return executeRetryableRead(binding, "admin", getCommandCreator(), CommandResultDocumentCodec.create(decoder, DATABASES), + public BatchCursor execute(final ReadBinding binding, final OperationContext operationContext) { + return executeRetryableRead(binding, operationContext, "admin", getCommandCreator(), + CommandResultDocumentCodec.create(decoder, DATABASES), singleBatchCursorTransformer(DATABASES), retryReads); } @Override - public void executeAsync(final AsyncReadBinding binding, final SingleResultCallback> callback) { - executeRetryableReadAsync(binding, "admin", getCommandCreator(), CommandResultDocumentCodec.create(decoder, DATABASES), + public void executeAsync(final AsyncReadBinding binding, final OperationContext operationContext, + final SingleResultCallback> callback) { + executeRetryableReadAsync(binding, operationContext, "admin", getCommandCreator(), CommandResultDocumentCodec.create(decoder, DATABASES), asyncSingleBatchCursorTransformer(DATABASES), retryReads, errorHandlingCallback(callback, LOGGER)); } diff --git a/driver-core/src/main/com/mongodb/internal/operation/ListIndexesOperation.java b/driver-core/src/main/com/mongodb/internal/operation/ListIndexesOperation.java index a97acd64d58..b3ddf31e48c 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/ListIndexesOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/ListIndexesOperation.java @@ -25,6 +25,7 @@ import com.mongodb.internal.async.function.RetryState; import com.mongodb.internal.binding.AsyncReadBinding; import com.mongodb.internal.binding.ReadBinding; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonString; @@ -49,8 +50,8 @@ import static com.mongodb.internal.operation.CursorHelper.getCursorDocumentFromBatchSize; import static com.mongodb.internal.operation.DocumentHelper.putIfNotNull; import static com.mongodb.internal.operation.OperationHelper.LOGGER; +import static com.mongodb.internal.operation.OperationHelper.applyTimeoutModeToOperationContext; import static com.mongodb.internal.operation.OperationHelper.canRetryRead; -import static com.mongodb.internal.operation.OperationHelper.setNonTailableCursorMaxTimeSupplier; import static com.mongodb.internal.operation.SingleBatchCursor.createEmptySingleBatchCursor; import static com.mongodb.internal.operation.SyncOperationHelper.CommandReadTransformer; import static com.mongodb.internal.operation.SyncOperationHelper.createReadCommandAndExecute; @@ -123,36 +124,40 @@ public String getCommandName() { } @Override - public BatchCursor execute(final ReadBinding binding) { - RetryState retryState = initialRetryState(retryReads, binding.getOperationContext().getTimeoutContext()); - Supplier> read = decorateReadWithRetries(retryState, binding.getOperationContext(), () -> - withSourceAndConnection(binding::getReadConnectionSource, false, (source, connection) -> { - retryState.breakAndThrowIfRetryAnd(() -> !canRetryRead(source.getServerDescription(), binding.getOperationContext())); + public BatchCursor execute(final ReadBinding binding, final OperationContext operationContext) { + OperationContext listIndexesOperationContext = applyTimeoutModeToOperationContext(timeoutMode, operationContext); + + RetryState retryState = initialRetryState(retryReads, listIndexesOperationContext.getTimeoutContext()); + Supplier> read = decorateReadWithRetries(retryState, listIndexesOperationContext, () -> + withSourceAndConnection(binding::getReadConnectionSource, false, (source, connection, operationContextWithMinRTT) -> { + retryState.breakAndThrowIfRetryAnd(() -> !canRetryRead(source.getServerDescription(), operationContextWithMinRTT)); try { - return createReadCommandAndExecute(retryState, binding.getOperationContext(), source, namespace.getDatabaseName(), + return createReadCommandAndExecute(retryState, operationContextWithMinRTT, source, namespace.getDatabaseName(), getCommandCreator(), createCommandDecoder(), transformer(), connection); } catch (MongoCommandException e) { return rethrowIfNotNamespaceError(e, createEmptySingleBatchCursor(source.getServerDescription().getAddress(), batchSize)); } - }) + }, listIndexesOperationContext) ); return read.get(); } @Override - public void executeAsync(final AsyncReadBinding binding, final SingleResultCallback> callback) { - RetryState retryState = initialRetryState(retryReads, binding.getOperationContext().getTimeoutContext()); + public void executeAsync(final AsyncReadBinding binding, final OperationContext operationContext, final SingleResultCallback> callback) { + OperationContext listIndexesOperationContext = applyTimeoutModeToOperationContext(timeoutMode, operationContext); + + RetryState retryState = initialRetryState(retryReads, operationContext.getTimeoutContext()); binding.retain(); AsyncCallbackSupplier> asyncRead = decorateReadWithRetriesAsync( - retryState, binding.getOperationContext(), (AsyncCallbackSupplier>) funcCallback -> - withAsyncSourceAndConnection(binding::getReadConnectionSource, false, funcCallback, - (source, connection, releasingCallback) -> { + retryState, listIndexesOperationContext, (AsyncCallbackSupplier>) funcCallback -> + withAsyncSourceAndConnection(binding::getReadConnectionSource, false, listIndexesOperationContext, funcCallback, + (source, connection, operationContextWithMinRtt, releasingCallback) -> { if (retryState.breakAndCompleteIfRetryAnd(() -> !canRetryRead(source.getServerDescription(), - binding.getOperationContext()), releasingCallback)) { + operationContextWithMinRtt), releasingCallback)) { return; } - createReadCommandAndExecuteAsync(retryState, binding.getOperationContext(), source, + createReadCommandAndExecuteAsync(retryState, operationContextWithMinRtt, source, namespace.getDatabaseName(), getCommandCreator(), createCommandDecoder(), asyncTransformer(), connection, (result, t) -> { @@ -173,20 +178,19 @@ private CommandCreator getCommandCreator() { return (operationContext, serverDescription, connectionDescription) -> { BsonDocument commandDocument = new BsonDocument(getCommandName(), new BsonString(namespace.getCollectionName())) .append("cursor", getCursorDocumentFromBatchSize(batchSize == 0 ? null : batchSize)); - setNonTailableCursorMaxTimeSupplier(timeoutMode, operationContext); putIfNotNull(commandDocument, "comment", comment); return commandDocument; }; } private CommandReadTransformer> transformer() { - return (result, source, connection) -> - cursorDocumentToBatchCursor(timeoutMode, result, batchSize, decoder, comment, source, connection); + return (result, source, connection, operationContext) -> + cursorDocumentToBatchCursor(timeoutMode, result, batchSize, decoder, comment, source, connection, operationContext); } private CommandReadTransformerAsync> asyncTransformer() { - return (result, source, connection) -> - cursorDocumentToAsyncBatchCursor(timeoutMode, result, batchSize, decoder, comment, source, connection); + return (result, source, connection, operationContext) -> + cursorDocumentToAsyncBatchCursor(timeoutMode, result, batchSize, decoder, comment, source, connection, operationContext); } private Codec createCommandDecoder() { diff --git a/driver-core/src/main/com/mongodb/internal/operation/ListSearchIndexesOperation.java b/driver-core/src/main/com/mongodb/internal/operation/ListSearchIndexesOperation.java index 7fadead0b57..e7b5d84752d 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/ListSearchIndexesOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/ListSearchIndexesOperation.java @@ -24,6 +24,7 @@ import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.binding.AsyncReadBinding; import com.mongodb.internal.binding.ReadBinding; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.lang.NonNull; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; @@ -79,9 +80,9 @@ public String getCommandName() { } @Override - public BatchCursor execute(final ReadBinding binding) { + public BatchCursor execute(final ReadBinding binding, final OperationContext operationContext) { try { - return asAggregateOperation().execute(binding); + return asAggregateOperation().execute(binding, operationContext); } catch (MongoCommandException exception) { int cursorBatchSize = batchSize == null ? 0 : batchSize; if (!isNamespaceError(exception)) { @@ -93,8 +94,8 @@ public BatchCursor execute(final ReadBinding binding) { } @Override - public void executeAsync(final AsyncReadBinding binding, final SingleResultCallback> callback) { - asAggregateOperation().executeAsync(binding, (cursor, exception) -> { + public void executeAsync(final AsyncReadBinding binding, final OperationContext operationContext, final SingleResultCallback> callback) { + asAggregateOperation().executeAsync(binding, operationContext, (cursor, exception) -> { if (exception != null && !isNamespaceError(exception)) { callback.onResult(null, exception); } else if (exception != null) { diff --git a/driver-core/src/main/com/mongodb/internal/operation/MapReduceToCollectionOperation.java b/driver-core/src/main/com/mongodb/internal/operation/MapReduceToCollectionOperation.java index bfcc73a5aa6..4e46598a82f 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/MapReduceToCollectionOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/MapReduceToCollectionOperation.java @@ -24,6 +24,7 @@ import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.binding.AsyncWriteBinding; import com.mongodb.internal.binding.WriteBinding; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.lang.Nullable; import org.bson.BsonBoolean; import org.bson.BsonDocument; @@ -215,17 +216,19 @@ public String getCommandName() { } @Override - public MapReduceStatistics execute(final WriteBinding binding) { - return executeCommand(binding, namespace.getDatabaseName(), getCommandCreator(), transformer(binding - .getOperationContext() - .getTimeoutContext())); + public MapReduceStatistics execute(final WriteBinding binding, final OperationContext operationContext) { + return executeCommand(binding, + operationContext, + namespace.getDatabaseName(), + getCommandCreator(), + transformer(operationContext.getTimeoutContext())); } @Override - public void executeAsync(final AsyncWriteBinding binding, final SingleResultCallback callback) { - executeCommandAsync(binding, namespace.getDatabaseName(), getCommandCreator(), transformerAsync(binding - .getOperationContext() - .getTimeoutContext()), callback); + public void executeAsync(final AsyncWriteBinding binding, final OperationContext operationContext, + final SingleResultCallback callback) { + executeCommandAsync(binding, operationContext, namespace.getDatabaseName(), getCommandCreator(), + transformerAsync(operationContext.getTimeoutContext()), callback); } /** @@ -239,7 +242,7 @@ public ReadOperationSimple asExplainableOperation(final ExplainVer } private CommandReadOperation createExplainableOperation(final ExplainVerbosity explainVerbosity) { - return new CommandReadOperation<>(getNamespace().getDatabaseName(), getCommandName(), + return new ExplainCommandOperation<>(getNamespace().getDatabaseName(), getCommandName(), (operationContext, serverDescription, connectionDescription) -> { BsonDocument command = getCommandCreator().create(operationContext, serverDescription, connectionDescription); applyMaxTimeMS(operationContext.getTimeoutContext(), command); diff --git a/driver-core/src/main/com/mongodb/internal/operation/MapReduceWithInlineResultsOperation.java b/driver-core/src/main/com/mongodb/internal/operation/MapReduceWithInlineResultsOperation.java index 6661c2a5c77..80993fd24be 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/MapReduceWithInlineResultsOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/MapReduceWithInlineResultsOperation.java @@ -22,6 +22,7 @@ import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.binding.AsyncReadBinding; import com.mongodb.internal.binding.ReadBinding; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonInt32; @@ -170,16 +171,17 @@ public String getCommandName() { } @Override - public MapReduceBatchCursor execute(final ReadBinding binding) { - return executeRetryableRead(binding, namespace.getDatabaseName(), + public MapReduceBatchCursor execute(final ReadBinding binding, final OperationContext operationContext) { + return executeRetryableRead(binding, operationContext, namespace.getDatabaseName(), getCommandCreator(), CommandResultDocumentCodec.create(decoder, "results"), transformer(), false); } @Override - public void executeAsync(final AsyncReadBinding binding, final SingleResultCallback> callback) { + public void executeAsync(final AsyncReadBinding binding, final OperationContext operationContext, + final SingleResultCallback> callback) { SingleResultCallback> errHandlingCallback = errorHandlingCallback(callback, LOGGER); - executeRetryableReadAsync(binding, namespace.getDatabaseName(), + executeRetryableReadAsync(binding, operationContext, namespace.getDatabaseName(), getCommandCreator(), CommandResultDocumentCodec.create(decoder, "results"), asyncTransformer(), false, errHandlingCallback); } @@ -189,7 +191,7 @@ public ReadOperationSimple asExplainableOperation(final ExplainVer } private CommandReadOperation createExplainableOperation(final ExplainVerbosity explainVerbosity) { - return new CommandReadOperation<>(namespace.getDatabaseName(), getCommandName(), + return new ExplainCommandOperation<>(namespace.getDatabaseName(), getCommandName(), (operationContext, serverDescription, connectionDescription) -> { BsonDocument command = getCommandCreator().create(operationContext, serverDescription, connectionDescription); applyMaxTimeMS(operationContext.getTimeoutContext(), command); @@ -199,7 +201,7 @@ private CommandReadOperation createExplainableOperation(final Expl } private CommandReadTransformer> transformer() { - return (result, source, connection) -> + return (result, source, connection, operationContext) -> new MapReduceInlineResultsCursor<>( new SingleBatchCursor<>(BsonDocumentWrapperHelper.toList(result, "results"), 0, connection.getDescription().getServerAddress()), @@ -207,7 +209,7 @@ private CommandReadTransformer> transforme } private CommandReadTransformerAsync> asyncTransformer() { - return (result, source, connection) -> new MapReduceInlineResultsAsyncCursor<>( + return (result, source, connection, operationContext) -> new MapReduceInlineResultsAsyncCursor<>( new AsyncSingleBatchCursor<>(BsonDocumentWrapperHelper.toList(result, "results"), 0), MapReduceHelper.createStatistics(result)); } diff --git a/driver-core/src/main/com/mongodb/internal/operation/MixedBulkWriteOperation.java b/driver-core/src/main/com/mongodb/internal/operation/MixedBulkWriteOperation.java index 39ff2dab17f..517a4f81688 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/MixedBulkWriteOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/MixedBulkWriteOperation.java @@ -185,8 +185,8 @@ public String getCommandName() { } @Override - public BulkWriteResult execute(final WriteBinding binding) { - TimeoutContext timeoutContext = binding.getOperationContext().getTimeoutContext(); + public BulkWriteResult execute(final WriteBinding binding, final OperationContext operationContext) { + TimeoutContext timeoutContext = operationContext.getTimeoutContext(); /* We cannot use the tracking of attempts built in the `RetryState` class because conceptually we have to maintain multiple attempt * counters while executing a single bulk write operation: * - a counter that limits attempts to select server and checkout a connection before we created a batch; @@ -196,24 +196,26 @@ public BulkWriteResult execute(final WriteBinding binding) { * and the code related to the attempt tracking in `BulkWriteTracker` will be removed. */ RetryState retryState = new RetryState(timeoutContext); BulkWriteTracker.attachNew(retryState, retryWrites, timeoutContext); - Supplier retryingBulkWrite = decorateWriteWithRetries(retryState, binding.getOperationContext(), () -> - withSourceAndConnection(binding::getWriteConnectionSource, true, (source, connection) -> { + Supplier retryingBulkWrite = decorateWriteWithRetries(retryState, operationContext, () -> + withSourceAndConnection(binding::getWriteConnectionSource, true, (source, connection, operationContextWithMinRTT) -> { + TimeoutContext timeoutContextWithMinRtt = operationContextWithMinRTT.getTimeoutContext(); ConnectionDescription connectionDescription = connection.getDescription(); // attach `maxWireVersion` ASAP because it is used to check whether we can retry retryState.attach(AttachmentKeys.maxWireVersion(), connectionDescription.getMaxWireVersion(), true); - SessionContext sessionContext = binding.getOperationContext().getSessionContext(); + SessionContext sessionContext = operationContext.getSessionContext(); WriteConcern writeConcern = validateAndGetEffectiveWriteConcern(this.writeConcern, sessionContext); if (!isRetryableWrite(retryWrites, writeConcern, connectionDescription, sessionContext)) { - handleMongoWriteConcernWithResponseException(retryState, true, timeoutContext); + handleMongoWriteConcernWithResponseException(retryState, true, timeoutContextWithMinRtt); } validateWriteRequests(connectionDescription, bypassDocumentValidation, writeRequests, writeConcern); if (!retryState.attachment(AttachmentKeys.bulkWriteTracker()).orElseThrow(Assertions::fail).batch().isPresent()) { BulkWriteTracker.attachNew(retryState, BulkWriteBatch.createBulkWriteBatch(namespace, connectionDescription, ordered, writeConcern, - bypassDocumentValidation, retryWrites, writeRequests, binding.getOperationContext(), comment, variables), timeoutContext); + bypassDocumentValidation, retryWrites, writeRequests, operationContextWithMinRTT, comment, variables), + timeoutContextWithMinRtt); } - return executeBulkWriteBatch(retryState, writeConcern, binding, connection); - }) + return executeBulkWriteBatch(retryState, writeConcern, binding, operationContextWithMinRTT, connection); + }, operationContext) ); try { return retryingBulkWrite.get(); @@ -222,24 +224,26 @@ public BulkWriteResult execute(final WriteBinding binding) { } } - public void executeAsync(final AsyncWriteBinding binding, final SingleResultCallback callback) { - TimeoutContext timeoutContext = binding.getOperationContext().getTimeoutContext(); + public void executeAsync(final AsyncWriteBinding binding, final OperationContext operationContext, final SingleResultCallback callback) { + TimeoutContext timeoutContext = operationContext.getTimeoutContext(); // see the comment in `execute(WriteBinding)` explaining the manual tracking of attempts RetryState retryState = new RetryState(timeoutContext); BulkWriteTracker.attachNew(retryState, retryWrites, timeoutContext); binding.retain(); AsyncCallbackSupplier retryingBulkWrite = this.decorateWriteWithRetries(retryState, - binding.getOperationContext(), + operationContext, funcCallback -> - withAsyncSourceAndConnection(binding::getWriteConnectionSource, true, funcCallback, - (source, connection, releasingCallback) -> { + withAsyncSourceAndConnection(binding::getWriteConnectionSource, true, operationContext, funcCallback, + (source, connection, operationContextWithMinRtt, releasingCallback) -> { + TimeoutContext timeoutContextWithMinRtt = operationContextWithMinRtt.getTimeoutContext(); ConnectionDescription connectionDescription = connection.getDescription(); + // attach `maxWireVersion` ASAP because it is used to check whether we can retry retryState.attach(AttachmentKeys.maxWireVersion(), connectionDescription.getMaxWireVersion(), true); - SessionContext sessionContext = binding.getOperationContext().getSessionContext(); + SessionContext sessionContext = operationContextWithMinRtt.getSessionContext(); WriteConcern writeConcern = validateAndGetEffectiveWriteConcern(this.writeConcern, sessionContext); - if (!isRetryableWrite(retryWrites, writeConcern, connectionDescription, sessionContext) - && handleMongoWriteConcernWithResponseExceptionAsync(retryState, releasingCallback, timeoutContext)) { + if (!isRetryableWrite(retryWrites, writeConcern, connectionDescription, sessionContext) + && handleMongoWriteConcernWithResponseExceptionAsync(retryState, releasingCallback, timeoutContextWithMinRtt)) { return; } if (validateWriteRequestsAndCompleteIfInvalid(connectionDescription, bypassDocumentValidation, writeRequests, @@ -250,13 +254,13 @@ && handleMongoWriteConcernWithResponseExceptionAsync(retryState, releasingCallba if (!retryState.attachment(AttachmentKeys.bulkWriteTracker()).orElseThrow(Assertions::fail).batch().isPresent()) { BulkWriteTracker.attachNew(retryState, BulkWriteBatch.createBulkWriteBatch(namespace, connectionDescription, ordered, writeConcern, - bypassDocumentValidation, retryWrites, writeRequests, binding.getOperationContext(), comment, variables), timeoutContext); + bypassDocumentValidation, retryWrites, writeRequests, operationContextWithMinRtt, comment, variables), timeoutContextWithMinRtt); } } catch (Throwable t) { releasingCallback.onResult(null, t); return; } - executeBulkWriteBatchAsync(retryState, writeConcern, binding, connection, releasingCallback); + executeBulkWriteBatchAsync(retryState, writeConcern, binding, operationContextWithMinRtt, connection, releasingCallback); }) ).whenComplete(binding::release); retryingBulkWrite.get(exceptionTransformingCallback(errorHandlingCallback(callback, LOGGER))); @@ -266,12 +270,12 @@ private BulkWriteResult executeBulkWriteBatch( final RetryState retryState, final WriteConcern effectiveWriteConcern, final WriteBinding binding, + final OperationContext operationContext, final Connection connection) { BulkWriteTracker currentBulkWriteTracker = retryState.attachment(AttachmentKeys.bulkWriteTracker()) .orElseThrow(Assertions::fail); BulkWriteBatch currentBatch = currentBulkWriteTracker.batch().orElseThrow(Assertions::fail); int maxWireVersion = connection.getDescription().getMaxWireVersion(); - OperationContext operationContext = binding.getOperationContext(); TimeoutContext timeoutContext = operationContext.getTimeoutContext(); while (currentBatch.shouldProcessBatch()) { @@ -314,6 +318,7 @@ private void executeBulkWriteBatchAsync( final RetryState retryState, final WriteConcern effectiveWriteConcern, final AsyncWriteBinding binding, + final OperationContext operationContext, final AsyncConnection connection, final SingleResultCallback callback) { LoopState loopState = new LoopState(); @@ -326,13 +331,13 @@ private void executeBulkWriteBatchAsync( if (loopState.breakAndCompleteIf(() -> !currentBatch.shouldProcessBatch(), iterationCallback)) { return; } - OperationContext operationContext = binding.getOperationContext(); + TimeoutContext timeoutContext = operationContext.getTimeoutContext(); executeCommandAsync(effectiveWriteConcern, operationContext, connection, currentBatch, (result, t) -> { if (t == null) { if (currentBatch.getRetryWrites() && !operationContext.getSessionContext().hasActiveTransaction()) { MongoException writeConcernBasedError = ProtocolHelper.createSpecialException(result, - connection.getDescription().getServerAddress(), "errMsg", binding.getOperationContext().getTimeoutContext()); + connection.getDescription().getServerAddress(), "errMsg", operationContext.getTimeoutContext()); if (writeConcernBasedError != null) { if (currentBulkWriteTracker.lastAttempt()) { addRetryableWriteErrorLabel(writeConcernBasedError, maxWireVersion); diff --git a/driver-core/src/main/com/mongodb/internal/operation/OperationHelper.java b/driver-core/src/main/com/mongodb/internal/operation/OperationHelper.java index f980d309a8a..26dbdc60126 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/OperationHelper.java +++ b/driver-core/src/main/com/mongodb/internal/operation/OperationHelper.java @@ -25,6 +25,7 @@ import com.mongodb.connection.ConnectionDescription; import com.mongodb.connection.ServerDescription; import com.mongodb.connection.ServerType; +import com.mongodb.internal.TimeoutContext; import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.async.function.AsyncCallbackFunction; import com.mongodb.internal.async.function.AsyncCallbackSupplier; @@ -192,10 +193,12 @@ static boolean canRetryRead(final ServerDescription serverDescription, final Ope return true; } - static void setNonTailableCursorMaxTimeSupplier(final TimeoutMode timeoutMode, final OperationContext operationContext) { + static OperationContext applyTimeoutModeToOperationContext(final TimeoutMode timeoutMode, + final OperationContext operationContext) { if (timeoutMode == TimeoutMode.ITERATION) { - operationContext.getTimeoutContext().disableMaxTimeOverride(); + return operationContext.withTimeoutContextOverride(TimeoutContext::withDisabledMaxTimeOverride); } + return operationContext; } /** diff --git a/driver-core/src/main/com/mongodb/internal/operation/ReadOperation.java b/driver-core/src/main/com/mongodb/internal/operation/ReadOperation.java index 6a90d490b30..8545a8a4c53 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/ReadOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/ReadOperation.java @@ -19,6 +19,7 @@ import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.binding.AsyncReadBinding; import com.mongodb.internal.binding.ReadBinding; +import com.mongodb.internal.connection.OperationContext; /** * An operation that reads from a MongoDB server. @@ -36,15 +37,17 @@ public interface ReadOperation { * General execute which can return anything of type T * * @param binding the binding to execute in the context of + * @param operationContext the operation context * @return T, the result of the execution */ - T execute(ReadBinding binding); + T execute(ReadBinding binding, OperationContext operationContext); /** * General execute which can return anything of type R * * @param binding the binding to execute in the context of + * @param operationContext the operation context * @param callback the callback to be called when the operation has been executed */ - void executeAsync(AsyncReadBinding binding, SingleResultCallback callback); + void executeAsync(AsyncReadBinding binding, OperationContext operationContext, SingleResultCallback callback); } diff --git a/driver-core/src/main/com/mongodb/internal/operation/RenameCollectionOperation.java b/driver-core/src/main/com/mongodb/internal/operation/RenameCollectionOperation.java index ea477bf67bd..02beeabcfd8 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/RenameCollectionOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/RenameCollectionOperation.java @@ -21,6 +21,7 @@ import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.binding.AsyncWriteBinding; import com.mongodb.internal.binding.WriteBinding; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.lang.Nullable; import org.bson.BsonBoolean; import org.bson.BsonDocument; @@ -80,20 +81,22 @@ public String getCommandName() { } @Override - public Void execute(final WriteBinding binding) { - return withConnection(binding, connection -> executeCommand(binding, "admin", getCommand(), connection, - writeConcernErrorTransformer(binding.getOperationContext().getTimeoutContext()))); + public Void execute(final WriteBinding binding, final OperationContext operationContext) { + return withConnection(binding, operationContext, (connection, operationContextWithMinRtt) -> + executeCommand(binding, + operationContextWithMinRtt, "admin", getCommand(), connection, + writeConcernErrorTransformer(operationContextWithMinRtt.getTimeoutContext()))); } @Override - public void executeAsync(final AsyncWriteBinding binding, final SingleResultCallback callback) { - withAsyncConnection(binding, (connection, t) -> { + public void executeAsync(final AsyncWriteBinding binding, final OperationContext operationContext, final SingleResultCallback callback) { + withAsyncConnection(binding, operationContext, (connection, operationContextWithMinRtt, t) -> { SingleResultCallback errHandlingCallback = errorHandlingCallback(callback, LOGGER); if (t != null) { errHandlingCallback.onResult(null, t); } else { - executeCommandAsync(binding, "admin", getCommand(), assertNotNull(connection), - writeConcernErrorTransformerAsync(binding.getOperationContext().getTimeoutContext()), + executeCommandAsync(binding, operationContextWithMinRtt, "admin", getCommand(), assertNotNull(connection), + writeConcernErrorTransformerAsync(operationContextWithMinRtt.getTimeoutContext()), releasingCallback(errHandlingCallback, connection)); } }); diff --git a/driver-core/src/main/com/mongodb/internal/operation/SyncOperationHelper.java b/driver-core/src/main/com/mongodb/internal/operation/SyncOperationHelper.java index 6d013df59ba..8a7fa630d00 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/SyncOperationHelper.java +++ b/driver-core/src/main/com/mongodb/internal/operation/SyncOperationHelper.java @@ -21,10 +21,6 @@ import com.mongodb.client.cursor.TimeoutMode; import com.mongodb.internal.TimeoutContext; import com.mongodb.internal.VisibleForTesting; -import com.mongodb.internal.async.SingleResultCallback; -import com.mongodb.internal.async.function.AsyncCallbackBiFunction; -import com.mongodb.internal.async.function.AsyncCallbackFunction; -import com.mongodb.internal.async.function.AsyncCallbackSupplier; import com.mongodb.internal.async.function.RetryState; import com.mongodb.internal.async.function.RetryingSyncSupplier; import com.mongodb.internal.binding.ConnectionSource; @@ -43,7 +39,6 @@ import org.bson.codecs.BsonDocumentCodec; import org.bson.codecs.Decoder; -import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; @@ -64,11 +59,16 @@ final class SyncOperationHelper { interface CallableWithConnection { - T call(Connection connection); + T call(Connection connection, OperationContext operationContext); } interface CallableWithSource { - T call(ConnectionSource source); + T call(ConnectionSource source, OperationContext operationContext); + } + + @FunctionalInterface + interface ExecutionFunction { + R apply(ConnectionSource source, Connection connection, OperationContext operationContext); } interface CommandReadTransformer { @@ -80,7 +80,7 @@ interface CommandReadTransformer { * @return the function result */ @Nullable - R apply(T t, ConnectionSource source, Connection connection); + R apply(T t, ConnectionSource source, Connection connection, OperationContext operationContext); } interface CommandWriteTransformer { @@ -97,38 +97,56 @@ interface CommandWriteTransformer { private static final BsonDocumentCodec BSON_DOCUMENT_CODEC = new BsonDocumentCodec(); - static T withReadConnectionSource(final ReadBinding binding, final CallableWithSource callable) { - ConnectionSource source = binding.getReadConnectionSource(); + static T withReadConnectionSource(final ReadBinding binding, + final OperationContext operationContext, + final CallableWithSource callable) { + OperationContext serverSelectionOperationContext = + operationContext.withTimeoutContextOverride(TimeoutContext::withComputedServerSelectionTimeoutContextNew); + ConnectionSource source = binding.getReadConnectionSource(serverSelectionOperationContext); try { - return callable.call(source); + return callable.call(source, operationContext.withMinRoundTripTime(source.getServerDescription())); } finally { source.release(); } } - static T withConnection(final WriteBinding binding, final CallableWithConnection callable) { - ConnectionSource source = binding.getWriteConnectionSource(); - try { - return withConnectionSource(source, callable); - } finally { - source.release(); - } + static T withConnection(final WriteBinding binding, + final OperationContext operationContext, + final CallableWithConnection callable) { + return withSourceAndConnection( + binding::getWriteConnectionSource, + false, + (source, connection, operationContextWithMinRtt) -> + callable.call(connection, operationContextWithMinRtt), + operationContext); } /** * Gets a {@link ConnectionSource} and a {@link Connection} from the {@code sourceSupplier} and executes the {@code function} with them. * Guarantees to {@linkplain ReferenceCounted#release() release} the source and the connection after completion of the {@code function}. * - * @param wrapConnectionSourceException See {@link #withSuppliedResource(Supplier, boolean, Function)}. - * @see #withSuppliedResource(Supplier, boolean, Function) - * @see AsyncOperationHelper#withAsyncSourceAndConnection(AsyncCallbackSupplier, boolean, SingleResultCallback, AsyncCallbackBiFunction) + * */ - static R withSourceAndConnection(final Supplier sourceSupplier, - final boolean wrapConnectionSourceException, - final BiFunction function) throws ResourceSupplierInternalException { - return withSuppliedResource(sourceSupplier, wrapConnectionSourceException, source -> - withSuppliedResource(source::getConnection, wrapConnectionSourceException, connection -> - function.apply(source, connection))); + static R withSourceAndConnection(final Function sourceFunction, + final boolean wrapConnectionSourceException, + final ExecutionFunction function, + final OperationContext originalOperationContext) throws ResourceSupplierInternalException { + OperationContext serverSelectionOperationContext = + originalOperationContext.withTimeoutContextOverride(TimeoutContext::withComputedServerSelectionTimeoutContextNew); + + return withSuppliedResource( + sourceFunction, + wrapConnectionSourceException, + serverSelectionOperationContext, + source -> withSuppliedResource( + source::getConnection, + wrapConnectionSourceException, + serverSelectionOperationContext.withMinRoundTripTime(source.getServerDescription()), + connection -> function.apply( + source, + connection, + originalOperationContext.withMinRoundTripTime(source.getServerDescription()))) + ); } /** @@ -138,14 +156,16 @@ static R withSourceAndConnection(final Supplier sourceSupp * @param wrapSupplierException If {@code true} and {@code resourceSupplier} completes abruptly, then the exception is wrapped * into {@link OperationHelper.ResourceSupplierInternalException}, such that it can be accessed * via {@link OperationHelper.ResourceSupplierInternalException#getCause()}. - * @see AsyncOperationHelper#withAsyncSuppliedResource(AsyncCallbackSupplier, boolean, SingleResultCallback, AsyncCallbackFunction) */ - static R withSuppliedResource(final Supplier resourceSupplier, - final boolean wrapSupplierException, final Function function) throws OperationHelper.ResourceSupplierInternalException { + static R withSuppliedResource(final Function resourceSupplier, + final boolean wrapSupplierException, + final OperationContext operationContext, + final Function function) + throws OperationHelper.ResourceSupplierInternalException { T resource = null; try { try { - resource = resourceSupplier.get(); + resource = resourceSupplier.apply(operationContext); } catch (Exception supplierException) { if (wrapSupplierException) { throw new ResourceSupplierInternalException(supplierException); @@ -161,80 +181,77 @@ static R withSuppliedResource(final Supplier } } - private static T withConnectionSource(final ConnectionSource source, final CallableWithConnection callable) { - Connection connection = source.getConnection(); - try { - return callable.call(connection); - } finally { - connection.release(); - } - } - static T executeRetryableRead( final ReadBinding binding, + final OperationContext operationContext, final String database, final CommandCreator commandCreator, final Decoder decoder, final CommandReadTransformer transformer, final boolean retryReads) { - return executeRetryableRead(binding, binding::getReadConnectionSource, database, commandCreator, + return executeRetryableRead(operationContext, binding::getReadConnectionSource, database, commandCreator, decoder, transformer, retryReads); } static T executeRetryableRead( - final ReadBinding binding, - final Supplier readConnectionSourceSupplier, + final OperationContext operationContext, + final Function readConnectionSourceSupplier, final String database, final CommandCreator commandCreator, final Decoder decoder, final CommandReadTransformer transformer, final boolean retryReads) { - RetryState retryState = CommandOperationHelper.initialRetryState(retryReads, binding.getOperationContext().getTimeoutContext()); + RetryState retryState = CommandOperationHelper.initialRetryState(retryReads, operationContext.getTimeoutContext()); - Supplier read = decorateReadWithRetries(retryState, binding.getOperationContext(), () -> - withSourceAndConnection(readConnectionSourceSupplier, false, (source, connection) -> { - retryState.breakAndThrowIfRetryAnd(() -> !canRetryRead(source.getServerDescription(), binding.getOperationContext())); - return createReadCommandAndExecute(retryState, binding.getOperationContext(), source, database, + Supplier read = decorateReadWithRetries(retryState, operationContext, () -> + withSourceAndConnection(readConnectionSourceSupplier, false, (source, connection, operationContextWithMinRtt) -> { + retryState.breakAndThrowIfRetryAnd(() -> !canRetryRead(source.getServerDescription(), operationContextWithMinRtt)); + return createReadCommandAndExecute(retryState, operationContextWithMinRtt, source, database, commandCreator, decoder, transformer, connection); - }) + }, operationContext) ); return read.get(); } @VisibleForTesting(otherwise = PRIVATE) - static T executeCommand(final WriteBinding binding, final String database, final CommandCreator commandCreator, + static T executeCommand(final WriteBinding binding, final OperationContext operationContext, final String database, + final CommandCreator commandCreator, final CommandWriteTransformer transformer) { - return withSourceAndConnection(binding::getWriteConnectionSource, false, (source, connection) -> + return withSourceAndConnection(binding::getWriteConnectionSource, false, (source, connection, operationContextWithMinRtt) -> transformer.apply(assertNotNull( connection.command(database, - commandCreator.create(binding.getOperationContext(), + commandCreator.create(operationContextWithMinRtt, source.getServerDescription(), connection.getDescription()), - NoOpFieldNameValidator.INSTANCE, primary(), BSON_DOCUMENT_CODEC, binding.getOperationContext())), - connection)); + NoOpFieldNameValidator.INSTANCE, primary(), BSON_DOCUMENT_CODEC, operationContextWithMinRtt)), + connection), operationContext); } @VisibleForTesting(otherwise = PRIVATE) - static T executeCommand(final WriteBinding binding, final String database, final BsonDocument command, + static T executeCommand(final WriteBinding binding, final OperationContext operationContext, final String database, + final BsonDocument command, final Decoder decoder, final CommandWriteTransformer transformer) { - return withSourceAndConnection(binding::getWriteConnectionSource, false, (source, connection) -> + return withSourceAndConnection(binding::getWriteConnectionSource, false, (source, connection, operationContextWithMinRtt) -> transformer.apply(assertNotNull( connection.command(database, command, NoOpFieldNameValidator.INSTANCE, primary(), decoder, - binding.getOperationContext())), connection)); + operationContextWithMinRtt)), connection), + operationContext); } @Nullable - static T executeCommand(final WriteBinding binding, final String database, final BsonDocument command, + static T executeCommand(final WriteBinding binding, final OperationContext operationContext, final String database, + final BsonDocument command, final Connection connection, final CommandWriteTransformer transformer) { notNull("binding", binding); return transformer.apply(assertNotNull( connection.command(database, command, NoOpFieldNameValidator.INSTANCE, primary(), BSON_DOCUMENT_CODEC, - binding.getOperationContext())), + operationContext)), connection); } static R executeRetryableWrite( final WriteBinding binding, + final OperationContext operationContext, final String database, @Nullable final ReadPreference readPreference, final FieldNameValidator fieldNameValidator, @@ -242,14 +259,14 @@ static R executeRetryableWrite( final CommandCreator commandCreator, final CommandWriteTransformer transformer, final com.mongodb.Function retryCommandModifier) { - RetryState retryState = CommandOperationHelper.initialRetryState(true, binding.getOperationContext().getTimeoutContext()); - Supplier retryingWrite = decorateWriteWithRetries(retryState, binding.getOperationContext(), () -> { + RetryState retryState = CommandOperationHelper.initialRetryState(true, operationContext.getTimeoutContext()); + Supplier retryingWrite = decorateWriteWithRetries(retryState, operationContext, () -> { boolean firstAttempt = retryState.isFirstAttempt(); - SessionContext sessionContext = binding.getOperationContext().getSessionContext(); + SessionContext sessionContext = operationContext.getSessionContext(); if (!firstAttempt && sessionContext.hasActiveTransaction()) { sessionContext.clearTransactionContext(); } - return withSourceAndConnection(binding::getWriteConnectionSource, true, (source, connection) -> { + return withSourceAndConnection(binding::getWriteConnectionSource, true, (source, connection, operationContextWithMinRtt) -> { int maxWireVersion = connection.getDescription().getMaxWireVersion(); try { retryState.breakAndThrowIfRetryAnd(() -> !canRetryWrite(connection.getDescription(), sessionContext)); @@ -257,7 +274,7 @@ static R executeRetryableWrite( .map(previousAttemptCommand -> { assertFalse(firstAttempt); return retryCommandModifier.apply(previousAttemptCommand); - }).orElseGet(() -> commandCreator.create(binding.getOperationContext(), source.getServerDescription(), + }).orElseGet(() -> commandCreator.create(operationContextWithMinRtt, source.getServerDescription(), connection.getDescription())); // attach `maxWireVersion`, `retryableCommandFlag` ASAP because they are used to check whether we should retry retryState.attach(AttachmentKeys.maxWireVersion(), maxWireVersion, true) @@ -265,7 +282,7 @@ static R executeRetryableWrite( .attach(AttachmentKeys.commandDescriptionSupplier(), command::getFirstKey, false) .attach(AttachmentKeys.command(), command, false); return transformer.apply(assertNotNull(connection.command(database, command, fieldNameValidator, readPreference, - commandResultDecoder, binding.getOperationContext())), + commandResultDecoder, operationContextWithMinRtt)), connection); } catch (MongoException e) { if (!firstAttempt) { @@ -273,7 +290,7 @@ static R executeRetryableWrite( } throw e; } - }); + }, operationContext); }); try { return retryingWrite.get(); @@ -295,8 +312,11 @@ static T createReadCommandAndExecute( BsonDocument command = commandCreator.create(operationContext, source.getServerDescription(), connection.getDescription()); retryState.attach(AttachmentKeys.commandDescriptionSupplier(), command::getFirstKey, false); - return transformer.apply(assertNotNull(connection.command(database, command, NoOpFieldNameValidator.INSTANCE, - source.getReadPreference(), decoder, operationContext)), source, connection); + + D result = assertNotNull(connection.command(database, command, NoOpFieldNameValidator.INSTANCE, + source.getReadPreference(), decoder, operationContext)); + + return transformer.apply(result, source, connection, operationContext); } @@ -329,15 +349,19 @@ static CommandWriteTransformer writeConcernErrorTransformer( } static CommandReadTransformer> singleBatchCursorTransformer(final String fieldName) { - return (result, source, connection) -> + return (result, source, connection, operationContext) -> new SingleBatchCursor<>(BsonDocumentWrapperHelper.toList(result, fieldName), 0, connection.getDescription().getServerAddress()); } static CommandBatchCursor cursorDocumentToBatchCursor(final TimeoutMode timeoutMode, final BsonDocument cursorDocument, - final int batchSize, final Decoder decoder, @Nullable final BsonValue comment, final ConnectionSource source, - final Connection connection) { - return new CommandBatchCursor<>(timeoutMode, cursorDocument, batchSize, 0, decoder, comment, source, connection); + final int batchSize, final Decoder decoder, + @Nullable final BsonValue comment, final ConnectionSource source, + final Connection connection, final OperationContext operationContext) { + return new CommandBatchCursor<>(timeoutMode, 0, operationContext, new CommandCoreCursor<>( + cursorDocument, batchSize, decoder, comment, source, connection + )); + } private SyncOperationHelper() { diff --git a/driver-core/src/main/com/mongodb/internal/operation/TransactionOperation.java b/driver-core/src/main/com/mongodb/internal/operation/TransactionOperation.java index a15a2aa88e3..703d440fb04 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/TransactionOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/TransactionOperation.java @@ -22,6 +22,7 @@ import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.binding.AsyncWriteBinding; import com.mongodb.internal.binding.WriteBinding; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.internal.validator.NoOpFieldNameValidator; import org.bson.BsonDocument; import org.bson.BsonInt32; @@ -54,19 +55,19 @@ public WriteConcern getWriteConcern() { } @Override - public Void execute(final WriteBinding binding) { - isTrue("in transaction", binding.getOperationContext().getSessionContext().hasActiveTransaction()); - TimeoutContext timeoutContext = binding.getOperationContext().getTimeoutContext(); - return executeRetryableWrite(binding, "admin", null, NoOpFieldNameValidator.INSTANCE, + public Void execute(final WriteBinding binding, final OperationContext operationContext) { + isTrue("in transaction", operationContext.getSessionContext().hasActiveTransaction()); + TimeoutContext timeoutContext = operationContext.getTimeoutContext(); + return executeRetryableWrite(binding, operationContext, "admin", null, NoOpFieldNameValidator.INSTANCE, new BsonDocumentCodec(), getCommandCreator(), writeConcernErrorTransformer(timeoutContext), getRetryCommandModifier(timeoutContext)); } @Override - public void executeAsync(final AsyncWriteBinding binding, final SingleResultCallback callback) { - isTrue("in transaction", binding.getOperationContext().getSessionContext().hasActiveTransaction()); - TimeoutContext timeoutContext = binding.getOperationContext().getTimeoutContext(); - executeRetryableWriteAsync(binding, "admin", null, NoOpFieldNameValidator.INSTANCE, + public void executeAsync(final AsyncWriteBinding binding, final OperationContext operationContext, final SingleResultCallback callback) { + isTrue("in transaction", operationContext.getSessionContext().hasActiveTransaction()); + TimeoutContext timeoutContext = operationContext.getTimeoutContext(); + executeRetryableWriteAsync(binding, operationContext, "admin", null, NoOpFieldNameValidator.INSTANCE, new BsonDocumentCodec(), getCommandCreator(), writeConcernErrorTransformerAsync(timeoutContext), getRetryCommandModifier(timeoutContext), errorHandlingCallback(callback, LOGGER)); diff --git a/driver-core/src/main/com/mongodb/internal/operation/WriteOperation.java b/driver-core/src/main/com/mongodb/internal/operation/WriteOperation.java index 73cec2f416b..07a84c74d39 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/WriteOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/WriteOperation.java @@ -19,6 +19,7 @@ import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.binding.AsyncWriteBinding; import com.mongodb.internal.binding.WriteBinding; +import com.mongodb.internal.connection.OperationContext; /** * An operation which writes to a MongoDB server. @@ -32,19 +33,21 @@ public interface WriteOperation { */ String getCommandName(); + //TODO javadoc /** * General execute which can return anything of type T * * @param binding the binding to execute in the context of * @return T, the result of the execution */ - T execute(WriteBinding binding); + T execute(WriteBinding binding, OperationContext operationContext); + //TODO javadoc /** * General execute which can return anything of type T * * @param binding the binding to execute in the context of * @param callback the callback to be called when the operation has been executed */ - void executeAsync(AsyncWriteBinding binding, SingleResultCallback callback); + void executeAsync(AsyncWriteBinding binding, OperationContext operationContext, SingleResultCallback callback); } diff --git a/driver-core/src/main/com/mongodb/internal/session/BaseClientSessionImpl.java b/driver-core/src/main/com/mongodb/internal/session/BaseClientSessionImpl.java index 80f88cc08f5..48d5efc9f57 100644 --- a/driver-core/src/main/com/mongodb/internal/session/BaseClientSessionImpl.java +++ b/driver-core/src/main/com/mongodb/internal/session/BaseClientSessionImpl.java @@ -222,7 +222,7 @@ protected void setTimeoutContext(@Nullable final TimeoutContext timeoutContext) protected void resetTimeout() { if (timeoutContext != null) { - timeoutContext.resetTimeoutIfPresent(); + timeoutContext = timeoutContext.withNewlyStartedTimeout(); } } diff --git a/driver-core/src/test/functional/com/mongodb/ClusterFixture.java b/driver-core/src/test/functional/com/mongodb/ClusterFixture.java index 6bbf9233cb1..9d482222782 100644 --- a/driver-core/src/test/functional/com/mongodb/ClusterFixture.java +++ b/driver-core/src/test/functional/com/mongodb/ClusterFixture.java @@ -39,14 +39,12 @@ import com.mongodb.internal.binding.AsyncOperationContextBinding; import com.mongodb.internal.binding.AsyncReadBinding; import com.mongodb.internal.binding.AsyncReadWriteBinding; -import com.mongodb.internal.binding.AsyncSessionBinding; import com.mongodb.internal.binding.AsyncSingleConnectionBinding; -import com.mongodb.internal.binding.AsyncWriteBinding; import com.mongodb.internal.binding.ClusterBinding; import com.mongodb.internal.binding.OperationContextBinding; import com.mongodb.internal.binding.ReadWriteBinding; import com.mongodb.internal.binding.ReferenceCounted; -import com.mongodb.internal.binding.SessionBinding; +import com.mongodb.internal.binding.SimpleSessionContext; import com.mongodb.internal.binding.SingleConnectionBinding; import com.mongodb.internal.connection.AsyncConnection; import com.mongodb.internal.connection.AsynchronousSocketChannelStreamFactory; @@ -145,6 +143,8 @@ public final class ClusterFixture { private static Cluster cluster; private static Cluster asyncCluster; private static final Map BINDING_MAP = new HashMap<>(); + private static final Map SESSION_CONTEXT_MAP = new HashMap<>(); + private static final Map ASYNC_SESSION_CONTEXT_MAP = new HashMap<>(); private static final Map ASYNC_BINDING_MAP = new HashMap<>(); private static ServerVersion serverVersion; @@ -184,7 +184,7 @@ public static ServerVersion getServerVersion() { if (serverVersion == null) { serverVersion = getVersion(new CommandReadOperation<>("admin", new BsonDocument("buildInfo", new BsonInt32(1)), new BsonDocumentCodec()) - .execute(new ClusterBinding(getCluster(), ReadPreference.nearest(), ReadConcern.DEFAULT, OPERATION_CONTEXT))); + .execute(new ClusterBinding(getCluster(), ReadPreference.nearest()), OPERATION_CONTEXT)); } return serverVersion; } @@ -246,7 +246,7 @@ public static boolean hasEncryptionTestsEnabled() { public static Document getServerStatus() { return new CommandReadOperation<>("admin", new BsonDocument("serverStatus", new BsonInt32(1)), new DocumentCodec()) - .execute(getBinding()); + .execute(getBinding(), OPERATION_CONTEXT); } public static boolean supportsFsync() { @@ -261,7 +261,7 @@ static class ShutdownHook extends Thread { public void run() { if (cluster != null) { try { - new DropDatabaseOperation(getDefaultDatabaseName(), WriteConcern.ACKNOWLEDGED).execute(getBinding()); + new DropDatabaseOperation(getDefaultDatabaseName(), WriteConcern.ACKNOWLEDGED).execute(getBinding(), OPERATION_CONTEXT); } catch (MongoCommandException e) { // if we do not have permission to drop the database, assume it is cleaned up in some other way if (!e.getMessage().contains("Command dropDatabase requires authentication")) { @@ -313,7 +313,7 @@ public static synchronized ConnectionString getConnectionString() { try { BsonDocument helloResult = new CommandReadOperation<>("admin", new BsonDocument(LEGACY_HELLO, new BsonInt32(1)), new BsonDocumentCodec()) - .execute(new ClusterBinding(cluster, ReadPreference.nearest(), ReadConcern.DEFAULT, OPERATION_CONTEXT)); + .execute(new ClusterBinding(cluster, ReadPreference.nearest()), OPERATION_CONTEXT); if (helloResult.containsKey("setName")) { connectionString = new ConnectionString(DEFAULT_URI + "/?replicaSet=" + helloResult.getString("setName").getValue()); @@ -361,7 +361,7 @@ public static ReadWriteBinding getBinding() { } public static ReadWriteBinding getBinding(final Cluster cluster) { - return new ClusterBinding(cluster, ReadPreference.primary(), ReadConcern.DEFAULT, OPERATION_CONTEXT); + return new ClusterBinding(cluster, ReadPreference.primary()); } public static ReadWriteBinding getBinding(final TimeoutSettings timeoutSettings) { @@ -384,13 +384,13 @@ private static ReadWriteBinding getBinding(final Cluster cluster, final ReadPreference readPreference, final OperationContext operationContext) { if (!BINDING_MAP.containsKey(readPreference)) { - ReadWriteBinding binding = new SessionBinding(new ClusterBinding(cluster, readPreference, ReadConcern.DEFAULT, - operationContext)); + ReadWriteBinding binding = new ClusterBinding(cluster, readPreference); BINDING_MAP.put(readPreference, binding); + SESSION_CONTEXT_MAP.put(readPreference, new SimpleSessionContext()); } ReadWriteBinding readWriteBinding = BINDING_MAP.get(readPreference); return new OperationContextBinding(readWriteBinding, - operationContext.withSessionContext(readWriteBinding.getOperationContext().getSessionContext())); + operationContext.withSessionContext(SESSION_CONTEXT_MAP.get(readPreference))); } public static SingleConnectionBinding getSingleConnectionBinding() { @@ -406,7 +406,7 @@ public static AsyncSingleConnectionBinding getAsyncSingleConnectionBinding(final } public static AsyncReadWriteBinding getAsyncBinding(final Cluster cluster) { - return new AsyncClusterBinding(cluster, ReadPreference.primary(), ReadConcern.DEFAULT, OPERATION_CONTEXT); + return new AsyncClusterBinding(cluster, ReadPreference.primary()); } public static AsyncReadWriteBinding getAsyncBinding() { @@ -430,13 +430,13 @@ public static AsyncReadWriteBinding getAsyncBinding( final ReadPreference readPreference, final OperationContext operationContext) { if (!ASYNC_BINDING_MAP.containsKey(readPreference)) { - AsyncReadWriteBinding binding = new AsyncSessionBinding(new AsyncClusterBinding(cluster, readPreference, ReadConcern.DEFAULT, - operationContext)); + AsyncReadWriteBinding binding = new AsyncClusterBinding(cluster, readPreference); ASYNC_BINDING_MAP.put(readPreference, binding); + ASYNC_SESSION_CONTEXT_MAP.put(readPreference, new SimpleSessionContext()); } AsyncReadWriteBinding readWriteBinding = ASYNC_BINDING_MAP.get(readPreference); return new AsyncOperationContextBinding(readWriteBinding, - operationContext.withSessionContext(readWriteBinding.getOperationContext().getSessionContext())); + operationContext.withSessionContext(ASYNC_SESSION_CONTEXT_MAP.get(readPreference))); } public static synchronized Cluster getCluster() { @@ -596,7 +596,7 @@ public static BsonDocument getServerParameters() { if (serverParameters == null) { serverParameters = new CommandReadOperation<>("admin", new BsonDocument("getParameter", new BsonString("*")), new BsonDocumentCodec()) - .execute(getBinding()); + .execute(getBinding(), OPERATION_CONTEXT); } return serverParameters; } @@ -660,7 +660,7 @@ public static void configureFailPoint(final BsonDocument failPointDocument) { if (!isSharded()) { try { new CommandReadOperation<>("admin", failPointDocument, new BsonDocumentCodec()) - .execute(getBinding()); + .execute(getBinding(), OPERATION_CONTEXT); } catch (MongoCommandException e) { if (e.getErrorCode() == COMMAND_NOT_FOUND_ERROR_CODE) { failsPointsSupported = false; @@ -676,7 +676,7 @@ public static void disableFailPoint(final String failPoint) { .append("mode", new BsonString("off")); try { new CommandReadOperation<>("admin", failPointDocument, new BsonDocumentCodec()) - .execute(getBinding()); + .execute(getBinding(), OPERATION_CONTEXT); } catch (MongoCommandException e) { // ignore } @@ -690,7 +690,7 @@ public static T executeSync(final WriteOperation op) { @SuppressWarnings("overloads") public static T executeSync(final WriteOperation op, final ReadWriteBinding binding) { - return op.execute(binding); + return op.execute(binding, applySessionContext(OPERATION_CONTEXT, binding.getReadPreference())); } @SuppressWarnings("overloads") @@ -700,7 +700,12 @@ public static T executeSync(final ReadOperation op) { @SuppressWarnings("overloads") public static T executeSync(final ReadOperation op, final ReadWriteBinding binding) { - return op.execute(binding); + return op.execute(binding, OPERATION_CONTEXT); + } + + @SuppressWarnings("overloads") + public static T executeSync(final ReadOperation op, final ReadWriteBinding binding, final OperationContext operationContext) { + return op.execute(binding, operationContext); } @SuppressWarnings("overloads") @@ -709,9 +714,9 @@ public static T executeAsync(final WriteOperation op) throws Throwable { } @SuppressWarnings("overloads") - public static T executeAsync(final WriteOperation op, final AsyncWriteBinding binding) throws Throwable { + public static T executeAsync(final WriteOperation op, final AsyncReadWriteBinding binding) throws Throwable { FutureResultCallback futureResultCallback = new FutureResultCallback<>(); - op.executeAsync(binding, futureResultCallback); + op.executeAsync(binding, applySessionContext(OPERATION_CONTEXT, binding.getReadPreference()), futureResultCallback); return futureResultCallback.get(TIMEOUT, SECONDS); } @@ -723,7 +728,13 @@ public static T executeAsync(final ReadOperation op) throws Throwable @SuppressWarnings("overloads") public static T executeAsync(final ReadOperation op, final AsyncReadBinding binding) throws Throwable { FutureResultCallback futureResultCallback = new FutureResultCallback<>(); - op.executeAsync(binding, futureResultCallback); + op.executeAsync(binding, OPERATION_CONTEXT, futureResultCallback); + return futureResultCallback.get(TIMEOUT, SECONDS); + } + + public static T executeAsync(final ReadOperation op, final AsyncReadBinding binding, final OperationContext operationContext) throws Throwable { + FutureResultCallback futureResultCallback = new FutureResultCallback<>(); + op.executeAsync(binding, operationContext, futureResultCallback); return futureResultCallback.get(TIMEOUT, SECONDS); } @@ -787,19 +798,19 @@ public static List collectCursorResults(final BatchCursor batchCursor) public static AsyncConnectionSource getWriteConnectionSource(final AsyncReadWriteBinding binding) throws Throwable { FutureResultCallback futureResultCallback = new FutureResultCallback<>(); - binding.getWriteConnectionSource(futureResultCallback); + binding.getWriteConnectionSource(OPERATION_CONTEXT, futureResultCallback); return futureResultCallback.get(TIMEOUT, SECONDS); } public static AsyncConnectionSource getReadConnectionSource(final AsyncReadWriteBinding binding) throws Throwable { FutureResultCallback futureResultCallback = new FutureResultCallback<>(); - binding.getReadConnectionSource(futureResultCallback); + binding.getReadConnectionSource(OPERATION_CONTEXT, futureResultCallback); return futureResultCallback.get(TIMEOUT, SECONDS); } public static AsyncConnection getConnection(final AsyncConnectionSource source) throws Throwable { FutureResultCallback futureResultCallback = new FutureResultCallback<>(); - source.getConnection(futureResultCallback); + source.getConnection(OPERATION_CONTEXT, futureResultCallback); return futureResultCallback.get(TIMEOUT, SECONDS); } @@ -833,4 +844,16 @@ public static ClusterSettings.Builder setDirectConnection(final ClusterSettings. return builder.mode(ClusterConnectionMode.SINGLE).hosts(singletonList(getPrimary())); } + private static OperationContext applySessionContext(final OperationContext operationContext, final ReadPreference readPreference) { + SimpleSessionContext simpleSessionContext = SESSION_CONTEXT_MAP.get(readPreference); + if (simpleSessionContext == null) { + simpleSessionContext = new SimpleSessionContext(); + SESSION_CONTEXT_MAP.put(readPreference, simpleSessionContext); + } + return operationContext.withSessionContext(simpleSessionContext); + } + + public static OperationContext getOperationContext(final ReadPreference readPreference) { + return applySessionContext(OPERATION_CONTEXT, readPreference); + } } diff --git a/driver-core/src/test/functional/com/mongodb/OperationFunctionalSpecification.groovy b/driver-core/src/test/functional/com/mongodb/OperationFunctionalSpecification.groovy index dcefaaa65ba..6648edc50c7 100644 --- a/driver-core/src/test/functional/com/mongodb/OperationFunctionalSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/OperationFunctionalSpecification.groovy @@ -30,6 +30,7 @@ import com.mongodb.connection.ServerConnectionState import com.mongodb.connection.ServerDescription import com.mongodb.connection.ServerType import com.mongodb.connection.ServerVersion +import com.mongodb.internal.async.SingleResultCallback import com.mongodb.internal.binding.AsyncConnectionSource import com.mongodb.internal.binding.AsyncReadBinding import com.mongodb.internal.binding.AsyncReadWriteBinding @@ -45,6 +46,7 @@ import com.mongodb.internal.binding.WriteBinding import com.mongodb.internal.bulk.InsertRequest import com.mongodb.internal.connection.AsyncConnection import com.mongodb.internal.connection.Connection +import com.mongodb.internal.connection.OperationContext import com.mongodb.internal.connection.ServerHelper import com.mongodb.internal.connection.SplittablePayload import com.mongodb.internal.operation.MixedBulkWriteOperation @@ -106,7 +108,7 @@ class OperationFunctionalSpecification extends Specification { void acknowledgeWrite(final SingleConnectionBinding binding) { new MixedBulkWriteOperation(getNamespace(), [new InsertRequest(new BsonDocument())], true, - ACKNOWLEDGED, false).execute(binding) + ACKNOWLEDGED, false).execute(binding, OPERATION_CONTEXT) binding.release() } @@ -149,10 +151,18 @@ class OperationFunctionalSpecification extends Specification { ClusterFixture.executeSync(operation, binding) } + def execute(operation, ReadWriteBinding binding, OperationContext operationContext) { + ClusterFixture.executeSync(operation, binding, operationContext) + } + def execute(operation, AsyncReadWriteBinding binding) { ClusterFixture.executeAsync(operation, binding) } + def execute(operation, AsyncReadWriteBinding binding, OperationContext operationContext) { + ClusterFixture.executeAsync(operation, binding, operationContext) + } + def executeAndCollectBatchCursorResults(operation, boolean async) { def cursor = execute(operation, async) def results = [] @@ -282,10 +292,9 @@ class OperationFunctionalSpecification extends Specification { } def connectionSource = Stub(ConnectionSource) { - getConnection() >> { + getConnection(_ as OperationContext) >> { connection } - getOperationContext() >> operationContext getReadPreference() >> readPreference getServerDescription() >> { def builder = ServerDescription.builder().address(Stub(ServerAddress)).state(ServerConnectionState.CONNECTED) @@ -298,11 +307,9 @@ class OperationFunctionalSpecification extends Specification { def readBinding = Stub(ReadBinding) { getReadConnectionSource(*_) >> connectionSource getReadPreference() >> readPreference - getOperationContext() >> operationContext } def writeBinding = Stub(WriteBinding) { - getWriteConnectionSource() >> connectionSource - getOperationContext() >> operationContext + getWriteConnectionSource(_) >> connectionSource } if (retryable) { @@ -336,9 +343,9 @@ class OperationFunctionalSpecification extends Specification { 1 * connection.release() } if (operation instanceof ReadOperation) { - operation.execute(readBinding) + operation.execute(readBinding, operationContext) } else if (operation instanceof WriteOperation) { - operation.execute(writeBinding) + operation.execute(writeBinding, operationContext) } } @@ -359,9 +366,10 @@ class OperationFunctionalSpecification extends Specification { } def connectionSource = Stub(AsyncConnectionSource) { - getConnection(_) >> { it[0].onResult(connection, null) } + getConnection(_ as OperationContext, _ as SingleResultCallback) >> { + it[1].onResult(connection, null) + } getReadPreference() >> readPreference - getOperationContext() >> operationContext getServerDescription() >> { def builder = ServerDescription.builder().address(Stub(ServerAddress)).state(ServerConnectionState.CONNECTED) if (new ServerVersion(serverVersion).compareTo(new ServerVersion(3, 6)) >= 0) { @@ -373,11 +381,11 @@ class OperationFunctionalSpecification extends Specification { def readBinding = Stub(AsyncReadBinding) { getReadConnectionSource(*_) >> { it.last().onResult(connectionSource, null) } getReadPreference() >> readPreference - getOperationContext() >> operationContext } def writeBinding = Stub(AsyncWriteBinding) { - getWriteConnectionSource(_) >> { it[0].onResult(connectionSource, null) } - getOperationContext() >> operationContext + getWriteConnectionSource(_ as OperationContext, _ as SingleResultCallback) >> { + it[1].onResult(connectionSource, null) + } } def callback = new FutureResultCallback() @@ -415,9 +423,9 @@ class OperationFunctionalSpecification extends Specification { } if (operation instanceof ReadOperation) { - operation.executeAsync(readBinding, callback) + operation.executeAsync(readBinding, operationContext, callback) } else if (operation instanceof WriteOperation) { - operation.executeAsync(writeBinding, callback) + operation.executeAsync(writeBinding, operationContext, callback) } try { callback.get(1000, TimeUnit.MILLISECONDS) @@ -447,18 +455,16 @@ class OperationFunctionalSpecification extends Specification { }) def connectionSource = Stub(ConnectionSource) { - getConnection() >> { + getConnection(_) >> { if (serverVersions.isEmpty()){ throw new MongoSocketOpenException('No Server', new ServerAddress(), new Exception('no server')) } else { connection } } - getOperationContext() >> operationContext } def writeBinding = Stub(WriteBinding) { - getWriteConnectionSource() >> connectionSource - getOperationContext() >> operationContext + getWriteConnectionSource(_) >> connectionSource } 1 * connection.command(*_) >> { @@ -466,7 +472,7 @@ class OperationFunctionalSpecification extends Specification { } expectedConnectionReleaseCount * connection.release() - operation.execute(writeBinding) + operation.execute(writeBinding, operationContext) } def testAyncRetryableOperationThrows(operation, Queue> serverVersions, Queue serverTypes, @@ -490,27 +496,25 @@ class OperationFunctionalSpecification extends Specification { }) def connectionSource = Stub(AsyncConnectionSource) { - getConnection(_) >> { + getConnection(_ as OperationContext, _ as SingleResultCallback) >> { if (serverVersions.isEmpty()) { - it[0].onResult(null, + it[1].onResult(null, new MongoSocketOpenException('No Server', new ServerAddress(), new Exception('no server'))) } else { - it[0].onResult(connection, null) + it[1].onResult(connection, null) } } - getOperationContext() >> operationContext } def writeBinding = Stub(AsyncWriteBinding) { - getWriteConnectionSource(_) >> { it[0].onResult(connectionSource, null) } - getOperationContext() >> operationContext + getWriteConnectionSource(_ as OperationContext, _ as SingleResultCallback) >> { it[1].onResult(connectionSource, null) } } def callback = new FutureResultCallback() 1 * connection.commandAsync(*_) >> { it.last().onResult(null, exception) } expectedConnectionReleaseCount * connection.release() - operation.executeAsync(writeBinding, callback) + operation.executeAsync(writeBinding, operationContext, callback) callback.get(1000, TimeUnit.MILLISECONDS) } diff --git a/driver-core/src/test/functional/com/mongodb/client/test/CollectionHelper.java b/driver-core/src/test/functional/com/mongodb/client/test/CollectionHelper.java index 3e58712ca9c..6f06f43ee9f 100644 --- a/driver-core/src/test/functional/com/mongodb/client/test/CollectionHelper.java +++ b/driver-core/src/test/functional/com/mongodb/client/test/CollectionHelper.java @@ -72,6 +72,7 @@ import java.util.Optional; import java.util.stream.Collectors; +import static com.mongodb.ClusterFixture.OPERATION_CONTEXT; import static com.mongodb.ClusterFixture.executeAsync; import static com.mongodb.ClusterFixture.getBinding; import static java.util.Arrays.asList; @@ -91,7 +92,7 @@ public CollectionHelper(final Codec codec, final MongoNamespace namespace) { public T hello() { return new CommandReadOperation<>("admin", BsonDocument.parse("{isMaster: 1}"), codec) - .execute(getBinding()); + .execute(getBinding(), OPERATION_CONTEXT); } public static void drop(final MongoNamespace namespace) { @@ -104,7 +105,7 @@ public static void drop(final MongoNamespace namespace, final WriteConcern write boolean success = false; while (!success) { try { - new DropCollectionOperation(namespace, writeConcern).execute(getBinding()); + new DropCollectionOperation(namespace, writeConcern).execute(getBinding(), OPERATION_CONTEXT); success = true; } catch (MongoWriteConcernException e) { LOGGER.info("Retrying drop collection after a write concern error: " + e); @@ -129,7 +130,7 @@ public static void dropDatabase(final String name, final WriteConcern writeConce return; } try { - new DropDatabaseOperation(name, writeConcern).execute(getBinding()); + new DropDatabaseOperation(name, writeConcern).execute(getBinding(), OPERATION_CONTEXT); } catch (MongoCommandException e) { if (!e.getErrorMessage().contains("ns not found")) { throw e; @@ -139,7 +140,7 @@ public static void dropDatabase(final String name, final WriteConcern writeConce public static BsonDocument getCurrentClusterTime() { return new CommandReadOperation("admin", new BsonDocument("ping", new BsonInt32(1)), new BsonDocumentCodec()) - .execute(getBinding()).getDocument("$clusterTime", null); + .execute(getBinding(), OPERATION_CONTEXT).getDocument("$clusterTime", null); } public MongoNamespace getNamespace() { @@ -211,7 +212,7 @@ public void create(final String collectionName, final CreateCollectionOptions op boolean success = false; while (!success) { try { - operation.execute(getBinding()); + operation.execute(getBinding(), OPERATION_CONTEXT); success = true; } catch (MongoCommandException e) { if ("Interrupted".equals(e.getErrorCodeName())) { @@ -230,7 +231,7 @@ public void killCursor(final MongoNamespace namespace, final ServerCursor server .append("cursors", new BsonArray(singletonList(new BsonInt64(serverCursor.getId())))); try { new CommandReadOperation<>(namespace.getDatabaseName(), command, new BsonDocumentCodec()) - .execute(getBinding()); + .execute(getBinding(), OPERATION_CONTEXT); } catch (Exception e) { // Ignore any exceptions killing old cursors } @@ -262,7 +263,7 @@ public void insertDocuments(final List documents, final WriteConce for (BsonDocument document : documents) { insertRequests.add(new InsertRequest(document)); } - new MixedBulkWriteOperation(namespace, insertRequests, true, writeConcern, false).execute(binding); + new MixedBulkWriteOperation(namespace, insertRequests, true, writeConcern, false).execute(binding, OPERATION_CONTEXT); } public void insertDocuments(final Document... documents) { @@ -305,7 +306,7 @@ public List find() { public Optional listSearchIndex(final String indexName) { ListSearchIndexesOperation listSearchIndexesOperation = new ListSearchIndexesOperation<>(namespace, codec, indexName, null, null, null, null, true); - BatchCursor cursor = listSearchIndexesOperation.execute(getBinding()); + BatchCursor cursor = listSearchIndexesOperation.execute(getBinding(), OPERATION_CONTEXT); List results = new ArrayList<>(); while (cursor.hasNext()) { @@ -318,13 +319,13 @@ public Optional listSearchIndex(final String indexName) { public void createSearchIndex(final SearchIndexRequest searchIndexModel) { CreateSearchIndexesOperation searchIndexesOperation = new CreateSearchIndexesOperation(namespace, singletonList(searchIndexModel)); - searchIndexesOperation.execute(getBinding()); + searchIndexesOperation.execute(getBinding(), OPERATION_CONTEXT); } public List find(final Codec codec) { BatchCursor cursor = new FindOperation<>(namespace, codec) .sort(new BsonDocument("_id", new BsonInt32(1))) - .execute(getBinding()); + .execute(getBinding(), OPERATION_CONTEXT); List results = new ArrayList<>(); while (cursor.hasNext()) { results.addAll(cursor.next()); @@ -343,7 +344,7 @@ public void updateOne(final Bson filter, final Bson update, final boolean isUpse WriteRequest.Type.UPDATE) .upsert(isUpsert)), true, WriteConcern.ACKNOWLEDGED, false) - .execute(getBinding()); + .execute(getBinding(), OPERATION_CONTEXT); } public void replaceOne(final Bson filter, final Bson update, final boolean isUpsert) { @@ -353,7 +354,7 @@ public void replaceOne(final Bson filter, final Bson update, final boolean isUps WriteRequest.Type.REPLACE) .upsert(isUpsert)), true, WriteConcern.ACKNOWLEDGED, false) - .execute(getBinding()); + .execute(getBinding(), OPERATION_CONTEXT); } public void deleteOne(final Bson filter) { @@ -368,7 +369,7 @@ private void delete(final Bson filter, final boolean multi) { new MixedBulkWriteOperation(namespace, singletonList(new DeleteRequest(filter.toBsonDocument(Document.class, registry)).multi(multi)), true, WriteConcern.ACKNOWLEDGED, false) - .execute(getBinding()); + .execute(getBinding(), OPERATION_CONTEXT); } public List find(final Bson filter) { @@ -393,7 +394,7 @@ private List aggregate(final List pipeline, final Decoder decode bsonDocumentPipeline.add(cur.toBsonDocument(Document.class, registry)); } BatchCursor cursor = new AggregateOperation<>(namespace, bsonDocumentPipeline, decoder, level) - .execute(getBinding()); + .execute(getBinding(), OPERATION_CONTEXT); List results = new ArrayList<>(); while (cursor.hasNext()) { results.addAll(cursor.next()); @@ -428,7 +429,7 @@ public List find(final BsonDocument filter, final BsonDocument sort, fina public List find(final BsonDocument filter, final BsonDocument sort, final BsonDocument projection, final Decoder decoder) { BatchCursor cursor = new FindOperation<>(namespace, decoder).filter(filter).sort(sort) - .projection(projection).execute(getBinding()); + .projection(projection).execute(getBinding(), OPERATION_CONTEXT); List results = new ArrayList<>(); while (cursor.hasNext()) { results.addAll(cursor.next()); @@ -441,7 +442,7 @@ public long count() { } public long count(final ReadBinding binding) { - return new CountDocumentsOperation(namespace).execute(binding); + return new CountDocumentsOperation(namespace).execute(binding, OPERATION_CONTEXT); } public long count(final AsyncReadWriteBinding binding) throws Throwable { @@ -450,7 +451,7 @@ public long count(final AsyncReadWriteBinding binding) throws Throwable { public long count(final Bson filter) { return new CountDocumentsOperation(namespace) - .filter(toBsonDocument(filter)).execute(getBinding()); + .filter(toBsonDocument(filter)).execute(getBinding(), OPERATION_CONTEXT); } public BsonDocument wrap(final Document document) { @@ -463,34 +464,36 @@ public BsonDocument toBsonDocument(final Bson document) { public void createIndex(final BsonDocument key) { new CreateIndexesOperation(namespace, singletonList(new IndexRequest(key)), WriteConcern.ACKNOWLEDGED) - .execute(getBinding()); + .execute(getBinding(), OPERATION_CONTEXT); } public void createIndex(final Document key) { new CreateIndexesOperation(namespace, singletonList(new IndexRequest(wrap(key))), WriteConcern.ACKNOWLEDGED) - .execute(getBinding()); + .execute(getBinding(), OPERATION_CONTEXT); } public void createUniqueIndex(final Document key) { new CreateIndexesOperation(namespace, singletonList(new IndexRequest(wrap(key)).unique(true)), WriteConcern.ACKNOWLEDGED) - .execute(getBinding()); + .execute(getBinding(), OPERATION_CONTEXT); } public void createIndex(final Document key, final String defaultLanguage) { new CreateIndexesOperation(namespace, - singletonList(new IndexRequest(wrap(key)).defaultLanguage(defaultLanguage)), WriteConcern.ACKNOWLEDGED).execute(getBinding()); + singletonList(new IndexRequest(wrap(key)).defaultLanguage(defaultLanguage)), WriteConcern.ACKNOWLEDGED).execute( + getBinding(), OPERATION_CONTEXT); } public void createIndex(final Bson key) { new CreateIndexesOperation(namespace, - singletonList(new IndexRequest(key.toBsonDocument(Document.class, registry))), WriteConcern.ACKNOWLEDGED).execute(getBinding()); + singletonList(new IndexRequest(key.toBsonDocument(Document.class, registry))), WriteConcern.ACKNOWLEDGED).execute( + getBinding(), OPERATION_CONTEXT); } public List listIndexes(){ List indexes = new ArrayList<>(); BatchCursor cursor = new ListIndexesOperation<>(namespace, new BsonDocumentCodec()) - .execute(getBinding()); + .execute(getBinding(), OPERATION_CONTEXT); while (cursor.hasNext()) { indexes.addAll(cursor.next()); } @@ -500,7 +503,7 @@ public List listIndexes(){ public static void killAllSessions() { try { new CommandReadOperation<>("admin", - new BsonDocument("killAllSessions", new BsonArray()), new BsonDocumentCodec()).execute(getBinding()); + new BsonDocument("killAllSessions", new BsonArray()), new BsonDocumentCodec()).execute(getBinding(), OPERATION_CONTEXT); } catch (MongoCommandException e) { // ignore exception caused by killing the implicit session that the killAllSessions command itself is running in } @@ -510,7 +513,8 @@ public void renameCollection(final MongoNamespace newNamespace) { try { new CommandReadOperation<>("admin", new BsonDocument("renameCollection", new BsonString(getNamespace().getFullName())) - .append("to", new BsonString(newNamespace.getFullName())), new BsonDocumentCodec()).execute(getBinding()); + .append("to", new BsonString(newNamespace.getFullName())), new BsonDocumentCodec()).execute( + getBinding(), OPERATION_CONTEXT); } catch (MongoCommandException e) { // do nothing } @@ -522,11 +526,11 @@ public void runAdminCommand(final String command) { public void runAdminCommand(final BsonDocument command) { new CommandReadOperation<>("admin", command, new BsonDocumentCodec()) - .execute(getBinding()); + .execute(getBinding(), OPERATION_CONTEXT); } public void runAdminCommand(final BsonDocument command, final ReadPreference readPreference) { new CommandReadOperation<>("admin", command, new BsonDocumentCodec()) - .execute(getBinding(readPreference)); + .execute(getBinding(readPreference), OPERATION_CONTEXT); } } diff --git a/driver-core/src/test/functional/com/mongodb/connection/ConnectionSpecification.groovy b/driver-core/src/test/functional/com/mongodb/connection/ConnectionSpecification.groovy index b3da89231e7..5658ec5ea43 100644 --- a/driver-core/src/test/functional/com/mongodb/connection/ConnectionSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/connection/ConnectionSpecification.groovy @@ -16,15 +16,15 @@ package com.mongodb.connection - import com.mongodb.OperationFunctionalSpecification import com.mongodb.internal.operation.CommandReadOperation import org.bson.BsonDocument import org.bson.BsonInt32 import org.bson.codecs.BsonDocumentCodec -import static com.mongodb.ClusterFixture.getBinding import static com.mongodb.ClusterFixture.LEGACY_HELLO +import static com.mongodb.ClusterFixture.OPERATION_CONTEXT +import static com.mongodb.ClusterFixture.getBinding import static com.mongodb.connection.ConnectionDescription.getDefaultMaxMessageSize import static com.mongodb.connection.ConnectionDescription.getDefaultMaxWriteBatchSize @@ -32,8 +32,8 @@ class ConnectionSpecification extends OperationFunctionalSpecification { def 'should have id'() { when: - def source = getBinding().getReadConnectionSource() - def connection = source.connection + def source = getBinding().getReadConnectionSource(OPERATION_CONTEXT) + def connection = source.getConnection(OPERATION_CONTEXT) then: connection.getDescription().getConnectionId() != null @@ -50,8 +50,8 @@ class ConnectionSpecification extends OperationFunctionalSpecification { new BsonInt32(getDefaultMaxMessageSize())).intValue() def expectedMaxBatchCount = commandResult.getNumber('maxWriteBatchSize', new BsonInt32(getDefaultMaxWriteBatchSize())).intValue() - def source = getBinding().getReadConnectionSource() - def connection = source.connection + def source = getBinding().getReadConnectionSource(OPERATION_CONTEXT) + def connection = source.getConnection(OPERATION_CONTEXT) then: connection.description.serverAddress == source.getServerDescription().getAddress() @@ -66,6 +66,6 @@ class ConnectionSpecification extends OperationFunctionalSpecification { } private static BsonDocument getHelloResult() { new CommandReadOperation('admin', new BsonDocument(LEGACY_HELLO, new BsonInt32(1)), - new BsonDocumentCodec()).execute(getBinding()) + new BsonDocumentCodec()).execute(getBinding(), OPERATION_CONTEXT) } } diff --git a/driver-core/src/test/functional/com/mongodb/internal/binding/AsyncOperationContextBinding.java b/driver-core/src/test/functional/com/mongodb/internal/binding/AsyncOperationContextBinding.java index 17b1a1c4a7e..0a891b55a88 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/binding/AsyncOperationContextBinding.java +++ b/driver-core/src/test/functional/com/mongodb/internal/binding/AsyncOperationContextBinding.java @@ -40,8 +40,9 @@ public ReadPreference getReadPreference() { } @Override - public void getWriteConnectionSource(final SingleResultCallback callback) { - wrapped.getWriteConnectionSource((result, t) -> { + public void getWriteConnectionSource(final OperationContext operationContext, + final SingleResultCallback callback) { + wrapped.getWriteConnectionSource(operationContext, (result, t) -> { if (t != null) { callback.onResult(null, t); } else { @@ -50,14 +51,11 @@ public void getWriteConnectionSource(final SingleResultCallback callback) { - wrapped.getReadConnectionSource((result, t) -> { + public void getReadConnectionSource(final OperationContext operationContext, + final SingleResultCallback callback) { + wrapped.getReadConnectionSource(operationContext, (result, t) -> { if (t != null) { callback.onResult(null, t); } else { @@ -69,8 +67,9 @@ public void getReadConnectionSource(final SingleResultCallback callback) { - wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference, (result, t) -> { + wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference, operationContext, (result, t) -> { if (t != null) { callback.onResult(null, t); } else { @@ -79,6 +78,10 @@ public void getReadConnectionSource(final int minWireVersion, final ReadPreferen }); } + public OperationContext getOperationContext() { + return operationContext; + } + @Override public int getCount() { return wrapped.getCount(); @@ -107,19 +110,14 @@ public ServerDescription getServerDescription() { return wrapped.getServerDescription(); } - @Override - public OperationContext getOperationContext() { - return operationContext; - } - @Override public ReadPreference getReadPreference() { return wrapped.getReadPreference(); } @Override - public void getConnection(final SingleResultCallback callback) { - wrapped.getConnection(callback); + public void getConnection(final OperationContext operationContext, final SingleResultCallback callback) { + wrapped.getConnection(operationContext, callback); } @Override diff --git a/driver-core/src/test/functional/com/mongodb/internal/binding/AsyncSessionBinding.java b/driver-core/src/test/functional/com/mongodb/internal/binding/AsyncSessionBinding.java index fa588a340d0..5461a6ee007 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/binding/AsyncSessionBinding.java +++ b/driver-core/src/test/functional/com/mongodb/internal/binding/AsyncSessionBinding.java @@ -27,11 +27,11 @@ public final class AsyncSessionBinding implements AsyncReadWriteBinding { private final AsyncReadWriteBinding wrapped; - private final OperationContext operationContext; + // private final OperationContext operationContext; public AsyncSessionBinding(final AsyncReadWriteBinding wrapped) { this.wrapped = notNull("wrapped", wrapped); - this.operationContext = wrapped.getOperationContext().withSessionContext(new SimpleSessionContext()); + // this.operationContext = wrapped.getOperationContext().withSessionContext(new SimpleSessionContext()); } @Override @@ -40,8 +40,8 @@ public ReadPreference getReadPreference() { } @Override - public void getWriteConnectionSource(final SingleResultCallback callback) { - wrapped.getWriteConnectionSource((result, t) -> { + public void getWriteConnectionSource(final OperationContext operationContext, final SingleResultCallback callback) { + wrapped.getWriteConnectionSource(operationContext, (result, t) -> { if (t != null) { callback.onResult(null, t); } else { @@ -51,13 +51,8 @@ public void getWriteConnectionSource(final SingleResultCallback callback) { - wrapped.getReadConnectionSource((result, t) -> { + public void getReadConnectionSource(final OperationContext operationContext, final SingleResultCallback callback) { + wrapped.getReadConnectionSource(operationContext, (result, t) -> { if (t != null) { callback.onResult(null, t); } else { @@ -69,8 +64,9 @@ public void getReadConnectionSource(final SingleResultCallback callback) { - wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference, (result, t) -> { + wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference, operationContext, (result, t) -> { if (t != null) { callback.onResult(null, t); } else { @@ -107,19 +103,14 @@ public ServerDescription getServerDescription() { return wrapped.getServerDescription(); } - @Override - public OperationContext getOperationContext() { - return operationContext; - } - @Override public ReadPreference getReadPreference() { return wrapped.getReadPreference(); } @Override - public void getConnection(final SingleResultCallback callback) { - wrapped.getConnection(callback); + public void getConnection(final OperationContext operationContext, final SingleResultCallback callback) { + wrapped.getConnection(operationContext, callback); } @Override diff --git a/driver-core/src/test/functional/com/mongodb/internal/binding/AsyncSessionBindingSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/binding/AsyncSessionBindingSpecification.groovy index 87fa1b9c4ff..173cd9f0935 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/binding/AsyncSessionBindingSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/binding/AsyncSessionBindingSpecification.groovy @@ -26,7 +26,6 @@ class AsyncSessionBindingSpecification extends Specification { def 'should wrap the passed in async binding'() { given: def wrapped = Mock(AsyncReadWriteBinding) - wrapped.getOperationContext() >> OPERATION_CONTEXT def binding = new AsyncSessionBinding(wrapped) when: @@ -54,23 +53,16 @@ class AsyncSessionBindingSpecification extends Specification { 1 * wrapped.release() when: - binding.getReadConnectionSource(Stub(SingleResultCallback)) + binding.getReadConnectionSource(OPERATION_CONTEXT, Stub(SingleResultCallback)) then: - 1 * wrapped.getReadConnectionSource(_) + 1 * wrapped.getReadConnectionSource(OPERATION_CONTEXT, _) when: - binding.getWriteConnectionSource(Stub(SingleResultCallback)) + binding.getWriteConnectionSource(OPERATION_CONTEXT, Stub(SingleResultCallback)) then: - 1 * wrapped.getWriteConnectionSource(_) - - when: - def context = binding.getOperationContext().getSessionContext() - - then: - 0 * wrapped.getOperationContext().getSessionContext() - context instanceof SimpleSessionContext + 1 * wrapped.getWriteConnectionSource(OPERATION_CONTEXT, _) } } diff --git a/driver-core/src/test/functional/com/mongodb/internal/binding/AsyncSingleConnectionBinding.java b/driver-core/src/test/functional/com/mongodb/internal/binding/AsyncSingleConnectionBinding.java index 3fff8b66e06..46fde18d1c7 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/binding/AsyncSingleConnectionBinding.java +++ b/driver-core/src/test/functional/com/mongodb/internal/binding/AsyncSingleConnectionBinding.java @@ -129,16 +129,13 @@ public ReadPreference getReadPreference() { return readPreference; } - @Override - public OperationContext getOperationContext() { - return operationContext; - } @Override - public void getReadConnectionSource(final SingleResultCallback callback) { + public void getReadConnectionSource(final OperationContext operationContext, + final SingleResultCallback callback) { isTrue("open", getCount() > 0); if (readPreference == primary()) { - getWriteConnectionSource(callback); + getWriteConnectionSource(operationContext, callback); } else { callback.onResult(new SingleAsyncConnectionSource(readServerDescription, readConnection), null); } @@ -146,12 +143,13 @@ public void getReadConnectionSource(final SingleResultCallback callback) { - getReadConnectionSource(callback); + getReadConnectionSource(operationContext, callback); } @Override - public void getWriteConnectionSource(final SingleResultCallback callback) { + public void getWriteConnectionSource(final OperationContext operationContext, final SingleResultCallback callback) { isTrue("open", getCount() > 0); callback.onResult(new SingleAsyncConnectionSource(writeServerDescription, writeConnection), null); } @@ -174,6 +172,7 @@ private SingleAsyncConnectionSource(final ServerDescription serverDescription, final AsyncConnection connection) { this.serverDescription = serverDescription; this.connection = connection; + //TODO do we need to retain and release those references properly? AsyncSingleConnectionBinding.this.retain(); } @@ -182,18 +181,13 @@ public ServerDescription getServerDescription() { return serverDescription; } - @Override - public OperationContext getOperationContext() { - return operationContext; - } - @Override public ReadPreference getReadPreference() { return readPreference; } @Override - public void getConnection(final SingleResultCallback callback) { + public void getConnection(final OperationContext operationContext, final SingleResultCallback callback) { isTrue("open", getCount() > 0); callback.onResult(connection.retain(), null); } diff --git a/driver-core/src/test/functional/com/mongodb/internal/binding/OperationContextBinding.java b/driver-core/src/test/functional/com/mongodb/internal/binding/OperationContextBinding.java index 6af3f4520d4..3428db4f82e 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/binding/OperationContextBinding.java +++ b/driver-core/src/test/functional/com/mongodb/internal/binding/OperationContextBinding.java @@ -54,23 +54,22 @@ public int release() { } @Override - public ConnectionSource getReadConnectionSource() { - return new SessionBindingConnectionSource(wrapped.getReadConnectionSource()); + public ConnectionSource getReadConnectionSource(final OperationContext operationContext) { + return new SessionBindingConnectionSource(wrapped.getReadConnectionSource(operationContext)); } @Override - public ConnectionSource getReadConnectionSource(final int minWireVersion, final ReadPreference fallbackReadPreference) { - return new SessionBindingConnectionSource(wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference)); + public ConnectionSource getReadConnectionSource(final int minWireVersion, final ReadPreference fallbackReadPreference, final OperationContext operationContext) { + return new SessionBindingConnectionSource(wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference, operationContext)); } - @Override public OperationContext getOperationContext() { return operationContext; } @Override - public ConnectionSource getWriteConnectionSource() { - return new SessionBindingConnectionSource(wrapped.getWriteConnectionSource()); + public ConnectionSource getWriteConnectionSource(final OperationContext operationContext) { + return new SessionBindingConnectionSource(wrapped.getWriteConnectionSource(operationContext)); } private class SessionBindingConnectionSource implements ConnectionSource { @@ -85,19 +84,14 @@ public ServerDescription getServerDescription() { return wrapped.getServerDescription(); } - @Override - public OperationContext getOperationContext() { - return operationContext; - } - @Override public ReadPreference getReadPreference() { return wrapped.getReadPreference(); } @Override - public Connection getConnection() { - return wrapped.getConnection(); + public Connection getConnection(final OperationContext operationContext) { + return wrapped.getConnection(operationContext); } @Override diff --git a/driver-core/src/test/functional/com/mongodb/internal/binding/SessionBinding.java b/driver-core/src/test/functional/com/mongodb/internal/binding/SessionBinding.java index 3a2666a8093..04ed3bf5f4b 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/binding/SessionBinding.java +++ b/driver-core/src/test/functional/com/mongodb/internal/binding/SessionBinding.java @@ -25,11 +25,11 @@ public class SessionBinding implements ReadWriteBinding { private final ReadWriteBinding wrapped; - private final OperationContext operationContext; public SessionBinding(final ReadWriteBinding wrapped) { this.wrapped = notNull("wrapped", wrapped); - this.operationContext = wrapped.getOperationContext().withSessionContext(new SimpleSessionContext()); + //TODO + // this.operationContext = wrapped.getOperationContext().withSessionContext(new SimpleSessionContext()); } @Override @@ -54,23 +54,19 @@ public int release() { } @Override - public ConnectionSource getReadConnectionSource() { - return new SessionBindingConnectionSource(wrapped.getReadConnectionSource()); + public ConnectionSource getReadConnectionSource(final OperationContext operationContext) { + return new SessionBindingConnectionSource(wrapped.getReadConnectionSource(operationContext)); } @Override - public ConnectionSource getReadConnectionSource(final int minWireVersion, final ReadPreference fallbackReadPreference) { - return new SessionBindingConnectionSource(wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference)); + public ConnectionSource getReadConnectionSource(final int minWireVersion, final ReadPreference fallbackReadPreference, final OperationContext operationContext) { + return new SessionBindingConnectionSource(wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference, operationContext)); } - @Override - public OperationContext getOperationContext() { - return operationContext; - } @Override - public ConnectionSource getWriteConnectionSource() { - return new SessionBindingConnectionSource(wrapped.getWriteConnectionSource()); + public ConnectionSource getWriteConnectionSource(final OperationContext operationContext) { + return new SessionBindingConnectionSource(wrapped.getWriteConnectionSource(operationContext)); } private class SessionBindingConnectionSource implements ConnectionSource { @@ -85,19 +81,14 @@ public ServerDescription getServerDescription() { return wrapped.getServerDescription(); } - @Override - public OperationContext getOperationContext() { - return operationContext; - } - @Override public ReadPreference getReadPreference() { return wrapped.getReadPreference(); } @Override - public Connection getConnection() { - return wrapped.getConnection(); + public Connection getConnection(final OperationContext operationContext) { + return wrapped.getConnection(operationContext); } @Override diff --git a/driver-core/src/test/functional/com/mongodb/internal/binding/SingleConnectionBinding.java b/driver-core/src/test/functional/com/mongodb/internal/binding/SingleConnectionBinding.java index 6bf3cff636d..911155ce10d 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/binding/SingleConnectionBinding.java +++ b/driver-core/src/test/functional/com/mongodb/internal/binding/SingleConnectionBinding.java @@ -41,7 +41,6 @@ public class SingleConnectionBinding implements ReadWriteBinding { private final ServerDescription readServerDescription; private final ServerDescription writeServerDescription; private int count = 1; - private final OperationContext operationContext; /** * Create a new binding with the given cluster. @@ -53,7 +52,6 @@ public class SingleConnectionBinding implements ReadWriteBinding { public SingleConnectionBinding(final Cluster cluster, final ReadPreference readPreference, final OperationContext operationContext) { notNull("cluster", cluster); this.readPreference = notNull("readPreference", readPreference); - this.operationContext = operationContext; ServerTuple writeServerTuple = cluster.selectServer(new WritableServerSelector(), operationContext); writeServerDescription = writeServerTuple.getServerDescription(); writeConnection = writeServerTuple.getServer().getConnection(operationContext); @@ -90,27 +88,23 @@ public ReadPreference getReadPreference() { } @Override - public ConnectionSource getReadConnectionSource() { + public ConnectionSource getReadConnectionSource(final OperationContext operationContext) { isTrue("open", getCount() > 0); if (readPreference == primary()) { - return getWriteConnectionSource(); + return getWriteConnectionSource(operationContext); } else { return new SingleConnectionSource(readServerDescription, readConnection); } } @Override - public ConnectionSource getReadConnectionSource(final int minWireVersion, final ReadPreference fallbackReadPreference) { + public ConnectionSource getReadConnectionSource(final int minWireVersion, final ReadPreference fallbackReadPreference, final OperationContext operationContext) { throw new UnsupportedOperationException(); } - @Override - public OperationContext getOperationContext() { - return operationContext; - } @Override - public ConnectionSource getWriteConnectionSource() { + public ConnectionSource getWriteConnectionSource(final OperationContext operationContext) { isTrue("open", getCount() > 0); return new SingleConnectionSource(writeServerDescription, writeConnection); } @@ -120,7 +114,8 @@ private final class SingleConnectionSource implements ConnectionSource { private final Connection connection; private int count = 1; - SingleConnectionSource(final ServerDescription serverDescription, final Connection connection) { + SingleConnectionSource(final ServerDescription serverDescription, + final Connection connection) { this.serverDescription = serverDescription; this.connection = connection; SingleConnectionBinding.this.retain(); @@ -131,18 +126,13 @@ public ServerDescription getServerDescription() { return serverDescription; } - @Override - public OperationContext getOperationContext() { - return operationContext; - } - @Override public ReadPreference getReadPreference() { return readPreference; } @Override - public Connection getConnection() { + public Connection getConnection(final OperationContext operationContext) { isTrue("open", getCount() > 0); return connection.retain(); } diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/ScramSha256AuthenticationSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/connection/ScramSha256AuthenticationSpecification.groovy index 4901872c1fc..36aac9b6908 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/connection/ScramSha256AuthenticationSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/ScramSha256AuthenticationSpecification.groovy @@ -16,9 +16,9 @@ package com.mongodb.internal.connection +import com.mongodb.ClusterFixture import com.mongodb.MongoCredential import com.mongodb.MongoSecurityException -import com.mongodb.ReadConcern import com.mongodb.ReadPreference import com.mongodb.async.FutureResultCallback import com.mongodb.internal.binding.AsyncClusterBinding @@ -85,14 +85,17 @@ class ScramSha256AuthenticationSpecification extends Specification { .append('pwd', password) .append('roles', ['root']) .append('mechanisms', mechanisms) + def binding = getBinding() new CommandReadOperation<>('admin', new BsonDocumentWrapper(createUserCommand, new DocumentCodec()), new DocumentCodec()) - .execute(getBinding()) + .execute(binding, ClusterFixture.getOperationContext(binding.getReadPreference())) } def dropUser(final String userName) { + def binding = getBinding() + def operationContext = ClusterFixture.getOperationContext(binding.getReadPreference()) new CommandReadOperation<>('admin', new BsonDocument('dropUser', new BsonString(userName)), - new BsonDocumentCodec()).execute(getBinding()) + new BsonDocumentCodec()).execute(binding, operationContext) } def 'test authentication and authorization'() { @@ -102,7 +105,7 @@ class ScramSha256AuthenticationSpecification extends Specification { when: new CommandReadOperation('admin', new BsonDocumentWrapper(new Document('dbstats', 1), new DocumentCodec()), new DocumentCodec()) - .execute(new ClusterBinding(cluster, ReadPreference.primary(), ReadConcern.DEFAULT, OPERATION_CONTEXT)) + .execute(new ClusterBinding(cluster, ReadPreference.primary()), OPERATION_CONTEXT) then: noExceptionThrown() @@ -119,12 +122,13 @@ class ScramSha256AuthenticationSpecification extends Specification { def cluster = createAsyncCluster(credential) def callback = new FutureResultCallback() + when: // make this synchronous + def binding = new AsyncClusterBinding(cluster, ReadPreference.primary()) new CommandReadOperation('admin', new BsonDocumentWrapper(new Document('dbstats', 1), new DocumentCodec()), new DocumentCodec()) - .executeAsync(new AsyncClusterBinding(cluster, ReadPreference.primary(), ReadConcern.DEFAULT, OPERATION_CONTEXT), - callback) + .executeAsync(binding, OPERATION_CONTEXT, callback) callback.get() then: @@ -144,7 +148,7 @@ class ScramSha256AuthenticationSpecification extends Specification { when: new CommandReadOperation('admin', new BsonDocumentWrapper(new Document('dbstats', 1), new DocumentCodec()), new DocumentCodec()) - .execute(new ClusterBinding(cluster, ReadPreference.primary(), ReadConcern.DEFAULT, OPERATION_CONTEXT)) + .execute(new ClusterBinding(cluster, ReadPreference.primary()), OPERATION_CONTEXT) then: thrown(MongoSecurityException) @@ -164,7 +168,7 @@ class ScramSha256AuthenticationSpecification extends Specification { when: new CommandReadOperation('admin', new BsonDocumentWrapper(new Document('dbstats', 1), new DocumentCodec()), new DocumentCodec()) - .executeAsync(new AsyncClusterBinding(cluster, ReadPreference.primary(), ReadConcern.DEFAULT, OPERATION_CONTEXT), + .executeAsync(new AsyncClusterBinding(cluster, ReadPreference.primary()), OPERATION_CONTEXT, callback) callback.get() @@ -185,7 +189,7 @@ class ScramSha256AuthenticationSpecification extends Specification { when: new CommandReadOperation('admin', new BsonDocumentWrapper(new Document('dbstats', 1), new DocumentCodec()), new DocumentCodec()) - .execute(new ClusterBinding(cluster, ReadPreference.primary(), ReadConcern.DEFAULT, OPERATION_CONTEXT)) + .execute(new ClusterBinding(cluster, ReadPreference.primary()), OPERATION_CONTEXT) then: noExceptionThrown() @@ -205,7 +209,7 @@ class ScramSha256AuthenticationSpecification extends Specification { when: new CommandReadOperation('admin', new BsonDocumentWrapper(new Document('dbstats', 1), new DocumentCodec()), new DocumentCodec()) - .executeAsync(new AsyncClusterBinding(cluster, ReadPreference.primary(), ReadConcern.DEFAULT, OPERATION_CONTEXT), + .executeAsync(new AsyncClusterBinding(cluster, ReadPreference.primary()), OPERATION_CONTEXT, callback) callback.get() diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/AggregateOperationSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/AggregateOperationSpecification.groovy index 0ce503f466e..aa7506d6516 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/AggregateOperationSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/AggregateOperationSpecification.groovy @@ -16,7 +16,7 @@ package com.mongodb.internal.operation - +import com.mongodb.ClusterFixture import com.mongodb.MongoNamespace import com.mongodb.OperationFunctionalSpecification import com.mongodb.ReadConcern @@ -31,12 +31,14 @@ import com.mongodb.connection.ConnectionDescription import com.mongodb.connection.ConnectionId import com.mongodb.connection.ServerId import com.mongodb.connection.ServerVersion +import com.mongodb.internal.async.SingleResultCallback import com.mongodb.internal.binding.AsyncConnectionSource import com.mongodb.internal.binding.AsyncReadBinding import com.mongodb.internal.binding.ConnectionSource import com.mongodb.internal.binding.ReadBinding import com.mongodb.internal.connection.AsyncConnection import com.mongodb.internal.connection.Connection +import com.mongodb.internal.connection.OperationContext import com.mongodb.internal.session.SessionContext import org.bson.BsonArray import org.bson.BsonBoolean @@ -54,8 +56,8 @@ import static com.mongodb.ClusterFixture.OPERATION_CONTEXT import static com.mongodb.ClusterFixture.collectCursorResults import static com.mongodb.ClusterFixture.executeAsync import static com.mongodb.ClusterFixture.getAsyncCluster -import static com.mongodb.ClusterFixture.getBinding import static com.mongodb.ClusterFixture.getCluster +import static com.mongodb.ClusterFixture.getOperationContext import static com.mongodb.ClusterFixture.isSharded import static com.mongodb.ClusterFixture.isStandalone import static com.mongodb.ExplainVerbosity.QUERY_PLANNER @@ -64,6 +66,7 @@ import static com.mongodb.internal.connection.ServerHelper.waitForLastRelease import static com.mongodb.internal.operation.OperationReadConcernHelper.appendReadConcernToCommand import static com.mongodb.internal.operation.ServerVersionHelper.UNKNOWN_WIRE_VERSION import static com.mongodb.internal.operation.TestOperationHelper.getKeyPattern +import static org.junit.jupiter.api.Assertions.assertEquals class AggregateOperationSpecification extends OperationFunctionalSpecification { @@ -226,8 +229,10 @@ class AggregateOperationSpecification extends OperationFunctionalSpecification { def viewSuffix = '-view' def viewName = getCollectionName() + viewSuffix def viewNamespace = new MongoNamespace(getDatabaseName(), viewName) + + def binding = ClusterFixture.getBinding(ClusterFixture.getCluster()) new CreateViewOperation(getDatabaseName(), viewName, getCollectionName(), [], WriteConcern.ACKNOWLEDGED) - .execute(getBinding(getCluster())) + .execute(binding, ClusterFixture.getOperationContext(binding.getReadPreference())) when: AggregateOperation operation = new AggregateOperation(viewNamespace, [], new DocumentCodec()) @@ -239,8 +244,9 @@ class AggregateOperationSpecification extends OperationFunctionalSpecification { results.containsAll(['Pete', 'Sam']) cleanup: + binding = ClusterFixture.getBinding(ClusterFixture.getCluster()) new DropCollectionOperation(viewNamespace, WriteConcern.ACKNOWLEDGED) - .execute(getBinding(getCluster())) + .execute(binding, ClusterFixture.getOperationContext(binding.getReadPreference())) where: async << [true, false] @@ -265,7 +271,9 @@ class AggregateOperationSpecification extends OperationFunctionalSpecification { when: AggregateOperation operation = new AggregateOperation(getNamespace(), [], new DocumentCodec()) .allowDiskUse(allowDiskUse) - def cursor = operation.execute(getBinding()) + + def binding = ClusterFixture.getBinding() + def cursor = operation.execute(binding, ClusterFixture.getOperationContext(binding.getReadPreference())) then: cursor.next()*.getString('name') == ['Pete', 'Sam', 'Pete'] @@ -278,7 +286,9 @@ class AggregateOperationSpecification extends OperationFunctionalSpecification { when: AggregateOperation operation = new AggregateOperation(getNamespace(), [], new DocumentCodec()) .batchSize(batchSize) - def cursor = operation.execute(getBinding()) + + def binding = ClusterFixture.getBinding() + def cursor = operation.execute(binding, ClusterFixture.getOperationContext(binding.getReadPreference())) then: cursor.next()*.getString('name') == ['Pete', 'Sam', 'Pete'] @@ -343,8 +353,10 @@ class AggregateOperationSpecification extends OperationFunctionalSpecification { def 'should apply comment'() { given: def profileCollectionHelper = getCollectionHelper(new MongoNamespace(getDatabaseName(), 'system.profile')) + + def binding = ClusterFixture.getBinding() new CommandReadOperation<>(getDatabaseName(), new BsonDocument('profile', new BsonInt32(2)), - new BsonDocumentCodec()).execute(getBinding()) + new BsonDocumentCodec()).execute(binding, getOperationContext(binding.getReadPreference())) def expectedComment = 'this is a comment' def operation = new AggregateOperation(getNamespace(), [], new DocumentCodec()) .comment(new BsonString(expectedComment)) @@ -356,9 +368,11 @@ class AggregateOperationSpecification extends OperationFunctionalSpecification { Document profileDocument = profileCollectionHelper.find(Filters.exists('command.aggregate')).get(0) ((Document) profileDocument.get('command')).get('comment') == expectedComment + cleanup: + binding = ClusterFixture.getBinding() new CommandReadOperation<>(getDatabaseName(), new BsonDocument('profile', new BsonInt32(0)), - new BsonDocumentCodec()).execute(getBinding()) + new BsonDocumentCodec()).execute(binding, getOperationContext(binding.getReadPreference())) profileCollectionHelper.drop() where: @@ -372,11 +386,9 @@ class AggregateOperationSpecification extends OperationFunctionalSpecification { def source = Stub(ConnectionSource) def connection = Mock(Connection) binding.readPreference >> ReadPreference.primary() - binding.operationContext >> operationContext - binding.readConnectionSource >> source - source.connection >> connection + binding.getReadConnectionSource(_) >> source + source.getConnection(_) >> connection source.retain() >> source - source.operationContext >> operationContext def commandDocument = new BsonDocument('aggregate', new BsonString(getCollectionName())) .append('pipeline', new BsonArray()) .append('cursor', new BsonDocument()) @@ -385,15 +397,17 @@ class AggregateOperationSpecification extends OperationFunctionalSpecification { def operation = new AggregateOperation(getNamespace(), [], new DocumentCodec()) when: - operation.execute(binding) + operation.execute(binding, operationContext) then: _ * connection.description >> new ConnectionDescription(new ConnectionId(new ServerId(new ClusterId(), new ServerAddress())), 6, STANDALONE, 1000, 100000, 100000, []) - 1 * connection.command(_, commandDocument, _, _, _, operationContext) >> - new BsonDocument('cursor', new BsonDocument('id', new BsonInt64(1)) - .append('ns', new BsonString(getNamespace().getFullName())) - .append('firstBatch', new BsonArrayWrapper([]))) + 1 * connection.command(_, commandDocument, _, _, _, _ as OperationContext) >> { + assertEquals(((OperationContext) it[5]).getId(), operationContext.getId()) + new BsonDocument('cursor', new BsonDocument('id', new BsonInt64(1)) + .append('ns', new BsonString(getNamespace().getFullName())) + .append('firstBatch', new BsonArrayWrapper([]))) + } 1 * connection.release() where: @@ -413,10 +427,8 @@ class AggregateOperationSpecification extends OperationFunctionalSpecification { def binding = Stub(AsyncReadBinding) def source = Stub(AsyncConnectionSource) def connection = Mock(AsyncConnection) - binding.operationContext >> operationContext - binding.getReadConnectionSource(_) >> { it[0].onResult(source, null) } - source.operationContext >> operationContext - source.getConnection(_) >> { it[0].onResult(connection, null) } + binding.getReadConnectionSource(_ as OperationContext, _ as SingleResultCallback) >> { it[1].onResult(source, null) } + source.getConnection(_ as OperationContext, _ as SingleResultCallback) >> { it[1].onResult(connection, null) } source.retain() >> source def commandDocument = new BsonDocument('aggregate', new BsonString(getCollectionName())) .append('pipeline', new BsonArray()) @@ -426,12 +438,12 @@ class AggregateOperationSpecification extends OperationFunctionalSpecification { def operation = new AggregateOperation(getNamespace(), [], new DocumentCodec()) when: - executeAsync(operation, binding) + executeAsync(operation, binding, operationContext) then: _ * connection.description >> new ConnectionDescription(new ConnectionId(new ServerId(new ClusterId(), new ServerAddress())), 6, STANDALONE, 1000, 100000, 100000, []) - 1 * connection.commandAsync(_, commandDocument, _, _, _, operationContext, _) >> { + 1 * connection.commandAsync(_, commandDocument, _, _, _, _, _) >> { it.last().onResult(new BsonDocument('cursor', new BsonDocument('id', new BsonInt64(1)) .append('ns', new BsonString(getNamespace().getFullName())) .append('firstBatch', new BsonArrayWrapper([]))), null) diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/AggregateToCollectionOperationSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/AggregateToCollectionOperationSpecification.groovy index ed617289316..6ebdcdc6b40 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/AggregateToCollectionOperationSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/AggregateToCollectionOperationSpecification.groovy @@ -16,6 +16,7 @@ package com.mongodb.internal.operation +import com.mongodb.ClusterFixture import com.mongodb.MongoCommandException import com.mongodb.MongoNamespace import com.mongodb.MongoWriteConcernException @@ -275,8 +276,9 @@ class AggregateToCollectionOperationSpecification extends OperationFunctionalSpe def 'should apply comment'() { given: def profileCollectionHelper = getCollectionHelper(new MongoNamespace(getDatabaseName(), 'system.profile')) + def binding = getBinding() new CommandReadOperation<>(getDatabaseName(), new BsonDocument('profile', new BsonInt32(2)), - new BsonDocumentCodec()).execute(getBinding()) + new BsonDocumentCodec()).execute(binding, ClusterFixture.getOperationContext(binding.getReadPreference())) def expectedComment = 'this is a comment' AggregateToCollectionOperation operation = createOperation(getNamespace(), [Aggregates.out('outputCollection').toBsonDocument(BsonDocument, registry)], ACKNOWLEDGED) @@ -291,7 +293,7 @@ class AggregateToCollectionOperationSpecification extends OperationFunctionalSpe cleanup: new CommandReadOperation<>(getDatabaseName(), new BsonDocument('profile', new BsonInt32(0)), - new BsonDocumentCodec()).execute(getBinding()) + new BsonDocumentCodec()).execute(binding, ClusterFixture.getOperationContext(binding.getReadPreference())) profileCollectionHelper.drop() where: diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/AsyncCommandBatchCursorFunctionalTest.java b/driver-core/src/test/functional/com/mongodb/internal/operation/AsyncCommandBatchCursorFunctionalTest.java index 88dc199ee29..58e01242420 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/AsyncCommandBatchCursorFunctionalTest.java +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/AsyncCommandBatchCursorFunctionalTest.java @@ -55,6 +55,7 @@ import java.util.stream.IntStream; import java.util.stream.Stream; +import static com.mongodb.ClusterFixture.OPERATION_CONTEXT; import static com.mongodb.ClusterFixture.checkReferenceCountReachesTarget; import static com.mongodb.ClusterFixture.getAsyncBinding; import static com.mongodb.ClusterFixture.getConnection; @@ -110,8 +111,8 @@ void cleanup() { void shouldExhaustCursorAsyncWithMultipleBatches() { // given BsonDocument commandResult = executeFindCommand(0, 3); // Fetch in batches of size 3 - cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 3, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new AsyncCommandCoreCursor<>(commandResult, 3, DOCUMENT_DECODER, null, connectionSource, connection)); // when FutureResultCallback>> futureCallback = new FutureResultCallback<>(); @@ -132,8 +133,8 @@ void shouldExhaustCursorAsyncWithMultipleBatches() { void shouldExhaustCursorAsyncWithClosedCursor() { // given BsonDocument commandResult = executeFindCommand(0, 3); - cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 3, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new AsyncCommandCoreCursor<>(commandResult, 3, DOCUMENT_DECODER, null, connectionSource, connection)); cursor.close(); @@ -155,8 +156,8 @@ void shouldExhaustCursorAsyncWithEmptyCursor() { getCollectionHelper().deleteMany(Filters.empty()); BsonDocument commandResult = executeFindCommand(0, 3); // No documents to fetch - cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 3, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new AsyncCommandCoreCursor<>(commandResult, 3, DOCUMENT_DECODER, null, connectionSource, connection)); // when FutureResultCallback>> futureCallback = new FutureResultCallback<>(); @@ -172,33 +173,37 @@ void shouldExhaustCursorAsyncWithEmptyCursor() { @DisplayName("server cursor should not be null") void theServerCursorShouldNotBeNull() { BsonDocument commandResult = executeFindCommand(2); - cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 0, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + AsyncCommandCoreCursor coreCursor = + new AsyncCommandCoreCursor<>(commandResult, 0, DOCUMENT_DECODER, null, connectionSource, connection); + cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + coreCursor); - assertNotNull(cursor.getServerCursor()); + assertNotNull(coreCursor.getServerCursor()); } @Test @DisplayName("should get Exceptions for operations on the cursor after closing") void shouldGetExceptionsForOperationsOnTheCursorAfterClosing() { BsonDocument commandResult = executeFindCommand(5); - cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 0, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + AsyncCommandCoreCursor coreCursor = + new AsyncCommandCoreCursor<>(commandResult, 0, DOCUMENT_DECODER, null, connectionSource, connection); + cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + coreCursor); cursor.close(); assertDoesNotThrow(() -> cursor.close()); checkReferenceCountReachesTarget(connectionSource, 1); assertThrows(IllegalStateException.class, this::cursorNext); - assertNull(cursor.getServerCursor()); + assertNull(coreCursor.getServerCursor()); } @Test @DisplayName("should throw an Exception when going off the end") void shouldThrowAnExceptionWhenGoingOffTheEnd() { BsonDocument commandResult = executeFindCommand(2, 1); - cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 0, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new AsyncCommandCoreCursor<>(commandResult, 0, DOCUMENT_DECODER, null, connectionSource, connection)); cursorNext(); cursorNext(); @@ -211,8 +216,8 @@ void shouldThrowAnExceptionWhenGoingOffTheEnd() { @DisplayName("test normal exhaustion") void testNormalExhaustion() { BsonDocument commandResult = executeFindCommand(); - cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 0, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new AsyncCommandCoreCursor<>(commandResult, 3, DOCUMENT_DECODER, null, connectionSource, connection)); assertEquals(10, cursorFlatten().size()); } @@ -222,8 +227,8 @@ void testNormalExhaustion() { @DisplayName("test limit exhaustion") void testLimitExhaustion(final int limit, final int batchSize, final int expectedTotal) { BsonDocument commandResult = executeFindCommand(limit, batchSize); - cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, batchSize, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new AsyncCommandCoreCursor<>(commandResult, batchSize, DOCUMENT_DECODER, null, connectionSource, connection)); assertEquals(expectedTotal, cursorFlatten().size()); @@ -241,8 +246,8 @@ void shouldBlockWaitingForNextBatchOnATailableCursor(final boolean awaitData, fi BsonDocument commandResult = executeFindCommand(new BsonDocument("ts", new BsonDocument("$gte", new BsonTimestamp(5, 0))), 0, 2, true, awaitData); - cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 2, maxTimeMS, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, maxTimeMS, OPERATION_CONTEXT, + new AsyncCommandCoreCursor<>(commandResult, 2, DOCUMENT_DECODER, null, connectionSource, connection)); assertFalse(cursor.isClosed()); assertEquals(1, cursorNext().get(0).get("_id")); @@ -264,8 +269,8 @@ void testTailableInterrupt() throws InterruptedException { BsonDocument commandResult = executeFindCommand(new BsonDocument("ts", new BsonDocument("$gte", new BsonTimestamp(5, 0))), 0, 2, true, true); - cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 2, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new AsyncCommandCoreCursor<>(commandResult, 2, DOCUMENT_DECODER, null, connectionSource, connection)); CountDownLatch latch = new CountDownLatch(1); AtomicInteger seen = new AtomicInteger(); @@ -297,12 +302,14 @@ void testTailableInterrupt() throws InterruptedException { void shouldKillCursorIfLimitIsReachedOnInitialQuery() { assumeFalse(isSharded()); BsonDocument commandResult = executeFindCommand(5, 10); - cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 0, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + AsyncCommandCoreCursor coreCursor = + new AsyncCommandCoreCursor<>(commandResult, 0, DOCUMENT_DECODER, null, connectionSource, connection); + cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + coreCursor); assertNotNull(cursorNext()); assertTrue(cursor.isClosed()); - assertNull(cursor.getServerCursor()); + assertNull(coreCursor.getServerCursor()); } @Test @@ -310,10 +317,12 @@ void shouldKillCursorIfLimitIsReachedOnInitialQuery() { void shouldKillCursorIfLimitIsReachedOnGetMore() { assumeFalse(isSharded()); BsonDocument commandResult = executeFindCommand(5, 3); - cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 3, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + AsyncCommandCoreCursor coreCursor = + new AsyncCommandCoreCursor<>(commandResult, 3, DOCUMENT_DECODER, null, connectionSource, connection); + cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + coreCursor); - ServerCursor serverCursor = cursor.getServerCursor(); + ServerCursor serverCursor = coreCursor.getServerCursor(); assertNotNull(serverCursor); assertNotNull(cursorNext()); assertNotNull(cursorNext()); @@ -330,12 +339,14 @@ void shouldReleaseConnectionSourceIfLimitIsReachedOnInitialQuery() { assumeFalse(isSharded()); BsonDocument commandResult = executeFindCommand(5, 10); - cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 0, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + AsyncCommandCoreCursor coreCursor = + new AsyncCommandCoreCursor<>(commandResult, 0, DOCUMENT_DECODER, null, connectionSource, connection); + cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + coreCursor); assertDoesNotThrow(() -> checkReferenceCountReachesTarget(connectionSource, 1)); assertDoesNotThrow(() -> checkReferenceCountReachesTarget(connection, 1)); - assertNull(cursor.getServerCursor()); + assertNull(coreCursor.getServerCursor()); } @Test @@ -343,8 +354,8 @@ void shouldReleaseConnectionSourceIfLimitIsReachedOnInitialQuery() { void shouldReleaseConnectionSourceIfLimitIsReachedOnGetMore() { assumeFalse(isSharded()); BsonDocument commandResult = executeFindCommand(5, 3); - cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 3, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new AsyncCommandCoreCursor<>(commandResult, 3, DOCUMENT_DECODER, null, connectionSource, connection)); assertNotNull(cursorNext()); assertNotNull(cursorNext()); @@ -356,8 +367,8 @@ void shouldReleaseConnectionSourceIfLimitIsReachedOnGetMore() { @DisplayName("test limit with get more") void testLimitWithGetMore() { BsonDocument commandResult = executeFindCommand(5, 2); - cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 2, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new AsyncCommandCoreCursor<>(commandResult, 2, DOCUMENT_DECODER, null, connectionSource, connection)); assertNotNull(cursorNext()); assertNotNull(cursorNext()); @@ -379,8 +390,8 @@ void testLimitWithLargeDocuments() { ); BsonDocument commandResult = executeFindCommand(300, 0); - cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 0, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new AsyncCommandCoreCursor<>(commandResult, 0, DOCUMENT_DECODER, null, connectionSource, connection)); assertEquals(300, cursorFlatten().size()); } @@ -389,8 +400,8 @@ void testLimitWithLargeDocuments() { @DisplayName("should respect batch size") void shouldRespectBatchSize() { BsonDocument commandResult = executeFindCommand(2); - cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 2, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new AsyncCommandCoreCursor<>(commandResult, 2, DOCUMENT_DECODER, null, connectionSource, connection)); assertEquals(2, cursor.getBatchSize()); assertEquals(2, cursorNext().size()); @@ -406,16 +417,18 @@ void shouldRespectBatchSize() { @DisplayName("should throw cursor not found exception") void shouldThrowCursorNotFoundException() throws Throwable { BsonDocument commandResult = executeFindCommand(2); - cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 2, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + AsyncCommandCoreCursor coreCursor = + new AsyncCommandCoreCursor<>(commandResult, 2, DOCUMENT_DECODER, null, connectionSource, connection); + cursor = new AsyncCommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + coreCursor); - ServerCursor serverCursor = cursor.getServerCursor(); + ServerCursor serverCursor = coreCursor.getServerCursor(); assertNotNull(serverCursor); AsyncConnection localConnection = getConnection(connectionSource); this.block(cb -> localConnection.commandAsync(getNamespace().getDatabaseName(), new BsonDocument("killCursors", new BsonString(getNamespace().getCollectionName())) .append("cursors", new BsonArray(singletonList(new BsonInt64(serverCursor.getId())))), - NoOpFieldNameValidator.INSTANCE, ReadPreference.primary(), new BsonDocumentCodec(), connectionSource.getOperationContext(), cb)); + NoOpFieldNameValidator.INSTANCE, ReadPreference.primary(), new BsonDocumentCodec(), OPERATION_CONTEXT, cb)); localConnection.release(); cursorNext(); @@ -481,7 +494,7 @@ private BsonDocument executeFindCommand(final BsonDocument filter, final int lim BsonDocument results = block(cb -> connection.commandAsync(getDatabaseName(), findCommand, NoOpFieldNameValidator.INSTANCE, readPreference, CommandResultDocumentCodec.create(DOCUMENT_DECODER, FIRST_BATCH), - connectionSource.getOperationContext(), cb)); + OPERATION_CONTEXT, cb)); assertNotNull(results); return results; diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/AsyncCommandBatchCursorTest.java b/driver-core/src/test/functional/com/mongodb/internal/operation/AsyncCommandBatchCursorTest.java index e9a30686d5f..499f6724d6e 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/AsyncCommandBatchCursorTest.java +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/AsyncCommandBatchCursorTest.java @@ -17,252 +17,73 @@ package com.mongodb.internal.operation; import com.mongodb.MongoClientSettings; -import com.mongodb.MongoNamespace; -import com.mongodb.MongoOperationTimeoutException; -import com.mongodb.MongoSocketException; -import com.mongodb.ServerAddress; import com.mongodb.client.cursor.TimeoutMode; -import com.mongodb.connection.ConnectionDescription; -import com.mongodb.connection.ServerDescription; -import com.mongodb.connection.ServerType; -import com.mongodb.connection.ServerVersion; +import com.mongodb.internal.IgnorableRequestContext; import com.mongodb.internal.TimeoutContext; import com.mongodb.internal.TimeoutSettings; -import com.mongodb.internal.async.SingleResultCallback; -import com.mongodb.internal.binding.AsyncConnectionSource; -import com.mongodb.internal.connection.AsyncConnection; +import com.mongodb.internal.connection.NoOpSessionContext; import com.mongodb.internal.connection.OperationContext; -import org.bson.BsonArray; -import org.bson.BsonDocument; -import org.bson.BsonInt32; -import org.bson.BsonInt64; -import org.bson.BsonString; import org.bson.Document; -import org.bson.codecs.Decoder; -import org.bson.codecs.DocumentCodec; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.ArgumentCaptor; import java.time.Duration; -import static com.mongodb.internal.operation.OperationUnitSpecification.getMaxWireVersionForServerVersion; -import static com.mongodb.internal.thread.InterruptionUtil.interruptAndCreateMongoInterruptedException; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.argThat; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; class AsyncCommandBatchCursorTest { - - private static final MongoNamespace NAMESPACE = new MongoNamespace("test", "test"); - private static final BsonInt64 CURSOR_ID = new BsonInt64(1); - private static final BsonDocument COMMAND_CURSOR_DOCUMENT = new BsonDocument("ok", new BsonInt32(1)) - .append("cursor", - new BsonDocument("ns", new BsonString(NAMESPACE.getFullName())) - .append("id", CURSOR_ID) - .append("firstBatch", new BsonArrayWrapper<>(new BsonArray()))); - - private static final Decoder DOCUMENT_CODEC = new DocumentCodec(); private static final Duration TIMEOUT = Duration.ofMillis(3_000); - - - private AsyncConnection mockConnection; - private ConnectionDescription mockDescription; - private AsyncConnectionSource connectionSource; private OperationContext operationContext; private TimeoutContext timeoutContext; - private ServerDescription serverDescription; + private AsyncCoreCursor coreCursor; @BeforeEach void setUp() { - ServerVersion serverVersion = new ServerVersion(3, 6); - - mockConnection = mock(AsyncConnection.class, "connection"); - mockDescription = mock(ConnectionDescription.class); - when(mockDescription.getMaxWireVersion()).thenReturn(getMaxWireVersionForServerVersion(serverVersion.getVersionList())); - when(mockDescription.getServerType()).thenReturn(ServerType.LOAD_BALANCER); - when(mockConnection.getDescription()).thenReturn(mockDescription); - when(mockConnection.retain()).thenReturn(mockConnection); - - connectionSource = mock(AsyncConnectionSource.class); - operationContext = mock(OperationContext.class); - timeoutContext = new TimeoutContext(TimeoutSettings.create( - MongoClientSettings.builder().timeout(TIMEOUT.toMillis(), MILLISECONDS).build())); - serverDescription = mock(ServerDescription.class); - when(operationContext.getTimeoutContext()).thenReturn(timeoutContext); - when(connectionSource.getOperationContext()).thenReturn(operationContext); - doAnswer(invocation -> { - SingleResultCallback callback = invocation.getArgument(0); - callback.onResult(mockConnection, null); - return null; - }).when(connectionSource).getConnection(any()); - when(connectionSource.getServerDescription()).thenReturn(serverDescription); + coreCursor = mock(AsyncCoreCursor.class); + timeoutContext = spy(new TimeoutContext(TimeoutSettings.create( + MongoClientSettings.builder().timeout(TIMEOUT.toMillis(), MILLISECONDS).build()))); + operationContext = spy(new OperationContext( + IgnorableRequestContext.INSTANCE, + NoOpSessionContext.INSTANCE, + timeoutContext, + null)); } - - @Test - void shouldSkipKillsCursorsCommandWhenNetworkErrorOccurs() { - //given - doAnswer(invocation -> { - SingleResultCallback argument = invocation.getArgument(6); - argument.onResult(null, new MongoSocketException("test", new ServerAddress())); - return null; - }).when(mockConnection).commandAsync(eq(NAMESPACE.getDatabaseName()), any(), any(), any(), any(), any(), any()); - when(serverDescription.getType()).thenReturn(ServerType.LOAD_BALANCER); - AsyncCommandBatchCursor commandBatchCursor = createBatchCursor(0); - - //when - commandBatchCursor.next((result, t) -> { - Assertions.assertNull(result); - Assertions.assertNotNull(t); - Assertions.assertEquals(MongoSocketException.class, t.getClass()); - }); - - //then - commandBatchCursor.close(); - verify(mockConnection, times(1)).commandAsync(eq(NAMESPACE.getDatabaseName()), any(), any(), any(), any(), any(), any()); - } - - - @Test - void shouldNotSkipKillsCursorsCommandWhenTimeoutExceptionDoesNotHaveNetworkErrorCause() { - //given - doAnswer(invocation -> { - SingleResultCallback argument = invocation.getArgument(6); - argument.onResult(null, new MongoOperationTimeoutException("test")); - return null; - }).when(mockConnection).commandAsync(eq(NAMESPACE.getDatabaseName()), any(), any(), any(), any(), any(), any()); - when(serverDescription.getType()).thenReturn(ServerType.LOAD_BALANCER); - - AsyncCommandBatchCursor commandBatchCursor = createBatchCursor(0); - - //when - commandBatchCursor.next((result, t) -> { - Assertions.assertNull(result); - Assertions.assertNotNull(t); - Assertions.assertEquals(MongoOperationTimeoutException.class, t.getClass()); - }); - - commandBatchCursor.close(); - - - //then - verify(mockConnection, times(2)).commandAsync(any(), - any(), any(), any(), any(), any(), any()); - verify(mockConnection, times(1)).commandAsync(eq(NAMESPACE.getDatabaseName()), - argThat(bsonDocument -> bsonDocument.containsKey("getMore")), any(), any(), any(), any(), any()); - verify(mockConnection, times(1)).commandAsync(eq(NAMESPACE.getDatabaseName()), - argThat(bsonDocument -> bsonDocument.containsKey("killCursors")), any(), any(), any(), any(), any()); - } - - @Test - void shouldSkipKillsCursorsCommandWhenTimeoutExceptionHaveNetworkErrorCause() { - //given - doAnswer(invocation -> { - SingleResultCallback argument = invocation.getArgument(6); - argument.onResult(null, new MongoOperationTimeoutException("test", new MongoSocketException("test", new ServerAddress()))); - return null; - }).when(mockConnection).commandAsync(eq(NAMESPACE.getDatabaseName()), any(), any(), any(), any(), any(), any()); - when(serverDescription.getType()).thenReturn(ServerType.LOAD_BALANCER); - - AsyncCommandBatchCursor commandBatchCursor = createBatchCursor(0); - - //when - commandBatchCursor.next((result, t) -> { - Assertions.assertNull(result); - Assertions.assertNotNull(t); - Assertions.assertEquals(MongoOperationTimeoutException.class, t.getClass()); - }); - - commandBatchCursor.close(); - - //then - verify(mockConnection, times(1)).commandAsync(any(), - any(), any(), any(), any(), any(), any()); - verify(mockConnection, times(1)).commandAsync(eq(NAMESPACE.getDatabaseName()), - argThat(bsonDocument -> bsonDocument.containsKey("getMore")), any(), any(), any(), any(), any()); - verify(mockConnection, never()).commandAsync(eq(NAMESPACE.getDatabaseName()), - argThat(bsonDocument -> bsonDocument.containsKey("killCursors")), any(), any(), any(), any(), any()); - } - - @Test + @ParameterizedTest(name = "closeShouldResetTimeoutContextToDefaultMaxTime with maxTimeMS={0}") @SuppressWarnings("try") - void closeShouldResetTimeoutContextToDefaultMaxTime() { - long maxTimeMS = 10; - com.mongodb.assertions.Assertions.assertTrue(maxTimeMS < TIMEOUT.toMillis()); + @ValueSource(ints = {10, 0}) + void closeShouldResetTimeoutContextToDefaultMaxTime(final int maxTimeMS) { + //given try (AsyncCommandBatchCursor commandBatchCursor = createBatchCursor(maxTimeMS)) { - // verify that the `maxTimeMS` override was applied - timeoutContext.runMaxTimeMS(remainingMillis -> assertTrue(remainingMillis <= maxTimeMS)); - } catch (Exception e) { - throw new RuntimeException(e); - } - timeoutContext.runMaxTimeMS(remainingMillis -> { - // verify that the `maxTimeMS` override was reset - assertTrue(remainingMillis > maxTimeMS); - assertTrue(remainingMillis <= TIMEOUT.toMillis()); - }); - } - @ParameterizedTest - @ValueSource(booleans = {false, true}) - void closeShouldNotResetOriginalTimeout(final boolean disableTimeoutResetWhenClosing) { - doAnswer(invocation -> { - SingleResultCallback argument = invocation.getArgument(6); - argument.onResult(null, null); - return null; - }).when(mockConnection).commandAsync(any(), any(), any(), any(), any(), any(), any()); - Duration thirdOfTimeout = TIMEOUT.dividedBy(3); - com.mongodb.assertions.Assertions.assertTrue(thirdOfTimeout.toMillis() > 0); - try (AsyncCommandBatchCursor commandBatchCursor = createBatchCursor(0)) { - if (disableTimeoutResetWhenClosing) { - commandBatchCursor.disableTimeoutResetWhenClosing(); - } - try { - Thread.sleep(thirdOfTimeout.toMillis()); - } catch (InterruptedException e) { - throw interruptAndCreateMongoInterruptedException(null, e); - } - when(mockConnection.release()).then(invocation -> { - Thread.sleep(thirdOfTimeout.toMillis()); - return null; + //when + commandBatchCursor.close(); + + // then verify that the `maxTimeMS` override was not applied + ArgumentCaptor operationContextArgumentCaptor = ArgumentCaptor.forClass(OperationContext.class); + verify(coreCursor).close(operationContextArgumentCaptor.capture()); + OperationContext operationContextForNext = operationContextArgumentCaptor.getValue(); + operationContextForNext.getTimeoutContext().runMaxTimeMS(remainingMillis -> { + // verify that the `maxTimeMS` override was reset + assertTrue(remainingMillis > maxTimeMS); + assertTrue(remainingMillis <= TIMEOUT.toMillis()); }); - } catch (Exception e) { - throw new RuntimeException(e); + } - verify(mockConnection, times(1)).release(); - // at this point at least (2 * thirdOfTimeout) have passed - com.mongodb.assertions.Assertions.assertNotNull(timeoutContext.getTimeout()).run( - MILLISECONDS, - com.mongodb.assertions.Assertions::fail, - remainingMillis -> { - // Verify that the original timeout has not been intact. - // If `close` had reset it, we would have observed more than `thirdOfTimeout` left. - assertTrue(remainingMillis <= thirdOfTimeout.toMillis()); - }, - Assertions::fail); } - private AsyncCommandBatchCursor createBatchCursor(final long maxTimeMS) { - return new AsyncCommandBatchCursor( + return new AsyncCommandBatchCursor<>( TimeoutMode.CURSOR_LIFETIME, - COMMAND_CURSOR_DOCUMENT, - 0, maxTimeMS, - DOCUMENT_CODEC, - null, - connectionSource, - mockConnection); + operationContext, + coreCursor); } } diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/AsyncCommandCoreCursorTest.java b/driver-core/src/test/functional/com/mongodb/internal/operation/AsyncCommandCoreCursorTest.java new file mode 100644 index 00000000000..453166438eb --- /dev/null +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/AsyncCommandCoreCursorTest.java @@ -0,0 +1,211 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.operation; + +import com.mongodb.MongoClientSettings; +import com.mongodb.MongoNamespace; +import com.mongodb.MongoOperationTimeoutException; +import com.mongodb.MongoSocketException; +import com.mongodb.ServerAddress; +import com.mongodb.connection.ConnectionDescription; +import com.mongodb.connection.ServerDescription; +import com.mongodb.connection.ServerType; +import com.mongodb.connection.ServerVersion; +import com.mongodb.internal.IgnorableRequestContext; +import com.mongodb.internal.TimeoutContext; +import com.mongodb.internal.TimeoutSettings; +import com.mongodb.internal.async.SingleResultCallback; +import com.mongodb.internal.binding.AsyncConnectionSource; +import com.mongodb.internal.connection.AsyncConnection; +import com.mongodb.internal.connection.NoOpSessionContext; +import com.mongodb.internal.connection.OperationContext; +import org.bson.BsonArray; +import org.bson.BsonDocument; +import org.bson.BsonInt32; +import org.bson.BsonInt64; +import org.bson.BsonString; +import org.bson.Document; +import org.bson.codecs.Decoder; +import org.bson.codecs.DocumentCodec; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.time.Duration; + +import static com.mongodb.internal.operation.OperationUnitSpecification.getMaxWireVersionForServerVersion; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class AsyncCommandCoreCursorTest { + + private static final MongoNamespace NAMESPACE = new MongoNamespace("test", "test"); + private static final BsonInt64 CURSOR_ID = new BsonInt64(1); + private static final BsonDocument COMMAND_CURSOR_DOCUMENT = new BsonDocument("ok", new BsonInt32(1)) + .append("cursor", + new BsonDocument("ns", new BsonString(NAMESPACE.getFullName())) + .append("id", CURSOR_ID) + .append("firstBatch", new BsonArrayWrapper<>(new BsonArray()))); + + private static final Decoder DOCUMENT_CODEC = new DocumentCodec(); + private static final Duration TIMEOUT = Duration.ofMillis(3_000); + + + private AsyncConnection mockConnection; + private ConnectionDescription mockDescription; + private AsyncConnectionSource connectionSource; + private OperationContext operationContext; + private TimeoutContext timeoutContext; + private ServerDescription serverDescription; + private AsyncCoreCursor coreCursor; + + @BeforeEach + void setUp() { + coreCursor = mock(AsyncCoreCursor.class); + timeoutContext = spy(new TimeoutContext(TimeoutSettings.create( + MongoClientSettings.builder().timeout(TIMEOUT.toMillis(), MILLISECONDS).build()))); + operationContext = spy(new OperationContext( + IgnorableRequestContext.INSTANCE, + NoOpSessionContext.INSTANCE, + timeoutContext, + null)); + + ServerVersion serverVersion = new ServerVersion(3, 6); + mockConnection = mock(AsyncConnection.class, "connection"); + mockDescription = mock(ConnectionDescription.class); + when(mockDescription.getMaxWireVersion()).thenReturn(getMaxWireVersionForServerVersion(serverVersion.getVersionList())); + when(mockDescription.getServerType()).thenReturn(ServerType.LOAD_BALANCER); + when(mockConnection.getDescription()).thenReturn(mockDescription); + when(mockConnection.retain()).thenReturn(mockConnection); + + connectionSource = mock(AsyncConnectionSource.class); + serverDescription = mock(ServerDescription.class); + when(operationContext.getTimeoutContext()).thenReturn(timeoutContext); + doAnswer(invocation -> { + SingleResultCallback callback = invocation.getArgument(0); + callback.onResult(mockConnection, null); + return null; + }).when(connectionSource).getConnection(any(), any()); + when(connectionSource.getServerDescription()).thenReturn(serverDescription); + } + + + @Test + void shouldSkipKillsCursorsCommandWhenNetworkErrorOccurs() { + //given + doAnswer(invocation -> { + SingleResultCallback argument = invocation.getArgument(6); + argument.onResult(null, new MongoSocketException("test", new ServerAddress())); + return null; + }).when(mockConnection).commandAsync(eq(NAMESPACE.getDatabaseName()), any(), any(), any(), any(), any(), any()); + when(serverDescription.getType()).thenReturn(ServerType.LOAD_BALANCER); + AsyncCoreCursor commandBatchCursor = createBatchCursor(); + + //when + commandBatchCursor.next(operationContext, (result, t) -> { + Assertions.assertNull(result); + Assertions.assertNotNull(t); + Assertions.assertEquals(MongoSocketException.class, t.getClass()); + }); + + //then + commandBatchCursor.close(operationContext); + verify(mockConnection, times(1)).commandAsync(eq(NAMESPACE.getDatabaseName()), any(), any(), any(), any(), any(), any()); + } + + + @Test + void shouldNotSkipKillsCursorsCommandWhenTimeoutExceptionDoesNotHaveNetworkErrorCause() { + //given + doAnswer(invocation -> { + SingleResultCallback argument = invocation.getArgument(6); + argument.onResult(null, new MongoOperationTimeoutException("test")); + return null; + }).when(mockConnection).commandAsync(eq(NAMESPACE.getDatabaseName()), any(), any(), any(), any(), any(), any()); + when(serverDescription.getType()).thenReturn(ServerType.LOAD_BALANCER); + + AsyncCoreCursor commandBatchCursor = createBatchCursor(); + + //when + commandBatchCursor.next(operationContext, (result, t) -> { + Assertions.assertNull(result); + Assertions.assertNotNull(t); + Assertions.assertEquals(MongoOperationTimeoutException.class, t.getClass()); + }); + + commandBatchCursor.close(operationContext); + + + //then + verify(mockConnection, times(2)).commandAsync(any(), + any(), any(), any(), any(), any(), any()); + verify(mockConnection, times(1)).commandAsync(eq(NAMESPACE.getDatabaseName()), + argThat(bsonDocument -> bsonDocument.containsKey("getMore")), any(), any(), any(), any(), any()); + verify(mockConnection, times(1)).commandAsync(eq(NAMESPACE.getDatabaseName()), + argThat(bsonDocument -> bsonDocument.containsKey("killCursors")), any(), any(), any(), any(), any()); + } + + @Test + void shouldSkipKillsCursorsCommandWhenTimeoutExceptionHaveNetworkErrorCause() { + //given + doAnswer(invocation -> { + SingleResultCallback argument = invocation.getArgument(6); + argument.onResult(null, new MongoOperationTimeoutException("test", new MongoSocketException("test", new ServerAddress()))); + return null; + }).when(mockConnection).commandAsync(eq(NAMESPACE.getDatabaseName()), any(), any(), any(), any(), any(), any()); + when(serverDescription.getType()).thenReturn(ServerType.LOAD_BALANCER); + + AsyncCoreCursor commandBatchCursor = createBatchCursor(); + + //when + commandBatchCursor.next(operationContext, (result, t) -> { + Assertions.assertNull(result); + Assertions.assertNotNull(t); + Assertions.assertEquals(MongoOperationTimeoutException.class, t.getClass()); + }); + + commandBatchCursor.close(operationContext); + + //then + verify(mockConnection, times(1)).commandAsync(any(), + any(), any(), any(), any(), any(), any()); + verify(mockConnection, times(1)).commandAsync(eq(NAMESPACE.getDatabaseName()), + argThat(bsonDocument -> bsonDocument.containsKey("getMore")), any(), any(), any(), any(), any()); + verify(mockConnection, never()).commandAsync(eq(NAMESPACE.getDatabaseName()), + argThat(bsonDocument -> bsonDocument.containsKey("killCursors")), any(), any(), any(), any(), any()); + } + + + private AsyncCoreCursor createBatchCursor() { + return new AsyncCommandCoreCursor<>( + COMMAND_CURSOR_DOCUMENT, + 0, + DOCUMENT_CODEC, + null, + connectionSource, + mockConnection); + } +} diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/ChangeStreamOperationSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/ChangeStreamOperationSpecification.groovy index 9134375ffec..19285eda077 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/ChangeStreamOperationSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/ChangeStreamOperationSpecification.groovy @@ -37,6 +37,7 @@ import com.mongodb.internal.binding.ReadBinding import com.mongodb.internal.client.model.changestream.ChangeStreamLevel import com.mongodb.internal.connection.AsyncConnection import com.mongodb.internal.connection.Connection +import com.mongodb.internal.connection.OperationContext import com.mongodb.internal.session.SessionContext import org.bson.BsonArray import org.bson.BsonBoolean @@ -641,10 +642,8 @@ class ChangeStreamOperationSpecification extends OperationFunctionalSpecificatio }) def changeStream def binding = Stub(ReadBinding) { - getOperationContext() >> operationContext - getReadConnectionSource() >> Stub(ConnectionSource) { - getOperationContext() >> operationContext - getConnection() >> Stub(Connection) { + getReadConnectionSource(_) >> Stub(ConnectionSource) { + getConnection(_) >> Stub(Connection) { command(*_) >> { changeStream = getChangeStream(it[1]) new BsonDocument('cursor', new BsonDocument('id', new BsonInt64(1)) @@ -662,7 +661,7 @@ class ChangeStreamOperationSpecification extends OperationFunctionalSpecificatio new ChangeStreamOperation(helper.getNamespace(), FullDocument.DEFAULT, FullDocumentBeforeChange.DEFAULT, [], CODEC) .resumeAfter(new BsonDocument()) - .execute(binding) + .execute(binding, operationContext) then: changeStream.containsKey('resumeAfter') @@ -672,7 +671,7 @@ class ChangeStreamOperationSpecification extends OperationFunctionalSpecificatio new ChangeStreamOperation(helper.getNamespace(), FullDocument.DEFAULT, FullDocumentBeforeChange.DEFAULT, [], CODEC) .startAfter(new BsonDocument()) - .execute(binding) + .execute(binding, operationContext) then: changeStream.containsKey('startAfter') @@ -683,7 +682,7 @@ class ChangeStreamOperationSpecification extends OperationFunctionalSpecificatio new ChangeStreamOperation(helper.getNamespace(), FullDocument.DEFAULT, FullDocumentBeforeChange.DEFAULT, [], CODEC) .startAtOperationTime(startAtTime) - .execute(binding) + .execute(binding, operationContext) then: changeStream.getTimestamp('startAtOperationTime') == startAtTime @@ -698,11 +697,9 @@ class ChangeStreamOperationSpecification extends OperationFunctionalSpecificatio }) def changeStream def binding = Stub(AsyncReadBinding) { - getOperationContext() >> operationContext - getReadConnectionSource(_) >> { + getReadConnectionSource(_ as OperationContext, _ as SingleResultCallback) >> { it.last().onResult(Stub(AsyncConnectionSource) { - getOperationContext() >> operationContext - getConnection(_) >> { + getConnection(_ as OperationContext, _ as SingleResultCallback) >> { it.last().onResult(Stub(AsyncConnection) { commandAsync(*_) >> { changeStream = getChangeStream(it[1]) @@ -723,7 +720,7 @@ class ChangeStreamOperationSpecification extends OperationFunctionalSpecificatio new ChangeStreamOperation(helper.getNamespace(), FullDocument.DEFAULT, FullDocumentBeforeChange.DEFAULT, [], CODEC) .resumeAfter(new BsonDocument()) - .executeAsync(binding, Stub(SingleResultCallback)) + .executeAsync(binding, operationContext, Stub(SingleResultCallback)) then: changeStream.containsKey('resumeAfter') @@ -733,7 +730,7 @@ class ChangeStreamOperationSpecification extends OperationFunctionalSpecificatio new ChangeStreamOperation(helper.getNamespace(), FullDocument.DEFAULT, FullDocumentBeforeChange.DEFAULT, [], CODEC) .startAfter(new BsonDocument()) - .executeAsync(binding, Stub(SingleResultCallback)) + .executeAsync(binding, operationContext, Stub(SingleResultCallback)) then: changeStream.containsKey('startAfter') @@ -744,7 +741,7 @@ class ChangeStreamOperationSpecification extends OperationFunctionalSpecificatio new ChangeStreamOperation(helper.getNamespace(), FullDocument.DEFAULT, FullDocumentBeforeChange.DEFAULT, [], CODEC) .startAtOperationTime(startAtTime) - .executeAsync(binding, Stub(SingleResultCallback)) + .executeAsync(binding, operationContext, Stub(SingleResultCallback)) then: changeStream.getTimestamp('startAtOperationTime') == startAtTime diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/CommandBatchCursorFunctionalTest.java b/driver-core/src/test/functional/com/mongodb/internal/operation/CommandBatchCursorFunctionalTest.java index d9861c71659..3d0774d425a 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/CommandBatchCursorFunctionalTest.java +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/CommandBatchCursorFunctionalTest.java @@ -16,6 +16,7 @@ package com.mongodb.internal.operation; +import com.mongodb.ClusterFixture; import com.mongodb.MongoCursorNotFoundException; import com.mongodb.MongoQueryException; import com.mongodb.ReadPreference; @@ -54,6 +55,7 @@ import java.util.stream.IntStream; import java.util.stream.Stream; +import static com.mongodb.ClusterFixture.OPERATION_CONTEXT; import static com.mongodb.ClusterFixture.checkReferenceCountReachesTarget; import static com.mongodb.ClusterFixture.getBinding; import static com.mongodb.ClusterFixture.getReferenceCountAfterTimeout; @@ -85,8 +87,8 @@ void setup() { .collect(Collectors.toList()); getCollectionHelper().insertDocuments(documents); - connectionSource = getBinding().getWriteConnectionSource(); - connection = connectionSource.getConnection(); + connectionSource = getBinding().getWriteConnectionSource(ClusterFixture.OPERATION_CONTEXT); + connection = connectionSource.getConnection(ClusterFixture.OPERATION_CONTEXT); } @AfterEach @@ -107,8 +109,8 @@ void cleanup() { void shouldExhaustCursorWithMultipleBatches() { // given BsonDocument commandResult = executeFindCommand(0, 3); // Fetch in batches of size 3 - cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 3, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new CommandCoreCursor<>(commandResult, 3, DOCUMENT_DECODER, null, connectionSource, connection)); // when List> result = cursor.exhaust(); @@ -125,8 +127,8 @@ void shouldExhaustCursorWithMultipleBatches() { void shouldExhaustCursorWithClosedCursor() { // given BsonDocument commandResult = executeFindCommand(0, 3); - cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 3, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new CommandCoreCursor<>(commandResult, 3, DOCUMENT_DECODER, null, connectionSource, connection)); cursor.close(); // when & then @@ -141,8 +143,8 @@ void shouldExhaustCursorWithEmptyCursor() { getCollectionHelper().deleteMany(Filters.empty()); BsonDocument commandResult = executeFindCommand(0, 3); // No documents to fetch - cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 3, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new CommandCoreCursor<>(commandResult, 3, DOCUMENT_DECODER, null, connectionSource, connection)); // when List> result = cursor.exhaust(); @@ -155,8 +157,8 @@ void shouldExhaustCursorWithEmptyCursor() { @DisplayName("server cursor should not be null") void theServerCursorShouldNotBeNull() { BsonDocument commandResult = executeFindCommand(2); - cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 0, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new CommandCoreCursor<>(commandResult, 0, DOCUMENT_DECODER, null, connectionSource, connection)); assertNotNull(cursor.getServerCursor()); } @@ -165,8 +167,8 @@ void theServerCursorShouldNotBeNull() { @DisplayName("test server address should not be null") void theServerAddressShouldNotNull() { BsonDocument commandResult = executeFindCommand(); - cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 0, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new CommandCoreCursor<>(commandResult, 0, DOCUMENT_DECODER, null, connectionSource, connection)); assertNotNull(cursor.getServerAddress()); } @@ -175,8 +177,8 @@ void theServerAddressShouldNotNull() { @DisplayName("should get Exceptions for operations on the cursor after closing") void shouldGetExceptionsForOperationsOnTheCursorAfterClosing() { BsonDocument commandResult = executeFindCommand(); - cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 0, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new CommandCoreCursor<>(commandResult, 0, DOCUMENT_DECODER, null, connectionSource, connection)); cursor.close(); @@ -190,8 +192,8 @@ void shouldGetExceptionsForOperationsOnTheCursorAfterClosing() { @DisplayName("should throw an Exception when going off the end") void shouldThrowAnExceptionWhenGoingOffTheEnd() { BsonDocument commandResult = executeFindCommand(1); - cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 0, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new CommandCoreCursor<>(commandResult, 0, DOCUMENT_DECODER, null, connectionSource, connection)); cursor.next(); cursor.next(); @@ -202,8 +204,8 @@ void shouldThrowAnExceptionWhenGoingOffTheEnd() { @DisplayName("test cursor remove") void testCursorRemove() { BsonDocument commandResult = executeFindCommand(); - cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 0, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new CommandCoreCursor<>(commandResult, 0, DOCUMENT_DECODER, null, connectionSource, connection)); assertThrows(UnsupportedOperationException.class, () -> cursor.remove()); } @@ -212,8 +214,8 @@ void testCursorRemove() { @DisplayName("test normal exhaustion") void testNormalExhaustion() { BsonDocument commandResult = executeFindCommand(); - cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 0, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new CommandCoreCursor<>(commandResult, 0, DOCUMENT_DECODER, null, connectionSource, connection)); assertEquals(10, cursorFlatten().size()); } @@ -223,8 +225,8 @@ void testNormalExhaustion() { @DisplayName("test limit exhaustion") void testLimitExhaustion(final int limit, final int batchSize, final int expectedTotal) { BsonDocument commandResult = executeFindCommand(limit, batchSize); - cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, batchSize, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new CommandCoreCursor<>(commandResult, 0, DOCUMENT_DECODER, null, connectionSource, connection)); assertEquals(expectedTotal, cursorFlatten().size()); @@ -242,8 +244,8 @@ void shouldBlockWaitingForNextBatchOnATailableCursor(final boolean awaitData, fi BsonDocument commandResult = executeFindCommand(new BsonDocument("ts", new BsonDocument("$gte", new BsonTimestamp(5, 0))), 0, 2, true, awaitData); - cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 2, maxTimeMS, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, maxTimeMS, OPERATION_CONTEXT, + new CommandCoreCursor<>(commandResult, 2, DOCUMENT_DECODER, null, connectionSource, connection)); assertTrue(cursor.hasNext()); assertEquals(1, cursor.next().get(0).get("_id")); @@ -265,8 +267,8 @@ void testTryNextWithTailable() { BsonDocument commandResult = executeFindCommand(new BsonDocument("ts", new BsonDocument("$gte", new BsonTimestamp(5, 0))), 0, 2, true, true); - cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 2, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new CommandCoreCursor<>(commandResult, 2, DOCUMENT_DECODER, null, connectionSource, connection)); List nextBatch = cursor.tryNext(); assertNotNull(nextBatch); @@ -291,8 +293,8 @@ void hasNextShouldThrowWhenCursorIsClosedInAnotherThread() throws InterruptedExc BsonDocument commandResult = executeFindCommand(new BsonDocument("ts", new BsonDocument("$gte", new BsonTimestamp(5, 0))), 0, 2, true, true); - cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 2, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new CommandCoreCursor<>(commandResult, 2, DOCUMENT_DECODER, null, connectionSource, connection)); assertTrue(cursor.hasNext()); assertEquals(1, cursor.next().get(0).get("_id")); @@ -318,8 +320,8 @@ void testMaxTimeMS() { long maxTimeMS = 500; BsonDocument commandResult = executeFindCommand(new BsonDocument("ts", new BsonDocument("$gte", new BsonTimestamp(5, 0))), 0, 2, true, true); - cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 2, maxTimeMS, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, maxTimeMS, OPERATION_CONTEXT, + new CommandCoreCursor<>(commandResult, 2, DOCUMENT_DECODER, null, connectionSource, connection)); List nextBatch = cursor.tryNext(); assertNotNull(nextBatch); @@ -342,8 +344,8 @@ void testTailableInterrupt() throws InterruptedException { BsonDocument commandResult = executeFindCommand(new BsonDocument("ts", new BsonDocument("$gte", new BsonTimestamp(5, 0))), 0, 2, true, true); - cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 2, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new CommandCoreCursor<>(commandResult, 2, DOCUMENT_DECODER, null, connectionSource, connection)); CountDownLatch latch = new CountDownLatch(1); AtomicInteger seen = new AtomicInteger(); @@ -375,8 +377,8 @@ void testTailableInterrupt() throws InterruptedException { void shouldKillCursorIfLimitIsReachedOnInitialQuery() { assumeFalse(isSharded()); BsonDocument commandResult = executeFindCommand(5, 10); - cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 0, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new CommandCoreCursor<>(commandResult, 0, DOCUMENT_DECODER, null, connectionSource, connection)); assertNotNull(cursor.next()); assertFalse(cursor.hasNext()); @@ -388,8 +390,8 @@ void shouldKillCursorIfLimitIsReachedOnInitialQuery() { void shouldKillCursorIfLimitIsReachedOnGetMore() { assumeFalse(isSharded()); BsonDocument commandResult = executeFindCommand(5, 3); - cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 3, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new CommandCoreCursor<>(commandResult, 3, DOCUMENT_DECODER, null, connectionSource, connection)); ServerCursor serverCursor = cursor.getServerCursor(); assertNotNull(serverCursor); @@ -407,8 +409,8 @@ void shouldKillCursorIfLimitIsReachedOnGetMore() { void shouldReleaseConnectionSourceIfLimitIsReachedOnInitialQuery() { assumeFalse(isSharded()); BsonDocument commandResult = executeFindCommand(5, 10); - cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 0, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new CommandCoreCursor<>(commandResult, 0, DOCUMENT_DECODER, null, connectionSource, connection)); assertNull(cursor.getServerCursor()); assertDoesNotThrow(() -> checkReferenceCountReachesTarget(connectionSource, 1)); @@ -420,8 +422,8 @@ void shouldReleaseConnectionSourceIfLimitIsReachedOnInitialQuery() { void shouldReleaseConnectionSourceIfLimitIsReachedOnGetMore() { assumeFalse(isSharded()); BsonDocument commandResult = executeFindCommand(5, 3); - cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 3, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new CommandCoreCursor<>(commandResult, 3, DOCUMENT_DECODER, null, connectionSource, connection)); assertNotNull(cursor.next()); assertNotNull(cursor.next()); @@ -433,8 +435,8 @@ void shouldReleaseConnectionSourceIfLimitIsReachedOnGetMore() { @DisplayName("test limit with get more") void testLimitWithGetMore() { BsonDocument commandResult = executeFindCommand(5, 2); - cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 2, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new CommandCoreCursor<>(commandResult, 2, DOCUMENT_DECODER, null, connectionSource, connection)); assertNotNull(cursor.next()); assertNotNull(cursor.next()); @@ -454,8 +456,8 @@ void testLimitWithLargeDocuments() { ); BsonDocument commandResult = executeFindCommand(300, 0); - cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 0, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new CommandCoreCursor<>(commandResult, 0, DOCUMENT_DECODER, null, connectionSource, connection)); assertEquals(300, cursorFlatten().size()); } @@ -464,8 +466,8 @@ void testLimitWithLargeDocuments() { @DisplayName("should respect batch size") void shouldRespectBatchSize() { BsonDocument commandResult = executeFindCommand(2); - cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 2, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new CommandCoreCursor<>(commandResult, 2, DOCUMENT_DECODER, null, connectionSource, connection)); assertEquals(2, cursor.getBatchSize()); assertEquals(2, cursor.next().size()); @@ -481,16 +483,16 @@ void shouldRespectBatchSize() { @DisplayName("should throw cursor not found exception") void shouldThrowCursorNotFoundException() { BsonDocument commandResult = executeFindCommand(2); - cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 2, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new CommandCoreCursor<>(commandResult, 2, DOCUMENT_DECODER, null, connectionSource, connection)); ServerCursor serverCursor = cursor.getServerCursor(); assertNotNull(serverCursor); - Connection localConnection = connectionSource.getConnection(); + Connection localConnection = connectionSource.getConnection(OPERATION_CONTEXT); localConnection.command(getNamespace().getDatabaseName(), new BsonDocument("killCursors", new BsonString(getNamespace().getCollectionName())) .append("cursors", new BsonArray(singletonList(new BsonInt64(serverCursor.getId())))), - NoOpFieldNameValidator.INSTANCE, ReadPreference.primary(), new BsonDocumentCodec(), connectionSource.getOperationContext()); + NoOpFieldNameValidator.INSTANCE, ReadPreference.primary(), new BsonDocumentCodec(), OPERATION_CONTEXT); localConnection.release(); cursor.next(); @@ -504,8 +506,8 @@ void shouldThrowCursorNotFoundException() { @DisplayName("should report available documents") void shouldReportAvailableDocuments() { BsonDocument commandResult = executeFindCommand(3); - cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, commandResult, 2, 0, DOCUMENT_DECODER, - null, connectionSource, connection); + cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, 0, OPERATION_CONTEXT, + new CommandCoreCursor<>(commandResult, 2, DOCUMENT_DECODER, null, connectionSource, connection)); assertEquals(3, cursor.available()); @@ -582,7 +584,7 @@ private BsonDocument executeFindCommand(final BsonDocument filter, final int lim BsonDocument results = connection.command(getDatabaseName(), findCommand, NoOpFieldNameValidator.INSTANCE, readPreference, CommandResultDocumentCodec.create(DOCUMENT_DECODER, FIRST_BATCH), - connectionSource.getOperationContext()); + OPERATION_CONTEXT); assertNotNull(results); return results; diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/CountDocumentsOperationSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/CountDocumentsOperationSpecification.groovy index 8d13cba9f61..1e538b1af11 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/CountDocumentsOperationSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/CountDocumentsOperationSpecification.groovy @@ -16,6 +16,7 @@ package com.mongodb.internal.operation +import com.mongodb.ClusterFixture import com.mongodb.MongoException import com.mongodb.MongoNamespace import com.mongodb.OperationFunctionalSpecification @@ -26,6 +27,7 @@ import com.mongodb.connection.ClusterId import com.mongodb.connection.ConnectionDescription import com.mongodb.connection.ConnectionId import com.mongodb.connection.ServerId +import com.mongodb.internal.async.SingleResultCallback import com.mongodb.internal.binding.AsyncConnectionSource import com.mongodb.internal.binding.AsyncReadBinding import com.mongodb.internal.binding.ConnectionSource @@ -33,6 +35,7 @@ import com.mongodb.internal.binding.ReadBinding import com.mongodb.internal.bulk.IndexRequest import com.mongodb.internal.connection.AsyncConnection import com.mongodb.internal.connection.Connection +import com.mongodb.internal.connection.OperationContext import com.mongodb.internal.session.SessionContext import org.bson.BsonArray import org.bson.BsonDocument @@ -45,7 +48,6 @@ import org.bson.codecs.DocumentCodec import static com.mongodb.ClusterFixture.OPERATION_CONTEXT import static com.mongodb.ClusterFixture.executeAsync -import static com.mongodb.ClusterFixture.getBinding import static com.mongodb.connection.ServerType.STANDALONE import static com.mongodb.internal.operation.OperationReadConcernHelper.appendReadConcernToCommand import static com.mongodb.internal.operation.ServerVersionHelper.UNKNOWN_WIRE_VERSION @@ -151,8 +153,10 @@ class CountDocumentsOperationSpecification extends OperationFunctionalSpecificat def 'should use hint with the count'() { given: def indexDefinition = new BsonDocument('y', new BsonInt32(1)) + + def binding = ClusterFixture.getBinding() new CreateIndexesOperation(getNamespace(), [new IndexRequest(indexDefinition).sparse(true)], null) - .execute(getBinding()) + .execute(binding, ClusterFixture.getOperationContext(binding.getReadPreference())) def operation = new CountDocumentsOperation(getNamespace()).hint(indexDefinition) when: @@ -260,11 +264,9 @@ class CountDocumentsOperationSpecification extends OperationFunctionalSpecificat def source = Stub(ConnectionSource) def connection = Mock(Connection) binding.readPreference >> ReadPreference.primary() - binding.operationContext >> operationContext - binding.readConnectionSource >> source - source.connection >> connection + binding.getReadConnectionSource(_) >> source + source.getConnection(_) >> connection source.retain() >> source - source.operationContext >> operationContext def pipeline = new BsonArray([BsonDocument.parse('{ $match: {}}'), BsonDocument.parse('{$group: {_id: 1, n: {$sum: 1}}}')]) def commandDocument = new BsonDocument('aggregate', new BsonString(getCollectionName())) .append('pipeline', pipeline) @@ -274,12 +276,12 @@ class CountDocumentsOperationSpecification extends OperationFunctionalSpecificat def operation = new CountDocumentsOperation(getNamespace()) when: - operation.execute(binding) + operation.execute(binding, operationContext) then: _ * connection.description >> new ConnectionDescription(new ConnectionId(new ServerId(new ClusterId(), new ServerAddress())), 6, STANDALONE, 1000, 100000, 100000, []) - 1 * connection.command(_, commandDocument, _, _, _, operationContext) >> helper.cursorResult + 1 * connection.command(_, commandDocument, _, _, _, _) >> helper.cursorResult 1 * connection.release() where: @@ -300,11 +302,9 @@ class CountDocumentsOperationSpecification extends OperationFunctionalSpecificat def source = Stub(AsyncConnectionSource) def connection = Mock(AsyncConnection) binding.readPreference >> ReadPreference.primary() - binding.operationContext >> operationContext - binding.getReadConnectionSource(_) >> { it[0].onResult(source, null) } - source.getConnection(_) >> { it[0].onResult(connection, null) } + binding.getReadConnectionSource(_ as OperationContext, _ as SingleResultCallback) >> { it[1].onResult(source, null) } + source.getConnection(_ as OperationContext, _ as SingleResultCallback) >> { it[1].onResult(connection, null) } source.retain() >> source - source.operationContext >> operationContext def pipeline = new BsonArray([BsonDocument.parse('{ $match: {}}'), BsonDocument.parse('{$group: {_id: 1, n: {$sum: 1}}}')]) def commandDocument = new BsonDocument('aggregate', new BsonString(getCollectionName())) .append('pipeline', pipeline) @@ -314,7 +314,7 @@ class CountDocumentsOperationSpecification extends OperationFunctionalSpecificat def operation = new CountDocumentsOperation(getNamespace()) when: - executeAsync(operation, binding) + executeAsync(operation, binding, operationContext) then: _ * connection.description >> new ConnectionDescription(new ConnectionId(new ServerId(new ClusterId(), new ServerAddress())), diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/CreateCollectionOperationSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/CreateCollectionOperationSpecification.groovy index b33ec785094..860ffb4a2bf 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/CreateCollectionOperationSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/CreateCollectionOperationSpecification.groovy @@ -16,6 +16,7 @@ package com.mongodb.internal.operation +import com.mongodb.ClusterFixture import com.mongodb.MongoBulkWriteException import com.mongodb.MongoWriteConcernException import com.mongodb.OperationFunctionalSpecification @@ -108,9 +109,13 @@ class CreateCollectionOperationSpecification extends OperationFunctionalSpecific when: execute(operation, async) + then: + def binding = ClusterFixture.getBinding() new ListCollectionsOperation(getDatabaseName(), new BsonDocumentCodec()) - .execute(getBinding()).next().find { it -> it.getString('name').value == getCollectionName() } + .execute(binding, ClusterFixture.getOperationContext(binding.getReadPreference())) + .next() + .find { it -> it.getString('name').value == getCollectionName() } .getDocument('options').getDocument('storageEngine') == operation.storageEngineOptions where: @@ -127,8 +132,11 @@ class CreateCollectionOperationSpecification extends OperationFunctionalSpecific execute(operation, async) then: + def binding = ClusterFixture.getBinding() new ListCollectionsOperation(getDatabaseName(), new BsonDocumentCodec()) - .execute(getBinding()).next().find { it -> it.getString('name').value == getCollectionName() } + .execute(binding, ClusterFixture.getOperationContext(binding.getReadPreference())) + .next() + .find { it -> it.getString('name').value == getCollectionName() } .getDocument('options').getDocument('storageEngine') == operation.storageEngineOptions where: async << [true, false] @@ -244,8 +252,10 @@ class CreateCollectionOperationSpecification extends OperationFunctionalSpecific } def getCollectionInfo(String collectionName) { + def binding = getBinding() new ListCollectionsOperation(databaseName, new BsonDocumentCodec()).filter(new BsonDocument('name', - new BsonString(collectionName))).execute(getBinding()).tryNext()?.head() + new BsonString(collectionName))).execute(binding, + ClusterFixture.getOperationContext(binding.getReadPreference())).tryNext()?.head() } def collectionNameExists(String collectionName) { @@ -255,15 +265,16 @@ class CreateCollectionOperationSpecification extends OperationFunctionalSpecific BsonDocument storageStats() { if (serverVersionLessThan(6, 2)) { + def binding = getBinding() return new CommandReadOperation<>(getDatabaseName(), new BsonDocument('collStats', new BsonString(getCollectionName())), - new BsonDocumentCodec()).execute(getBinding()) + new BsonDocumentCodec()).execute(binding, ClusterFixture.getOperationContext(binding.getReadPreference())) } + def binding = ClusterFixture.getBinding() BatchCursor cursor = new AggregateOperation( - getNamespace(), singletonList(new BsonDocument('$collStats', new BsonDocument('storageStats', new BsonDocument()))), - new BsonDocumentCodec()).execute(getBinding()) + new BsonDocumentCodec()).execute(binding, ClusterFixture.getOperationContext(binding.getReadPreference())) try { return cursor.next().first().getDocument('storageStats') } finally { diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/CreateIndexesOperationSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/CreateIndexesOperationSpecification.groovy index 78a9914e022..fce0904b786 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/CreateIndexesOperationSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/CreateIndexesOperationSpecification.groovy @@ -16,6 +16,7 @@ package com.mongodb.internal.operation +import com.mongodb.ClusterFixture import com.mongodb.CreateIndexCommitQuorum import com.mongodb.DuplicateKeyException import com.mongodb.MongoClientException @@ -34,7 +35,6 @@ import org.bson.Document import org.bson.codecs.DocumentCodec import spock.lang.IgnoreIf -import static com.mongodb.ClusterFixture.getBinding import static com.mongodb.ClusterFixture.isDiscoverableReplicaSet import static com.mongodb.ClusterFixture.serverVersionAtLeast import static com.mongodb.ClusterFixture.serverVersionLessThan @@ -491,7 +491,10 @@ class CreateIndexesOperationSpecification extends OperationFunctionalSpecificati List getIndexes() { def indexes = [] - def cursor = new ListIndexesOperation(getNamespace(), new DocumentCodec()).execute(getBinding()) + + def binding = ClusterFixture.getBinding() + def cursor = new ListIndexesOperation(getNamespace(), new DocumentCodec()) + .execute(binding, ClusterFixture.getOperationContext(binding.getReadPreference())) while (cursor.hasNext()) { indexes.addAll(cursor.next()) } diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/CreateViewOperationSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/CreateViewOperationSpecification.groovy index 07a35800242..b8145de44b4 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/CreateViewOperationSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/CreateViewOperationSpecification.groovy @@ -30,6 +30,7 @@ import org.bson.codecs.BsonDocumentCodec import spock.lang.IgnoreIf import static com.mongodb.ClusterFixture.getBinding +import static com.mongodb.ClusterFixture.getOperationContext import static com.mongodb.ClusterFixture.isDiscoverableReplicaSet class CreateViewOperationSpecification extends OperationFunctionalSpecification { @@ -121,8 +122,9 @@ class CreateViewOperationSpecification extends OperationFunctionalSpecification } def getCollectionInfo(String collectionName) { + def binding = getBinding() new ListCollectionsOperation(databaseName, new BsonDocumentCodec()).filter(new BsonDocument('name', - new BsonString(collectionName))).execute(getBinding()).tryNext()?.head() + new BsonString(collectionName))).execute(binding, getOperationContext(binding.getReadPreference())).tryNext()?.head() } def collectionNameExists(String collectionName) { diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/DistinctOperationSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/DistinctOperationSpecification.groovy index 726a3723df5..f73c301d422 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/DistinctOperationSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/DistinctOperationSpecification.groovy @@ -33,6 +33,7 @@ import com.mongodb.internal.binding.ConnectionSource import com.mongodb.internal.binding.ReadBinding import com.mongodb.internal.connection.AsyncConnection import com.mongodb.internal.connection.Connection +import com.mongodb.internal.connection.OperationContext import com.mongodb.internal.session.SessionContext import org.bson.BsonBoolean import org.bson.BsonDocument @@ -56,6 +57,7 @@ import static com.mongodb.connection.ServerType.STANDALONE import static com.mongodb.internal.operation.OperationReadConcernHelper.appendReadConcernToCommand import static com.mongodb.internal.operation.ServerVersionHelper.UNKNOWN_WIRE_VERSION import static org.bson.codecs.configuration.CodecRegistries.fromProviders +import static org.junit.jupiter.api.Assertions.assertEquals class DistinctOperationSpecification extends OperationFunctionalSpecification { @@ -230,11 +232,9 @@ class DistinctOperationSpecification extends OperationFunctionalSpecification { def source = Stub(ConnectionSource) def connection = Mock(Connection) binding.readPreference >> ReadPreference.primary() - binding.operationContext >> operationContext - binding.readConnectionSource >> source - source.connection >> connection + binding.getReadConnectionSource(_) >> source + source.getConnection(_) >> connection source.retain() >> source - source.operationContext >> operationContext def commandDocument = new BsonDocument('distinct', new BsonString(getCollectionName())) .append('key', new BsonString('str')) appendReadConcernToCommand(sessionContext, UNKNOWN_WIRE_VERSION, commandDocument) @@ -242,13 +242,15 @@ class DistinctOperationSpecification extends OperationFunctionalSpecification { def operation = new DistinctOperation(getNamespace(), 'str', new StringCodec()) when: - operation.execute(binding) + operation.execute(binding, operationContext) then: _ * connection.description >> new ConnectionDescription(new ConnectionId(new ServerId(new ClusterId(), new ServerAddress())), 6, STANDALONE, 1000, 100000, 100000, []) - 1 * connection.command(_, commandDocument, _, _, _, operationContext) >> - new BsonDocument('values', new BsonArrayWrapper([])) + 1 * connection.command(_, commandDocument, _, _, _, _) >> { + assertEquals(((OperationContext) it[5]).getId(), operationContext.getId()) + new BsonDocument('values', new BsonArrayWrapper([])) + } 1 * connection.release() where: @@ -269,10 +271,8 @@ class DistinctOperationSpecification extends OperationFunctionalSpecification { def source = Stub(AsyncConnectionSource) def connection = Mock(AsyncConnection) binding.readPreference >> ReadPreference.primary() - binding.getReadConnectionSource(_) >> { it[0].onResult(source, null) } - binding.operationContext >> operationContext - source.operationContext >> operationContext - source.getConnection(_) >> { it[0].onResult(connection, null) } + binding.getReadConnectionSource(_, _) >> { it[1].onResult(source, null) } + source.getConnection(_, _) >> { it[1].onResult(connection, null) } source.retain() >> source def commandDocument = new BsonDocument('distinct', new BsonString(getCollectionName())) .append('key', new BsonString('str')) @@ -281,12 +281,13 @@ class DistinctOperationSpecification extends OperationFunctionalSpecification { def operation = new DistinctOperation(getNamespace(), 'str', new StringCodec()) when: - executeAsync(operation, binding) + executeAsync(operation, binding, operationContext) then: _ * connection.description >> new ConnectionDescription(new ConnectionId(new ServerId(new ClusterId(), new ServerAddress())), 6, STANDALONE, 1000, 100000, 100000, []) - 1 * connection.commandAsync(_, commandDocument, _, _, _, operationContext, *_) >> { + 1 * connection.commandAsync(_, commandDocument, _, _, _, _, *_) >> { + assertEquals(((OperationContext) it[5]).getId(), operationContext.getId()) it.last().onResult(new BsonDocument('values', new BsonArrayWrapper([])), null) } 1 * connection.release() diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/DropCollectionOperationSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/DropCollectionOperationSpecification.groovy index 164dc66d654..eb8f3efa573 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/DropCollectionOperationSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/DropCollectionOperationSpecification.groovy @@ -16,6 +16,7 @@ package com.mongodb.internal.operation +import com.mongodb.ClusterFixture import com.mongodb.MongoNamespace import com.mongodb.MongoWriteConcernException import com.mongodb.OperationFunctionalSpecification @@ -36,7 +37,9 @@ class DropCollectionOperationSpecification extends OperationFunctionalSpecificat assert collectionNameExists(getCollectionName()) when: - new DropCollectionOperation(getNamespace(), WriteConcern.ACKNOWLEDGED).execute(getBinding()) + def binding = getBinding() + new DropCollectionOperation(getNamespace(), WriteConcern.ACKNOWLEDGED) + .execute(binding, ClusterFixture.getOperationContext(binding.getReadPreference())) then: !collectionNameExists(getCollectionName()) @@ -60,7 +63,8 @@ class DropCollectionOperationSpecification extends OperationFunctionalSpecificat def namespace = new MongoNamespace(getDatabaseName(), 'nonExistingCollection') when: - new DropCollectionOperation(namespace, WriteConcern.ACKNOWLEDGED).execute(getBinding()) + new DropCollectionOperation(namespace, WriteConcern.ACKNOWLEDGED) + .execute(binding, ClusterFixture.getOperationContext(binding.getReadPreference())) then: !collectionNameExists('nonExistingCollection') @@ -86,7 +90,8 @@ class DropCollectionOperationSpecification extends OperationFunctionalSpecificat def operation = new DropCollectionOperation(getNamespace(), new WriteConcern(5)) when: - async ? executeAsync(operation) : operation.execute(getBinding()) + def binding = getBinding() + async ? executeAsync(operation) : operation.execute(binding, ClusterFixture.getOperationContext(binding.getReadPreference())) then: def ex = thrown(MongoWriteConcernException) @@ -98,7 +103,8 @@ class DropCollectionOperationSpecification extends OperationFunctionalSpecificat } def collectionNameExists(String collectionName) { - def cursor = new ListCollectionsOperation(databaseName, new DocumentCodec()).execute(getBinding()) + def cursor = new ListCollectionsOperation(databaseName, new DocumentCodec()) + .execute(binding, ClusterFixture.getOperationContext(binding.getReadPreference())) if (!cursor.hasNext()) { return false } diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/DropDatabaseOperationSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/DropDatabaseOperationSpecification.groovy index d91ac02e8cc..b56e2c1fe50 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/DropDatabaseOperationSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/DropDatabaseOperationSpecification.groovy @@ -16,6 +16,7 @@ package com.mongodb.internal.operation + import com.mongodb.MongoWriteConcernException import com.mongodb.OperationFunctionalSpecification import com.mongodb.WriteConcern @@ -27,6 +28,7 @@ import spock.lang.IgnoreIf import static com.mongodb.ClusterFixture.configureFailPoint import static com.mongodb.ClusterFixture.executeAsync import static com.mongodb.ClusterFixture.getBinding +import static com.mongodb.ClusterFixture.getOperationContext import static com.mongodb.ClusterFixture.isDiscoverableReplicaSet import static com.mongodb.ClusterFixture.isSharded @@ -75,8 +77,10 @@ class DropDatabaseOperationSpecification extends OperationFunctionalSpecificatio 'data : {failCommands : ["dropDatabase"], ' + 'writeConcernError : {code : 100, errmsg : "failed"}}}')) + + def binding = getBinding() when: - async ? executeAsync(operation) : operation.execute(getBinding()) + async ? executeAsync(operation) : operation.execute(binding, getOperationContext(binding.getReadPreference())) then: def ex = thrown(MongoWriteConcernException) @@ -88,7 +92,8 @@ class DropDatabaseOperationSpecification extends OperationFunctionalSpecificatio } def databaseNameExists(String databaseName) { - new ListDatabasesOperation(new DocumentCodec()).execute(getBinding()).next()*.name.contains(databaseName) + new ListDatabasesOperation(new DocumentCodec()).execute(binding, + getOperationContext(binding.getReadPreference())).next()*.name.contains(databaseName) } } diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/DropIndexOperationSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/DropIndexOperationSpecification.groovy index 611c0197faf..4f128addff3 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/DropIndexOperationSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/DropIndexOperationSpecification.groovy @@ -16,6 +16,7 @@ package com.mongodb.internal.operation +import com.mongodb.ClusterFixture import com.mongodb.MongoException import com.mongodb.MongoWriteConcernException import com.mongodb.OperationFunctionalSpecification @@ -154,7 +155,9 @@ class DropIndexOperationSpecification extends OperationFunctionalSpecification { def getIndexes() { def indexes = [] - def cursor = new ListIndexesOperation(getNamespace(), new DocumentCodec()).execute(getBinding()) + def binding = getBinding() + def cursor = new ListIndexesOperation(getNamespace(), new DocumentCodec()) + .execute(binding, ClusterFixture.getOperationContext(binding.getReadPreference())) while (cursor.hasNext()) { indexes.addAll(cursor.next()) } diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/FindOperationSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/FindOperationSpecification.groovy index f61ab70f2ae..5eb707201d5 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/FindOperationSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/FindOperationSpecification.groovy @@ -31,6 +31,7 @@ import com.mongodb.connection.ConnectionDescription import com.mongodb.connection.ConnectionId import com.mongodb.connection.ServerId import com.mongodb.internal.TimeoutContext +import com.mongodb.internal.async.SingleResultCallback import com.mongodb.internal.binding.AsyncClusterBinding import com.mongodb.internal.binding.AsyncConnectionSource import com.mongodb.internal.binding.AsyncReadBinding @@ -39,6 +40,7 @@ import com.mongodb.internal.binding.ConnectionSource import com.mongodb.internal.binding.ReadBinding import com.mongodb.internal.connection.AsyncConnection import com.mongodb.internal.connection.Connection +import com.mongodb.internal.connection.OperationContext import com.mongodb.internal.session.SessionContext import org.bson.BsonBoolean import org.bson.BsonDocument @@ -55,10 +57,10 @@ import spock.lang.IgnoreIf import static com.mongodb.ClusterFixture.OPERATION_CONTEXT import static com.mongodb.ClusterFixture.executeAsync import static com.mongodb.ClusterFixture.executeSync -import static com.mongodb.ClusterFixture.getAsyncBinding import static com.mongodb.ClusterFixture.getAsyncCluster import static com.mongodb.ClusterFixture.getBinding import static com.mongodb.ClusterFixture.getCluster +import static com.mongodb.ClusterFixture.getOperationContext import static com.mongodb.ClusterFixture.isSharded import static com.mongodb.ClusterFixture.serverVersionLessThan import static com.mongodb.CursorType.NonTailable @@ -385,8 +387,10 @@ class FindOperationSpecification extends OperationFunctionalSpecification { def 'should apply comment'() { given: def profileCollectionHelper = getCollectionHelper(new MongoNamespace(getDatabaseName(), 'system.profile')) + + def binding = getBinding() new CommandReadOperation<>(getDatabaseName(), new BsonDocument('profile', new BsonInt32(2)), - new BsonDocumentCodec()).execute(getBinding()) + new BsonDocumentCodec()).execute(binding, getOperationContext(binding.getReadPreference())) def expectedComment = 'this is a comment' def operation = new FindOperation(getNamespace(), new DocumentCodec()) .comment(new BsonString(expectedComment)) @@ -401,7 +405,7 @@ class FindOperationSpecification extends OperationFunctionalSpecification { cleanup: new CommandReadOperation<>(getDatabaseName(), new BsonDocument('profile', new BsonInt32(0)), new BsonDocumentCodec()) - .execute(getBinding()) + .execute(binding, getOperationContext(binding.getReadPreference())) profileCollectionHelper.drop() where: @@ -431,9 +435,8 @@ class FindOperationSpecification extends OperationFunctionalSpecification { given: collectionHelper.insertDocuments(new DocumentCodec(), new Document()) def operation = new FindOperation(getNamespace(), new DocumentCodec()) - def syncBinding = new ClusterBinding(getCluster(), ReadPreference.secondary(), ReadConcern.DEFAULT, OPERATION_CONTEXT) - def asyncBinding = new AsyncClusterBinding(getAsyncCluster(), ReadPreference.secondary(), ReadConcern.DEFAULT, - OPERATION_CONTEXT) + def syncBinding = new ClusterBinding(getCluster(), ReadPreference.secondary()) + def asyncBinding = new AsyncClusterBinding(getAsyncCluster(), ReadPreference.secondary()) when: def result = async ? executeAsync(operation, asyncBinding) : executeSync(operation, syncBinding) @@ -457,8 +460,8 @@ class FindOperationSpecification extends OperationFunctionalSpecification { def hedgeOptions = isHedgeEnabled != null ? ReadPreferenceHedgeOptions.builder().enabled(isHedgeEnabled as boolean).build() : null def readPreference = ReadPreference.primaryPreferred().withHedgeOptions(hedgeOptions) - def syncBinding = new ClusterBinding(getCluster(), readPreference, ReadConcern.DEFAULT, OPERATION_CONTEXT) - def asyncBinding = new AsyncClusterBinding(getAsyncCluster(), readPreference, ReadConcern.DEFAULT, OPERATION_CONTEXT) + def syncBinding = new ClusterBinding(getCluster(), readPreference) + def asyncBinding = new AsyncClusterBinding(getAsyncCluster(), readPreference) def cursor = async ? executeAsync(operation, asyncBinding) : executeSync(operation, syncBinding) def firstBatch = { if (async) { @@ -484,26 +487,26 @@ class FindOperationSpecification extends OperationFunctionalSpecification { def source = Stub(ConnectionSource) def connection = Mock(Connection) binding.readPreference >> ReadPreference.primary() - binding.operationContext >> operationContext - binding.readConnectionSource >> source - source.connection >> connection + binding.getReadConnectionSource(_) >> source + source.getConnection(_) >> connection source.retain() >> source - source.operationContext >> operationContext def commandDocument = new BsonDocument('find', new BsonString(getCollectionName())) appendReadConcernToCommand(sessionContext, UNKNOWN_WIRE_VERSION, commandDocument) def operation = new FindOperation(getNamespace(), new DocumentCodec()) when: - operation.execute(binding) + operation.execute(binding, operationContext) then: _ * connection.description >> new ConnectionDescription(new ConnectionId(new ServerId(new ClusterId(), new ServerAddress())), 6, STANDALONE, 1000, 100000, 100000, []) - 1 * connection.command(_, commandDocument, _, _, _, operationContext) >> - new BsonDocument('cursor', new BsonDocument('id', new BsonInt64(1)) - .append('ns', new BsonString(getNamespace().getFullName())) - .append('firstBatch', new BsonArrayWrapper([]))) + 1 * connection.command(_, commandDocument, _, _, _, _) >> { + assertEquals(((OperationContext) it[5]).getId(), operationContext.getId()) + new BsonDocument('cursor', new BsonDocument('id', new BsonInt64(1)) + .append('ns', new BsonString(getNamespace().getFullName())) + .append('firstBatch', new BsonArrayWrapper([]))) + } 1 * connection.release() where: @@ -524,10 +527,8 @@ class FindOperationSpecification extends OperationFunctionalSpecification { def source = Stub(AsyncConnectionSource) def connection = Mock(AsyncConnection) binding.readPreference >> ReadPreference.primary() - binding.operationContext >> operationContext - binding.getReadConnectionSource(_) >> { it[0].onResult(source, null) } - source.operationContext >> operationContext - source.getConnection(_) >> { it[0].onResult(connection, null) } + binding.getReadConnectionSource(_ as OperationContext, _ as SingleResultCallback) >> { it[1].onResult(source, null) } + source.getConnection(_ as OperationContext, _ as SingleResultCallback) >> { it[1].onResult(connection, null) } source.retain() >> source def commandDocument = new BsonDocument('find', new BsonString(getCollectionName())) appendReadConcernToCommand(sessionContext, UNKNOWN_WIRE_VERSION, commandDocument) @@ -535,12 +536,13 @@ class FindOperationSpecification extends OperationFunctionalSpecification { def operation = new FindOperation(getNamespace(), new DocumentCodec()) when: - executeAsync(operation, binding) + executeAsync(operation, binding, operationContext) then: _ * connection.description >> new ConnectionDescription(new ConnectionId(new ServerId(new ClusterId(), new ServerAddress())), 6, STANDALONE, 1000, 100000, 100000, []) - 1 * connection.commandAsync(_, commandDocument, _, _, _, operationContext, _) >> { + 1 * connection.commandAsync(_, commandDocument, _, _, _, _, _) >> { + assertEquals(((OperationContext) it[5]).getId(), operationContext.getId()) it.last().onResult(new BsonDocument('cursor', new BsonDocument('id', new BsonInt64(1)) .append('ns', new BsonString(getNamespace().getFullName())) .append('firstBatch', new BsonArrayWrapper([]))), null) @@ -565,26 +567,26 @@ class FindOperationSpecification extends OperationFunctionalSpecification { def source = Stub(ConnectionSource) def connection = Mock(Connection) binding.readPreference >> ReadPreference.primary() - binding.readConnectionSource >> source - binding.operationContext >> operationContext - source.connection >> connection + binding.getReadConnectionSource(_) >> source + source.getConnection(_) >> connection source.retain() >> source - source.operationContext >> operationContext def commandDocument = new BsonDocument('find', new BsonString(getCollectionName())).append('allowDiskUse', BsonBoolean.TRUE) appendReadConcernToCommand(sessionContext, UNKNOWN_WIRE_VERSION, commandDocument) def operation = new FindOperation(getNamespace(), new DocumentCodec()).allowDiskUse(true) when: - operation.execute(binding) + operation.execute(binding, operationContext) then: _ * connection.description >> new ConnectionDescription(new ConnectionId(new ServerId(new ClusterId(), new ServerAddress())), 6, STANDALONE, 1000, 100000, 100000, []) - 1 * connection.command(_, commandDocument, _, _, _, operationContext) >> - new BsonDocument('cursor', new BsonDocument('id', new BsonInt64(1)) - .append('ns', new BsonString(getNamespace().getFullName())) - .append('firstBatch', new BsonArrayWrapper([]))) + 1 * connection.command(_, commandDocument, _, _, _, _) >> { + assertEquals(((OperationContext) it[5]).getId(), operationContext.getId()) + new BsonDocument('cursor', new BsonDocument('id', new BsonInt64(1)) + .append('ns', new BsonString(getNamespace().getFullName())) + .append('firstBatch', new BsonArrayWrapper([]))) + } 1 * connection.release() where: @@ -604,11 +606,9 @@ class FindOperationSpecification extends OperationFunctionalSpecification { def binding = Stub(AsyncReadBinding) def source = Stub(AsyncConnectionSource) def connection = Mock(AsyncConnection) - binding.operationContext >> operationContext binding.readPreference >> ReadPreference.primary() - binding.getReadConnectionSource(_) >> { it[0].onResult(source, null) } - source.operationContext >> operationContext - source.getConnection(_) >> { it[0].onResult(connection, null) } + binding.getReadConnectionSource(_, _) >> { it[1].onResult(source, null) } + source.getConnection(_, _) >> { it[1].onResult(connection, null) } source.retain() >> source def commandDocument = new BsonDocument('find', new BsonString(getCollectionName())).append('allowDiskUse', BsonBoolean.TRUE) appendReadConcernToCommand(sessionContext, UNKNOWN_WIRE_VERSION, commandDocument) @@ -616,12 +616,13 @@ class FindOperationSpecification extends OperationFunctionalSpecification { def operation = new FindOperation(getNamespace(), new DocumentCodec()).allowDiskUse(true) when: - executeAsync(operation, binding) + executeAsync(operation, binding, operationContext) then: _ * connection.description >> new ConnectionDescription(new ConnectionId(new ServerId(new ClusterId(), new ServerAddress())), 6, STANDALONE, 1000, 100000, 100000, []) - 1 * connection.commandAsync(_, commandDocument, _, _, _, operationContext, _) >> { + 1 * connection.commandAsync(_, commandDocument, _, _, _, _, _) >> { + assertEquals(((OperationContext) it[5]).getId(), operationContext.getId()) it.last().onResult(new BsonDocument('cursor', new BsonDocument('id', new BsonInt64(1)) .append('ns', new BsonString(getNamespace().getFullName())) .append('firstBatch', new BsonArrayWrapper([]))), null) @@ -644,7 +645,7 @@ class FindOperationSpecification extends OperationFunctionalSpecification { given: def (cursorType, long maxAwaitTimeMS, long maxTimeMSForCursor) = cursorDetails def timeoutSettings = ClusterFixture.TIMEOUT_SETTINGS_WITH_INFINITE_TIMEOUT.withMaxAwaitTimeMS(maxAwaitTimeMS) - def timeoutContext = Spy(TimeoutContext, constructorArgs: [timeoutSettings]) + def timeoutContext = new TimeoutContext(timeoutSettings) def operationContext = OPERATION_CONTEXT.withTimeoutContext(timeoutContext) collectionHelper.create(getCollectionName(), new CreateCollectionOptions().capped(true).sizeInBytes(1000)) @@ -652,14 +653,19 @@ class FindOperationSpecification extends OperationFunctionalSpecification { .cursorType(cursorType) when: + def cursor; if (async) { - execute(operation, getBinding(operationContext)) + cursor = execute(operation, ClusterFixture.getAsyncBinding(operationContext), operationContext) } else { - execute(operation, getAsyncBinding(operationContext)) + cursor = execute(operation, ClusterFixture.getBinding(operationContext), operationContext) } then: - timeoutContext.setMaxTimeOverride(maxTimeMSForCursor) + cursor.operationContext.getTimeoutContext().getMaxAwaitTimeMS() == maxAwaitTimeMS + // should have maxTimeMS override + cursor.operationContext.getTimeoutContext().runMaxTimeMS { long actualMaxTimeMs -> + assertEquals(maxTimeMSForCursor, actualMaxTimeMs) + } where: [async, cursorDetails] << [ diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/ListCollectionsOperationSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/ListCollectionsOperationSpecification.groovy index 0d2688e0da6..ad55b706ba2 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/ListCollectionsOperationSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/ListCollectionsOperationSpecification.groovy @@ -33,6 +33,7 @@ import com.mongodb.internal.binding.ConnectionSource import com.mongodb.internal.binding.ReadBinding import com.mongodb.internal.connection.AsyncConnection import com.mongodb.internal.connection.Connection +import com.mongodb.internal.connection.OperationContext import org.bson.BsonBoolean import org.bson.BsonDocument import org.bson.BsonDouble @@ -45,6 +46,8 @@ import org.bson.codecs.DocumentCodec import static com.mongodb.ClusterFixture.OPERATION_CONTEXT import static com.mongodb.ClusterFixture.executeAsync import static com.mongodb.ClusterFixture.getBinding +import static com.mongodb.ClusterFixture.getOperationContext +import static org.junit.jupiter.api.Assertions.assertEquals class ListCollectionsOperationSpecification extends OperationFunctionalSpecification { @@ -54,8 +57,10 @@ class ListCollectionsOperationSpecification extends OperationFunctionalSpecifica given: def operation = new ListCollectionsOperation(madeUpDatabase, new DocumentCodec()) + + def binding = getBinding() when: - def cursor = operation.execute(getBinding()) + def cursor = operation.execute(binding, getOperationContext(binding.getReadPreference())) then: !cursor.hasNext() @@ -90,8 +95,10 @@ class ListCollectionsOperationSpecification extends OperationFunctionalSpecifica helper.insertDocuments(codec, ['a': 1] as Document) helper2.insertDocuments(codec, ['a': 1] as Document) + + def binding = getBinding() when: - def cursor = operation.execute(getBinding()) + def cursor = operation.execute(binding, getOperationContext(binding.getReadPreference())) def collections = cursor.next() def names = collections*.get('name') @@ -111,8 +118,11 @@ class ListCollectionsOperationSpecification extends OperationFunctionalSpecifica helper.insertDocuments(codec, ['a': 1] as Document) helper2.insertDocuments(codec, ['a': 1] as Document) + + def binding = getBinding() when: - def cursor = operation.execute(getBinding()) + def cursor = operation.execute(binding, getOperationContext(binding.getReadPreference()) + ) def collections = cursor.next() def names = collections*.get('name') @@ -130,8 +140,10 @@ class ListCollectionsOperationSpecification extends OperationFunctionalSpecifica def codec = new DocumentCodec() helper.insertDocuments(codec, ['a': 1] as Document) + + def binding = getBinding() when: - def cursor = operation.execute(getBinding()) + def cursor = operation.execute(binding, getOperationContext(binding.getReadPreference())) def collections = cursor.next() def names = collections*.get('name') @@ -146,8 +158,10 @@ class ListCollectionsOperationSpecification extends OperationFunctionalSpecifica .nameOnly(true) getCollectionHelper().create('collection5', new CreateCollectionOptions()) + + def binding = getBinding() when: - def cursor = operation.execute(getBinding()) + def cursor = operation.execute(binding, getOperationContext(binding.getReadPreference())) def collection = cursor.next()[0] then: @@ -161,8 +175,10 @@ class ListCollectionsOperationSpecification extends OperationFunctionalSpecifica .authorizedCollections(true) getCollectionHelper().create('collection6', new CreateCollectionOptions()) + + def binding = getBinding() when: - def cursor = operation.execute(getBinding()) + def cursor = operation.execute(binding, getOperationContext(binding.getReadPreference())) def collection = cursor.next()[0] then: @@ -176,8 +192,10 @@ class ListCollectionsOperationSpecification extends OperationFunctionalSpecifica .authorizedCollections(true) getCollectionHelper().create('collection8', new CreateCollectionOptions()) + + def binding = getBinding() when: - def cursor = operation.execute(getBinding()) + def cursor = operation.execute(binding, getOperationContext(binding.getReadPreference())) def collection = cursor.next()[0] then: @@ -206,13 +224,17 @@ class ListCollectionsOperationSpecification extends OperationFunctionalSpecifica } def 'should filter indexes when calling hasNext before next'() { + def binding = getBinding() given: - new DropDatabaseOperation(databaseName, WriteConcern.ACKNOWLEDGED).execute(getBinding()) + new DropDatabaseOperation(databaseName, WriteConcern.ACKNOWLEDGED) + .execute(binding, getOperationContext(binding.getReadPreference())) addSeveralIndexes() def operation = new ListCollectionsOperation(databaseName, new DocumentCodec()).batchSize(2) + when: - def cursor = operation.execute(getBinding()) + binding = getBinding() + def cursor = operation.execute(binding, getOperationContext(binding.getReadPreference())) then: cursor.hasNext() @@ -222,13 +244,16 @@ class ListCollectionsOperationSpecification extends OperationFunctionalSpecifica } def 'should filter indexes without calling hasNext before next'() { + def binding = getBinding() given: - new DropDatabaseOperation(databaseName, WriteConcern.ACKNOWLEDGED).execute(getBinding()) + new DropDatabaseOperation(databaseName, WriteConcern.ACKNOWLEDGED) + .execute(binding, getOperationContext(binding.getReadPreference())) addSeveralIndexes() def operation = new ListCollectionsOperation(databaseName, new DocumentCodec()).batchSize(2) when: - def cursor = operation.execute(getBinding()) + binding = getBinding() + def cursor = operation.execute(binding, getOperationContext(binding.getReadPreference())) def list = cursorToListWithNext(cursor) then: @@ -244,13 +269,17 @@ class ListCollectionsOperationSpecification extends OperationFunctionalSpecifica } def 'should filter indexes when calling hasNext before tryNext'() { + def binding = getBinding() given: - new DropDatabaseOperation(databaseName, WriteConcern.ACKNOWLEDGED).execute(getBinding()) + new DropDatabaseOperation(databaseName, WriteConcern.ACKNOWLEDGED) + .execute(binding, getOperationContext(binding.getReadPreference())) addSeveralIndexes() def operation = new ListCollectionsOperation(databaseName, new DocumentCodec()).batchSize(2) + when: - def cursor = operation.execute(getBinding()) + binding = getBinding() + def cursor = operation.execute(binding, getOperationContext(binding.getReadPreference())) then: cursor.hasNext() @@ -267,12 +296,15 @@ class ListCollectionsOperationSpecification extends OperationFunctionalSpecifica def 'should filter indexes without calling hasNext before tryNext'() { given: - new DropDatabaseOperation(databaseName, WriteConcern.ACKNOWLEDGED).execute(getBinding()) + def binding = getBinding() + new DropDatabaseOperation(databaseName, WriteConcern.ACKNOWLEDGED) + .execute(binding, getOperationContext(binding.getReadPreference())) addSeveralIndexes() def operation = new ListCollectionsOperation(databaseName, new DocumentCodec()).batchSize(2) when: - def cursor = operation.execute(getBinding()) + binding = getBinding() + def cursor = operation.execute(binding, getOperationContext(binding.getReadPreference())) def list = cursorToListWithTryNext(cursor) then: @@ -284,7 +316,9 @@ class ListCollectionsOperationSpecification extends OperationFunctionalSpecifica def 'should filter indexes asynchronously'() { given: - new DropDatabaseOperation(databaseName, WriteConcern.ACKNOWLEDGED).execute(getBinding()) + def binding = getBinding() + new DropDatabaseOperation(databaseName, WriteConcern.ACKNOWLEDGED) + .execute(binding, getOperationContext(binding.getReadPreference())) addSeveralIndexes() def operation = new ListCollectionsOperation(databaseName, new DocumentCodec()).batchSize(2) @@ -307,8 +341,10 @@ class ListCollectionsOperationSpecification extends OperationFunctionalSpecifica getCollectionHelper(new MongoNamespace(databaseName, 'collection4')).insertDocuments(codec, ['a': 1] as Document) getCollectionHelper(new MongoNamespace(databaseName, 'collection5')).insertDocuments(codec, ['a': 1] as Document) + when: - def cursor = operation.execute(getBinding()) + def binding = getBinding() + def cursor = operation.execute(binding, getOperationContext(binding.getReadPreference())) def collections = cursor.next() then: @@ -364,23 +400,24 @@ class ListCollectionsOperationSpecification extends OperationFunctionalSpecifica given: def connection = Mock(Connection) def connectionSource = Stub(ConnectionSource) { - getConnection() >> connection + getConnection(_) >> connection getReadPreference() >> readPreference - getOperationContext() >> OPERATION_CONTEXT } def readBinding = Stub(ReadBinding) { - getReadConnectionSource() >> connectionSource + getReadConnectionSource(_) >> connectionSource getReadPreference() >> readPreference - getOperationContext() >> OPERATION_CONTEXT } def operation = new ListCollectionsOperation(helper.dbName, helper.decoder) when: '3.6.0' - operation.execute(readBinding) + operation.execute(readBinding, OPERATION_CONTEXT) then: _ * connection.getDescription() >> helper.threeSixConnectionDescription - 1 * connection.command(_, _, _, readPreference, _, OPERATION_CONTEXT) >> helper.commandResult + 1 * connection.command(_, _, _, readPreference, _, _) >> { + assertEquals(((OperationContext) it[5]).getId(), OPERATION_CONTEXT.getId()) + helper.commandResult + } 1 * connection.release() where: @@ -391,23 +428,22 @@ class ListCollectionsOperationSpecification extends OperationFunctionalSpecifica given: def connection = Mock(AsyncConnection) def connectionSource = Stub(AsyncConnectionSource) { - getConnection(_) >> { it[0].onResult(connection, null) } + getConnection(_, _) >> { it[1].onResult(connection, null) } getReadPreference() >> readPreference - getOperationContext() >> OPERATION_CONTEXT } def readBinding = Stub(AsyncReadBinding) { - getReadConnectionSource(_) >> { it[0].onResult(connectionSource, null) } + getReadConnectionSource(_, _) >> { it[1].onResult(connectionSource, null) } getReadPreference() >> readPreference - getOperationContext() >> OPERATION_CONTEXT } def operation = new ListCollectionsOperation(helper.dbName, helper.decoder) when: '3.6.0' - operation.executeAsync(readBinding, Stub(SingleResultCallback)) + operation.executeAsync(readBinding, OPERATION_CONTEXT, Stub(SingleResultCallback)) then: _ * connection.getDescription() >> helper.threeSixConnectionDescription - 1 * connection.commandAsync(helper.dbName, _, _, readPreference, _, OPERATION_CONTEXT, *_) >> { + 1 * connection.commandAsync(helper.dbName, _, _, readPreference, _, _, *_) >> { + assertEquals(((OperationContext) it[5]).getId(), OPERATION_CONTEXT.getId()) it.last().onResult(helper.commandResult, null) } where: diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/ListDatabasesOperationSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/ListDatabasesOperationSpecification.groovy index 740f9073dcd..55504d0babc 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/ListDatabasesOperationSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/ListDatabasesOperationSpecification.groovy @@ -72,23 +72,21 @@ class ListDatabasesOperationSpecification extends OperationFunctionalSpecificati given: def connection = Mock(Connection) def connectionSource = Stub(ConnectionSource) { - getConnection() >> connection + getConnection(_) >> connection getReadPreference() >> readPreference - getOperationContext() >> OPERATION_CONTEXT } def readBinding = Stub(ReadBinding) { - getReadConnectionSource() >> connectionSource + getReadConnectionSource(_) >> connectionSource getReadPreference() >> readPreference - getOperationContext() >> OPERATION_CONTEXT } def operation = new ListDatabasesOperation(helper.decoder) when: - operation.execute(readBinding) + operation.execute(readBinding, OPERATION_CONTEXT) then: _ * connection.getDescription() >> helper.connectionDescription - 1 * connection.command(_, _, _, readPreference, _, OPERATION_CONTEXT) >> helper.commandResult + 1 * connection.command(_, _, _, readPreference, _, _) >> helper.commandResult 1 * connection.release() where: @@ -100,16 +98,16 @@ class ListDatabasesOperationSpecification extends OperationFunctionalSpecificati def connection = Mock(AsyncConnection) def connectionSource = Stub(AsyncConnectionSource) { getReadPreference() >> readPreference - getConnection(_) >> { it[0].onResult(connection, null) } + getConnection(_, _) >> { it[1].onResult(connection, null) } } def readBinding = Stub(AsyncReadBinding) { getReadPreference() >> readPreference - getReadConnectionSource(_) >> { it[0].onResult(connectionSource, null) } + getReadConnectionSource(_, _) >> { it[1].onResult(connectionSource, null) } } def operation = new ListDatabasesOperation(helper.decoder) when: - operation.executeAsync(readBinding, Stub(SingleResultCallback)) + operation.executeAsync(readBinding, OPERATION_CONTEXT, Stub(SingleResultCallback)) then: _ * connection.getDescription() >> helper.connectionDescription diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/ListIndexesOperationSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/ListIndexesOperationSpecification.groovy index 462bf367e50..c11d67bcf22 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/ListIndexesOperationSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/ListIndexesOperationSpecification.groovy @@ -33,6 +33,7 @@ import com.mongodb.internal.binding.ReadBinding import com.mongodb.internal.bulk.IndexRequest import com.mongodb.internal.connection.AsyncConnection import com.mongodb.internal.connection.Connection +import com.mongodb.internal.connection.OperationContext import org.bson.BsonDocument import org.bson.BsonDouble import org.bson.BsonInt32 @@ -41,10 +42,12 @@ import org.bson.BsonString import org.bson.Document import org.bson.codecs.Decoder import org.bson.codecs.DocumentCodec +import org.junit.jupiter.api.Assertions import static com.mongodb.ClusterFixture.OPERATION_CONTEXT import static com.mongodb.ClusterFixture.executeAsync import static com.mongodb.ClusterFixture.getBinding +import static com.mongodb.ClusterFixture.getOperationContext class ListIndexesOperationSpecification extends OperationFunctionalSpecification { @@ -52,8 +55,10 @@ class ListIndexesOperationSpecification extends OperationFunctionalSpecification given: def operation = new ListIndexesOperation(getNamespace(), new DocumentCodec()) + + def binding = getBinding() when: - def cursor = operation.execute(getBinding()) + def cursor = operation.execute(binding, getOperationContext(binding.getReadPreference())) then: !cursor.hasNext() @@ -79,8 +84,10 @@ class ListIndexesOperationSpecification extends OperationFunctionalSpecification def operation = new ListIndexesOperation(getNamespace(), new DocumentCodec()) getCollectionHelper().insertDocuments(new DocumentCodec(), new Document('documentThat', 'forces creation of the Collection')) + + def binding = getBinding() when: - BatchCursor indexes = operation.execute(getBinding()) + BatchCursor indexes = operation.execute(binding, getOperationContext(binding.getReadPreference())) then: def firstBatch = indexes.next() @@ -111,11 +118,15 @@ class ListIndexesOperationSpecification extends OperationFunctionalSpecification def operation = new ListIndexesOperation(getNamespace(), new DocumentCodec()) collectionHelper.createIndex(new BsonDocument('theField', new BsonInt32(1))) collectionHelper.createIndex(new BsonDocument('compound', new BsonInt32(1)).append('index', new BsonInt32(-1))) + + def binding = getBinding() new CreateIndexesOperation(namespace, - [new IndexRequest(new BsonDocument('unique', new BsonInt32(1))).unique(true)], null).execute(getBinding()) + [new IndexRequest(new BsonDocument('unique', new BsonInt32(1))).unique(true)], null).execute(binding, + getOperationContext(binding.getReadPreference())) when: - BatchCursor cursor = operation.execute(getBinding()) + binding = getBinding() + BatchCursor cursor = operation.execute(binding, getOperationContext(binding.getReadPreference())) then: def indexes = cursor.next() @@ -131,8 +142,11 @@ class ListIndexesOperationSpecification extends OperationFunctionalSpecification def operation = new ListIndexesOperation(getNamespace(), new DocumentCodec()) collectionHelper.createIndex(new BsonDocument('theField', new BsonInt32(1))) collectionHelper.createIndex(new BsonDocument('compound', new BsonInt32(1)).append('index', new BsonInt32(-1))) + + def binding = getBinding() new CreateIndexesOperation(namespace, - [new IndexRequest(new BsonDocument('unique', new BsonInt32(1))).unique(true)], null).execute(getBinding()) + [new IndexRequest(new BsonDocument('unique', new BsonInt32(1))).unique(true)], null).execute(binding, + getOperationContext(binding.getReadPreference())) when: def cursor = executeAsync(operation) @@ -155,8 +169,10 @@ class ListIndexesOperationSpecification extends OperationFunctionalSpecification collectionHelper.createIndex(new BsonDocument('collection4', new BsonInt32(1))) collectionHelper.createIndex(new BsonDocument('collection5', new BsonInt32(1))) + + def binding = getBinding() when: - def cursor = operation.execute(getBinding()) + def cursor = operation.execute(binding, getOperationContext(binding.getReadPreference())) def collections = cursor.next() then: @@ -211,23 +227,24 @@ class ListIndexesOperationSpecification extends OperationFunctionalSpecification given: def connection = Mock(Connection) def connectionSource = Stub(ConnectionSource) { - getConnection() >> connection + getConnection(_) >> connection getReadPreference() >> readPreference - getOperationContext() >> OPERATION_CONTEXT } def readBinding = Stub(ReadBinding) { - getReadConnectionSource() >> connectionSource + getReadConnectionSource(_) >> connectionSource getReadPreference() >> readPreference - getOperationContext() >> OPERATION_CONTEXT } def operation = new ListIndexesOperation(helper.namespace, helper.decoder) when: '3.6.0' - operation.execute(readBinding) + operation.execute(readBinding, OPERATION_CONTEXT) then: _ * connection.getDescription() >> helper.threeSixConnectionDescription - 1 * connection.command(_, _, _, readPreference, _, OPERATION_CONTEXT) >> helper.commandResult + 1 * connection.command(_, _, _, readPreference, _, _) >> { + Assertions.assertEquals(((OperationContext) it[5]).getId(), OPERATION_CONTEXT.getId()) + helper.commandResult + } 1 * connection.release() where: @@ -239,16 +256,16 @@ class ListIndexesOperationSpecification extends OperationFunctionalSpecification def connection = Mock(AsyncConnection) def connectionSource = Stub(AsyncConnectionSource) { getReadPreference() >> readPreference - getConnection(_) >> { it[0].onResult(connection, null) } + getConnection(_, _) >> { it[1].onResult(connection, null) } } def readBinding = Stub(AsyncReadBinding) { getReadPreference() >> readPreference - getReadConnectionSource(_) >> { it[0].onResult(connectionSource, null) } + getReadConnectionSource(_, _) >> { it[1].onResult(connectionSource, null) } } def operation = new ListIndexesOperation(helper.namespace, helper.decoder) when: '3.6.0' - operation.executeAsync(readBinding, Stub(SingleResultCallback)) + operation.executeAsync(readBinding, OPERATION_CONTEXT, Stub(SingleResultCallback)) then: _ * connection.getDescription() >> helper.threeSixConnectionDescription diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/MapReduceToCollectionOperationSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/MapReduceToCollectionOperationSpecification.groovy index 0f48042da47..5d6be781d1f 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/MapReduceToCollectionOperationSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/MapReduceToCollectionOperationSpecification.groovy @@ -16,6 +16,7 @@ package com.mongodb.internal.operation +import com.mongodb.ClusterFixture import com.mongodb.MongoCommandException import com.mongodb.MongoNamespace import com.mongodb.MongoWriteConcernException @@ -62,9 +63,12 @@ class MapReduceToCollectionOperationSpecification extends OperationFunctionalSpe } def cleanup() { - new DropCollectionOperation(mapReduceInputNamespace, WriteConcern.ACKNOWLEDGED).execute(getBinding()) + def binding = getBinding() + def operationContext = ClusterFixture.getOperationContext(binding.getReadPreference()) + new DropCollectionOperation(mapReduceInputNamespace, WriteConcern.ACKNOWLEDGED) + .execute(binding, operationContext) new DropCollectionOperation(mapReduceOutputNamespace, WriteConcern.ACKNOWLEDGED) - .execute(getBinding()) + .execute(binding, operationContext) } def 'should have the correct defaults'() { diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/MapReduceWithInlineResultsOperationSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/MapReduceWithInlineResultsOperationSpecification.groovy index 17b3c28f637..8efd4e00f6c 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/MapReduceWithInlineResultsOperationSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/MapReduceWithInlineResultsOperationSpecification.groovy @@ -26,12 +26,14 @@ import com.mongodb.connection.ClusterId import com.mongodb.connection.ConnectionDescription import com.mongodb.connection.ConnectionId import com.mongodb.connection.ServerId +import com.mongodb.internal.async.SingleResultCallback import com.mongodb.internal.binding.AsyncConnectionSource import com.mongodb.internal.binding.AsyncReadBinding import com.mongodb.internal.binding.ConnectionSource import com.mongodb.internal.binding.ReadBinding import com.mongodb.internal.connection.AsyncConnection import com.mongodb.internal.connection.Connection +import com.mongodb.internal.connection.OperationContext import com.mongodb.internal.session.SessionContext import org.bson.BsonBoolean import org.bson.BsonDocument @@ -220,11 +222,9 @@ class MapReduceWithInlineResultsOperationSpecification extends OperationFunction def source = Stub(ConnectionSource) def connection = Mock(Connection) binding.readPreference >> ReadPreference.primary() - binding.operationContext >> operationContext - binding.readConnectionSource >> source - source.connection >> connection + binding.getReadConnectionSource(_) >> source + source.getConnection(_) >> connection source.retain() >> source - source.operationContext >> operationContext def commandDocument = BsonDocument.parse(''' { "mapReduce" : "coll", "map" : { "$code" : "function(){ }" }, @@ -237,12 +237,12 @@ class MapReduceWithInlineResultsOperationSpecification extends OperationFunction new BsonJavaScript('function(){ }'), new BsonJavaScript('function(key, values){ }'), bsonDocumentCodec) when: - operation.execute(binding) + operation.execute(binding, operationContext) then: _ * connection.description >> new ConnectionDescription(new ConnectionId(new ServerId(new ClusterId(), new ServerAddress())), 6, STANDALONE, 1000, 100000, 100000, []) - 1 * connection.command(_, commandDocument, _, _, _, operationContext) >> + 1 * connection.command(_, commandDocument, _, _, _, _) >> new BsonDocument('results', new BsonArrayWrapper([])) .append('counts', new BsonDocument('input', new BsonInt32(0)) @@ -269,10 +269,8 @@ class MapReduceWithInlineResultsOperationSpecification extends OperationFunction def source = Stub(AsyncConnectionSource) def connection = Mock(AsyncConnection) binding.readPreference >> ReadPreference.primary() - binding.operationContext >> operationContext - binding.getReadConnectionSource(_) >> { it[0].onResult(source, null) } - source.operationContext >> operationContext - source.getConnection(_) >> { it[0].onResult(connection, null) } + binding.getReadConnectionSource(_ as OperationContext, _ as SingleResultCallback) >> { it[1].onResult(source, null) } + source.getConnection(_ as OperationContext, _ as SingleResultCallback) >> { it[1].onResult(connection, null) } source.retain() >> source def commandDocument = BsonDocument.parse(''' { "mapReduce" : "coll", @@ -286,12 +284,12 @@ class MapReduceWithInlineResultsOperationSpecification extends OperationFunction new BsonJavaScript('function(){ }'), new BsonJavaScript('function(key, values){ }'), bsonDocumentCodec) when: - executeAsync(operation, binding) + executeAsync(operation, binding, operationContext) then: _ * connection.description >> new ConnectionDescription(new ConnectionId(new ServerId(new ClusterId(), new ServerAddress())), 6, STANDALONE, 1000, 100000, 100000, []) - 1 * connection.commandAsync(_, commandDocument, _, _, _, operationContext, _) >> { + 1 * connection.commandAsync(_, commandDocument, _, _, _, _, _) >> { it.last().onResult(new BsonDocument('results', new BsonArrayWrapper([])) .append('counts', new BsonDocument('input', new BsonInt32(0)) diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/RenameCollectionOperationSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/RenameCollectionOperationSpecification.groovy index f2e75a235df..bc55bf5a134 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/RenameCollectionOperationSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/RenameCollectionOperationSpecification.groovy @@ -16,6 +16,7 @@ package com.mongodb.internal.operation + import com.mongodb.MongoNamespace import com.mongodb.MongoServerException import com.mongodb.MongoWriteConcernException @@ -27,6 +28,7 @@ import spock.lang.IgnoreIf import static com.mongodb.ClusterFixture.executeAsync import static com.mongodb.ClusterFixture.getBinding +import static com.mongodb.ClusterFixture.getOperationContext import static com.mongodb.ClusterFixture.isDiscoverableReplicaSet import static com.mongodb.ClusterFixture.isSharded @@ -34,8 +36,9 @@ import static com.mongodb.ClusterFixture.isSharded class RenameCollectionOperationSpecification extends OperationFunctionalSpecification { def cleanup() { + def binding = getBinding() new DropCollectionOperation(new MongoNamespace(getDatabaseName(), 'newCollection'), - WriteConcern.ACKNOWLEDGED).execute(getBinding()) + WriteConcern.ACKNOWLEDGED).execute(binding, getOperationContext(binding.getReadPreference())) } def 'should return rename a collection'() { @@ -81,8 +84,10 @@ class RenameCollectionOperationSpecification extends OperationFunctionalSpecific def operation = new RenameCollectionOperation(getNamespace(), new MongoNamespace(getDatabaseName(), 'newCollection'), new WriteConcern(5)) + + def binding = getBinding() when: - async ? executeAsync(operation) : operation.execute(getBinding()) + async ? executeAsync(operation) : operation.execute(binding, getOperationContext(binding.getReadPreference())) then: def ex = thrown(MongoWriteConcernException) @@ -94,7 +99,9 @@ class RenameCollectionOperationSpecification extends OperationFunctionalSpecific } def collectionNameExists(String collectionName) { - def cursor = new ListCollectionsOperation(databaseName, new DocumentCodec()).execute(getBinding()) + def binding = getBinding() + def cursor = new ListCollectionsOperation(databaseName, new DocumentCodec()).execute(binding, + getOperationContext(binding.getReadPreference())) if (!cursor.hasNext()) { return false } diff --git a/driver-core/src/test/resources/logback-test.xml b/driver-core/src/test/resources/logback-test.xml index dde5eeba5aa..6941b8913e0 100644 --- a/driver-core/src/test/resources/logback-test.xml +++ b/driver-core/src/test/resources/logback-test.xml @@ -11,7 +11,7 @@ - + diff --git a/driver-core/src/test/unit/com/mongodb/internal/TimeoutContextTest.java b/driver-core/src/test/unit/com/mongodb/internal/TimeoutContextTest.java index 130d408076e..be4526aada7 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/TimeoutContextTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/TimeoutContextTest.java @@ -228,9 +228,9 @@ void testValidatedMinRoundTripTime() { Supplier supplier = () -> new TimeoutContext(TIMEOUT_SETTINGS.withTimeoutMS(100L)); assertTrue(getMaxTimeMS(supplier.get()) <= 100); - assertTrue(getMaxTimeMS(supplier.get().minRoundTripTimeMS(10)) <= 90); - assertThrows(MongoOperationTimeoutException.class, () -> getMaxTimeMS(supplier.get().minRoundTripTimeMS(101))); - assertThrows(MongoOperationTimeoutException.class, () -> getMaxTimeMS(supplier.get().minRoundTripTimeMS(100))); + assertTrue(getMaxTimeMS(supplier.get().withMinRoundTripTime(10)) <= 90); + assertThrows(MongoOperationTimeoutException.class, () -> getMaxTimeMS(supplier.get().withMinRoundTripTimeMS(101))); + assertThrows(MongoOperationTimeoutException.class, () -> getMaxTimeMS(supplier.get().withMinRoundTripTimeMS(100))); } @Test @@ -277,7 +277,7 @@ void testCreateTimeoutContextWithTimeout() { void shouldOverrideMaximeMS() { TimeoutContext timeoutContext = new TimeoutContext(TIMEOUT_SETTINGS.withTimeoutMS(100L).withMaxTimeMS(1)); - timeoutContext.setMaxTimeOverride(2L); + timeoutContext = timeoutContext.withMaxTimeOverride(2L); assertEquals(2, getMaxTimeMS(timeoutContext)); } @@ -286,9 +286,9 @@ void shouldOverrideMaximeMS() { @DisplayName("should reset maxTimeMS to default behaviour") void shouldResetMaximeMS() { TimeoutContext timeoutContext = new TimeoutContext(TIMEOUT_SETTINGS.withTimeoutMS(100L).withMaxTimeMS(1)); - timeoutContext.setMaxTimeOverride(1L); + timeoutContext = timeoutContext.withMaxTimeOverride(1L); - timeoutContext.resetToDefaultMaxTime(); + timeoutContext = timeoutContext.withDefaultMaxTime(); assertTrue(getMaxTimeMS(timeoutContext) > 1); } diff --git a/driver-core/src/test/unit/com/mongodb/internal/binding/SingleServerBindingSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/binding/SingleServerBindingSpecification.groovy index 824a724ee81..d52fb593a70 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/binding/SingleServerBindingSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/binding/SingleServerBindingSpecification.groovy @@ -41,22 +41,12 @@ class SingleServerBindingSpecification extends Specification { .build()) } def address = new ServerAddress() - def operationContext = OPERATION_CONTEXT when: - - def binding = new SingleServerBinding(cluster, address, operationContext) + def binding = new SingleServerBinding(cluster, address) then: binding.readPreference == ReadPreference.primary() - binding.getOperationContext() == operationContext - - - when: - def source = binding.getReadConnectionSource() - - then: - source.getOperationContext() == operationContext } def 'should increment and decrement reference counts'() { @@ -72,13 +62,13 @@ class SingleServerBindingSpecification extends Specification { def address = new ServerAddress() when: - def binding = new SingleServerBinding(cluster, address, OPERATION_CONTEXT) + def binding = new SingleServerBinding(cluster, address) then: binding.count == 1 when: - def source = binding.getReadConnectionSource() + def source = binding.getReadConnectionSource(OPERATION_CONTEXT) then: source.count == 1 @@ -106,7 +96,7 @@ class SingleServerBindingSpecification extends Specification { binding.count == 1 when: - source = binding.getWriteConnectionSource() + source = binding.getWriteConnectionSource(OPERATION_CONTEXT) then: source.count == 1 diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/AbstractConnectionPoolTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/AbstractConnectionPoolTest.java index 92e224df835..69a2c236048 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/AbstractConnectionPoolTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/AbstractConnectionPoolTest.java @@ -77,6 +77,7 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.stream.Collectors; +import static com.mongodb.ClusterFixture.OPERATION_CONTEXT; import static com.mongodb.ClusterFixture.OPERATION_CONTEXT_FACTORY; import static com.mongodb.ClusterFixture.TIMEOUT_SETTINGS; import static com.mongodb.assertions.Assertions.assertFalse; @@ -541,7 +542,7 @@ private Event getNextEvent(final Iterator eventsIterator, final private static void executeAdminCommand(final BsonDocument command) { new CommandReadOperation<>("admin", command, new BsonDocumentCodec()) - .execute(ClusterFixture.getBinding()); + .execute(ClusterFixture.getBinding(), OPERATION_CONTEXT); } private void setFailPoint() { diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerSpecification.groovy index 6552a69a70d..3910da575f0 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerSpecification.groovy @@ -16,7 +16,6 @@ package com.mongodb.internal.connection - import com.mongodb.MongoException import com.mongodb.MongoNodeIsRecoveringException import com.mongodb.MongoNotPrimaryException @@ -351,7 +350,7 @@ class DefaultServerSpecification extends Specification { clusterClock.advance(clusterClockClusterTime) def server = new DefaultServer(serverId, SINGLE, Mock(ConnectionPool), new TestConnectionFactory(), Mock(ServerMonitor), Mock(SdamServerDescriptionManager), Mock(ServerListener), Mock(CommandListener), clusterClock, false) - def testConnection = (TestConnection) server.getConnection() + def testConnection = (TestConnection) server.getConnection(OPERATION_CONTEXT) def sessionContext = new TestSessionContext(initialClusterTime) def response = BsonDocument.parse( '''{ diff --git a/driver-core/src/test/unit/com/mongodb/internal/mockito/InsufficientStubbingDetectorDemoTest.java b/driver-core/src/test/unit/com/mongodb/internal/mockito/InsufficientStubbingDetectorDemoTest.java index 40d33c31288..5d8bd8e61b1 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/mockito/InsufficientStubbingDetectorDemoTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/mockito/InsufficientStubbingDetectorDemoTest.java @@ -40,33 +40,33 @@ void beforeEach() { @Test void mockObjectWithDefaultAnswer() { ReadBinding binding = Mockito.mock(ReadBinding.class); - assertThrows(NullPointerException.class, () -> operation.execute(binding)); + assertThrows(NullPointerException.class, () -> operation.execute(binding, OPERATION_CONTEXT)); } @Test void mockObjectWithThrowsException() { ReadBinding binding = Mockito.mock(ReadBinding.class, new ThrowsException(new AssertionError("Insufficient stubbing for " + ReadBinding.class))); - assertThrows(AssertionError.class, () -> operation.execute(binding)); + assertThrows(AssertionError.class, () -> operation.execute(binding, OPERATION_CONTEXT)); } @Test void mockObjectWithInsufficientStubbingDetector() { ReadBinding binding = MongoMockito.mock(ReadBinding.class); - assertThrows(AssertionError.class, () -> operation.execute(binding)); + assertThrows(AssertionError.class, () -> operation.execute(binding, OPERATION_CONTEXT)); } @Test void stubbingWithThrowsException() { ReadBinding binding = Mockito.mock(ReadBinding.class, new ThrowsException(new AssertionError("Unfortunately, you cannot do stubbing"))); - assertThrows(AssertionError.class, () -> when(binding.getOperationContext()).thenReturn(OPERATION_CONTEXT)); + assertThrows(AssertionError.class, () -> when(binding.getReadConnectionSource(OPERATION_CONTEXT)).thenReturn(null)); } @Test void stubbingWithInsufficientStubbingDetector() { MongoMockito.mock(ReadBinding.class, bindingMock -> - when(bindingMock.getOperationContext()).thenReturn(OPERATION_CONTEXT) + when(bindingMock.getReadConnectionSource(OPERATION_CONTEXT)).thenReturn(null) ); } } diff --git a/driver-core/src/test/unit/com/mongodb/internal/operation/AsyncChangeStreamBatchCursorSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/operation/AsyncChangeStreamBatchCursorSpecification.groovy index 998c0a28b6e..37a7220a851 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/operation/AsyncChangeStreamBatchCursorSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/operation/AsyncChangeStreamBatchCursorSpecification.groovy @@ -16,32 +16,37 @@ package com.mongodb.internal.operation +import com.mongodb.MongoClientSettings import com.mongodb.MongoException import com.mongodb.async.FutureResultCallback +import com.mongodb.internal.IgnorableRequestContext import com.mongodb.internal.TimeoutContext +import com.mongodb.internal.TimeoutSettings import com.mongodb.internal.async.SingleResultCallback import com.mongodb.internal.binding.AsyncReadBinding +import com.mongodb.internal.connection.NoOpSessionContext import com.mongodb.internal.connection.OperationContext import org.bson.Document +import org.bson.RawBsonDocument import spock.lang.Specification +import java.util.concurrent.TimeUnit + import static java.util.concurrent.TimeUnit.SECONDS class AsyncChangeStreamBatchCursorSpecification extends Specification { def 'should call the underlying AsyncCommandBatchCursor'() { given: - def changeStreamOpertation = Stub(ChangeStreamOperation) + def changeStreamOperation = Stub(ChangeStreamOperation) def binding = Mock(AsyncReadBinding) - def operationContext = Mock(OperationContext) - def timeoutContext = Mock(TimeoutContext) - binding.getOperationContext() >> operationContext - operationContext.getTimeoutContext() >> timeoutContext - timeoutContext.hasTimeoutMS() >> hasTimeoutMS + def operationContext = getOperationContext() + operationContext.getTimeoutContext().hasTimeoutMS() >> hasTimeoutMS - def wrapped = Mock(AsyncCommandBatchCursor) def callback = Stub(SingleResultCallback) - def cursor = new AsyncChangeStreamBatchCursor(changeStreamOpertation, wrapped, binding, null, + AsyncCoreCursor wrapped = Mock(AsyncCoreCursor) + def cursor = new AsyncChangeStreamBatchCursor(changeStreamOperation, + wrapped, binding, operationContext, null, ServerVersionHelper.FOUR_DOT_FOUR_WIRE_VERSION) when: @@ -54,13 +59,13 @@ class AsyncChangeStreamBatchCursorSpecification extends Specification { cursor.next(callback) then: - 1 * wrapped.next(_) >> { it[0].onResult([], null) } + 1 * wrapped.next(_ as OperationContext, _) >> { it[1].onResult([], null) } when: cursor.close() then: - 1 * wrapped.close() + 1 * wrapped.close(_) 1 * binding.release() when: @@ -77,24 +82,23 @@ class AsyncChangeStreamBatchCursorSpecification extends Specification { def 'should not close the cursor in next if the cursor was closed before next completed'() { def changeStreamOpertation = Stub(ChangeStreamOperation) def binding = Mock(AsyncReadBinding) - def operationContext = Mock(OperationContext) - def timeoutContext = Mock(TimeoutContext) - binding.getOperationContext() >> operationContext - operationContext.getTimeoutContext() >> timeoutContext - timeoutContext.hasTimeoutMS() >> hasTimeoutMS - def wrapped = Mock(AsyncCommandBatchCursor) + def operationContext = getOperationContext() + operationContext.getTimeoutContext().hasTimeoutMS() >> hasTimeoutMS + def callback = Stub(SingleResultCallback) - def cursor = new AsyncChangeStreamBatchCursor(changeStreamOpertation, wrapped, binding, null, + AsyncCoreCursor wrapped = Mock(AsyncCoreCursor) + def cursor = new AsyncChangeStreamBatchCursor(changeStreamOpertation, + wrapped, binding, operationContext, null, ServerVersionHelper.FOUR_DOT_FOUR_WIRE_VERSION) when: cursor.next(callback) then: - 1 * wrapped.next(_) >> { + 1 * wrapped.next(_ as OperationContext, _) >> { // Simulate the user calling close while wrapped.next() is in flight cursor.close() - it[0].onResult([], null) + it[1].onResult([], null) } then: @@ -110,13 +114,12 @@ class AsyncChangeStreamBatchCursorSpecification extends Specification { def 'should throw a MongoException when next/tryNext is called after the cursor is closed'() { def changeStreamOpertation = Stub(ChangeStreamOperation) def binding = Mock(AsyncReadBinding) - def operationContext = Mock(OperationContext) - def timeoutContext = Mock(TimeoutContext) - binding.getOperationContext() >> operationContext - operationContext.getTimeoutContext() >> timeoutContext - timeoutContext.hasTimeoutMS() >> hasTimeoutMS - def wrapped = Mock(AsyncCommandBatchCursor) - def cursor = new AsyncChangeStreamBatchCursor(changeStreamOpertation, wrapped, binding, null, + def operationContext = getOperationContext() + operationContext.getTimeoutContext().hasTimeoutMS() >> hasTimeoutMS + + AsyncCoreCursor wrapped = Mock(AsyncCoreCursor) + def cursor = new AsyncChangeStreamBatchCursor(changeStreamOpertation, + wrapped, binding, operationContext, null, ServerVersionHelper.FOUR_DOT_FOUR_WIRE_VERSION) given: @@ -138,4 +141,13 @@ class AsyncChangeStreamBatchCursorSpecification extends Specification { cursor.next(futureResultCallback) futureResultCallback.get(1, SECONDS) } + + OperationContext getOperationContext() { + def timeoutContext = Spy(new TimeoutContext(TimeoutSettings.create( + MongoClientSettings.builder().timeout(3, TimeUnit.SECONDS).build()))) + Spy(new OperationContext( + IgnorableRequestContext.INSTANCE, + NoOpSessionContext.INSTANCE, + timeoutContext, null)) + } } diff --git a/driver-core/src/test/unit/com/mongodb/internal/operation/AsyncCommandBatchCursorSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/operation/AsyncCommandBatchCursorSpecification.groovy index d2bcd0804bb..901c2607ead 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/operation/AsyncCommandBatchCursorSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/operation/AsyncCommandBatchCursorSpecification.groovy @@ -16,6 +16,7 @@ package com.mongodb.internal.operation + import com.mongodb.MongoClientSettings import com.mongodb.MongoCommandException import com.mongodb.MongoException @@ -29,11 +30,13 @@ import com.mongodb.connection.ServerConnectionState import com.mongodb.connection.ServerDescription import com.mongodb.connection.ServerType import com.mongodb.connection.ServerVersion +import com.mongodb.internal.IgnorableRequestContext import com.mongodb.internal.TimeoutContext import com.mongodb.internal.TimeoutSettings import com.mongodb.internal.async.SingleResultCallback import com.mongodb.internal.binding.AsyncConnectionSource import com.mongodb.internal.connection.AsyncConnection +import com.mongodb.internal.connection.NoOpSessionContext import com.mongodb.internal.connection.OperationContext import org.bson.BsonArray import org.bson.BsonDocument @@ -58,7 +61,8 @@ class AsyncCommandBatchCursorSpecification extends Specification { def initialConnection = referenceCountedAsyncConnection() def connection = referenceCountedAsyncConnection() def connectionSource = getAsyncConnectionSource(connection) - def timeoutContext = connectionSource.getOperationContext().getTimeoutContext() + def operationContext = getOperationContext() + def timeoutContext = operationContext.getTimeoutContext() def firstBatch = createCommandResult([]) def expectedCommand = new BsonDocument('getMore': new BsonInt64(CURSOR_ID)) .append('collection', new BsonString(NAMESPACE.getCollectionName())) @@ -69,10 +73,10 @@ class AsyncCommandBatchCursorSpecification extends Specification { def reply = getMoreResponse([], 0) when: - def cursor = new AsyncCommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, firstBatch, batchSize, maxTimeMS, CODEC, - null, connectionSource, initialConnection) + def commandCoreCursor = new AsyncCommandCoreCursor<>(firstBatch, batchSize, CODEC, null, connectionSource, initialConnection) + def cursor = new AsyncCommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, maxTimeMS, operationContext, commandCoreCursor) then: - 1 * timeoutContext.setMaxTimeOverride(*_) + 1 * timeoutContext.withMaxTimeOverride(*_) when: def batch = nextBatch(cursor) @@ -107,15 +111,17 @@ class AsyncCommandBatchCursorSpecification extends Specification { def serverVersion = new ServerVersion([3, 6, 0]) def connection = referenceCountedAsyncConnection(serverVersion) def connectionSource = getAsyncConnectionSource(connection) - def cursor = new AsyncCommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, firstBatch, 0, 0, CODEC, - null, connectionSource, initialConnection) + def operationContext = getOperationContext() + def commandCoreCursor = new AsyncCommandCoreCursor<>(firstBatch, 0, CODEC, null, connectionSource, initialConnection) + def cursor = new AsyncCommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, 0, operationContext, commandCoreCursor) when: cursor.close() then: - if (cursor.getServerCursor() != null) { - 1 * connection.commandAsync(NAMESPACE.databaseName, createKillCursorsDocument(cursor.getServerCursor()), _, primary(), *_) >> { + if (commandCoreCursor.getServerCursor() != null) { + 1 * connection.commandAsync(NAMESPACE.databaseName, + createKillCursorsDocument(commandCoreCursor.getServerCursor()), _, primary(), *_) >> { it.last().onResult(null, null) } } @@ -137,8 +143,8 @@ class AsyncCommandBatchCursorSpecification extends Specification { when: def firstBatch = createCommandResult(FIRST_BATCH, 0) - def cursor = new AsyncCommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, firstBatch, 0, 0, CODEC, - null, connectionSource, initialConnection) + def commandCoreCursor = new AsyncCommandCoreCursor<>(firstBatch, 0, CODEC, null, connectionSource, initialConnection) + def cursor = new AsyncCommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, 0, operationContext, commandCoreCursor) then: nextBatch(cursor) == FIRST_BATCH @@ -167,8 +173,8 @@ class AsyncCommandBatchCursorSpecification extends Specification { when: def firstBatch = createCommandResult([], CURSOR_ID) - def cursor = new AsyncCommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, firstBatch, 0, 0, CODEC, - null, connectionSource, initialConnection) + def commandCoreCursor = new AsyncCommandCoreCursor<>(firstBatch, 0, CODEC, null, connectionSource, initialConnection) + def cursor = new AsyncCommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, 0, operationContext, commandCoreCursor) def batch = nextBatch(cursor) then: @@ -213,8 +219,8 @@ class AsyncCommandBatchCursorSpecification extends Specification { def firstBatch = createCommandResult() when: - def cursor = new AsyncCommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, firstBatch, 0, 0, CODEC, - null, connectionSource, initialConnection) + def commandCoreCursor = new AsyncCommandCoreCursor<>(firstBatch, 0, CODEC, null, connectionSource, initialConnection) + def cursor = new AsyncCommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, 0, operationContext, commandCoreCursor) def batch = nextBatch(cursor) then: @@ -267,8 +273,9 @@ class AsyncCommandBatchCursorSpecification extends Specification { def connectionSource = getAsyncConnectionSource(connectionA, connectionB) when: - def cursor = new AsyncCommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, createCommandResult(FIRST_BATCH, 42), 0, 0, CODEC, - null, connectionSource, initialConnection) + def commandCoreCursor = new AsyncCommandCoreCursor<>(createCommandResult(FIRST_BATCH, 42), 0, + CODEC, null, connectionSource, initialConnection) + def cursor = new AsyncCommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, 0, operationContext, commandCoreCursor) def batch = nextBatch(cursor) then: @@ -303,8 +310,8 @@ class AsyncCommandBatchCursorSpecification extends Specification { def firstBatch = createCommandResult() when: - def cursor = new AsyncCommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, firstBatch, 0, 0, CODEC, - null, connectionSource, initialConnection) + def commandCoreCursor = new AsyncCommandCoreCursor<>(firstBatch, 0, CODEC, null, connectionSource, initialConnection) + def cursor = new AsyncCommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, 0, operationContext, commandCoreCursor) def batch = nextBatch(cursor) then: @@ -343,8 +350,8 @@ class AsyncCommandBatchCursorSpecification extends Specification { def initialConnection = referenceCountedAsyncConnection() def connectionSource = getAsyncConnectionSourceWithResult(ServerType.STANDALONE) { [null, MONGO_EXCEPTION] } def firstBatch = createCommandResult() - def cursor = new AsyncCommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, firstBatch, 0, 0, CODEC, - null, connectionSource, initialConnection) + def commandCoreCursor = new AsyncCommandCoreCursor<>(firstBatch, 0, CODEC, null, connectionSource, initialConnection) + def cursor = new AsyncCommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, 0, operationContext, commandCoreCursor) when: cursor.close() @@ -363,8 +370,8 @@ class AsyncCommandBatchCursorSpecification extends Specification { when: def firstBatch = createCommandResult() - def cursor = new AsyncCommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, firstBatch, 0, 0, CODEC, - null, connectionSource, initialConnection) + def commandCoreCursor = new AsyncCommandCoreCursor<>(firstBatch, 0, CODEC, null, connectionSource, initialConnection) + def cursor = new AsyncCommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, 0, operationContext, commandCoreCursor) then: nextBatch(cursor) @@ -390,8 +397,8 @@ class AsyncCommandBatchCursorSpecification extends Specification { when: def firstBatch = createCommandResult() - def cursor = new AsyncCommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, firstBatch, 0, 0, CODEC, - null, connectionSource, initialConnection) + def commandCoreCursor = new AsyncCommandCoreCursor<>(firstBatch, 0, CODEC, null, connectionSource, initialConnection) + def cursor = new AsyncCommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, 0, operationContext, commandCoreCursor) then: connectionSource.getCount() == 1 @@ -523,17 +530,12 @@ class AsyncCommandBatchCursorSpecification extends Specification { .state(ServerConnectionState.CONNECTED) .build() } - OperationContext operationContext = Mock(OperationContext) - def timeoutContext = Spy(new TimeoutContext(TimeoutSettings.create( - MongoClientSettings.builder().timeout(3, TimeUnit.SECONDS).build()))) - operationContext.getTimeoutContext() >> timeoutContext - mock.getOperationContext() >> operationContext - mock.getConnection(_) >> { + mock.getConnection(_ as OperationContext, _ as SingleResultCallback) >> { if (counter == 0) { throw new IllegalStateException('Tried to use released AsyncConnectionSource') } def (result, error) = connectionCallbackResults() - it[0].onResult(result, error) + it[1].onResult(result, error) } mock.retain() >> { if (released) { @@ -555,4 +557,13 @@ class AsyncCommandBatchCursorSpecification extends Specification { mock.getCount() >> { counter } mock } + + OperationContext getOperationContext() { + def timeoutContext = Spy(new TimeoutContext(TimeoutSettings.create( + MongoClientSettings.builder().timeout(3, TimeUnit.SECONDS).build()))) + Spy(new OperationContext( + IgnorableRequestContext.INSTANCE, + NoOpSessionContext.INSTANCE, + timeoutContext, null)) + } } diff --git a/driver-core/src/test/unit/com/mongodb/internal/operation/AsyncOperationHelperSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/operation/AsyncOperationHelperSpecification.groovy index ba69097cffa..d573822cab7 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/operation/AsyncOperationHelperSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/operation/AsyncOperationHelperSpecification.groovy @@ -27,6 +27,7 @@ import com.mongodb.internal.binding.AsyncConnectionSource import com.mongodb.internal.binding.AsyncReadBinding import com.mongodb.internal.binding.AsyncWriteBinding import com.mongodb.internal.connection.AsyncConnection +import com.mongodb.internal.connection.OperationContext import com.mongodb.internal.session.SessionContext import com.mongodb.internal.validator.NoOpFieldNameValidator import org.bson.BsonDocument @@ -80,17 +81,15 @@ class AsyncOperationHelperSpecification extends Specification { getReadConcern() >> ReadConcern.DEFAULT }) def connectionSource = Stub(AsyncConnectionSource) { - getConnection(_) >> { it[0].onResult(connection, null) } + getConnection(_ as OperationContext, _ as SingleResultCallback) >> { it[1].onResult(connection, null) } getServerDescription() >> serverDescription - getOperationContext() >> operationContext } def asyncWriteBinding = Stub(AsyncWriteBinding) { - getWriteConnectionSource(_) >> { it[0].onResult(connectionSource, null) } - getOperationContext() >> operationContext + getWriteConnectionSource(_ as OperationContext, _ as SingleResultCallback) >> { it[1].onResult(connectionSource, null) } } when: - executeRetryableWriteAsync(asyncWriteBinding, dbName, primary(), + executeRetryableWriteAsync(asyncWriteBinding, operationContext, dbName, primary(), NoOpFieldNameValidator.INSTANCE, decoder, commandCreator, FindAndModifyHelper.asyncTransformer(), { cmd -> cmd }, callback) @@ -109,15 +108,15 @@ class AsyncOperationHelperSpecification extends Specification { def callback = Stub(SingleResultCallback) def connection = Mock(AsyncConnection) def connectionSource = Stub(AsyncConnectionSource) { - getConnection(_) >> { it[0].onResult(connection, null) } + getConnection(_ as OperationContext, _ as SingleResultCallback) >> { it[1].onResult(connection, null) } } def asyncWriteBinding = Stub(AsyncWriteBinding) { - getWriteConnectionSource(_) >> { it[0].onResult(connectionSource, null) } + getWriteConnectionSource(_ as OperationContext, _ as SingleResultCallback) >> { it[1].onResult(connectionSource, null) } } def connectionDescription = Stub(ConnectionDescription) when: - executeCommandAsync(asyncWriteBinding, dbName, command, connection, { t, conn -> t }, callback) + executeCommandAsync(asyncWriteBinding, OPERATION_CONTEXT, dbName, command, connection, { t, conn -> t }, callback) then: _ * connection.getDescription() >> connectionDescription @@ -135,18 +134,16 @@ class AsyncOperationHelperSpecification extends Specification { def function = Stub(CommandReadTransformerAsync) def connection = Mock(AsyncConnection) def connectionSource = Stub(AsyncConnectionSource) { - getOperationContext() >> OPERATION_CONTEXT - getConnection(_) >> { it[0].onResult(connection, null) } + getConnection(_ as OperationContext, _ as SingleResultCallback) >> { it[1].onResult(connection, null) } getReadPreference() >> readPreference } def asyncReadBinding = Stub(AsyncReadBinding) { - getOperationContext() >> OPERATION_CONTEXT - getReadConnectionSource(_) >> { it[0].onResult(connectionSource, null) } + getReadConnectionSource(_ as OperationContext, _ as SingleResultCallback) >> { it[1].onResult(connectionSource, null) } } def connectionDescription = Stub(ConnectionDescription) when: - executeRetryableReadAsync(asyncReadBinding, dbName, commandCreator, decoder, function, false, callback) + executeRetryableReadAsync(asyncReadBinding, OPERATION_CONTEXT, dbName, commandCreator, decoder, function, false, callback) then: _ * connection.getDescription() >> connectionDescription diff --git a/driver-core/src/test/unit/com/mongodb/internal/operation/ChangeStreamBatchCursorSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/operation/ChangeStreamBatchCursorSpecification.groovy index 09c6ff221b6..919258a57c7 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/operation/ChangeStreamBatchCursorSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/operation/ChangeStreamBatchCursorSpecification.groovy @@ -16,11 +16,20 @@ package com.mongodb.internal.operation +import com.mongodb.MongoClientSettings +import com.mongodb.internal.IgnorableRequestContext +import com.mongodb.internal.TimeoutContext +import com.mongodb.internal.TimeoutSettings import com.mongodb.internal.binding.ReadBinding +import com.mongodb.internal.connection.NoOpSessionContext +import com.mongodb.internal.connection.OperationContext import org.bson.BsonDocument import org.bson.BsonInt32 +import org.bson.RawBsonDocument import spock.lang.Specification +import java.util.concurrent.TimeUnit + import static java.util.Collections.emptyList class ChangeStreamBatchCursorSpecification extends Specification { @@ -29,11 +38,14 @@ class ChangeStreamBatchCursorSpecification extends Specification { given: def changeStreamOperation = Stub(ChangeStreamOperation) def binding = Stub(ReadBinding) - def wrapped = Mock(CommandBatchCursor) def resumeToken = new BsonDocument('_id': new BsonInt32(1)) - def cursor = new ChangeStreamBatchCursor(changeStreamOperation, wrapped, binding, resumeToken, + def operationContext = getOperationContext() + CoreCursor wrapped = Mock(CoreCursor) + def cursor = new ChangeStreamBatchCursor(changeStreamOperation, + wrapped, binding, operationContext, resumeToken, ServerVersionHelper.FOUR_DOT_FOUR_WIRE_VERSION) + when: cursor.setBatchSize(10) @@ -44,27 +56,35 @@ class ChangeStreamBatchCursorSpecification extends Specification { cursor.tryNext() then: - 1 * wrapped.tryNext() + 1 * wrapped.tryNext(_ as OperationContext) 1 * wrapped.getPostBatchResumeToken() when: cursor.next() then: - 1 * wrapped.next() >> emptyList() + 1 * wrapped.next(_ as OperationContext) >> emptyList() 1 * wrapped.getPostBatchResumeToken() when: cursor.close() then: - 1 * wrapped.close() + 1 * wrapped.close(_ as OperationContext) when: cursor.close() then: - 0 * wrapped.close() + 0 * wrapped.close(_ as OperationContext) } + OperationContext getOperationContext() { + def timeoutContext = Spy(new TimeoutContext(TimeoutSettings.create( + MongoClientSettings.builder().timeout(3, TimeUnit.SECONDS).build()))) + Spy(new OperationContext( + IgnorableRequestContext.INSTANCE, + NoOpSessionContext.INSTANCE, + timeoutContext, null)); + } } diff --git a/driver-core/src/test/unit/com/mongodb/internal/operation/ChangeStreamBatchCursorTest.java b/driver-core/src/test/unit/com/mongodb/internal/operation/ChangeStreamBatchCursorTest.java index 48c3a50e79a..b0cc6bc4c96 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/operation/ChangeStreamBatchCursorTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/operation/ChangeStreamBatchCursorTest.java @@ -20,34 +20,47 @@ import com.mongodb.MongoOperationTimeoutException; import com.mongodb.ServerAddress; import com.mongodb.connection.ServerDescription; +import com.mongodb.internal.IgnorableRequestContext; import com.mongodb.internal.TimeoutContext; +import com.mongodb.internal.TimeoutSettings; import com.mongodb.internal.binding.ConnectionSource; import com.mongodb.internal.binding.ReadBinding; import com.mongodb.internal.connection.Connection; +import com.mongodb.internal.connection.NoOpSessionContext; import com.mongodb.internal.connection.OperationContext; import org.bson.BsonDocument; import org.bson.BsonInt32; import org.bson.Document; import org.bson.RawBsonDocument; +import org.bson.assertions.Assertions; import org.bson.codecs.DocumentCodec; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import static com.mongodb.ClusterFixture.sleep; import static com.mongodb.internal.operation.CommandBatchCursorHelper.MESSAGE_IF_CLOSED_AS_CURSOR; +import static java.lang.String.format; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; @@ -57,6 +70,8 @@ final class ChangeStreamBatchCursorTest { private static final List RESULT_FROM_NEW_CURSOR = new ArrayList<>(); + private static final long TIMEOUT_MILLISECONDS = TimeUnit.MINUTES.toMillis(10); + private static final int TIMEOUT_CONSUMPTION_SLEEP_MS = 100; private final int maxWireVersion = ServerVersionHelper.SIX_DOT_ZERO_WIRE_VERSION; private ServerDescription serverDescription; private TimeoutContext timeoutContext; @@ -65,26 +80,29 @@ final class ChangeStreamBatchCursorTest { private ConnectionSource connectionSource; private ReadBinding readBinding; private BsonDocument resumeToken; - private CommandBatchCursor commandBatchCursor; - private CommandBatchCursor newCommandBatchCursor; + private CoreCursor coreCursor; + private CoreCursor newCoreCursor; private ChangeStreamBatchCursor newChangeStreamCursor; private ChangeStreamOperation changeStreamOperation; @Test @DisplayName("should return result on next") void shouldReturnResultOnNext() { - when(commandBatchCursor.next()).thenReturn(RESULT_FROM_NEW_CURSOR); + when(coreCursor.next(any())).thenReturn(RESULT_FROM_NEW_CURSOR); ChangeStreamBatchCursor cursor = createChangeStreamCursor(); //when + sleep(TIMEOUT_CONSUMPTION_SLEEP_MS); // Simulate some delay to ensure timeout is reset. List next = cursor.next(); //then assertEquals(RESULT_FROM_NEW_CURSOR, next); - verify(timeoutContext, times(1)).resetTimeoutIfPresent(); - verify(commandBatchCursor, times(1)).next(); - verify(commandBatchCursor, atLeastOnce()).getPostBatchResumeToken(); - verifyNoMoreInteractions(commandBatchCursor); + assertTimeoutWasRefreshedForOperation(operationContextCaptor -> + verify(coreCursor).next(operationContextCaptor.capture())); + + verify(coreCursor, times(1)).next(any()); + verify(coreCursor, atLeastOnce()).getPostBatchResumeToken(); + verifyNoMoreInteractions(coreCursor); verify(changeStreamOperation, times(1)).getDecoder(); verifyNoMoreInteractions(changeStreamOperation); } @@ -92,37 +110,79 @@ void shouldReturnResultOnNext() { @Test @DisplayName("should throw timeout exception without resume attempt on next") void shouldThrowTimeoutExceptionWithoutResumeAttemptOnNext() { - when(commandBatchCursor.next()).thenThrow(new MongoOperationTimeoutException("timeout")); + when(coreCursor.next(any())).thenThrow(new MongoOperationTimeoutException("timeout")); ChangeStreamBatchCursor cursor = createChangeStreamCursor(); //when assertThrows(MongoOperationTimeoutException.class, cursor::next); //then - verify(timeoutContext, times(1)).resetTimeoutIfPresent(); - verify(commandBatchCursor, times(1)).next(); - verify(commandBatchCursor, atLeastOnce()).getPostBatchResumeToken(); - verifyNoMoreInteractions(commandBatchCursor); + verify(coreCursor, times(1)).next(any()); + verify(coreCursor, atLeastOnce()).getPostBatchResumeToken(); + verifyNoMoreInteractions(coreCursor); verifyNoResumeAttemptCalled(); } + @Test + @DisplayName("should not refresh timeout on next() after cursor close() when resume attempt is made") + void shouldNotRefreshTimeoutOnNextAfterCloseWhenResumeAttemptIsMade() { + // given + when(coreCursor.next(any())).thenThrow(new MongoOperationTimeoutException("timeout")); + ChangeStreamBatchCursor cursor = createChangeStreamCursor(); + // when + assertThrows(MongoOperationTimeoutException.class, cursor::next); + // trigger resume attempt to close the cursor + cursor.next(); + + // then + TimeoutContext timeoutContextForClose = captureTimeoutContext(captor -> verify(coreCursor) + .close(captor.capture())); + TimeoutContext timeoutContextForNext = captureTimeoutContext(captor -> verify(newCoreCursor) + .next(captor.capture())); + assertEquals(timeoutContextForNext.getTimeout(), timeoutContextForClose.getTimeout(), + "Timeout should not be refreshed on close after resume attempt"); + } + + @Test + @DisplayName("should not refresh timeout on close() after cursor next() when resume attempt is made") + void shouldNotRefreshTimeoutOnCloseAfterNextWhenResumeAttemptIsMade() { + // given + when(coreCursor.next(any())).thenThrow(new MongoNotPrimaryException(new BsonDocument(), new ServerAddress())); + ChangeStreamBatchCursor cursor = createChangeStreamCursor(); + + // when + cursor.next(); + + // then + TimeoutContext timeoutContextForNext = captureTimeoutContext(captor -> verify(coreCursor) + .next(captor.capture())); + TimeoutContext timeoutContextForClose = captureTimeoutContext(captor -> verify(coreCursor) + .close(captor.capture())); + assertEquals(timeoutContextForNext.getTimeout(), timeoutContextForClose.getTimeout(), + "Timeout should not be refreshed on close after resume attempt"); + } + @Test @DisplayName("should perform resume attempt on next when resumable error is thrown") void shouldPerformResumeAttemptOnNextWhenResumableErrorIsThrown() { - when(commandBatchCursor.next()).thenThrow(new MongoNotPrimaryException(new BsonDocument(), new ServerAddress())); + when(coreCursor.next(any())).thenThrow(new MongoNotPrimaryException(new BsonDocument(), new ServerAddress())); ChangeStreamBatchCursor cursor = createChangeStreamCursor(); + //when + sleep(TIMEOUT_CONSUMPTION_SLEEP_MS); List next = cursor.next(); //then assertEquals(RESULT_FROM_NEW_CURSOR, next); - verify(timeoutContext, times(1)).resetTimeoutIfPresent(); - verify(commandBatchCursor, times(1)).next(); - verify(commandBatchCursor, atLeastOnce()).getPostBatchResumeToken(); + assertTimeoutWasRefreshedForOperation(operationContextCaptor -> + verify(newCoreCursor).next(operationContextCaptor.capture())); + verify(coreCursor, times(1)).next(any()); + verify(coreCursor, atLeastOnce()).getPostBatchResumeToken(); verifyResumeAttemptCalled(); verify(changeStreamOperation, times(1)).getDecoder(); - verify(newCommandBatchCursor, times(1)).next(); - verify(newCommandBatchCursor, atLeastOnce()).getPostBatchResumeToken(); - verifyNoMoreInteractions(newCommandBatchCursor); + verify(newCoreCursor, times(1)).next(any()); + verify(newCoreCursor, atLeastOnce()).getPostBatchResumeToken(); + + verifyNoMoreInteractions(newCoreCursor); verifyNoMoreInteractions(changeStreamOperation); } @@ -130,45 +190,50 @@ void shouldPerformResumeAttemptOnNextWhenResumableErrorIsThrown() { @Test @DisplayName("should resume only once on subsequent calls after timeout error") void shouldResumeOnlyOnceOnSubsequentCallsAfterTimeoutError() { - when(commandBatchCursor.next()).thenThrow(new MongoOperationTimeoutException("timeout")); + when(coreCursor.next(any())).thenThrow(new MongoOperationTimeoutException("timeout")); ChangeStreamBatchCursor cursor = createChangeStreamCursor(); //when assertThrows(MongoOperationTimeoutException.class, cursor::next); //then - verify(timeoutContext, times(1)).resetTimeoutIfPresent(); - verify(commandBatchCursor, times(1)).next(); - verify(commandBatchCursor, atLeastOnce()).getPostBatchResumeToken(); - verifyNoMoreInteractions(commandBatchCursor); + assertTimeoutWasRefreshedForOperation(operationContextCaptor -> + verify(coreCursor).next(operationContextCaptor.capture())); + verify(coreCursor, times(1)).next(any()); + verify(coreCursor, atLeastOnce()).getPostBatchResumeToken(); + verifyNoMoreInteractions(coreCursor); verifyNoResumeAttemptCalled(); - clearInvocations(commandBatchCursor, newCommandBatchCursor, timeoutContext, changeStreamOperation, readBinding); + clearInvocations(coreCursor, newCoreCursor, timeoutContext, changeStreamOperation, readBinding); //when seconds next is called. Resume is attempted. + sleep(TIMEOUT_CONSUMPTION_SLEEP_MS); List next = cursor.next(); //then assertEquals(Collections.emptyList(), next); - verify(timeoutContext, times(1)).resetTimeoutIfPresent(); - verify(commandBatchCursor, times(1)).close(); - verifyNoMoreInteractions(commandBatchCursor); + verify(coreCursor, times(1)).close(any()); + verifyNoMoreInteractions(coreCursor); + assertTimeoutWasRefreshedForOperation(operationContextCaptor -> + verify(newCoreCursor).next(operationContextCaptor.capture())); verify(changeStreamOperation).setChangeStreamOptionsForResume(resumeToken, maxWireVersion); verify(changeStreamOperation, times(1)).getDecoder(); - verify(changeStreamOperation, times(1)).execute(readBinding); + verify(changeStreamOperation, times(1)).execute(eq(readBinding), any()); verifyNoMoreInteractions(changeStreamOperation); - verify(newCommandBatchCursor, times(1)).next(); - verify(newCommandBatchCursor, atLeastOnce()).getPostBatchResumeToken(); - clearInvocations(commandBatchCursor, newCommandBatchCursor, timeoutContext, changeStreamOperation, readBinding); + verify(newCoreCursor, times(1)).next(any()); + verify(newCoreCursor, atLeastOnce()).getPostBatchResumeToken(); + clearInvocations(coreCursor, newCoreCursor, timeoutContext, changeStreamOperation, readBinding); //when third next is called. No resume is attempted. + sleep(TIMEOUT_CONSUMPTION_SLEEP_MS); List next2 = cursor.next(); //then assertEquals(Collections.emptyList(), next2); - verifyNoInteractions(commandBatchCursor); - verify(timeoutContext, times(1)).resetTimeoutIfPresent(); - verify(newCommandBatchCursor, times(1)).next(); - verify(newCommandBatchCursor, atLeastOnce()).getPostBatchResumeToken(); - verifyNoMoreInteractions(newCommandBatchCursor); + verifyNoInteractions(coreCursor); + assertTimeoutWasRefreshedForOperation(operationContextCaptor -> + verify(newCoreCursor).next(operationContextCaptor.capture())); + verify(newCoreCursor, times(1)).next(any()); + verify(newCoreCursor, atLeastOnce()).getPostBatchResumeToken(); + verifyNoMoreInteractions(newCoreCursor); verify(changeStreamOperation, times(1)).getDecoder(); verifyNoMoreInteractions(changeStreamOperation); verifyNoInteractions(readBinding); @@ -178,21 +243,20 @@ void shouldResumeOnlyOnceOnSubsequentCallsAfterTimeoutError() { @Test @DisplayName("should propagate any errors occurred in aggregate operation during creating new change stream when previous next timed out") void shouldPropagateAnyErrorsOccurredInAggregateOperation() { - when(commandBatchCursor.next()).thenThrow(new MongoOperationTimeoutException("timeout")); + when(coreCursor.next(any())).thenThrow(new MongoOperationTimeoutException("timeout")); MongoNotPrimaryException resumableError = new MongoNotPrimaryException(new BsonDocument(), new ServerAddress()); - when(changeStreamOperation.execute(readBinding)).thenThrow(resumableError); + when(changeStreamOperation.execute(eq(readBinding), any())).thenThrow(resumableError); ChangeStreamBatchCursor cursor = createChangeStreamCursor(); //when assertThrows(MongoOperationTimeoutException.class, cursor::next); - clearInvocations(commandBatchCursor, newCommandBatchCursor, timeoutContext, changeStreamOperation, readBinding); + clearInvocations(coreCursor, newCoreCursor, timeoutContext, changeStreamOperation, readBinding); assertThrows(MongoNotPrimaryException.class, cursor::next); //then - verify(timeoutContext, times(1)).resetTimeoutIfPresent(); verifyResumeAttemptCalled(); verifyNoMoreInteractions(changeStreamOperation); - verifyNoInteractions(newCommandBatchCursor); + verifyNoInteractions(newCoreCursor); } @@ -203,31 +267,33 @@ void shouldResumeAfterTimeoutInAggregateOnNextCall() { ChangeStreamBatchCursor cursor = createChangeStreamCursor(); //first next operation times out on getMore - when(commandBatchCursor.next()).thenThrow(new MongoOperationTimeoutException("timeout during next call")); + when(coreCursor.next(any())).thenThrow(new MongoOperationTimeoutException("timeout during next call")); assertThrows(MongoOperationTimeoutException.class, cursor::next); - clearInvocations(commandBatchCursor, newCommandBatchCursor, timeoutContext, changeStreamOperation, readBinding); + clearInvocations(coreCursor, newCoreCursor, timeoutContext, changeStreamOperation, readBinding); //second next operation times out on resume attempt when creating change stream - when(changeStreamOperation.execute(readBinding)).thenThrow(new MongoOperationTimeoutException("timeout during resumption")); + when(changeStreamOperation.execute(eq(readBinding), any())).thenThrow( + new MongoOperationTimeoutException("timeout during resumption")); assertThrows(MongoOperationTimeoutException.class, cursor::next); - clearInvocations(commandBatchCursor, newCommandBatchCursor, timeoutContext, changeStreamOperation); + clearInvocations(coreCursor, newCoreCursor, timeoutContext, changeStreamOperation); - doReturn(newChangeStreamCursor).when(changeStreamOperation).execute(readBinding); + doReturn(newChangeStreamCursor).when(changeStreamOperation).execute(eq(readBinding), any()); //when third operation succeeds to resume and call next + sleep(TIMEOUT_CONSUMPTION_SLEEP_MS); List next = cursor.next(); //then assertEquals(RESULT_FROM_NEW_CURSOR, next); - verify(timeoutContext, times(1)).resetTimeoutIfPresent(); - verifyResumeAttemptCalled(); verify(changeStreamOperation, times(1)).getDecoder(); verifyNoMoreInteractions(changeStreamOperation); - verify(newCommandBatchCursor, times(1)).next(); - verify(newCommandBatchCursor, atLeastOnce()).getPostBatchResumeToken(); - verifyNoMoreInteractions(newCommandBatchCursor); + assertTimeoutWasRefreshedForOperation(operationContextCaptor -> + verify(newCoreCursor).next(operationContextCaptor.capture())); + verify(newCoreCursor, times(1)).next(any()); + verify(newCoreCursor, atLeastOnce()).getPostBatchResumeToken(); + verifyNoMoreInteractions(newCoreCursor); } @Test @@ -237,51 +303,49 @@ void shouldCloseChangeStreamWhenResumeOperationFailsDueToNonTimeoutError() { ChangeStreamBatchCursor cursor = createChangeStreamCursor(); //first next operation times out on getMore - when(commandBatchCursor.next()).thenThrow(new MongoOperationTimeoutException("timeout during next call")); + when(coreCursor.next(any())).thenThrow(new MongoOperationTimeoutException("timeout during next call")); assertThrows(MongoOperationTimeoutException.class, cursor::next); - clearInvocations(commandBatchCursor, newCommandBatchCursor, timeoutContext, changeStreamOperation, readBinding); + clearInvocations(coreCursor, newCoreCursor, timeoutContext, changeStreamOperation, readBinding); //when second next operation errors on resume attempt when creating change stream - when(changeStreamOperation.execute(readBinding)).thenThrow(new MongoNotPrimaryException(new BsonDocument(), new ServerAddress())); + when(changeStreamOperation.execute(eq(readBinding), any())).thenThrow( + new MongoNotPrimaryException(new BsonDocument(), new ServerAddress())); assertThrows(MongoNotPrimaryException.class, cursor::next); //then - verify(timeoutContext, times(1)).resetTimeoutIfPresent(); verifyResumeAttemptCalled(); verifyNoMoreInteractions(changeStreamOperation); - verifyNoInteractions(newCommandBatchCursor); - clearInvocations(commandBatchCursor, newCommandBatchCursor, timeoutContext, changeStreamOperation, readBinding); - + verifyNoInteractions(newCoreCursor); + clearInvocations(coreCursor, newCoreCursor, timeoutContext, changeStreamOperation, readBinding); //when third next operation errors with cursor closed exception - doThrow(new IllegalStateException(MESSAGE_IF_CLOSED_AS_CURSOR)).when(commandBatchCursor).next(); + doThrow(new IllegalStateException(MESSAGE_IF_CLOSED_AS_CURSOR)).when(coreCursor).next(any()); MongoException mongoException = assertThrows(MongoException.class, cursor::next); //then assertEquals(MESSAGE_IF_CLOSED_AS_CURSOR, mongoException.getMessage()); - verify(timeoutContext, times(1)).resetTimeoutIfPresent(); verifyNoResumeAttemptCalled(); } private ChangeStreamBatchCursor createChangeStreamCursor() { ChangeStreamBatchCursor cursor = - new ChangeStreamBatchCursor<>(changeStreamOperation, commandBatchCursor, readBinding, null, maxWireVersion); - clearInvocations(commandBatchCursor, newCommandBatchCursor, timeoutContext, changeStreamOperation, readBinding); + new ChangeStreamBatchCursor<>(changeStreamOperation, coreCursor, readBinding, operationContext, null, maxWireVersion); + clearInvocations(coreCursor, newCoreCursor, timeoutContext, changeStreamOperation, readBinding); return cursor; } private void verifyNoResumeAttemptCalled() { verifyNoInteractions(changeStreamOperation); - verifyNoInteractions(newCommandBatchCursor); + verifyNoInteractions(newCoreCursor); verifyNoInteractions(readBinding); } private void verifyResumeAttemptCalled() { - verify(commandBatchCursor, times(1)).close(); + verify(coreCursor, times(1)).close(any()); verify(changeStreamOperation).setChangeStreamOptionsForResume(resumeToken, maxWireVersion); - verify(changeStreamOperation, times(1)).execute(readBinding); - verifyNoMoreInteractions(commandBatchCursor); + verify(changeStreamOperation, times(1)).execute(eq(readBinding), any()); + verifyNoMoreInteractions(coreCursor); } @BeforeEach @@ -291,42 +355,76 @@ void setUp() { serverDescription = mock(ServerDescription.class); when(serverDescription.getMaxWireVersion()).thenReturn(maxWireVersion); - timeoutContext = mock(TimeoutContext.class); + timeoutContext = spy(new TimeoutContext(new TimeoutSettings( + 10, 10, 10, TIMEOUT_MILLISECONDS, 0 + ))); when(timeoutContext.hasTimeoutMS()).thenReturn(true); - doNothing().when(timeoutContext).resetTimeoutIfPresent(); - operationContext = mock(OperationContext.class); - when(operationContext.getTimeoutContext()).thenReturn(timeoutContext); + operationContext = spy(new OperationContext( + IgnorableRequestContext.INSTANCE, + NoOpSessionContext.INSTANCE, + timeoutContext, + null)); + connection = mock(Connection.class); when(connection.command(any(), any(), any(), any(), any(), any())).thenReturn(null); connectionSource = mock(ConnectionSource.class); - when(connectionSource.getConnection()).thenReturn(connection); + when(connectionSource.getConnection(any())).thenReturn(connection); when(connectionSource.release()).thenReturn(1); when(connectionSource.getServerDescription()).thenReturn(serverDescription); readBinding = mock(ReadBinding.class); - when(readBinding.getOperationContext()).thenReturn(operationContext); when(readBinding.retain()).thenReturn(readBinding); when(readBinding.release()).thenReturn(1); - when(readBinding.getReadConnectionSource()).thenReturn(connectionSource); + when(readBinding.getReadConnectionSource(any())).thenReturn(connectionSource); - commandBatchCursor = mock(CommandBatchCursor.class); - when(commandBatchCursor.getPostBatchResumeToken()).thenReturn(resumeToken); - doNothing().when(commandBatchCursor).close(); + coreCursor = mock(CoreCursor.class); + when(coreCursor.getPostBatchResumeToken()).thenReturn(resumeToken); + doNothing().when(coreCursor).close(any()); - newCommandBatchCursor = mock(CommandBatchCursor.class); - when(newCommandBatchCursor.getPostBatchResumeToken()).thenReturn(resumeToken); - when(newCommandBatchCursor.next()).thenReturn(RESULT_FROM_NEW_CURSOR); - doNothing().when(newCommandBatchCursor).close(); + newCoreCursor = mock(CoreCursor.class); + when(newCoreCursor.getPostBatchResumeToken()).thenReturn(resumeToken); + when(newCoreCursor.next(any())).thenReturn(RESULT_FROM_NEW_CURSOR); + doNothing().when(newCoreCursor).close(any()); newChangeStreamCursor = mock(ChangeStreamBatchCursor.class); - when(newChangeStreamCursor.getWrapped()).thenReturn(newCommandBatchCursor); + when(newChangeStreamCursor.getWrapped()).thenReturn(newCoreCursor); changeStreamOperation = mock(ChangeStreamOperation.class); when(changeStreamOperation.getDecoder()).thenReturn(new DocumentCodec()); doNothing().when(changeStreamOperation).setChangeStreamOptionsForResume(resumeToken, maxWireVersion); - when(changeStreamOperation.execute(readBinding)).thenReturn(newChangeStreamCursor); + when(changeStreamOperation.execute(eq(readBinding), any())).thenReturn(newChangeStreamCursor); } + + private void assertTimeoutWasRefreshedForOperation(final TimeoutContext timeoutContextUsedForOperation) { + assertNotNull(timeoutContextUsedForOperation.getTimeout(), "TimeoutMs was not set"); + timeoutContextUsedForOperation.getTimeout().run(TimeUnit.MILLISECONDS, () -> { + Assertions.fail("Non-infinite timeout was not expected to be refreshed to infinity"); + }, + (remainingMs) -> { + int allowedDifference = 20; + boolean originalAndRefreshedTimeoutDifference = TIMEOUT_MILLISECONDS - remainingMs < allowedDifference; + assertTrue(originalAndRefreshedTimeoutDifference, format("Timeout was expected to be refreshed " + + "to original timeout: %d, but remaining time was: %d. Allowed difference was: %d ", + TIMEOUT_MILLISECONDS, + remainingMs, + allowedDifference)); + }, + () -> { + Assertions.fail("Timeout was expected to be refreshed"); + }); + } + + private void assertTimeoutWasRefreshedForOperation(final Consumer> capturerConsumer) { + assertTimeoutWasRefreshedForOperation(captureTimeoutContext(capturerConsumer)); + } + + private static TimeoutContext captureTimeoutContext(final Consumer> capturerConsumer) { + ArgumentCaptor operationContextCaptor = ArgumentCaptor.forClass(OperationContext.class); + capturerConsumer.accept(operationContextCaptor); + TimeoutContext timeoutContextUsedForOperation = operationContextCaptor.getValue().getTimeoutContext(); + return timeoutContextUsedForOperation; + } } diff --git a/driver-core/src/test/unit/com/mongodb/internal/operation/CommandBatchCursorSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/operation/CommandBatchCursorSpecification.groovy index c95a119134a..c81c9ad759c 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/operation/CommandBatchCursorSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/operation/CommandBatchCursorSpecification.groovy @@ -30,10 +30,12 @@ import com.mongodb.connection.ServerConnectionState import com.mongodb.connection.ServerDescription import com.mongodb.connection.ServerType import com.mongodb.connection.ServerVersion +import com.mongodb.internal.IgnorableRequestContext import com.mongodb.internal.TimeoutContext import com.mongodb.internal.TimeoutSettings import com.mongodb.internal.binding.ConnectionSource import com.mongodb.internal.connection.Connection +import com.mongodb.internal.connection.NoOpSessionContext import com.mongodb.internal.connection.OperationContext import org.bson.BsonArray import org.bson.BsonDocument @@ -58,7 +60,8 @@ class CommandBatchCursorSpecification extends Specification { def initialConnection = referenceCountedConnection() def connection = referenceCountedConnection() def connectionSource = getConnectionSource(connection) - def timeoutContext = connectionSource.getOperationContext().getTimeoutContext() + def operationContext = getOperationContext() + def timeoutContext = operationContext.getTimeoutContext() def firstBatch = createCommandResult([]) def expectedCommand = new BsonDocument('getMore': new BsonInt64(CURSOR_ID)) @@ -70,11 +73,11 @@ class CommandBatchCursorSpecification extends Specification { def reply = getMoreResponse([], 0) when: - def cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, firstBatch, batchSize, maxTimeMS, CODEC, - null, connectionSource, initialConnection) + def commandCoreCursor = new CommandCoreCursor<>(firstBatch, batchSize, CODEC, null, connectionSource, initialConnection) + def cursor = new CommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, maxTimeMS, operationContext, commandCoreCursor) then: - 1 * timeoutContext.setMaxTimeOverride(*_) + 1 * timeoutContext.withMaxTimeOverride(*_) when: cursor.hasNext() @@ -83,7 +86,7 @@ class CommandBatchCursorSpecification extends Specification { 1 * connection.command(NAMESPACE.getDatabaseName(), expectedCommand, *_) >> reply then: - !cursor.isClosed() + !commandCoreCursor.isClosed() when: cursor.close() @@ -106,8 +109,8 @@ class CommandBatchCursorSpecification extends Specification { def serverVersion = new ServerVersion([3, 6, 0]) def connection = referenceCountedConnection(serverVersion) def connectionSource = getConnectionSource(connection) - def cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, firstBatch, 0, 0, CODEC, - null, connectionSource, initialConnection) + def commandCoreCursor = new CommandCoreCursor<>(firstBatch, 0, CODEC, null, connectionSource, initialConnection) + def cursor = new CommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, 0, operationContext, commandCoreCursor) when: cursor.close() @@ -134,8 +137,8 @@ class CommandBatchCursorSpecification extends Specification { when: def firstBatch = createCommandResult(FIRST_BATCH, 0) - def cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, firstBatch, 0, 0, CODEC, - null, connectionSource, initialConnection) + def commandCoreCursor = new CommandCoreCursor<>(firstBatch, 0, CODEC, null, connectionSource, initialConnection) + def cursor = new CommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, 0, operationContext, commandCoreCursor) then: cursor.next() == FIRST_BATCH @@ -146,7 +149,7 @@ class CommandBatchCursorSpecification extends Specification { then: // Unlike the AsyncCommandBatchCursor - the cursor isn't automatically closed - !cursor.isClosed() + !commandCoreCursor.isClosed() } def 'should handle getMore when there are empty results but there is a cursor'() { @@ -158,8 +161,8 @@ class CommandBatchCursorSpecification extends Specification { when: def firstBatch = createCommandResult([], CURSOR_ID) - def cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, firstBatch, 0, 0, CODEC, - null, connectionSource, initialConnection) + def commandCoreCursor = new CommandCoreCursor<>(firstBatch, 0, CODEC, null, connectionSource, initialConnection) + def cursor = new CommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, 0, operationContext, commandCoreCursor) def batch = cursor.next() then: @@ -212,8 +215,8 @@ class CommandBatchCursorSpecification extends Specification { def firstBatch = createCommandResult() when: - def cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, firstBatch, 0, 0, CODEC, - null, connectionSource, initialConnection) + def commandCoreCursor = new CommandCoreCursor<>(firstBatch, 0, CODEC, null, connectionSource, initialConnection) + def cursor = new CommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, 0, operationContext, commandCoreCursor) List batch = cursor.next() then: @@ -245,7 +248,7 @@ class CommandBatchCursorSpecification extends Specification { connectionB.getCount() == 0 initialConnection.getCount() == 0 connectionSource.getCount() == 0 - cursor.isClosed() + commandCoreCursor.isClosed() where: serverType | responseCursorId @@ -264,8 +267,9 @@ class CommandBatchCursorSpecification extends Specification { def connectionSource = getConnectionSource(connectionA, connectionB) when: - def cursor = new CommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, createCommandResult(FIRST_BATCH, 42), 0, 0, CODEC, + def commandCoreCursor = new CommandCoreCursor<>(createCommandResult(FIRST_BATCH, 42), 0, CODEC, null, connectionSource, initialConnection) + def cursor = new CommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, 0, operationContext, commandCoreCursor) def batch = cursor.next() then: @@ -300,8 +304,8 @@ class CommandBatchCursorSpecification extends Specification { def firstBatch = createCommandResult() when: - def cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, firstBatch, 0, 0, CODEC, - null, connectionSource, initialConnection) + def commandCoreCursor = new CommandCoreCursor<>(firstBatch, 0, CODEC, null, connectionSource, initialConnection) + def cursor = new CommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, 0, operationContext, commandCoreCursor) def batch = cursor.next() then: @@ -328,7 +332,7 @@ class CommandBatchCursorSpecification extends Specification { then: connectionA.getCount() == 0 - cursor.isClosed() + commandCoreCursor.isClosed() where: serverType << [ServerType.LOAD_BALANCER, ServerType.STANDALONE] @@ -339,14 +343,14 @@ class CommandBatchCursorSpecification extends Specification { def initialConnection = referenceCountedConnection() def connectionSource = getConnectionSourceWithResult(ServerType.STANDALONE) { throw MONGO_EXCEPTION } def firstBatch = createCommandResult() - def cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, firstBatch, 0, 0, CODEC, - null, connectionSource, initialConnection) + def commandCoreCursor = new CommandCoreCursor<>(firstBatch, 0, CODEC, null, connectionSource, initialConnection) + def cursor = new CommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, 0, operationContext, commandCoreCursor) when: cursor.close() then: - cursor.isClosed() + commandCoreCursor.isClosed() initialConnection.getCount() == 0 connectionSource.getCount() == 0 } @@ -360,8 +364,8 @@ class CommandBatchCursorSpecification extends Specification { when: def firstBatch = createCommandResult() - def cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, firstBatch, 0, 0, CODEC, - null, connectionSource, initialConnection) + def commandCoreCursor = new CommandCoreCursor<>(firstBatch, 0, CODEC, null, connectionSource, initialConnection) + def cursor = new CommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, 0, operationContext, commandCoreCursor) then: cursor.next() @@ -387,9 +391,8 @@ class CommandBatchCursorSpecification extends Specification { when: def firstBatch = createCommandResult() - def cursor = new CommandBatchCursor<>(TimeoutMode.CURSOR_LIFETIME, firstBatch, 0, 0, CODEC, - null, connectionSource, initialConnection) - + def commandCoreCursor = new CommandCoreCursor<>(firstBatch, 0, CODEC, null, connectionSource, initialConnection) + def cursor = new CommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, 0, operationContext, commandCoreCursor) then: connectionSource.getCount() == 1 @@ -442,13 +445,13 @@ class CommandBatchCursorSpecification extends Specification { } def connectionSource = Stub(ConnectionSource) { getServerApi() >> null - getConnection() >> { connection } + getConnection(_) >> { connection } } connectionSource.retain() >> connectionSource def initialResults = createCommandResult([]) - def cursor = new CommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, initialResults, 2, 100, CODEC, - null, connectionSource, initialConnection) + def commandCoreCursor = new CommandCoreCursor<>(initialResults, 2, CODEC, null, connectionSource, initialConnection) + def cursor = new CommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, 100, operationContext, commandCoreCursor) when: cursor.close() @@ -467,14 +470,14 @@ class CommandBatchCursorSpecification extends Specification { given: def initialConnection = referenceCountedConnection() def connectionSource = Stub(ConnectionSource) { - getConnection() >> { throw new MongoSocketOpenException("can't open socket", SERVER_ADDRESS, new IOException()) } + getConnection(_) >> { throw new MongoSocketOpenException("can't open socket", SERVER_ADDRESS, new IOException()) } getServerApi() >> null } connectionSource.retain() >> connectionSource def initialResults = createCommandResult([]) - def cursor = new CommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, initialResults, 2, 100, CODEC, - null, connectionSource, initialConnection) + def commandCoreCursor = new CommandCoreCursor<>(initialResults, 2, CODEC, null, connectionSource, initialConnection) + def cursor = new CommandBatchCursor(TimeoutMode.CURSOR_LIFETIME, 100, operationContext, commandCoreCursor) when: cursor.close() @@ -573,12 +576,7 @@ class CommandBatchCursorSpecification extends Specification { .state(ServerConnectionState.CONNECTED) .build() } - OperationContext operationContext = Mock(OperationContext) - def timeoutContext = Spy(new TimeoutContext(TimeoutSettings.create( - MongoClientSettings.builder().timeout(3, TimeUnit.SECONDS).build()))) - operationContext.getTimeoutContext() >> timeoutContext - mock.getOperationContext() >> operationContext - mock.getConnection() >> { + mock.getConnection(_ as OperationContext) >> { if (counter == 0) { throw new IllegalStateException('Tried to use released ConnectionSource') } @@ -605,4 +603,13 @@ class CommandBatchCursorSpecification extends Specification { mock } + OperationContext getOperationContext() { + def timeoutContext = Spy(new TimeoutContext(TimeoutSettings.create( + MongoClientSettings.builder().timeout(3, TimeUnit.SECONDS).build()))) + Spy(new OperationContext( + IgnorableRequestContext.INSTANCE, + NoOpSessionContext.INSTANCE, + timeoutContext, null)) + } + } diff --git a/driver-core/src/test/unit/com/mongodb/internal/operation/CommandBatchCursorTest.java b/driver-core/src/test/unit/com/mongodb/internal/operation/CommandBatchCursorTest.java index c3bec291432..78f65e04328 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/operation/CommandBatchCursorTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/operation/CommandBatchCursorTest.java @@ -17,216 +17,159 @@ package com.mongodb.internal.operation; import com.mongodb.MongoClientSettings; -import com.mongodb.MongoNamespace; -import com.mongodb.MongoOperationTimeoutException; -import com.mongodb.MongoSocketException; -import com.mongodb.ServerAddress; import com.mongodb.client.cursor.TimeoutMode; -import com.mongodb.connection.ConnectionDescription; -import com.mongodb.connection.ServerDescription; -import com.mongodb.connection.ServerType; -import com.mongodb.connection.ServerVersion; +import com.mongodb.internal.IgnorableRequestContext; import com.mongodb.internal.TimeoutContext; import com.mongodb.internal.TimeoutSettings; -import com.mongodb.internal.binding.ConnectionSource; -import com.mongodb.internal.connection.Connection; +import com.mongodb.internal.connection.NoOpSessionContext; import com.mongodb.internal.connection.OperationContext; -import org.bson.BsonArray; -import org.bson.BsonDocument; -import org.bson.BsonInt32; -import org.bson.BsonInt64; -import org.bson.BsonString; import org.bson.Document; -import org.bson.codecs.Decoder; -import org.bson.codecs.DocumentCodec; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.ArgumentCaptor; import java.time.Duration; -import static com.mongodb.internal.operation.OperationUnitSpecification.getMaxWireVersionForServerVersion; -import static com.mongodb.internal.thread.InterruptionUtil.interruptAndCreateMongoInterruptedException; import static java.util.concurrent.TimeUnit.MILLISECONDS; -import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.argThat; -import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; class CommandBatchCursorTest { - - private static final MongoNamespace NAMESPACE = new MongoNamespace("test", "test"); - private static final BsonInt64 CURSOR_ID = new BsonInt64(1); - private static final BsonDocument COMMAND_CURSOR_DOCUMENT = new BsonDocument("ok", new BsonInt32(1)) - .append("cursor", - new BsonDocument("ns", new BsonString(NAMESPACE.getFullName())) - .append("id", CURSOR_ID) - .append("firstBatch", new BsonArrayWrapper<>(new BsonArray()))); - - private static final Decoder DOCUMENT_CODEC = new DocumentCodec(); private static final Duration TIMEOUT = Duration.ofMillis(3_000); - - private Connection mockConnection; - private ConnectionDescription mockDescription; - private ConnectionSource connectionSource; private OperationContext operationContext; private TimeoutContext timeoutContext; - private ServerDescription serverDescription; + private CoreCursor coreCursor; @BeforeEach void setUp() { - ServerVersion serverVersion = new ServerVersion(3, 6); - - mockConnection = mock(Connection.class, "connection"); - mockDescription = mock(ConnectionDescription.class); - when(mockDescription.getMaxWireVersion()).thenReturn(getMaxWireVersionForServerVersion(serverVersion.getVersionList())); - when(mockDescription.getServerType()).thenReturn(ServerType.LOAD_BALANCER); - when(mockConnection.getDescription()).thenReturn(mockDescription); - when(mockConnection.retain()).thenReturn(mockConnection); - - connectionSource = mock(ConnectionSource.class); - operationContext = mock(OperationContext.class); - timeoutContext = new TimeoutContext(TimeoutSettings.create( - MongoClientSettings.builder().timeout(TIMEOUT.toMillis(), MILLISECONDS).build())); - serverDescription = mock(ServerDescription.class); - when(operationContext.getTimeoutContext()).thenReturn(timeoutContext); - when(connectionSource.getOperationContext()).thenReturn(operationContext); - when(connectionSource.getConnection()).thenReturn(mockConnection); - when(connectionSource.getServerDescription()).thenReturn(serverDescription); - } - - - @Test - void shouldSkipKillsCursorsCommandWhenNetworkErrorOccurs() { - //given - when(mockConnection.command(eq(NAMESPACE.getDatabaseName()), any(), any(), any(), any(), any())).thenThrow( - new MongoSocketException("test", new ServerAddress())); - when(serverDescription.getType()).thenReturn(ServerType.LOAD_BALANCER); - - CommandBatchCursor commandBatchCursor = createBatchCursor(0); - //when - assertThrows(MongoSocketException.class, commandBatchCursor::next); - - //then - commandBatchCursor.close(); - verify(mockConnection, times(1)).command(eq(NAMESPACE.getDatabaseName()), any(), any(), any(), any(), any()); + coreCursor = mock(CoreCursor.class); + timeoutContext = spy(new TimeoutContext(TimeoutSettings.create( + MongoClientSettings.builder().timeout(TIMEOUT.toMillis(), MILLISECONDS).build()))); + operationContext = spy(new OperationContext( + IgnorableRequestContext.INSTANCE, + NoOpSessionContext.INSTANCE, + timeoutContext, + null)); } private CommandBatchCursor createBatchCursor(final long maxTimeMS) { return new CommandBatchCursor<>( TimeoutMode.CURSOR_LIFETIME, - COMMAND_CURSOR_DOCUMENT, - 0, maxTimeMS, - DOCUMENT_CODEC, - null, - connectionSource, - mockConnection); + operationContext, + coreCursor); } @Test - void shouldNotSkipKillsCursorsCommandWhenTimeoutExceptionDoesNotHaveNetworkErrorCause() { + @SuppressWarnings("try") + void nextShouldUseTimeoutContextWithMaxTimeOverride() { //given - when(mockConnection.command(eq(NAMESPACE.getDatabaseName()), any(), any(), any(), any(), any())).thenThrow( - new MongoOperationTimeoutException("test")); - when(serverDescription.getType()).thenReturn(ServerType.LOAD_BALANCER); + long maxTimeMS = 10; + com.mongodb.assertions.Assertions.assertTrue(maxTimeMS < TIMEOUT.toMillis()); - CommandBatchCursor commandBatchCursor = createBatchCursor(0); + try (CommandBatchCursor commandBatchCursor = createBatchCursor(maxTimeMS)) { - //when - assertThrows(MongoOperationTimeoutException.class, commandBatchCursor::next); + //when + commandBatchCursor.next(); - commandBatchCursor.close(); + // then verify that the `maxTimeMS` override was applied + ArgumentCaptor operationContextArgumentCaptor = ArgumentCaptor.forClass(OperationContext.class); + verify(coreCursor).next(operationContextArgumentCaptor.capture()); + OperationContext operationContextForNext = operationContextArgumentCaptor.getValue(); + operationContextForNext.getTimeoutContext() + .runMaxTimeMS(remainingMillis -> assertEquals(maxTimeMS, remainingMillis, "MaxTieMs override not applied")); + } + } + @Test + @SuppressWarnings("try") + void tryNextShouldUseTimeoutContextWithMaxTimeOverride() { + //given + long maxTimeMS = 10; + com.mongodb.assertions.Assertions.assertTrue(maxTimeMS < TIMEOUT.toMillis()); + + try (CommandBatchCursor commandBatchCursor = createBatchCursor(maxTimeMS)) { - //then - verify(mockConnection, times(2)).command(any(), - any(), any(), any(), any(), any()); - verify(mockConnection, times(1)).command(eq(NAMESPACE.getDatabaseName()), - argThat(bsonDocument -> bsonDocument.containsKey("getMore")), any(), any(), any(), any()); - verify(mockConnection, times(1)).command(eq(NAMESPACE.getDatabaseName()), - argThat(bsonDocument -> bsonDocument.containsKey("killCursors")), any(), any(), any(), any()); + //when + commandBatchCursor.tryNext(); + + // then verify that the `maxTimeMS` override was applied + ArgumentCaptor operationContextArgumentCaptor = ArgumentCaptor.forClass(OperationContext.class); + verify(coreCursor).tryNext(operationContextArgumentCaptor.capture()); + OperationContext operationContextForNext = operationContextArgumentCaptor.getValue(); + operationContextForNext.getTimeoutContext() + .runMaxTimeMS(remainingMillis -> assertEquals(maxTimeMS, remainingMillis, "MaxTieMs override not applied")); + } } @Test - void shouldSkipKillsCursorsCommandWhenTimeoutExceptionHaveNetworkErrorCause() { + @SuppressWarnings("try") + void nextShouldNotUseTimeoutContextWithMaxTimeOverride() { //given - when(mockConnection.command(eq(NAMESPACE.getDatabaseName()), any(), any(), any(), any(), any())).thenThrow( - new MongoOperationTimeoutException("test", new MongoSocketException("test", new ServerAddress()))); - when(serverDescription.getType()).thenReturn(ServerType.LOAD_BALANCER); - - CommandBatchCursor commandBatchCursor = createBatchCursor(0); - - //when - assertThrows(MongoOperationTimeoutException.class, commandBatchCursor::next); - commandBatchCursor.close(); - - //then - verify(mockConnection, times(1)).command(any(), - any(), any(), any(), any(), any()); - verify(mockConnection, times(1)).command(eq(NAMESPACE.getDatabaseName()), - argThat(bsonDocument -> bsonDocument.containsKey("getMore")), any(), any(), any(), any()); - verify(mockConnection, never()).command(eq(NAMESPACE.getDatabaseName()), - argThat(bsonDocument -> bsonDocument.containsKey("killCursors")), any(), any(), any(), any()); + int maxTimeMS = 0; + try (CommandBatchCursor commandBatchCursor = createBatchCursor(maxTimeMS)) { + + //when + commandBatchCursor.next(); + + // then verify that the `maxTimeMS` override was not applied + ArgumentCaptor operationContextArgumentCaptor = ArgumentCaptor.forClass(OperationContext.class); + verify(coreCursor).next(operationContextArgumentCaptor.capture()); + OperationContext operationContextForNext = operationContextArgumentCaptor.getValue(); + operationContextForNext.getTimeoutContext().runMaxTimeMS(remainingMillis -> { + // verify that the `maxTimeMS` override was reset + assertTrue(remainingMillis > maxTimeMS); + assertTrue(remainingMillis <= TIMEOUT.toMillis()); + }); + } } @Test @SuppressWarnings("try") - void closeShouldResetTimeoutContextToDefaultMaxTime() { - long maxTimeMS = 10; - com.mongodb.assertions.Assertions.assertTrue(maxTimeMS < TIMEOUT.toMillis()); + void tryNextShouldNotUseTimeoutContextWithMaxTimeOverride() { + //given + int maxTimeMS = 0; try (CommandBatchCursor commandBatchCursor = createBatchCursor(maxTimeMS)) { - // verify that the `maxTimeMS` override was applied - timeoutContext.runMaxTimeMS(remainingMillis -> assertTrue(remainingMillis <= maxTimeMS)); - } catch (Exception e) { - throw new RuntimeException(e); + + //when + commandBatchCursor.tryNext(); + + // then verify that the `maxTimeMS` override was not applied + ArgumentCaptor operationContextArgumentCaptor = ArgumentCaptor.forClass(OperationContext.class); + verify(coreCursor).tryNext(operationContextArgumentCaptor.capture()); + OperationContext operationContextForNext = operationContextArgumentCaptor.getValue(); + operationContextForNext.getTimeoutContext().runMaxTimeMS(remainingMillis -> { + // verify that the `maxTimeMS` override was reset + assertTrue(remainingMillis > maxTimeMS); + assertTrue(remainingMillis <= TIMEOUT.toMillis()); + }); } - timeoutContext.runMaxTimeMS(remainingMillis -> { - // verify that the `maxTimeMS` override was reset - assertTrue(remainingMillis > maxTimeMS); - assertTrue(remainingMillis <= TIMEOUT.toMillis()); - }); } - @ParameterizedTest - @ValueSource(booleans = {false, true}) - void closeShouldNotResetOriginalTimeout(final boolean disableTimeoutResetWhenClosing) { - Duration thirdOfTimeout = TIMEOUT.dividedBy(3); - com.mongodb.assertions.Assertions.assertTrue(thirdOfTimeout.toMillis() > 0); - try (CommandBatchCursor commandBatchCursor = createBatchCursor(0)) { - if (disableTimeoutResetWhenClosing) { - commandBatchCursor.disableTimeoutResetWhenClosing(); - } - try { - Thread.sleep(thirdOfTimeout.toMillis()); - } catch (InterruptedException e) { - throw interruptAndCreateMongoInterruptedException(null, e); - } - when(mockConnection.release()).then(invocation -> { - Thread.sleep(thirdOfTimeout.toMillis()); - return null; + @ParameterizedTest(name = "closeShouldResetTimeoutContextToDefaultMaxTime with maxTimeMS={0}") + @SuppressWarnings("try") + @ValueSource(ints = {10, 0}) + void closeShouldResetTimeoutContextToDefaultMaxTime(final int maxTimeMS) { + //given + try (CommandBatchCursor commandBatchCursor = createBatchCursor(maxTimeMS)) { + + //when + commandBatchCursor.close(); + + // then verify that the `maxTimeMS` override was not applied + ArgumentCaptor operationContextArgumentCaptor = ArgumentCaptor.forClass(OperationContext.class); + verify(coreCursor).close(operationContextArgumentCaptor.capture()); + OperationContext operationContextForNext = operationContextArgumentCaptor.getValue(); + operationContextForNext.getTimeoutContext().runMaxTimeMS(remainingMillis -> { + // verify that the `maxTimeMS` override was reset + assertTrue(remainingMillis > maxTimeMS); + assertTrue(remainingMillis <= TIMEOUT.toMillis()); }); - } catch (Exception e) { - throw new RuntimeException(e); } - verify(mockConnection, times(1)).release(); - // at this point at least (2 * thirdOfTimeout) have passed - com.mongodb.assertions.Assertions.assertNotNull(timeoutContext.getTimeout()).run( - MILLISECONDS, - com.mongodb.assertions.Assertions::fail, - remainingMillis -> { - // Verify that the original timeout has not been intact. - // If `close` had reset it, we would have observed more than `thirdOfTimeout` left. - assertTrue(remainingMillis <= thirdOfTimeout.toMillis()); - }, - Assertions::fail); } } diff --git a/driver-core/src/test/unit/com/mongodb/internal/operation/CommandCoreCursorTest.java b/driver-core/src/test/unit/com/mongodb/internal/operation/CommandCoreCursorTest.java new file mode 100644 index 00000000000..b6b7e22a616 --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/internal/operation/CommandCoreCursorTest.java @@ -0,0 +1,177 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.operation; + +import com.mongodb.MongoClientSettings; +import com.mongodb.MongoNamespace; +import com.mongodb.MongoOperationTimeoutException; +import com.mongodb.MongoSocketException; +import com.mongodb.ServerAddress; +import com.mongodb.connection.ConnectionDescription; +import com.mongodb.connection.ServerDescription; +import com.mongodb.connection.ServerType; +import com.mongodb.connection.ServerVersion; +import com.mongodb.internal.IgnorableRequestContext; +import com.mongodb.internal.TimeoutContext; +import com.mongodb.internal.TimeoutSettings; +import com.mongodb.internal.binding.ConnectionSource; +import com.mongodb.internal.connection.Connection; +import com.mongodb.internal.connection.NoOpSessionContext; +import com.mongodb.internal.connection.OperationContext; +import org.bson.BsonArray; +import org.bson.BsonDocument; +import org.bson.BsonInt32; +import org.bson.BsonInt64; +import org.bson.BsonString; +import org.bson.Document; +import org.bson.codecs.Decoder; +import org.bson.codecs.DocumentCodec; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.time.Duration; + +import static com.mongodb.internal.operation.OperationUnitSpecification.getMaxWireVersionForServerVersion; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class CommandCoreCursorTest { + private static final MongoNamespace NAMESPACE = new MongoNamespace("test", "test"); + private static final BsonInt64 CURSOR_ID = new BsonInt64(1); + private static final BsonDocument COMMAND_CURSOR_DOCUMENT = new BsonDocument("ok", new BsonInt32(1)) + .append("cursor", + new BsonDocument("ns", new BsonString(NAMESPACE.getFullName())) + .append("id", CURSOR_ID) + .append("firstBatch", new BsonArrayWrapper<>(new BsonArray()))); + + private static final Decoder DOCUMENT_CODEC = new DocumentCodec(); + private static final Duration TIMEOUT = Duration.ofMillis(3_000); + + private Connection mockConnection; + private ConnectionDescription mockDescription; + private ConnectionSource connectionSource; + private OperationContext operationContext; + private TimeoutContext timeoutContext; + private ServerDescription serverDescription; + + @BeforeEach + void setUp() { + ServerVersion serverVersion = new ServerVersion(3, 6); + + mockConnection = mock(Connection.class, "connection"); + mockDescription = mock(ConnectionDescription.class); + when(mockDescription.getMaxWireVersion()).thenReturn(getMaxWireVersionForServerVersion(serverVersion.getVersionList())); + when(mockDescription.getServerType()).thenReturn(ServerType.LOAD_BALANCER); + when(mockConnection.getDescription()).thenReturn(mockDescription); + when(mockConnection.retain()).thenReturn(mockConnection); + + connectionSource = mock(ConnectionSource.class); + timeoutContext = spy(new TimeoutContext(TimeoutSettings.create( + MongoClientSettings.builder().timeout(TIMEOUT.toMillis(), MILLISECONDS).build()))); + operationContext = spy(new OperationContext( + IgnorableRequestContext.INSTANCE, + NoOpSessionContext.INSTANCE, + timeoutContext, + null)); + serverDescription = mock(ServerDescription.class); + when(operationContext.getTimeoutContext()).thenReturn(timeoutContext); + when(connectionSource.getConnection(any())).thenReturn(mockConnection); + when(connectionSource.getServerDescription()).thenReturn(serverDescription); + } + + + @Test + void shouldSkipKillsCursorsCommandWhenNetworkErrorOccurs() { + //given + when(mockConnection.command(eq(NAMESPACE.getDatabaseName()), any(), any(), any(), any(), any())).thenThrow( + new MongoSocketException("test", new ServerAddress())); + when(serverDescription.getType()).thenReturn(ServerType.LOAD_BALANCER); + + CoreCursor coreCursor = createCoreCursor(); + //when + assertThrows(MongoSocketException.class, () -> coreCursor.next(operationContext)); + + //then + coreCursor.close(operationContext); + verify(mockConnection, times(1)).command(eq(NAMESPACE.getDatabaseName()), any(), any(), any(), any(), any()); + } + + @Test + void shouldNotSkipKillsCursorsCommandWhenTimeoutExceptionDoesNotHaveNetworkErrorCause() { + //given + when(mockConnection.command(eq(NAMESPACE.getDatabaseName()), any(), any(), any(), any(), any())).thenThrow( + new MongoOperationTimeoutException("test")); + when(serverDescription.getType()).thenReturn(ServerType.LOAD_BALANCER); + + CoreCursor coreCursor = createCoreCursor(); + + //when + assertThrows(MongoOperationTimeoutException.class, () -> coreCursor.next(operationContext)); + + coreCursor.close(operationContext); + + + //then + verify(mockConnection, times(2)).command(any(), + any(), any(), any(), any(), any()); + verify(mockConnection, times(1)).command(eq(NAMESPACE.getDatabaseName()), + argThat(bsonDocument -> bsonDocument.containsKey("getMore")), any(), any(), any(), any()); + verify(mockConnection, times(1)).command(eq(NAMESPACE.getDatabaseName()), + argThat(bsonDocument -> bsonDocument.containsKey("killCursors")), any(), any(), any(), any()); + } + + @Test + void shouldSkipKillsCursorsCommandWhenTimeoutExceptionHaveNetworkErrorCause() { + //given + when(mockConnection.command(eq(NAMESPACE.getDatabaseName()), any(), any(), any(), any(), any())).thenThrow( + new MongoOperationTimeoutException("test", new MongoSocketException("test", new ServerAddress()))); + when(serverDescription.getType()).thenReturn(ServerType.LOAD_BALANCER); + + CoreCursor coreCursor = createCoreCursor(); + + //when + assertThrows(MongoOperationTimeoutException.class, () -> coreCursor.next(operationContext)); + coreCursor.close(operationContext); + + //then + verify(mockConnection, times(1)).command(any(), + any(), any(), any(), any(), any()); + verify(mockConnection, times(1)).command(eq(NAMESPACE.getDatabaseName()), + argThat(bsonDocument -> bsonDocument.containsKey("getMore")), any(), any(), any(), any()); + verify(mockConnection, never()).command(eq(NAMESPACE.getDatabaseName()), + argThat(bsonDocument -> bsonDocument.containsKey("killCursors")), any(), any(), any(), any()); + } + + private CoreCursor createCoreCursor() { + return new CommandCoreCursor<>( + COMMAND_CURSOR_DOCUMENT, + 0, + DOCUMENT_CODEC, + null, + connectionSource, + mockConnection); + } +} diff --git a/driver-core/src/test/unit/com/mongodb/internal/operation/CommitTransactionOperationUnitSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/operation/CommitTransactionOperationUnitSpecification.groovy index 21ae1c4dfb9..75ed9e6c5f3 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/operation/CommitTransactionOperationUnitSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/operation/CommitTransactionOperationUnitSpecification.groovy @@ -21,8 +21,10 @@ import com.mongodb.MongoTimeoutException import com.mongodb.ReadConcern import com.mongodb.WriteConcern import com.mongodb.async.FutureResultCallback +import com.mongodb.internal.async.SingleResultCallback import com.mongodb.internal.binding.AsyncWriteBinding import com.mongodb.internal.binding.WriteBinding +import com.mongodb.internal.connection.OperationContext import com.mongodb.internal.session.SessionContext import static com.mongodb.ClusterFixture.OPERATION_CONTEXT @@ -35,13 +37,12 @@ class CommitTransactionOperationUnitSpecification extends OperationUnitSpecifica hasActiveTransaction() >> true } def writeBinding = Stub(WriteBinding) { - getWriteConnectionSource() >> { throw new MongoTimeoutException('Time out!') } - getOperationContext() >> OPERATION_CONTEXT.withSessionContext(sessionContext) + getWriteConnectionSource(_) >> { throw new MongoTimeoutException('Time out!') } } def operation = new CommitTransactionOperation(WriteConcern.ACKNOWLEDGED) when: - operation.execute(writeBinding) + operation.execute(writeBinding, OPERATION_CONTEXT.withSessionContext(sessionContext)) then: def e = thrown(MongoTimeoutException) @@ -55,16 +56,15 @@ class CommitTransactionOperationUnitSpecification extends OperationUnitSpecifica hasActiveTransaction() >> true } def writeBinding = Stub(AsyncWriteBinding) { - getWriteConnectionSource(_) >> { - it[0].onResult(null, new MongoTimeoutException('Time out!')) + getWriteConnectionSource(_ as OperationContext, _ as SingleResultCallback) >> { + it[1].onResult(null, new MongoTimeoutException('Time out!')) } - getOperationContext() >> OPERATION_CONTEXT.withSessionContext(sessionContext) } def operation = new CommitTransactionOperation(WriteConcern.ACKNOWLEDGED) def callback = new FutureResultCallback() when: - operation.executeAsync(writeBinding, callback) + operation.executeAsync(writeBinding, OPERATION_CONTEXT.withSessionContext(sessionContext), callback) callback.get() then: diff --git a/driver-core/src/test/unit/com/mongodb/internal/operation/ListCollectionsOperationTest.java b/driver-core/src/test/unit/com/mongodb/internal/operation/ListCollectionsOperationTest.java index 12a964db625..de1bfe405ed 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/operation/ListCollectionsOperationTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/operation/ListCollectionsOperationTest.java @@ -99,7 +99,7 @@ void authorizedCollectionsIsFalseByDefault() { } private BsonDocument executeOperationAndCaptureCommand() { - operation.execute(mocks.readBinding()); + operation.execute(mocks.readBinding(), OPERATION_CONTEXT); ArgumentCaptor commandCaptor = forClass(BsonDocument.class); verify(mocks.connection()).command(any(), commandCaptor.capture(), any(), any(), any(), any()); return commandCaptor.getValue(); @@ -108,9 +108,7 @@ private BsonDocument executeOperationAndCaptureCommand() { private static Mocks mocks(final MongoNamespace namespace) { Mocks result = new Mocks(); result.readBinding(mock(ReadBinding.class, bindingMock -> { - when(bindingMock.getOperationContext()).thenReturn(OPERATION_CONTEXT); ConnectionSource connectionSource = mock(ConnectionSource.class, connectionSourceMock -> { - when(connectionSourceMock.getOperationContext()).thenReturn(OPERATION_CONTEXT); when(connectionSourceMock.release()).thenReturn(1); ServerAddress serverAddress = new ServerAddress(); result.connection(mock(Connection.class, connectionMock -> { @@ -119,7 +117,7 @@ private static Mocks mocks(final MongoNamespace namespace) { when(connectionMock.getDescription()).thenReturn(connectionDescription); when(connectionMock.command(any(), any(), any(), any(), any(), any())).thenReturn(cursorDoc(namespace)); })); - when(connectionSourceMock.getConnection()).thenReturn(result.connection()); + when(connectionSourceMock.getConnection(any())).thenReturn(result.connection()); ServerDescription serverDescription = ServerDescription.builder() .address(serverAddress) .type(ServerType.STANDALONE) @@ -128,7 +126,7 @@ private static Mocks mocks(final MongoNamespace namespace) { when(connectionSourceMock.getServerDescription()).thenReturn(serverDescription); when(connectionSourceMock.getReadPreference()).thenReturn(ReadPreference.primary()); }); - when(bindingMock.getReadConnectionSource()).thenReturn(connectionSource); + when(bindingMock.getReadConnectionSource(any())).thenReturn(connectionSource); })); return result; } diff --git a/driver-core/src/test/unit/com/mongodb/internal/operation/OperationUnitSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/operation/OperationUnitSpecification.groovy index d298112656e..ec5cb74156f 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/operation/OperationUnitSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/operation/OperationUnitSpecification.groovy @@ -110,18 +110,15 @@ class OperationUnitSpecification extends Specification { } def connectionSource = Stub(ConnectionSource) { - getConnection() >> connection + getConnection(_) >> connection getReadPreference() >> readPreference - getOperationContext() >> operationContext } def readBinding = Stub(ReadBinding) { - getReadConnectionSource() >> connectionSource + getReadConnectionSource(_) >> connectionSource getReadPreference() >> readPreference - getOperationContext() >> operationContext } def writeBinding = Stub(WriteBinding) { - getWriteConnectionSource() >> connectionSource - getOperationContext() >> operationContext + getWriteConnectionSource(_) >> connectionSource } if (checkCommand) { @@ -144,9 +141,9 @@ class OperationUnitSpecification extends Specification { 1 * connection.release() if (operation instanceof ReadOperation) { - operation.execute(readBinding) + operation.execute(readBinding, operationContext) } else if (operation instanceof WriteOperation) { - operation.execute(writeBinding) + operation.execute(writeBinding, operationContext) } } @@ -167,18 +164,15 @@ class OperationUnitSpecification extends Specification { } def connectionSource = Stub(AsyncConnectionSource) { - getConnection(_) >> { it[0].onResult(connection, null) } + getConnection(_, _) >> { it[1].onResult(connection, null) } getReadPreference() >> readPreference - getOperationContext() >> getOperationContext() >> operationContext } def readBinding = Stub(AsyncReadBinding) { - getReadConnectionSource(_) >> { it[0].onResult(connectionSource, null) } + getReadConnectionSource(_, _) >> { it[1].onResult(connectionSource, null) } getReadPreference() >> readPreference - getOperationContext() >> operationContext } def writeBinding = Stub(AsyncWriteBinding) { - getWriteConnectionSource(_) >> { it[0].onResult(connectionSource, null) } - getOperationContext() >> operationContext + getWriteConnectionSource(_, _) >> { it[1].onResult(connectionSource, null) } } def callback = new FutureResultCallback() @@ -202,9 +196,9 @@ class OperationUnitSpecification extends Specification { 1 * connection.release() if (operation instanceof ReadOperation) { - operation.executeAsync(readBinding, callback) + operation.executeAsync(readBinding, operationContext, callback) } else if (operation instanceof WriteOperation) { - operation.executeAsync(writeBinding, callback) + operation.executeAsync(writeBinding, operationContext, callback) } try { callback.get(1000, TimeUnit.MILLISECONDS) diff --git a/driver-core/src/test/unit/com/mongodb/internal/operation/SyncOperationHelperSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/operation/SyncOperationHelperSpecification.groovy index df2d54bfb9d..bd9bd2f2578 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/operation/SyncOperationHelperSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/operation/SyncOperationHelperSpecification.groovy @@ -53,21 +53,19 @@ class SyncOperationHelperSpecification extends Specification { def connection = Mock(Connection) def function = Stub(CommandWriteTransformer) def connectionSource = Stub(ConnectionSource) { - getConnection() >> connection - getOperationContext() >> OPERATION_CONTEXT + getConnection(_) >> connection } def writeBinding = Stub(WriteBinding) { - getWriteConnectionSource() >> connectionSource - getOperationContext() >> OPERATION_CONTEXT + getWriteConnectionSource(_) >> connectionSource } def connectionDescription = Stub(ConnectionDescription) when: - executeCommand(writeBinding, dbName, command, decoder, function) + executeCommand(writeBinding, OPERATION_CONTEXT, dbName, command, decoder, function) then: _ * connection.getDescription() >> connectionDescription - 1 * connection.command(dbName, command, _, primary(), decoder, OPERATION_CONTEXT) >> new BsonDocument() + 1 * connection.command(dbName, command, _, primary(), decoder, _) >> new BsonDocument() 1 * connection.release() } @@ -94,24 +92,22 @@ class SyncOperationHelperSpecification extends Specification { } } def connectionSource = Stub(ConnectionSource) { - _ * getConnection() >> connection + _ * getConnection(_) >> connection _ * getServerDescription() >> Stub(ServerDescription) { getLogicalSessionTimeoutMinutes() >> 1 } - getOperationContext() >> operationContext } def writeBinding = Stub(WriteBinding) { - getWriteConnectionSource() >> connectionSource - getOperationContext() >> operationContext + getWriteConnectionSource(_) >> connectionSource } when: - executeRetryableWrite(writeBinding, dbName, primary(), + executeRetryableWrite(writeBinding, operationContext, dbName, primary(), NoOpFieldNameValidator.INSTANCE, decoder, commandCreator, FindAndModifyHelper.transformer()) { cmd -> cmd } then: - 2 * connection.command(dbName, command, _, primary(), decoder, operationContext) >> { results.poll() } + 2 * connection.command(dbName, command, _, primary(), decoder, _) >> { results.poll() } then: def ex = thrown(MongoWriteConcernException) @@ -127,22 +123,20 @@ class SyncOperationHelperSpecification extends Specification { def function = Stub(CommandReadTransformer) def connection = Mock(Connection) def connectionSource = Stub(ConnectionSource) { - getConnection() >> connection + getConnection(_) >> connection getReadPreference() >> readPreference - getOperationContext() >> OPERATION_CONTEXT } def readBinding = Stub(ReadBinding) { - getReadConnectionSource() >> connectionSource - getOperationContext() >> OPERATION_CONTEXT + getReadConnectionSource(_) >> connectionSource } def connectionDescription = Stub(ConnectionDescription) when: - executeRetryableRead(readBinding, dbName, commandCreator, decoder, function, false) + executeRetryableRead(readBinding, OPERATION_CONTEXT, dbName, commandCreator, decoder, function, false) then: _ * connection.getDescription() >> connectionDescription - 1 * connection.command(dbName, command, _, readPreference, decoder, OPERATION_CONTEXT) >> new BsonDocument() + 1 * connection.command(dbName, command, _, readPreference, decoder, _) >> new BsonDocument() 1 * connection.release() where: diff --git a/driver-legacy/src/main/com/mongodb/LegacyMixedBulkWriteOperation.java b/driver-legacy/src/main/com/mongodb/LegacyMixedBulkWriteOperation.java index 95990833f00..3c4be0dcc78 100644 --- a/driver-legacy/src/main/com/mongodb/LegacyMixedBulkWriteOperation.java +++ b/driver-legacy/src/main/com/mongodb/LegacyMixedBulkWriteOperation.java @@ -26,6 +26,7 @@ import com.mongodb.internal.bulk.InsertRequest; import com.mongodb.internal.bulk.UpdateRequest; import com.mongodb.internal.bulk.WriteRequest; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.internal.operation.MixedBulkWriteOperation; import com.mongodb.internal.operation.WriteOperation; import com.mongodb.lang.Nullable; @@ -98,9 +99,9 @@ public String getCommandName() { } @Override - public WriteConcernResult execute(final WriteBinding binding) { + public WriteConcernResult execute(final WriteBinding binding, final OperationContext operationContext) { try { - BulkWriteResult result = wrappedOperation.bypassDocumentValidation(bypassDocumentValidation).execute(binding); + BulkWriteResult result = wrappedOperation.bypassDocumentValidation(bypassDocumentValidation).execute(binding, operationContext); if (result.wasAcknowledged()) { return translateBulkWriteResult(result); } else { @@ -112,7 +113,7 @@ public WriteConcernResult execute(final WriteBinding binding) { } @Override - public void executeAsync(final AsyncWriteBinding binding, final SingleResultCallback callback) { + public void executeAsync(final AsyncWriteBinding binding,final OperationContext operationContext, final SingleResultCallback callback) { throw new UnsupportedOperationException("This operation is sync only"); } diff --git a/driver-legacy/src/main/com/mongodb/MongoClient.java b/driver-legacy/src/main/com/mongodb/MongoClient.java index 09d58e1b493..1b18017db79 100644 --- a/driver-legacy/src/main/com/mongodb/MongoClient.java +++ b/driver-legacy/src/main/com/mongodb/MongoClient.java @@ -859,18 +859,20 @@ private void cleanCursors() { try { ServerCursorAndNamespace cur; while ((cur = orphanedCursors.poll()) != null) { - ReadWriteBinding binding = new SingleServerBinding(delegate.getCluster(), cur.serverCursor.getAddress(), - new OperationContext(IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE, - new TimeoutContext(getTimeoutSettings()), options.getServerApi())); + OperationContext operationContext = new OperationContext(IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE, + new TimeoutContext(getTimeoutSettings()), options.getServerApi()); + + ReadWriteBinding binding = new SingleServerBinding(delegate.getCluster(), cur.serverCursor.getAddress()); try { - ConnectionSource source = binding.getReadConnectionSource(); + OperationContext serverSelectionOperationContext = operationContext.withTimeoutContextOverride(TimeoutContext::withComputedServerSelectionTimeoutContextNew); + ConnectionSource source = binding.getReadConnectionSource(serverSelectionOperationContext); try { - Connection connection = source.getConnection(); + Connection connection = source.getConnection(serverSelectionOperationContext); try { BsonDocument killCursorsCommand = new BsonDocument("killCursors", new BsonString(cur.namespace.getCollectionName())) .append("cursors", new BsonArray(singletonList(new BsonInt64(cur.serverCursor.getId())))); connection.command(cur.namespace.getDatabaseName(), killCursorsCommand, NoOpFieldNameValidator.INSTANCE, - ReadPreference.primary(), new BsonDocumentCodec(), source.getOperationContext()); + ReadPreference.primary(), new BsonDocumentCodec(), operationContext); } finally { connection.release(); } diff --git a/driver-legacy/src/test/functional/com/mongodb/DBTest.java b/driver-legacy/src/test/functional/com/mongodb/DBTest.java index 4ce9b3f760b..cf44573a2b4 100644 --- a/driver-legacy/src/test/functional/com/mongodb/DBTest.java +++ b/driver-legacy/src/test/functional/com/mongodb/DBTest.java @@ -31,6 +31,7 @@ import java.util.Locale; import java.util.UUID; +import static com.mongodb.ClusterFixture.OPERATION_CONTEXT; import static com.mongodb.ClusterFixture.disableMaxTimeFailPoint; import static com.mongodb.ClusterFixture.enableMaxTimeFailPoint; import static com.mongodb.ClusterFixture.getBinding; @@ -344,7 +345,7 @@ public void shouldApplyUuidRepresentationToCommandEncodingAndDecoding() { BsonDocument getCollectionInfo(final String collectionName) { return new ListCollectionsOperation<>(getDefaultDatabaseName(), new BsonDocumentCodec()) - .filter(new BsonDocument("name", new BsonString(collectionName))).execute(getBinding()).next().get(0); + .filter(new BsonDocument("name", new BsonString(collectionName))).execute(getBinding(), OPERATION_CONTEXT).next().get(0); } private boolean isCapped(final DBCollection collection) { diff --git a/driver-legacy/src/test/functional/com/mongodb/LegacyMixedBulkWriteOperationSpecification.groovy b/driver-legacy/src/test/functional/com/mongodb/LegacyMixedBulkWriteOperationSpecification.groovy index 42854387e4a..6a9c511c3bc 100644 --- a/driver-legacy/src/test/functional/com/mongodb/LegacyMixedBulkWriteOperationSpecification.groovy +++ b/driver-legacy/src/test/functional/com/mongodb/LegacyMixedBulkWriteOperationSpecification.groovy @@ -182,8 +182,9 @@ class LegacyMixedBulkWriteOperationSpecification extends OperationFunctionalSpec def 'should replace a single document'() { given: def insert = new InsertRequest(new BsonDocument('_id', new BsonInt32(1))) + def binding = getBinding() createBulkWriteOperationForInsert(getNamespace(), true, ACKNOWLEDGED, false, asList(insert)) - .execute(getBinding()) + .execute(binding, ClusterFixture.getOperationContext(binding.getReadPreference())) def replacement = new UpdateRequest(new BsonDocument('_id', new BsonInt32(1)), new BsonDocument('_id', new BsonInt32(1)).append('x', new BsonInt32(1)), REPLACE) diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionBinding.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionBinding.java index 2e87b3bccf8..55357244fae 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionBinding.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionBinding.java @@ -21,7 +21,7 @@ import com.mongodb.connection.ClusterType; import com.mongodb.connection.ServerDescription; import com.mongodb.internal.async.SingleResultCallback; -import com.mongodb.internal.async.function.AsyncCallbackSupplier; +import com.mongodb.internal.async.function.AsyncCallbackFunction; import com.mongodb.internal.binding.AbstractReferenceCounted; import com.mongodb.internal.binding.AsyncClusterAwareReadWriteBinding; import com.mongodb.internal.binding.AsyncConnectionSource; @@ -46,13 +46,11 @@ public class ClientSessionBinding extends AbstractReferenceCounted implements As private final AsyncClusterAwareReadWriteBinding wrapped; private final ClientSession session; private final boolean ownsSession; - private final OperationContext operationContext; public ClientSessionBinding(final ClientSession session, final boolean ownsSession, final AsyncClusterAwareReadWriteBinding wrapped) { this.wrapped = notNull("wrapped", wrapped).retain(); this.ownsSession = ownsSession; this.session = notNull("session", session); - this.operationContext = wrapped.getOperationContext().withSessionContext(new AsyncClientSessionContext(session)); } @Override @@ -61,37 +59,38 @@ public ReadPreference getReadPreference() { } @Override - public OperationContext getOperationContext() { - return operationContext; - } - - @Override - public void getReadConnectionSource(final SingleResultCallback callback) { - getConnectionSource(wrapped::getReadConnectionSource, callback); + public void getReadConnectionSource(final OperationContext operationContext, + final SingleResultCallback callback) { + getConnectionSource(wrapped::getReadConnectionSource, operationContext, callback); } @Override public void getReadConnectionSource(final int minWireVersion, final ReadPreference fallbackReadPreference, + final OperationContext operationContext, final SingleResultCallback callback) { - getConnectionSource(wrappedConnectionSourceCallback -> - wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference, wrappedConnectionSourceCallback), + getConnectionSource((opContext, wrappedConnectionSourceCallback) -> + wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference, opContext, + wrappedConnectionSourceCallback), + operationContext, callback); } - public void getWriteConnectionSource(final SingleResultCallback callback) { - getConnectionSource(wrapped::getWriteConnectionSource, callback); + @Override + public void getWriteConnectionSource(final OperationContext operationContext, final SingleResultCallback callback) { + getConnectionSource(wrapped::getWriteConnectionSource, operationContext, callback); } - private void getConnectionSource(final AsyncCallbackSupplier connectionSourceSupplier, + private void getConnectionSource(final AsyncCallbackFunction connectionSourceSupplier, + final OperationContext operationContext, final SingleResultCallback callback) { WrappingCallback wrappingCallback = new WrappingCallback(callback); if (!session.hasActiveTransaction()) { - connectionSourceSupplier.get(wrappingCallback); + connectionSourceSupplier.apply(operationContext, wrappingCallback); return; } if (TransactionContext.get(session) == null) { - connectionSourceSupplier.get((source, t) -> { + connectionSourceSupplier.apply(operationContext, (source, t) -> { if (t != null) { wrappingCallback.onResult(null, t); } else { @@ -105,7 +104,7 @@ private void getConnectionSource(final AsyncCallbackSupplier callback) { + public void getConnection(final OperationContext operationContext, final SingleResultCallback callback) { TransactionContext transactionContext = TransactionContext.get(session); if (transactionContext != null && transactionContext.isConnectionPinningRequired()) { AsyncConnection pinnedConnection = transactionContext.getPinnedConnection(); if (pinnedConnection == null) { - wrapped.getConnection((connection, t) -> { + wrapped.getConnection(operationContext, (connection, t) -> { if (t != null) { callback.onResult(null, t); } else { @@ -168,7 +162,7 @@ public void getConnection(final SingleResultCallback callback) callback.onResult(pinnedConnection.retain(), null); } } else { - wrapped.getConnection(callback); + wrapped.getConnection(operationContext, callback); } } @@ -193,13 +187,19 @@ public int release() { } } - private final class AsyncClientSessionContext extends ClientSessionContext { + public static final class AsyncClientSessionContext extends ClientSessionContext { private final ClientSession clientSession; + private final ReadConcern inheritedReadConcern; + private final boolean ownsSession; - AsyncClientSessionContext(final ClientSession clientSession) { + AsyncClientSessionContext(final ClientSession clientSession, + final boolean ownsSession, + final ReadConcern inheritedReadConcern) { super(clientSession); this.clientSession = clientSession; + this.ownsSession = ownsSession; + this.inheritedReadConcern = inheritedReadConcern; } @@ -242,7 +242,10 @@ public ReadConcern getReadConcern() { } else if (isSnapshot()) { return ReadConcern.SNAPSHOT; } else { - return wrapped.getOperationContext().getSessionContext().getReadConcern(); + //COMMENT the read concern was specified on wrapped BindingContext was the one that was inherited from either MongoCollection, MongoDatabase, etc. + // since we removed the BindingContext, we can now embedd the parent read concern directly. + //return wrapped.getOperationContext().getSessionContext().getReadConcern(); + return inheritedReadConcern; } } } diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/MapReducePublisherImpl.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/MapReducePublisherImpl.java index 27e69762a09..31260d1329a 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/MapReducePublisherImpl.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/MapReducePublisherImpl.java @@ -27,6 +27,8 @@ import com.mongodb.internal.binding.AsyncWriteBinding; import com.mongodb.internal.binding.WriteBinding; import com.mongodb.internal.client.model.FindOptions; +import com.mongodb.internal.connection.OperationContext; +import com.mongodb.internal.operation.AsyncOperations; import com.mongodb.internal.operation.MapReduceAsyncBatchCursor; import com.mongodb.internal.operation.MapReduceBatchCursor; import com.mongodb.internal.operation.MapReduceStatistics; @@ -241,8 +243,8 @@ public String getCommandName() { } @Override - public void executeAsync(final AsyncReadBinding binding, final SingleResultCallback> callback) { - operation.executeAsync(binding, callback::onResult); + public void executeAsync(final AsyncReadBinding binding, final OperationContext operationContext, final SingleResultCallback> callback) { + operation.executeAsync(binding, operationContext, callback::onResult); } } @@ -268,8 +270,8 @@ public Void execute(final WriteBinding binding) { } @Override - public void executeAsync(final AsyncWriteBinding binding, final SingleResultCallback callback) { - operation.executeAsync(binding, (result, t) -> callback.onResult(null, t)); + public void executeAsync(final AsyncWriteBinding binding, final OperationContext operationContext, final SingleResultCallback callback) { + operation.executeAsync(binding, operationContext, (result, t) -> callback.onResult(null, t)); } } } diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/OperationExecutorImpl.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/OperationExecutorImpl.java index 56b0526e4cb..0138295eb32 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/OperationExecutorImpl.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/OperationExecutorImpl.java @@ -85,14 +85,19 @@ public Mono execute(final ReadOperation operation, final ReadPrefer return Mono.from(subscriber -> clientSessionHelper.withClientSession(session, this) - .map(clientSession -> getReadWriteBinding(getContext(subscriber), - readPreference, readConcern, clientSession, session == null, operation.getCommandName())) - .flatMap(binding -> { + .flatMap(actualClientSession -> { + AsyncReadWriteBinding binding = + getReadWriteBinding(readPreference, actualClientSession, isImplicitSession(session)); + RequestContext requestContext = getContext(subscriber); + OperationContext operationContext = getOperationContext(requestContext, actualClientSession, readConcern, operation.getCommandName()) + .withSessionContext(new ClientSessionBinding.AsyncClientSessionContext(actualClientSession, + isImplicitSession(session), readConcern)); + if (session != null && session.hasActiveTransaction() && !binding.getReadPreference().equals(primary())) { binding.release(); return Mono.error(new MongoClientException("Read preference in a transaction must be primary")); } else { - return Mono.create(sink -> operation.executeAsync(binding, (result, t) -> { + return Mono.create(sink -> operation.executeAsync(binding, operationContext, (result, t) -> { try { binding.release(); } finally { @@ -121,10 +126,14 @@ public Mono execute(final WriteOperation operation, final ReadConcern return Mono.from(subscriber -> clientSessionHelper.withClientSession(session, this) - .map(clientSession -> getReadWriteBinding(getContext(subscriber), - primary(), readConcern, clientSession, session == null, operation.getCommandName())) - .flatMap(binding -> - Mono.create(sink -> operation.executeAsync(binding, (result, t) -> { + .flatMap(actualClientSession -> { + AsyncReadWriteBinding binding = getReadWriteBinding(primary(), actualClientSession, session == null); + RequestContext requestContext = getContext(subscriber); + OperationContext operationContext = getOperationContext(requestContext, actualClientSession, readConcern, operation.getCommandName()) + .withSessionContext(new ClientSessionBinding.AsyncClientSessionContext(actualClientSession, + isImplicitSession(session), readConcern)); + + return Mono.create(sink -> operation.executeAsync(binding, operationContext, (result, t) -> { try { binding.release(); } finally { @@ -134,7 +143,8 @@ public Mono execute(final WriteOperation operation, final ReadConcern Throwable exceptionToHandle = t instanceof MongoException ? OperationHelper.unwrap((MongoException) t) : t; labelException(session, exceptionToHandle); unpinServerAddressOnTransientTransactionError(session, exceptionToHandle); - }) + }); + } ).subscribe(subscriber) ); } @@ -177,13 +187,12 @@ private void unpinServerAddressOnTransientTransactionError(@Nullable final Clien } } - private AsyncReadWriteBinding getReadWriteBinding(final RequestContext requestContext, - final ReadPreference readPreference, final ReadConcern readConcern, final ClientSession session, - final boolean ownsSession, final String commandName) { + private AsyncReadWriteBinding getReadWriteBinding(final ReadPreference readPreference, + final ClientSession session, + final boolean ownsSession) { notNull("readPreference", readPreference); AsyncClusterAwareReadWriteBinding readWriteBinding = new AsyncClusterBinding(mongoClient.getCluster(), - getReadPreferenceForBinding(readPreference, session), readConcern, - getOperationContext(requestContext, session, readConcern, commandName)); + getReadPreferenceForBinding(readPreference, session)); Crypt crypt = mongoClient.getCrypt(); if (crypt != null) { @@ -221,4 +230,8 @@ private ReadPreference getReadPreferenceForBinding(final ReadPreference readPref } return readPreference; } + + private boolean isImplicitSession(@Nullable final ClientSession session) { + return session == null; + } } diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/VoidReadOperationThenCursorReadOperation.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/VoidReadOperationThenCursorReadOperation.java index e74949432b9..71314bd2eab 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/VoidReadOperationThenCursorReadOperation.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/VoidReadOperationThenCursorReadOperation.java @@ -19,6 +19,7 @@ import com.mongodb.internal.async.AsyncBatchCursor; import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.binding.AsyncReadBinding; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.internal.operation.ReadOperationCursor; import com.mongodb.internal.operation.ReadOperationSimple; @@ -46,12 +47,12 @@ public String getCommandName() { } @Override - public void executeAsync(final AsyncReadBinding binding, final SingleResultCallback> callback) { - readOperation.executeAsync(binding, (result, t) -> { + public void executeAsync(final AsyncReadBinding binding, final OperationContext operationContext, final SingleResultCallback> callback) { + readOperation.executeAsync(binding, operationContext, (result, t) -> { if (t != null) { callback.onResult(null, t); } else { - cursorReadOperation.executeAsync(binding, callback); + cursorReadOperation.executeAsync(binding, operationContext, callback); } }); } diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/VoidWriteOperationThenCursorReadOperation.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/VoidWriteOperationThenCursorReadOperation.java index 428ad21ca26..79328467a82 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/VoidWriteOperationThenCursorReadOperation.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/VoidWriteOperationThenCursorReadOperation.java @@ -20,6 +20,7 @@ import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.binding.AsyncReadBinding; import com.mongodb.internal.binding.AsyncWriteBinding; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.internal.operation.ReadOperationCursor; import com.mongodb.internal.operation.WriteOperation; @@ -39,12 +40,12 @@ public String getCommandName() { } @Override - public void executeAsync(final AsyncReadBinding binding, final SingleResultCallback> callback) { - writeOperation.executeAsync((AsyncWriteBinding) binding, (result, t) -> { + public void executeAsync(final AsyncReadBinding binding, final OperationContext operationContext, final SingleResultCallback> callback) { + writeOperation.executeAsync((AsyncWriteBinding) binding, operationContext, (result, t) -> { if (t != null) { callback.onResult(null, t); } else { - cursorReadOperation.executeAsync(binding, callback); + cursorReadOperation.executeAsync(binding, operationContext, callback); } }); } diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/Crypt.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/Crypt.java index 61ccaa320fe..a9407799dc7 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/Crypt.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/Crypt.java @@ -143,9 +143,9 @@ public Mono encrypt(final String databaseName, final RawBsonDoc * * @param commandResponse the encrypted command response */ - public Mono decrypt(final RawBsonDocument commandResponse, @Nullable final Timeout operationTimeout) { + public Mono decrypt(final RawBsonDocument commandResponse, @Nullable final Timeout timeout) { notNull("commandResponse", commandResponse); - return executeStateMachine(() -> mongoCrypt.createDecryptionContext(commandResponse), operationTimeout) + return executeStateMachine(() -> mongoCrypt.createDecryptionContext(commandResponse), timeout) .onErrorMap(this::wrapInClientException); } diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/CryptBinding.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/CryptBinding.java index 1dcc8a07d62..cff663e663c 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/CryptBinding.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/CryptBinding.java @@ -44,8 +44,8 @@ public ReadPreference getReadPreference() { } @Override - public void getWriteConnectionSource(final SingleResultCallback callback) { - wrapped.getWriteConnectionSource((result, t) -> { + public void getWriteConnectionSource(final OperationContext operationContext, final SingleResultCallback callback) { + wrapped.getWriteConnectionSource(operationContext, (result, t) -> { if (t != null) { callback.onResult(null, t); } else { @@ -55,13 +55,8 @@ public void getWriteConnectionSource(final SingleResultCallback callback) { - wrapped.getReadConnectionSource((result, t) -> { + public void getReadConnectionSource(final OperationContext operationContext, final SingleResultCallback callback) { + wrapped.getReadConnectionSource(operationContext, (result, t) -> { if (t != null) { callback.onResult(null, t); } else { @@ -72,8 +67,9 @@ public void getReadConnectionSource(final SingleResultCallback callback) { - wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference, (result, t) -> { + wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference, operationContext, (result, t) -> { if (t != null) { callback.onResult(null, t); } else { @@ -84,8 +80,10 @@ public void getReadConnectionSource(final int minWireVersion, final ReadPreferen @Override - public void getConnectionSource(final ServerAddress serverAddress, final SingleResultCallback callback) { - wrapped.getConnectionSource(serverAddress, (result, t) -> { + public void getConnectionSource(final ServerAddress serverAddress, + final OperationContext operationContext, + final SingleResultCallback callback) { + wrapped.getConnectionSource(serverAddress, operationContext, (result, t) -> { if (t != null) { callback.onResult(null, t); } else { @@ -110,6 +108,7 @@ public int release() { return wrapped.release(); } + private class CryptConnectionSource implements AsyncConnectionSource { private final AsyncConnectionSource wrapped; @@ -123,19 +122,14 @@ public ServerDescription getServerDescription() { return wrapped.getServerDescription(); } - @Override - public OperationContext getOperationContext() { - return wrapped.getOperationContext(); - } - @Override public ReadPreference getReadPreference() { return wrapped.getReadPreference(); } @Override - public void getConnection(final SingleResultCallback callback) { - wrapped.getConnection((result, t) -> { + public void getConnection(final OperationContext operationContext, final SingleResultCallback callback) { + wrapped.getConnection(operationContext, (result, t) -> { if (t != null) { callback.onResult(null, t); } else { diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/CryptConnection.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/CryptConnection.java index c05bfb663f2..f7febad164c 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/CryptConnection.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/CryptConnection.java @@ -121,14 +121,14 @@ public void commandAsync(final String database, final BsonDocument command, : new SplittablePayloadBsonWriter(bsonBinaryWriter, bsonOutput, createSplittablePayloadMessageSettings(), payload, MAX_SPLITTABLE_DOCUMENT_SIZE); - Timeout operationTimeout = operationContext.getTimeoutContext().getTimeout(); + Timeout timeout = operationContext.getTimeoutContext().getTimeout(); getEncoder(command).encode(writer, command, EncoderContext.builder().build()); - crypt.encrypt(database, new RawBsonDocument(bsonOutput.getInternalBuffer(), 0, bsonOutput.getSize()), operationTimeout) + crypt.encrypt(database, new RawBsonDocument(bsonOutput.getInternalBuffer(), 0, bsonOutput.getSize()), timeout) .flatMap((Function>) encryptedCommand -> Mono.create(sink -> wrapped.commandAsync(database, encryptedCommand, commandFieldNameValidator, readPreference, new RawBsonDocumentCodec(), operationContext, responseExpected, EmptyMessageSequences.INSTANCE, sinkToCallback(sink)))) - .flatMap(rawBsonDocument -> crypt.decrypt(rawBsonDocument, operationTimeout)) + .flatMap(rawBsonDocument -> crypt.decrypt(rawBsonDocument, timeout)) .map(decryptedResponse -> commandResultDecoder.decode(new BsonBinaryReader(decryptedResponse.getByteBuffer().asNIO()), DecoderContext.builder().build()) diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/gridfs/GridFSBucketImpl.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/gridfs/GridFSBucketImpl.java index 948c666489c..0b3a166e698 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/gridfs/GridFSBucketImpl.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/gridfs/GridFSBucketImpl.java @@ -232,7 +232,7 @@ public GridFSDownloadPublisher downloadToPublisher(final String filename) { @Override public GridFSDownloadPublisher downloadToPublisher(final String filename, final GridFSDownloadOptions options) { Function findPublisherCreator = - operationTimeout -> createGridFSFindPublisher(filesCollection, null, filename, options, operationTimeout); + timeout -> createGridFSFindPublisher(filesCollection, null, filename, options, timeout); return createGridFSDownloadPublisher(chunksCollection, null, findPublisherCreator); } diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/gridfs/GridFSPublisherCreator.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/gridfs/GridFSPublisherCreator.java index 166abca6a0b..cf50a9a2dc5 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/gridfs/GridFSPublisherCreator.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/gridfs/GridFSPublisherCreator.java @@ -93,9 +93,9 @@ public static GridFSFindPublisher createGridFSFindPublisher( final MongoCollection filesCollection, @Nullable final ClientSession clientSession, @Nullable final Bson filter, - @Nullable final Timeout operationTimeout) { + @Nullable final Timeout timeout) { notNull("filesCollection", filesCollection); - return new GridFSFindPublisherImpl(createFindPublisher(filesCollection, clientSession, filter, operationTimeout)); + return new GridFSFindPublisherImpl(createFindPublisher(filesCollection, clientSession, filter, timeout)); } public static GridFSFindPublisher createGridFSFindPublisher( @@ -103,7 +103,7 @@ public static GridFSFindPublisher createGridFSFindPublisher( @Nullable final ClientSession clientSession, final String filename, final GridFSDownloadOptions options, - @Nullable final Timeout operationTimeout) { + @Nullable final Timeout timeout) { notNull("filesCollection", filesCollection); notNull("filename", filename); notNull("options", options); @@ -119,7 +119,8 @@ public static GridFSFindPublisher createGridFSFindPublisher( sort = -1; } - return createGridFSFindPublisher(filesCollection, clientSession, new Document("filename", filename), operationTimeout).skip(skip) + return createGridFSFindPublisher(filesCollection, clientSession, new Document("filename", filename), timeout) + .skip(skip) .sort(new Document("uploadDate", sort)); } @@ -127,19 +128,19 @@ public static FindPublisher createFindPublisher( final MongoCollection filesCollection, @Nullable final ClientSession clientSession, @Nullable final Bson filter, - @Nullable final Timeout operationTimeout) { + @Nullable final Timeout timeout) { notNull("filesCollection", filesCollection); FindPublisher publisher; if (clientSession == null) { - publisher = collectionWithTimeout(filesCollection, operationTimeout).find(); + publisher = collectionWithTimeout(filesCollection, timeout).find(); } else { - publisher = collectionWithTimeout(filesCollection, operationTimeout).find(clientSession); + publisher = collectionWithTimeout(filesCollection, timeout).find(clientSession); } if (filter != null) { publisher = publisher.filter(filter); } - if (operationTimeout != null) { + if (timeout != null) { publisher.timeoutMode(TimeoutMode.CURSOR_LIFETIME); } return publisher; @@ -175,8 +176,8 @@ public static Publisher createDeletePublisher(final MongoCollection { - Timeout operationTimeout = startTimeout(filesCollection.getTimeout(MILLISECONDS)); - return collectionWithTimeoutMono(filesCollection, operationTimeout) + Timeout timeout = startTimeout(filesCollection.getTimeout(MILLISECONDS)); + return collectionWithTimeoutMono(filesCollection, timeout) .flatMap(wrappedCollection -> { if (clientSession == null) { return Mono.from(wrappedCollection.deleteOne(filter)); @@ -187,7 +188,7 @@ public static Publisher createDeletePublisher(final MongoCollection { if (clientSession == null) { return Mono.from(wrappedCollection.deleteMany(new BsonDocument("files_id", id))); @@ -228,15 +229,15 @@ public static Publisher createDropPublisher(final MongoCollection { - Timeout operationTimeout = startTimeout(filesCollection.getTimeout(MILLISECONDS)); - return collectionWithTimeoutMono(filesCollection, operationTimeout) + Timeout timeout = startTimeout(filesCollection.getTimeout(MILLISECONDS)); + return collectionWithTimeoutMono(filesCollection, timeout) .flatMap(wrappedCollection -> { if (clientSession == null) { return Mono.from(wrappedCollection.drop()); } else { return Mono.from(wrappedCollection.drop(clientSession)); } - }).then(collectionWithTimeoutDeferred(chunksCollection, operationTimeout)) + }).then(collectionWithTimeoutDeferred(chunksCollection, timeout)) .flatMap(wrappedCollection -> { if (clientSession == null) { return Mono.from(wrappedCollection.drop()); diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/ClientSideOperationTimeoutProseTest.java b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/ClientSideOperationTimeoutProseTest.java index 5613e6dbcd8..0039594f409 100644 --- a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/ClientSideOperationTimeoutProseTest.java +++ b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/ClientSideOperationTimeoutProseTest.java @@ -466,7 +466,7 @@ public void testTimeoutMsISHonoredForNnextOperationWhenSeveralGetMoreExecutedInt .getCollection(namespace.getCollectionName()).withReadPreference(ReadPreference.primary()); //when - ChangeStreamPublisher documentChangeStreamPublisher = collection.watch(); + ChangeStreamPublisher documentChangeStreamPublisher = collection.watch().maxAwaitTime(1000, TimeUnit.MILLISECONDS); StepVerifier.create(documentChangeStreamPublisher, 2) //then .expectError(MongoOperationTimeoutException.class) diff --git a/driver-reactive-streams/src/test/resources/logback-test.xml b/driver-reactive-streams/src/test/resources/logback-test.xml index 022806f0e4e..b25f68499b2 100644 --- a/driver-reactive-streams/src/test/resources/logback-test.xml +++ b/driver-reactive-streams/src/test/resources/logback-test.xml @@ -6,7 +6,7 @@ - + diff --git a/driver-reactive-streams/src/test/unit/com/mongodb/reactivestreams/client/internal/ClientSessionBindingSpecification.groovy b/driver-reactive-streams/src/test/unit/com/mongodb/reactivestreams/client/internal/ClientSessionBindingSpecification.groovy index d6233342291..7308560241f 100644 --- a/driver-reactive-streams/src/test/unit/com/mongodb/reactivestreams/client/internal/ClientSessionBindingSpecification.groovy +++ b/driver-reactive-streams/src/test/unit/com/mongodb/reactivestreams/client/internal/ClientSessionBindingSpecification.groovy @@ -16,7 +16,7 @@ package com.mongodb.reactivestreams.client.internal -import com.mongodb.ReadConcern + import com.mongodb.ReadPreference import com.mongodb.ServerAddress import com.mongodb.async.FutureResultCallback @@ -36,59 +36,31 @@ import spock.lang.Specification import static com.mongodb.ClusterFixture.OPERATION_CONTEXT class ClientSessionBindingSpecification extends Specification { - def 'should return the session context from the binding'() { - given: - def session = Stub(ClientSession) - def wrappedBinding = Stub(AsyncClusterAwareReadWriteBinding) { - getOperationContext() >> OPERATION_CONTEXT - } - def binding = new ClientSessionBinding(session, false, wrappedBinding) - - when: - def context = binding.getOperationContext().getSessionContext() - - then: - (context as ClientSessionContext).getClientSession() == session - } def 'should return the session context from the connection source'() { given: def session = Stub(ClientSession) - def wrappedBinding = Mock(AsyncClusterAwareReadWriteBinding) { - getOperationContext() >> OPERATION_CONTEXT - } + def wrappedBinding = Mock(AsyncClusterAwareReadWriteBinding); wrappedBinding.retain() >> wrappedBinding def binding = new ClientSessionBinding(session, false, wrappedBinding) when: def futureResultCallback = new FutureResultCallback() - binding.getReadConnectionSource(futureResultCallback) + binding.getReadConnectionSource(OPERATION_CONTEXT, futureResultCallback) then: - 1 * wrappedBinding.getReadConnectionSource(_) >> { - it[0].onResult(Stub(AsyncConnectionSource), null) + 1 * wrappedBinding.getReadConnectionSource(OPERATION_CONTEXT, _) >> { + it[1].onResult(Stub(AsyncConnectionSource), null) } - when: - def context = futureResultCallback.get().getOperationContext().getSessionContext() - - then: - (context as ClientSessionContext).getClientSession() == session - when: futureResultCallback = new FutureResultCallback() - binding.getWriteConnectionSource(futureResultCallback) + binding.getWriteConnectionSource(OPERATION_CONTEXT, futureResultCallback) then: - 1 * wrappedBinding.getWriteConnectionSource(_) >> { - it[0].onResult(Stub(AsyncConnectionSource), null) + 1 * wrappedBinding.getWriteConnectionSource(OPERATION_CONTEXT, _) >> { + it[1].onResult(Stub(AsyncConnectionSource), null) } - - when: - context = futureResultCallback.get().getOperationContext().getSessionContext() - - then: - (context as ClientSessionContext).getClientSession() == session } def 'should close client session when binding reference count drops to zero if it is owned by the binding'() { @@ -117,10 +89,10 @@ class ClientSessionBindingSpecification extends Specification { def wrappedBinding = createStubBinding() def binding = new ClientSessionBinding(session, true, wrappedBinding) def futureResultCallback = new FutureResultCallback() - binding.getReadConnectionSource(futureResultCallback) + binding.getReadConnectionSource(OPERATION_CONTEXT, futureResultCallback) def readConnectionSource = futureResultCallback.get() futureResultCallback = new FutureResultCallback() - binding.getWriteConnectionSource(futureResultCallback) + binding.getWriteConnectionSource(OPERATION_CONTEXT, futureResultCallback) def writeConnectionSource = futureResultCallback.get() when: @@ -162,20 +134,21 @@ class ClientSessionBindingSpecification extends Specification { 0 * session.close() } - def 'owned session is implicit'() { - given: - def session = Mock(ClientSession) - def wrappedBinding = createStubBinding() - - when: - def binding = new ClientSessionBinding(session, ownsSession, wrappedBinding) - - then: - binding.getOperationContext().getSessionContext().isImplicitSession() == ownsSession - - where: - ownsSession << [true, false] - } + //TODO move to SessionContext test +// def 'owned session is implicit'() { +// given: +// def session = Mock(ClientSession) +// def wrappedBinding = createStubBinding() +// +// when: +// def binding = new ClientSessionBinding(session, ownsSession, wrappedBinding) +// +// then: +// binding.getOperationContext(_).getSessionContext().isImplicitSession() == ownsSession +// +// where: +// ownsSession << [true, false] +// } private AsyncClusterAwareReadWriteBinding createStubBinding() { def cluster = Mock(Cluster) { @@ -187,6 +160,6 @@ class ClientSessionBindingSpecification extends Specification { .build()), null) } } - new AsyncClusterBinding(cluster, ReadPreference.primary(), ReadConcern.DEFAULT, OPERATION_CONTEXT) + new AsyncClusterBinding(cluster, ReadPreference.primary()) } } diff --git a/driver-sync/src/main/com/mongodb/client/internal/ClientSessionBinding.java b/driver-sync/src/main/com/mongodb/client/internal/ClientSessionBinding.java index 2d8a4dbfb30..48694bef74d 100644 --- a/driver-sync/src/main/com/mongodb/client/internal/ClientSessionBinding.java +++ b/driver-sync/src/main/com/mongodb/client/internal/ClientSessionBinding.java @@ -16,6 +16,7 @@ package com.mongodb.client.internal; +import com.mongodb.Function; import com.mongodb.ReadConcern; import com.mongodb.ReadPreference; import com.mongodb.client.ClientSession; @@ -30,8 +31,6 @@ import com.mongodb.internal.connection.OperationContext; import com.mongodb.internal.session.ClientSessionContext; -import java.util.function.Supplier; - import static com.mongodb.connection.ClusterType.LOAD_BALANCED; import static com.mongodb.connection.ClusterType.SHARDED; import static org.bson.assertions.Assertions.assertNotNull; @@ -44,14 +43,14 @@ public class ClientSessionBinding extends AbstractReferenceCounted implements Re private final ClusterAwareReadWriteBinding wrapped; private final ClientSession session; private final boolean ownsSession; - private final OperationContext operationContext; - public ClientSessionBinding(final ClientSession session, final boolean ownsSession, final ClusterAwareReadWriteBinding wrapped) { + public ClientSessionBinding(final ClientSession session, + final boolean ownsSession, + final ClusterAwareReadWriteBinding wrapped) { this.wrapped = wrapped; wrapped.retain(); this.session = notNull("session", session); this.ownsSession = ownsSession; - this.operationContext = wrapped.getOperationContext().withSessionContext(new SyncClientSessionContext(session)); } @Override @@ -84,32 +83,34 @@ public int release() { } @Override - public ConnectionSource getReadConnectionSource() { - return new SessionBindingConnectionSource(getConnectionSource(wrapped::getReadConnectionSource)); + public ConnectionSource getReadConnectionSource(final OperationContext operationContext) { + return new SessionBindingConnectionSource(getConnectionSource(wrapped::getReadConnectionSource, operationContext)); } @Override - public ConnectionSource getReadConnectionSource(final int minWireVersion, final ReadPreference fallbackReadPreference) { - return new SessionBindingConnectionSource(getConnectionSource(() -> - wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference))); - } + public ConnectionSource getReadConnectionSource(final int minWireVersion, final ReadPreference fallbackReadPreference, + final OperationContext operationContext) { + ConnectionSource connectionSource = getConnectionSource( + opContext -> wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference, opContext), + operationContext); - public ConnectionSource getWriteConnectionSource() { - return new SessionBindingConnectionSource(getConnectionSource(wrapped::getWriteConnectionSource)); + return new SessionBindingConnectionSource(connectionSource); } @Override - public OperationContext getOperationContext() { - return operationContext; + public ConnectionSource getWriteConnectionSource(final OperationContext operationContext) { + ConnectionSource connectionSource = getConnectionSource(wrapped::getWriteConnectionSource, operationContext); + return new SessionBindingConnectionSource(connectionSource); } - private ConnectionSource getConnectionSource(final Supplier wrappedConnectionSourceSupplier) { + private ConnectionSource getConnectionSource(final Function wrappedConnectionSourceSupplier, + final OperationContext operationContext) { if (!session.hasActiveTransaction()) { - return wrappedConnectionSourceSupplier.get(); + return wrappedConnectionSourceSupplier.apply(operationContext); } if (TransactionContext.get(session) == null) { - ConnectionSource source = wrappedConnectionSourceSupplier.get(); + ConnectionSource source = wrappedConnectionSourceSupplier.apply(operationContext); ClusterType clusterType = source.getServerDescription().getClusterType(); if (clusterType == SHARDED || clusterType == LOAD_BALANCED) { TransactionContext transactionContext = new TransactionContext<>(clusterType); @@ -118,7 +119,7 @@ private ConnectionSource getConnectionSource(final Supplier wr } return source; } else { - return wrapped.getConnectionSource(assertNotNull(session.getPinnedServerAddress())); + return wrapped.getConnectionSource(assertNotNull(session.getPinnedServerAddress()), operationContext); } } @@ -135,30 +136,25 @@ public ServerDescription getServerDescription() { return wrapped.getServerDescription(); } - @Override - public OperationContext getOperationContext() { - return operationContext; - } - @Override public ReadPreference getReadPreference() { return wrapped.getReadPreference(); } @Override - public Connection getConnection() { + public Connection getConnection(final OperationContext operationContext) { TransactionContext transactionContext = TransactionContext.get(session); if (transactionContext != null && transactionContext.isConnectionPinningRequired()) { Connection pinnedConnection = transactionContext.getPinnedConnection(); if (pinnedConnection == null) { - Connection connection = wrapped.getConnection(); + Connection connection = wrapped.getConnection(operationContext); transactionContext.pinConnection(connection, Connection::markAsPinned); return connection; } else { return pinnedConnection.retain(); } } else { - return wrapped.getConnection(); + return wrapped.getConnection(operationContext); } } @@ -184,13 +180,25 @@ public int release() { } } - private final class SyncClientSessionContext extends ClientSessionContext { + public static final class SyncClientSessionContext extends ClientSessionContext { private final ClientSession clientSession; - - SyncClientSessionContext(final ClientSession clientSession) { + private final boolean ownsSession; + private final ReadConcern inheritedReadConcern; + + /** + * @param clientSession the client session to use. + * @param inheritedReadConcern the read concern inherited from either {@link com.mongodb.client.MongoCollection}, + * {@link com.mongodb.client.MongoDatabase} and etc. + * @param ownsSession if true, the session is implicit. + */ + SyncClientSessionContext(final ClientSession clientSession, + final ReadConcern inheritedReadConcern, + final boolean ownsSession) { super(clientSession); this.clientSession = clientSession; + this.ownsSession = ownsSession; + this.inheritedReadConcern = inheritedReadConcern; } @Override @@ -215,7 +223,10 @@ public ReadConcern getReadConcern() { } else if (isSnapshot()) { return ReadConcern.SNAPSHOT; } else { - return wrapped.getOperationContext().getSessionContext().getReadConcern(); + //COMMENT the read concern was specified on wrapped BindingContext was the one that was inherited from either MongoCollection, MongoDatabase, etc. + // since we removed the BindingContext, we can now embedd the parent read concern directly. + //return wrapped.getOperationContext().getSessionContext().getReadConcern(); + return inheritedReadConcern; } } } diff --git a/driver-sync/src/main/com/mongodb/client/internal/CollectionInfoRetriever.java b/driver-sync/src/main/com/mongodb/client/internal/CollectionInfoRetriever.java index 9d02a1e8756..9ce34fd18a9 100644 --- a/driver-sync/src/main/com/mongodb/client/internal/CollectionInfoRetriever.java +++ b/driver-sync/src/main/com/mongodb/client/internal/CollectionInfoRetriever.java @@ -36,8 +36,8 @@ class CollectionInfoRetriever { this.client = notNull("client", client); } - public List filter(final String databaseName, final BsonDocument filter, @Nullable final Timeout operationTimeout) { - return databaseWithTimeout(client.getDatabase(databaseName), TIMEOUT_ERROR_MESSAGE, operationTimeout) + public List filter(final String databaseName, final BsonDocument filter, @Nullable final Timeout timeout) { + return databaseWithTimeout(client.getDatabase(databaseName), TIMEOUT_ERROR_MESSAGE, timeout) .listCollections(BsonDocument.class) .filter(filter) .into(new ArrayList<>()); diff --git a/driver-sync/src/main/com/mongodb/client/internal/CommandMarker.java b/driver-sync/src/main/com/mongodb/client/internal/CommandMarker.java index 73eed8efd01..d694cf682a3 100644 --- a/driver-sync/src/main/com/mongodb/client/internal/CommandMarker.java +++ b/driver-sync/src/main/com/mongodb/client/internal/CommandMarker.java @@ -84,11 +84,11 @@ class CommandMarker implements Closeable { } } - RawBsonDocument mark(final String databaseName, final RawBsonDocument command, @Nullable final Timeout operationTimeout) { + RawBsonDocument mark(final String databaseName, final RawBsonDocument command, @Nullable final Timeout timeout) { if (client != null) { try { try { - return executeCommand(databaseName, command, operationTimeout); + return executeCommand(databaseName, command, timeout); } catch (MongoOperationTimeoutException e){ throw e; } catch (MongoTimeoutException e) { @@ -96,7 +96,7 @@ RawBsonDocument mark(final String databaseName, final RawBsonDocument command, @ throw e; } startProcess(processBuilder); - return executeCommand(databaseName, command, operationTimeout); + return executeCommand(databaseName, command, timeout); } } catch (MongoException e) { throw wrapInClientException(e); @@ -113,14 +113,17 @@ public void close() { } } - private RawBsonDocument executeCommand(final String databaseName, final RawBsonDocument markableCommand, @Nullable final Timeout operationTimeout) { + private RawBsonDocument executeCommand( + final String databaseName, + final RawBsonDocument markableCommand, + @Nullable final Timeout timeout) { assertNotNull(client); MongoDatabase mongoDatabase = client.getDatabase(databaseName) .withReadConcern(ReadConcern.DEFAULT) .withReadPreference(ReadPreference.primary()); - return databaseWithTimeout(mongoDatabase, TIMEOUT_ERROR_MESSAGE, operationTimeout) + return databaseWithTimeout(mongoDatabase, TIMEOUT_ERROR_MESSAGE, timeout) .runCommand(markableCommand, RawBsonDocument.class); } diff --git a/driver-sync/src/main/com/mongodb/client/internal/Crypt.java b/driver-sync/src/main/com/mongodb/client/internal/Crypt.java index 15ba16e66da..ae7a75ae626 100644 --- a/driver-sync/src/main/com/mongodb/client/internal/Crypt.java +++ b/driver-sync/src/main/com/mongodb/client/internal/Crypt.java @@ -132,7 +132,10 @@ public class Crypt implements Closeable { * @param command the unencrypted command * @return the encrypted command */ - RawBsonDocument encrypt(final String databaseName, final RawBsonDocument command, @Nullable final Timeout timeoutOperation) { + RawBsonDocument encrypt( + final String databaseName, + final RawBsonDocument command, + @Nullable final Timeout timeout) { notNull("databaseName", databaseName); notNull("command", command); @@ -141,7 +144,7 @@ RawBsonDocument encrypt(final String databaseName, final RawBsonDocument command } try (MongoCryptContext encryptionContext = mongoCrypt.createEncryptionContext(databaseName, command)) { - return executeStateMachine(encryptionContext, databaseName, timeoutOperation); + return executeStateMachine(encryptionContext, databaseName, timeout); } catch (MongoCryptException e) { throw wrapInMongoException(e); } @@ -274,24 +277,27 @@ public void close() { } } - private RawBsonDocument executeStateMachine(final MongoCryptContext cryptContext, @Nullable final String databaseName, @Nullable final Timeout operationTimeout) { + private RawBsonDocument executeStateMachine( + final MongoCryptContext cryptContext, + @Nullable final String databaseName, + @Nullable final Timeout timeout) { while (true) { State state = cryptContext.getState(); switch (state) { case NEED_MONGO_COLLINFO: - collInfo(cryptContext, notNull("databaseName", databaseName), operationTimeout); + collInfo(cryptContext, notNull("databaseName", databaseName), timeout); break; case NEED_MONGO_MARKINGS: - mark(cryptContext, notNull("databaseName", databaseName), operationTimeout); + mark(cryptContext, notNull("databaseName", databaseName), timeout); break; case NEED_KMS_CREDENTIALS: fetchCredentials(cryptContext); break; case NEED_MONGO_KEYS: - fetchKeys(cryptContext, operationTimeout); + fetchKeys(cryptContext, timeout); break; case NEED_KMS: - decryptKeys(cryptContext, operationTimeout); + decryptKeys(cryptContext, timeout); break; case READY: return cryptContext.finish(); @@ -320,9 +326,9 @@ private void collInfo(final MongoCryptContext cryptContext, final String databas } } - private void mark(final MongoCryptContext cryptContext, final String databaseName, @Nullable final Timeout operationTimeout) { + private void mark(final MongoCryptContext cryptContext, final String databaseName, @Nullable final Timeout timeout) { try { - RawBsonDocument markedCommand = assertNotNull(commandMarker).mark(databaseName, cryptContext.getMongoOperation(), operationTimeout); + RawBsonDocument markedCommand = assertNotNull(commandMarker).mark(databaseName, cryptContext.getMongoOperation(), timeout); cryptContext.addMongoOperationResult(markedCommand); cryptContext.completeMongoOperation(); } catch (Throwable t) { diff --git a/driver-sync/src/main/com/mongodb/client/internal/CryptBinding.java b/driver-sync/src/main/com/mongodb/client/internal/CryptBinding.java index 036466077ec..8032bea371c 100644 --- a/driver-sync/src/main/com/mongodb/client/internal/CryptBinding.java +++ b/driver-sync/src/main/com/mongodb/client/internal/CryptBinding.java @@ -41,28 +41,23 @@ public ReadPreference getReadPreference() { } @Override - public ConnectionSource getReadConnectionSource() { - return new CryptConnectionSource(wrapped.getReadConnectionSource()); + public ConnectionSource getReadConnectionSource(final OperationContext operationContext) { + return new CryptConnectionSource(wrapped.getReadConnectionSource(operationContext)); } @Override - public ConnectionSource getReadConnectionSource(final int minWireVersion, final ReadPreference fallbackReadPreference) { - return new CryptConnectionSource(wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference)); + public ConnectionSource getReadConnectionSource(final int minWireVersion, final ReadPreference fallbackReadPreference, final OperationContext operationContext) { + return new CryptConnectionSource(wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference, operationContext)); } @Override - public ConnectionSource getWriteConnectionSource() { - return new CryptConnectionSource(wrapped.getWriteConnectionSource()); + public ConnectionSource getWriteConnectionSource(final OperationContext operationContext) { + return new CryptConnectionSource(wrapped.getWriteConnectionSource(operationContext)); } @Override - public ConnectionSource getConnectionSource(final ServerAddress serverAddress) { - return new CryptConnectionSource(wrapped.getConnectionSource(serverAddress)); - } - - @Override - public OperationContext getOperationContext() { - return wrapped.getOperationContext(); + public ConnectionSource getConnectionSource(final ServerAddress serverAddress, final OperationContext operationContext) { + return new CryptConnectionSource(wrapped.getConnectionSource(serverAddress, operationContext)); } @Override @@ -93,19 +88,14 @@ public ServerDescription getServerDescription() { return wrapped.getServerDescription(); } - @Override - public OperationContext getOperationContext() { - return wrapped.getOperationContext(); - } - @Override public ReadPreference getReadPreference() { return wrapped.getReadPreference(); } @Override - public Connection getConnection() { - return new CryptConnection(wrapped.getConnection(), crypt); + public Connection getConnection(final OperationContext operationContext) { + return new CryptConnection(wrapped.getConnection(operationContext), crypt); } @Override diff --git a/driver-sync/src/main/com/mongodb/client/internal/CryptConnection.java b/driver-sync/src/main/com/mongodb/client/internal/CryptConnection.java index 803df89a6b6..b5c814edb04 100644 --- a/driver-sync/src/main/com/mongodb/client/internal/CryptConnection.java +++ b/driver-sync/src/main/com/mongodb/client/internal/CryptConnection.java @@ -112,9 +112,11 @@ public T command(final String database, final BsonDocument command, final Fi getEncoder(command).encode(writer, command, EncoderContext.builder().build()); - Timeout operationTimeout = operationContext.getTimeoutContext().getTimeout(); - RawBsonDocument encryptedCommand = crypt.encrypt(database, - new RawBsonDocument(bsonOutput.getInternalBuffer(), 0, bsonOutput.getSize()), operationTimeout); + Timeout timeout = operationContext.getTimeoutContext().getTimeout(); + RawBsonDocument encryptedCommand = crypt.encrypt( + database, + new RawBsonDocument(bsonOutput.getInternalBuffer(), 0, bsonOutput.getSize()), + timeout); RawBsonDocument encryptedResponse = wrapped.command(database, encryptedCommand, commandFieldNameValidator, readPreference, new RawBsonDocumentCodec(), operationContext, responseExpected, EmptyMessageSequences.INSTANCE); @@ -123,7 +125,7 @@ public T command(final String database, final BsonDocument command, final Fi return null; } - RawBsonDocument decryptedResponse = crypt.decrypt(encryptedResponse, operationTimeout); + RawBsonDocument decryptedResponse = crypt.decrypt(encryptedResponse, timeout); BsonBinaryReader reader = new BsonBinaryReader(decryptedResponse.getByteBuffer().asNIO()); diff --git a/driver-sync/src/main/com/mongodb/client/internal/MapReduceIterableImpl.java b/driver-sync/src/main/com/mongodb/client/internal/MapReduceIterableImpl.java index be3e8ca05e9..c1118635647 100644 --- a/driver-sync/src/main/com/mongodb/client/internal/MapReduceIterableImpl.java +++ b/driver-sync/src/main/com/mongodb/client/internal/MapReduceIterableImpl.java @@ -29,6 +29,7 @@ import com.mongodb.internal.binding.AsyncReadBinding; import com.mongodb.internal.binding.ReadBinding; import com.mongodb.internal.client.model.FindOptions; +import com.mongodb.internal.connection.OperationContext; import com.mongodb.internal.operation.BatchCursor; import com.mongodb.internal.operation.MapReduceStatistics; import com.mongodb.internal.operation.Operations; @@ -241,12 +242,12 @@ public String getCommandName() { } @Override - public BatchCursor execute(final ReadBinding binding) { - return operation.execute(binding); + public BatchCursor execute(final ReadBinding binding, final OperationContext operationContext) { + return operation.execute(binding, operationContext); } @Override - public void executeAsync(final AsyncReadBinding binding, final SingleResultCallback> callback) { + public void executeAsync(final AsyncReadBinding binding, final OperationContext operationContext, final SingleResultCallback> callback) { throw new UnsupportedOperationException("This operation is sync only"); } } diff --git a/driver-sync/src/main/com/mongodb/client/internal/MongoClusterImpl.java b/driver-sync/src/main/com/mongodb/client/internal/MongoClusterImpl.java index 058122e9c26..77d217f5c99 100644 --- a/driver-sync/src/main/com/mongodb/client/internal/MongoClusterImpl.java +++ b/driver-sync/src/main/com/mongodb/client/internal/MongoClusterImpl.java @@ -416,14 +416,15 @@ public T execute(final ReadOperation operation, final ReadPreference r } ClientSession actualClientSession = getClientSession(session); - ReadBinding binding = getReadBinding(readPreference, readConcern, actualClientSession, session == null, - operation.getCommandName()); + OperationContext operationContext = getOperationContext(actualClientSession, readConcern, operation.getCommandName()) + .withSessionContext(new ClientSessionBinding.SyncClientSessionContext(actualClientSession, readConcern, isImplicitSession(session))); + ReadBinding binding = getReadBinding(readPreference, actualClientSession, isImplicitSession(session)); try { if (actualClientSession.hasActiveTransaction() && !binding.getReadPreference().equals(primary())) { throw new MongoClientException("Read preference in a transaction must be primary"); } - return operation.execute(binding); + return operation.execute(binding, operationContext); } catch (MongoException e) { MongoException exceptionToHandle = OperationHelper.unwrap(e); labelException(actualClientSession, exceptionToHandle); @@ -442,10 +443,12 @@ public T execute(final WriteOperation operation, final ReadConcern readCo } ClientSession actualClientSession = getClientSession(session); - WriteBinding binding = getWriteBinding(readConcern, actualClientSession, session == null, operation.getCommandName()); + OperationContext operationContext = getOperationContext(actualClientSession, readConcern, operation.getCommandName()) + .withSessionContext(new ClientSessionBinding.SyncClientSessionContext(actualClientSession, readConcern, isImplicitSession(session))); + WriteBinding binding = getWriteBinding(actualClientSession, isImplicitSession(session)); try { - return operation.execute(binding); + return operation.execute(binding, operationContext); } catch (MongoException e) { MongoException exceptionToHandle = OperationHelper.unwrap(e); labelException(actualClientSession, exceptionToHandle); @@ -469,23 +472,19 @@ public TimeoutSettings getTimeoutSettings() { return executorTimeoutSettings; } - WriteBinding getWriteBinding(final ReadConcern readConcern, final ClientSession session, final boolean ownsSession, - final String commandName) { - return getReadWriteBinding(primary(), readConcern, session, ownsSession, commandName); + WriteBinding getWriteBinding(final ClientSession session, final boolean ownsSession) { + return getReadWriteBinding(primary(), session, ownsSession); } - ReadBinding getReadBinding(final ReadPreference readPreference, final ReadConcern readConcern, final ClientSession session, - final boolean ownsSession, final String commandName) { - return getReadWriteBinding(readPreference, readConcern, session, ownsSession, commandName); + ReadBinding getReadBinding(final ReadPreference readPreference, final ClientSession session, + final boolean ownsSession) { + return getReadWriteBinding(readPreference, session, ownsSession); } - ReadWriteBinding getReadWriteBinding(final ReadPreference readPreference, - final ReadConcern readConcern, final ClientSession session, final boolean ownsSession, - final String commandName) { + ReadWriteBinding getReadWriteBinding(final ReadPreference readPreference, final ClientSession session, final boolean ownsSession) { ClusterAwareReadWriteBinding readWriteBinding = new ClusterBinding(cluster, - getReadPreferenceForBinding(readPreference, session), readConcern, - getOperationContext(session, readConcern, commandName)); + getReadPreferenceForBinding(readPreference, session)); if (crypt != null) { readWriteBinding = new CryptBinding(readWriteBinding, crypt); @@ -526,7 +525,7 @@ private void clearTransactionContextOnTransientTransactionError(@Nullable final } private ReadPreference getReadPreferenceForBinding(final ReadPreference readPreference, @Nullable final ClientSession session) { - if (session == null) { + if (isImplicitSession(session)) { return readPreference; } if (session.hasActiveTransaction()) { @@ -557,4 +556,8 @@ ClientSession getClientSession(@Nullable final ClientSession clientSessionFromOp return session; } } + + private boolean isImplicitSession(@Nullable final ClientSession session) { + return session == null; + } } diff --git a/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideOperationsTimeoutProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideOperationsTimeoutProseTest.java index 5cb042eaad4..524ba02061b 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideOperationsTimeoutProseTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideOperationsTimeoutProseTest.java @@ -819,6 +819,37 @@ public void testKillCursorsIsNotExecutedAfterGetMoreNetworkErrorWhenTimeoutIsNot } + /** + * Not a prose spec test. However, it is additional test case for better coverage. + */ + @DisplayName("KillCursors is not executed after getMore network error when timeout is not enabled") + @Test + public void test() { + long rtt = ClusterFixture.getPrimaryRTT(); + collectionHelper.create(namespace.getCollectionName(), new CreateCollectionOptions()); + collectionHelper.insertDocuments(new Document(), new Document()); + try (MongoClient mongoClient = createMongoClient( + getMongoClientSettingsBuilder() + .timeout(100, TimeUnit.MINUTES) + .retryReads(true) + .applyToConnectionPoolSettings(builder -> + builder.maxConnectionIdleTime(10, TimeUnit.MINUTES)) + .applyToClusterSettings(builder -> builder.serverSelectionTimeout(500, TimeUnit.MINUTES)) + .applyToSocketSettings(builder -> builder.readTimeout(500, TimeUnit.MILLISECONDS)))) { + MongoCollection collection = mongoClient.getDatabase(namespace.getDatabaseName()) + .getCollection(namespace.getCollectionName()).withReadPreference(ReadPreference.primary()); + + MongoCursor cursor = collection.find() + .batchSize(1) + .cursor(); + + sleep(500); + cursor.next(); + sleep(500); + cursor.close(); + } + } + /** * Not a prose spec test. However, it is additional test case for better coverage. */ diff --git a/driver-sync/src/test/resources/logback-test.xml b/driver-sync/src/test/resources/logback-test.xml index 022806f0e4e..b25f68499b2 100644 --- a/driver-sync/src/test/resources/logback-test.xml +++ b/driver-sync/src/test/resources/logback-test.xml @@ -6,7 +6,7 @@ - + diff --git a/driver-sync/src/test/unit/com/mongodb/client/internal/ClientSessionBindingSpecification.groovy b/driver-sync/src/test/unit/com/mongodb/client/internal/ClientSessionBindingSpecification.groovy index 49332bc8ed3..d5383a30ca5 100644 --- a/driver-sync/src/test/unit/com/mongodb/client/internal/ClientSessionBindingSpecification.groovy +++ b/driver-sync/src/test/unit/com/mongodb/client/internal/ClientSessionBindingSpecification.groovy @@ -16,59 +16,38 @@ package com.mongodb.client.internal -import com.mongodb.ReadConcern + import com.mongodb.ReadPreference import com.mongodb.client.ClientSession import com.mongodb.internal.binding.ClusterBinding import com.mongodb.internal.binding.ConnectionSource import com.mongodb.internal.binding.ReadWriteBinding import com.mongodb.internal.connection.Cluster -import com.mongodb.internal.session.ClientSessionContext import spock.lang.Specification import static com.mongodb.ClusterFixture.OPERATION_CONTEXT class ClientSessionBindingSpecification extends Specification { - def 'should return the session context from the binding'() { - given: - def session = Stub(ClientSession) - def wrappedBinding = Stub(ClusterBinding) { - getOperationContext() >> OPERATION_CONTEXT - } - def binding = new ClientSessionBinding(session, false, wrappedBinding) - - when: - def context = binding.getOperationContext().getSessionContext() - - then: - (context as ClientSessionContext).getClientSession() == session - } - def 'should return the session context from the connection source'() { + def 'should call underlying wrapped binding'() { given: def session = Stub(ClientSession) - def wrappedBinding = Mock(ClusterBinding) { - getOperationContext() >> OPERATION_CONTEXT - } + def wrappedBinding = Mock(ClusterBinding); def binding = new ClientSessionBinding(session, false, wrappedBinding) when: - def readConnectionSource = binding.getReadConnectionSource() - def context = readConnectionSource.getOperationContext().getSessionContext() + binding.getReadConnectionSource(OPERATION_CONTEXT) then: - (context as ClientSessionContext).getClientSession() == session - 1 * wrappedBinding.getReadConnectionSource() >> { + 1 * wrappedBinding.getReadConnectionSource(OPERATION_CONTEXT) >> { Stub(ConnectionSource) } when: - def writeConnectionSource = binding.getWriteConnectionSource() - context = writeConnectionSource.getOperationContext().getSessionContext() + binding.getWriteConnectionSource(OPERATION_CONTEXT) then: - (context as ClientSessionContext).getClientSession() == session - 1 * wrappedBinding.getWriteConnectionSource() >> { + 1 * wrappedBinding.getWriteConnectionSource(OPERATION_CONTEXT) >> { Stub(ConnectionSource) } } @@ -98,8 +77,8 @@ class ClientSessionBindingSpecification extends Specification { def session = Mock(ClientSession) def wrappedBinding = createStubBinding() def binding = new ClientSessionBinding(session, true, wrappedBinding) - def readConnectionSource = binding.getReadConnectionSource() - def writeConnectionSource = binding.getWriteConnectionSource() + def readConnectionSource = binding.getReadConnectionSource(OPERATION_CONTEXT) + def writeConnectionSource = binding.getWriteConnectionSource(OPERATION_CONTEXT) when: binding.release() @@ -140,23 +119,24 @@ class ClientSessionBindingSpecification extends Specification { 0 * session.close() } - def 'owned session is implicit'() { - given: - def session = Mock(ClientSession) - def wrappedBinding = createStubBinding() - - when: - def binding = new ClientSessionBinding(session, ownsSession, wrappedBinding) - - then: - binding.getOperationContext().getSessionContext().isImplicitSession() == ownsSession - - where: - ownsSession << [true, false] - } + //TODO move to SessionContext test +// def 'owned session is implicit'() { +// given: +// def session = Mock(ClientSession) +// def wrappedBinding = createStubBinding() +// +// when: +// def binding = new ClientSessionBinding(session, ownsSession, wrappedBinding) +// +// then: +// binding.getOperationContext().getSessionContext().isImplicitSession() == ownsSession +// +// where: +// ownsSession << [true, false] +// } private ReadWriteBinding createStubBinding() { def cluster = Stub(Cluster) - new ClusterBinding(cluster, ReadPreference.primary(), ReadConcern.DEFAULT, OPERATION_CONTEXT) + new ClusterBinding(cluster, ReadPreference.primary()) } }