Skip to content

Commit 105cf7a

Browse files
authored
Close server/pool in LoadBalancedCluster and LoadBalancedServer respectively (#720)
JAVA-4178
1 parent a82ce71 commit 105cf7a

File tree

3 files changed

+62
-13
lines changed

3 files changed

+62
-13
lines changed

driver-core/src/main/com/mongodb/internal/connection/LoadBalancedCluster.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import com.mongodb.event.ClusterListener;
3838
import com.mongodb.event.ClusterOpeningEvent;
3939
import com.mongodb.internal.async.SingleResultCallback;
40+
import com.mongodb.lang.Nullable;
4041
import com.mongodb.selector.ServerSelector;
4142
import org.bson.BsonTimestamp;
4243

@@ -72,6 +73,7 @@ final class LoadBalancedCluster implements Cluster {
7273
private final ClusterClock clusterClock = new ClusterClock();
7374
private final ClusterListener clusterListener;
7475
private ClusterDescription description;
76+
@Nullable
7577
private ClusterableServer server;
7678
private final AtomicBoolean closed = new AtomicBoolean();
7779
private final DnsSrvRecordMonitor dnsSrvRecordMonitor;
@@ -109,6 +111,9 @@ public void initialize(final Collection<ServerAddress> hosts) {
109111
List<ServerSelectionRequest> localWaitQueue;
110112
lock.lock();
111113
try {
114+
if (isClosed()) {
115+
return;
116+
}
112117
srvResolutionException = null;
113118
if (hosts.size() != 1) {
114119
srvRecordResolvedToMultipleHosts = true;
@@ -259,12 +264,17 @@ public void close() {
259264
if (dnsSrvRecordMonitor != null) {
260265
dnsSrvRecordMonitor.close();
261266
}
267+
ClusterableServer localServer;
262268
lock.lock();
263269
try {
264270
condition.signalAll();
271+
localServer = server;
265272
} finally {
266273
lock.unlock();
267274
}
275+
if (localServer != null) {
276+
localServer.close();
277+
}
268278
}
269279
}
270280

driver-core/src/main/com/mongodb/internal/connection/LoadBalancedServer.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ private void invalidate(final Throwable t, final ObjectId serviceId, final int g
109109
@Override
110110
public void close() {
111111
if (!closed.getAndSet(true)) {
112+
connectionPool.close();
112113
serverListener.serverClosed(new ServerClosedEvent(serverId));
113114
}
114115
}

driver-core/src/test/unit/com/mongodb/internal/connection/LoadBalancedClusterTest.java

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.junit.jupiter.api.RepeatedTest;
3636
import org.junit.jupiter.api.Tag;
3737
import org.junit.jupiter.api.Test;
38+
import org.mockito.ArgumentCaptor;
3839

3940
import java.time.Duration;
4041
import java.util.ArrayList;
@@ -57,7 +58,11 @@
5758
import static org.junit.jupiter.api.Assertions.assertTrue;
5859
import static org.mockito.ArgumentMatchers.any;
5960
import static org.mockito.ArgumentMatchers.eq;
61+
import static org.mockito.Mockito.atLeastOnce;
6062
import static org.mockito.Mockito.mock;
63+
import static org.mockito.Mockito.never;
64+
import static org.mockito.Mockito.times;
65+
import static org.mockito.Mockito.verify;
6166
import static org.mockito.Mockito.when;
6267

6368
public class LoadBalancedClusterTest {
@@ -317,6 +322,42 @@ public void shouldTimeoutSelectServerAsynchronouslyWhenThereIsSRVLookupException
317322
exception.getMessage());
318323
}
319324

325+
@Test
326+
void shouldNotInitServerAfterClosing() {
327+
// prepare mocks
328+
ClusterableServerFactory serverFactory = mock(ClusterableServerFactory.class);
329+
when(serverFactory.getSettings()).thenReturn(mock(ServerSettings.class));
330+
DnsSrvRecordMonitorFactory srvRecordMonitorFactory = mock(DnsSrvRecordMonitorFactory.class);
331+
when(srvRecordMonitorFactory.create(any(), any(DnsSrvRecordInitializer.class))).thenReturn(mock(DnsSrvRecordMonitor.class));
332+
ArgumentCaptor<DnsSrvRecordInitializer> serverInitializerCaptor = ArgumentCaptor.forClass(DnsSrvRecordInitializer.class);
333+
// create `cluster` and capture its `DnsSrvRecordInitializer` (server initializer)
334+
LoadBalancedCluster cluster = new LoadBalancedCluster(new ClusterId(),
335+
ClusterSettings.builder().mode(ClusterConnectionMode.LOAD_BALANCED).srvHost("foo.bar.com").build(),
336+
serverFactory, srvRecordMonitorFactory);
337+
verify(srvRecordMonitorFactory, times(1)).create(any(), serverInitializerCaptor.capture());
338+
// close `cluster`, call `DnsSrvRecordInitializer.initialize` and check that it does not result in creating a `ClusterableServer`
339+
cluster.close();
340+
serverInitializerCaptor.getValue().initialize(Collections.singleton(new ServerAddress()));
341+
verify(serverFactory, never()).create(any(), any(), any(), any());
342+
}
343+
344+
@Test
345+
void shouldCloseServerWhenClosing() {
346+
// prepare mocks
347+
ClusterableServerFactory serverFactory = mock(ClusterableServerFactory.class);
348+
when(serverFactory.getSettings()).thenReturn(mock(ServerSettings.class));
349+
ClusterableServer server = mock(ClusterableServer.class);
350+
when(serverFactory.create(any(), any(), any(), any())).thenReturn(server);
351+
// create `cluster` and check that it creates a `ClusterableServer`
352+
LoadBalancedCluster cluster = new LoadBalancedCluster(new ClusterId(),
353+
ClusterSettings.builder().mode(ClusterConnectionMode.LOAD_BALANCED).build(), serverFactory,
354+
mock(DnsSrvRecordMonitorFactory.class));
355+
verify(serverFactory, times(1)).create(any(), any(), any(), any());
356+
// close `cluster` and check that it closes `server`
357+
cluster.close();
358+
verify(server, atLeastOnce()).close();
359+
}
360+
320361
@RepeatedTest(value = 10, name = RepeatedTest.LONG_DISPLAY_NAME)
321362
@Tag("Slow")
322363
public void synchronousConcurrentTest() throws InterruptedException, ExecutionException, TimeoutException {
@@ -497,20 +538,17 @@ public boolean isInitialized() {
497538

498539
@Override
499540
public void start() {
500-
thread = new Thread(new Runnable() {
501-
@Override
502-
public void run() {
503-
try {
504-
Thread.sleep(sleepTime.toMillis());
505-
if (exception != null) {
506-
initializer.initialize(exception);
507-
} else {
508-
initializer.initialize(hosts);
509-
}
510-
initialized = true;
511-
} catch (InterruptedException e) {
512-
// ignore
541+
thread = new Thread(() -> {
542+
try {
543+
Thread.sleep(sleepTime.toMillis());
544+
if (exception != null) {
545+
initializer.initialize(exception);
546+
} else {
547+
initializer.initialize(hosts);
513548
}
549+
initialized = true;
550+
} catch (InterruptedException e) {
551+
// ignore
514552
}
515553
});
516554
thread.start();

0 commit comments

Comments
 (0)