diff --git a/docs/changelog/131937.yaml b/docs/changelog/131937.yaml new file mode 100644 index 0000000000000..d3132c5d30135 --- /dev/null +++ b/docs/changelog/131937.yaml @@ -0,0 +1,5 @@ +pr: 131937 +summary: Fix race condition in `RemoteClusterService.collectNodes()` +area: Distributed +type: bug +issues: [] diff --git a/server/src/main/java/org/elasticsearch/transport/RemoteClusterService.java b/server/src/main/java/org/elasticsearch/transport/RemoteClusterService.java index ac5233f1d54b4..31a90c1ca1631 100644 --- a/server/src/main/java/org/elasticsearch/transport/RemoteClusterService.java +++ b/server/src/main/java/org/elasticsearch/transport/RemoteClusterService.java @@ -17,6 +17,7 @@ import org.elasticsearch.action.support.CountDownActionListener; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.RefCountingListener; import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.client.internal.RemoteClusterClient; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; @@ -29,7 +30,6 @@ import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; -import org.elasticsearch.common.util.concurrent.CountDown; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.TimeValue; @@ -567,36 +567,26 @@ public void collectNodes(Set clusters, ActionListener(); for (String cluster : clusters) { - if (this.remoteClusters.containsKey(cluster) == false) { + final var connection = this.remoteClusters.get(cluster); + if (connection == null) { listener.onFailure(new NoSuchRemoteClusterException(cluster)); return; } + connectionsMap.put(cluster, connection); } final Map> clusterMap = new HashMap<>(); - CountDown countDown = new CountDown(clusters.size()); - Function nullFunction = s -> null; - for (final String cluster : clusters) { - RemoteClusterConnection connection = this.remoteClusters.get(cluster); - connection.collectNodes(new ActionListener>() { - @Override - public void onResponse(Function nodeLookup) { - synchronized (clusterMap) { - clusterMap.put(cluster, nodeLookup); - } - if (countDown.countDown()) { - listener.onResponse((clusterAlias, nodeId) -> clusterMap.getOrDefault(clusterAlias, nullFunction).apply(nodeId)); - } - } - - @Override - public void onFailure(Exception e) { - if (countDown.fastForward()) { // we need to check if it's true since we could have multiple failures - listener.onFailure(e); - } + final var finalListener = listener.safeMap( + ignored -> (clusterAlias, nodeId) -> clusterMap.getOrDefault(clusterAlias, s -> null).apply(nodeId) + ); + try (var refs = new RefCountingListener(finalListener)) { + connectionsMap.forEach((cluster, connection) -> connection.collectNodes(refs.acquire(nodeLookup -> { + synchronized (clusterMap) { + clusterMap.put(cluster, nodeLookup); } - }); + }))); } } diff --git a/server/src/test/java/org/elasticsearch/transport/RemoteClusterServiceTests.java b/server/src/test/java/org/elasticsearch/transport/RemoteClusterServiceTests.java index 99c4dde4d396f..48ef90d0772db 100644 --- a/server/src/test/java/org/elasticsearch/transport/RemoteClusterServiceTests.java +++ b/server/src/test/java/org/elasticsearch/transport/RemoteClusterServiceTests.java @@ -9,8 +9,10 @@ package org.elasticsearch.transport; import org.apache.logging.log4j.Level; +import org.apache.lucene.store.AlreadyClosedException; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.LatchedActionListener; import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.action.support.ActionTestUtils; import org.elasticsearch.action.support.IndicesOptions; @@ -1060,6 +1062,85 @@ public void onFailure(Exception e) { } } + public void testCollectNodesConcurrentWithSettingsChanges() throws IOException { + final List knownNodes_c1 = new CopyOnWriteArrayList<>(); + + try ( + var c1N1 = startTransport( + "cluster_1_node_1", + knownNodes_c1, + VersionInformation.CURRENT, + TransportVersion.current(), + Settings.EMPTY + ); + var transportService = MockTransportService.createNewService( + Settings.EMPTY, + VersionInformation.CURRENT, + TransportVersion.current(), + threadPool, + null + ) + ) { + final var c1N1Node = c1N1.getLocalNode(); + knownNodes_c1.add(c1N1Node); + final var seedList = List.of(c1N1Node.getAddress().toString()); + transportService.start(); + transportService.acceptIncomingRequests(); + + try (RemoteClusterService service = new RemoteClusterService(createSettings("cluster_1", seedList), transportService)) { + service.initializeRemoteClusters(); + assertTrue(service.isCrossClusterSearchEnabled()); + final var numTasks = between(3, 5); + final var taskLatch = new CountDownLatch(numTasks); + + ESTestCase.startInParallel(numTasks, threadNumber -> { + if (threadNumber == 0) { + taskLatch.countDown(); + boolean isLinked = true; + while (taskLatch.getCount() != 0) { + final var future = new PlainActionFuture(); + final var settings = createSettings("cluster_1", isLinked ? Collections.emptyList() : seedList); + service.updateRemoteCluster("cluster_1", settings, future); + safeGet(future); + isLinked = isLinked == false; + } + return; + } + + // Verify collectNodes() always invokes the listener, even if the node is concurrently being unlinked. + try { + for (int i = 0; i < 10; ++i) { + final var latch = new CountDownLatch(1); + final var exRef = new AtomicReference(); + service.collectNodes(Set.of("cluster_1"), new LatchedActionListener<>(new ActionListener<>() { + @Override + public void onResponse(BiFunction func) { + assertEquals(c1N1Node, func.apply("cluster_1", c1N1Node.getId())); + } + + @Override + public void onFailure(Exception e) { + exRef.set(e); + } + }, latch)); + safeAwait(latch); + if (exRef.get() != null) { + assertThat( + exRef.get(), + either(instanceOf(TransportException.class)).or(instanceOf(NoSuchRemoteClusterException.class)) + .or(instanceOf(AlreadyClosedException.class)) + .or(instanceOf(NoSeedNodeLeftException.class)) + ); + } + } + } finally { + taskLatch.countDown(); + } + }); + } + } + } + public void testRemoteClusterSkipIfDisconnectedSetting() { { Settings settings = Settings.builder()