Skip to content
Merged
5 changes: 5 additions & 0 deletions docs/changelog/131937.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 131937
summary: Fix race condition in `RemoteClusterService.collectNodes()`
area: Distributed
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.action.support.CountDownActionListener;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.RefCountingListener;
import org.elasticsearch.action.support.RefCountingRunnable;
import org.elasticsearch.client.internal.RemoteClusterClient;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
Expand All @@ -29,7 +30,6 @@
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.common.util.concurrent.CountDown;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.TimeValue;
Expand Down Expand Up @@ -567,36 +567,26 @@ public void collectNodes(Set<String> clusters, ActionListener<BiFunction<String,
"this node does not have the " + DiscoveryNodeRole.REMOTE_CLUSTER_CLIENT_ROLE.roleName() + " role"
);
}
final var connectionsMap = new HashMap<String, RemoteClusterConnection>();
for (String cluster : clusters) {
if (this.remoteClusters.containsKey(cluster) == false) {
final var connection = this.remoteClusters.get(cluster);
if (connection == null) {
listener.onFailure(new NoSuchRemoteClusterException(cluster));
return;
}
connectionsMap.put(cluster, connection);
}

final Map<String, Function<String, DiscoveryNode>> clusterMap = new HashMap<>();
CountDown countDown = new CountDown(clusters.size());
Function<String, DiscoveryNode> nullFunction = s -> null;
for (final String cluster : clusters) {
RemoteClusterConnection connection = this.remoteClusters.get(cluster);
connection.collectNodes(new ActionListener<Function<String, DiscoveryNode>>() {
@Override
public void onResponse(Function<String, DiscoveryNode> nodeLookup) {
synchronized (clusterMap) {
clusterMap.put(cluster, nodeLookup);
}
if (countDown.countDown()) {
listener.onResponse((clusterAlias, nodeId) -> clusterMap.getOrDefault(clusterAlias, nullFunction).apply(nodeId));
}
}

@Override
public void onFailure(Exception e) {
if (countDown.fastForward()) { // we need to check if it's true since we could have multiple failures
listener.onFailure(e);
}
final var finalListener = listener.<Void>safeMap(
ignored -> (clusterAlias, nodeId) -> clusterMap.getOrDefault(clusterAlias, s -> null).apply(nodeId)
);
try (var refs = new RefCountingListener(finalListener)) {
connectionsMap.forEach((cluster, connection) -> connection.collectNodes(refs.acquire(nodeLookup -> {
synchronized (clusterMap) {
clusterMap.put(cluster, nodeLookup);
}
});
})));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
package org.elasticsearch.transport;

import org.apache.logging.log4j.Level;
import org.apache.lucene.store.AlreadyClosedException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.LatchedActionListener;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.support.ActionTestUtils;
import org.elasticsearch.action.support.IndicesOptions;
Expand Down Expand Up @@ -1060,6 +1062,85 @@ public void onFailure(Exception e) {
}
}

public void testCollectNodesConcurrentWithSettingsChanges() throws IOException {
final List<DiscoveryNode> knownNodes_c1 = new CopyOnWriteArrayList<>();

try (
var c1N1 = startTransport(
"cluster_1_node_1",
knownNodes_c1,
VersionInformation.CURRENT,
TransportVersion.current(),
Settings.EMPTY
);
var transportService = MockTransportService.createNewService(
Settings.EMPTY,
VersionInformation.CURRENT,
TransportVersion.current(),
threadPool,
null
)
) {
final var c1N1Node = c1N1.getLocalNode();
knownNodes_c1.add(c1N1Node);
final var seedList = List.of(c1N1Node.getAddress().toString());
transportService.start();
transportService.acceptIncomingRequests();

try (RemoteClusterService service = new RemoteClusterService(createSettings("cluster_1", seedList), transportService)) {
service.initializeRemoteClusters();
assertTrue(service.isCrossClusterSearchEnabled());
final var numTasks = between(3, 5);
final var taskLatch = new CountDownLatch(numTasks);

ESTestCase.startInParallel(numTasks, threadNumber -> {
if (threadNumber == 0) {
taskLatch.countDown();
boolean isLinked = true;
while (taskLatch.getCount() != 0) {
final var future = new PlainActionFuture<RemoteClusterService.RemoteClusterConnectionStatus>();
final var settings = createSettings("cluster_1", isLinked ? Collections.emptyList() : seedList);
service.updateRemoteCluster("cluster_1", settings, future);
safeGet(future);
isLinked = isLinked == false;
}
return;
}

// Verify collectNodes() always invokes the listener, even if the node is concurrently being unlinked.
try {
for (int i = 0; i < 10; ++i) {
final var latch = new CountDownLatch(1);
final var exRef = new AtomicReference<Exception>();
service.collectNodes(Set.of("cluster_1"), new LatchedActionListener<>(new ActionListener<>() {
@Override
public void onResponse(BiFunction<String, String, DiscoveryNode> func) {
assertEquals(c1N1Node, func.apply("cluster_1", c1N1Node.getId()));
}

@Override
public void onFailure(Exception e) {
exRef.set(e);
}
}, latch));
safeAwait(latch);
if (exRef.get() != null) {
assertThat(
exRef.get(),
either(instanceOf(TransportException.class)).or(instanceOf(NoSuchRemoteClusterException.class))
.or(instanceOf(AlreadyClosedException.class))
.or(instanceOf(NoSeedNodeLeftException.class))
);
}
}
} finally {
taskLatch.countDown();
}
});
}
}
}

public void testRemoteClusterSkipIfDisconnectedSetting() {
{
Settings settings = Settings.builder()
Expand Down