diff --git a/docs/changelog/92729.yaml b/docs/changelog/92729.yaml new file mode 100644 index 0000000000000..8c4a34c9dcbc3 --- /dev/null +++ b/docs/changelog/92729.yaml @@ -0,0 +1,5 @@ +pr: 92729 +summary: Improve scalability of `BroadcastReplicationActions` +area: Network +type: bug +issues: [] diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportFlushAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportFlushAction.java index a828f6e413d77..790eb1f18641a 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportFlushAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportFlushAction.java @@ -17,6 +17,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import java.util.List; @@ -46,15 +47,12 @@ public TransportFlushAction( client, actionFilters, indexNameExpressionResolver, - TransportShardFlushAction.TYPE + TransportShardFlushAction.TYPE, + ThreadPool.Names.FLUSH, + 5 // the FLUSH threadpool has at most 5 threads by default, so this should be enough to keep everyone busy ); } - @Override - protected ReplicationResponse newShardResponse() { - return new ReplicationResponse(); - } - @Override protected ShardFlushRequest newShardRequest(FlushRequest request, ShardId shardId) { return new ShardFlushRequest(request, shardId); diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportRefreshAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportRefreshAction.java index ff9f6640b4120..813c4a77fbe27 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportRefreshAction.java +++ 
b/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportRefreshAction.java @@ -19,6 +19,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import java.util.List; @@ -48,15 +49,12 @@ public TransportRefreshAction( client, actionFilters, indexNameExpressionResolver, - TransportShardRefreshAction.TYPE + TransportShardRefreshAction.TYPE, + ThreadPool.Names.REFRESH, + 10 // the REFRESH threadpool has at most 10 threads by default, so this should be enough to keep everyone busy ); } - @Override - protected ReplicationResponse newShardResponse() { - return new ReplicationResponse(); - } - @Override protected BasicReplicationRequest newShardRequest(RefreshRequest request, ShardId shardId) { BasicReplicationRequest replicationRequest = new BasicReplicationRequest(shardId); diff --git a/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java b/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java index 62a2d3d38e061..dbe15d35e511c 100644 --- a/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java +++ b/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java @@ -10,30 +10,38 @@ import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.DefaultShardOperationFailedException; import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.action.support.ThreadedActionListener; import 
org.elasticsearch.action.support.TransportActions; import org.elasticsearch.action.support.broadcast.BaseBroadcastResponse; import org.elasticsearch.action.support.broadcast.BroadcastRequest; import org.elasticsearch.action.support.broadcast.BroadcastShardOperationFailedException; import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; -import org.elasticsearch.cluster.routing.IndexRoutingTable; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.util.concurrent.CountDown; +import org.elasticsearch.common.util.concurrent.ThrottledIterator; +import org.elasticsearch.core.Nullable; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.transport.Transports; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; import java.util.List; -import java.util.concurrent.CopyOnWriteArrayList; +import java.util.Map; /** * Base class for requests that should be executed on all shards of an index or several indices. 
@@ -48,9 +56,11 @@ public abstract class TransportBroadcastReplicationAction< private final ActionType replicatedBroadcastShardAction; private final ClusterService clusterService; private final IndexNameExpressionResolver indexNameExpressionResolver; + private final String executor; private final NodeClient client; + private final int maxRequestsPerNode; - public TransportBroadcastReplicationAction( + protected TransportBroadcastReplicationAction( String name, Writeable.Reader requestReader, ClusterService clusterService, @@ -58,126 +68,155 @@ public TransportBroadcastReplicationAction( NodeClient client, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, - ActionType replicatedBroadcastShardAction + ActionType replicatedBroadcastShardAction, + String executor, + int maxRequestsPerNode ) { - super(name, transportService, actionFilters, requestReader); + // Explicitly SAME since the REST layer runs this directly via the NodeClient so it doesn't fork even if we tell it to (see #92730) + super(name, transportService, actionFilters, requestReader, ThreadPool.Names.SAME); + this.client = client; this.replicatedBroadcastShardAction = replicatedBroadcastShardAction; this.clusterService = clusterService; this.indexNameExpressionResolver = indexNameExpressionResolver; + this.executor = executor; + this.maxRequestsPerNode = maxRequestsPerNode; } @Override protected void doExecute(Task task, Request request, ActionListener listener) { - final ClusterState clusterState = clusterService.state(); - List shards = shards(request, clusterState); - final CopyOnWriteArrayList shardsResponses = new CopyOnWriteArrayList<>(); - if (shards.size() == 0) { - finishAndNotifyListener(listener, shardsResponses); - } - final CountDown responsesCountDown = new CountDown(shards.size()); - for (final ShardId shardId : shards) { - ActionListener shardActionListener = new ActionListener() { - @Override - public void onResponse(ShardResponse shardResponse) { - 
shardsResponses.add(shardResponse); - logger.trace("{}: got response from {}", actionName, shardId); - if (responsesCountDown.countDown()) { - finishAndNotifyListener(listener, shardsResponses); - } - } - - @Override - public void onFailure(Exception e) { - logger.trace("{}: got failure from {}", actionName, shardId); - int totalNumCopies = clusterState.getMetadata().getIndexSafe(shardId.getIndex()).getNumberOfReplicas() + 1; - ShardResponse shardResponse = newShardResponse(); - ReplicationResponse.ShardInfo.Failure[] failures; - if (TransportActions.isShardNotAvailableException(e)) { - failures = new ReplicationResponse.ShardInfo.Failure[0]; - } else { - ReplicationResponse.ShardInfo.Failure failure = new ReplicationResponse.ShardInfo.Failure( - shardId, - null, - e, - ExceptionsHelper.status(e), - true - ); - failures = new ReplicationResponse.ShardInfo.Failure[totalNumCopies]; - Arrays.fill(failures, failure); - } - shardResponse.setShardInfo(new ReplicationResponse.ShardInfo(totalNumCopies, 0, failures)); - shardsResponses.add(shardResponse); - if (responsesCountDown.countDown()) { - finishAndNotifyListener(listener, shardsResponses); - } - } - }; - shardExecute(task, request, shardId, shardActionListener); - } + clusterService.threadPool().executor(executor).execute(ActionRunnable.wrap(listener, l -> { + final var clusterState = clusterService.state(); + ExceptionsHelper.reThrowIfNotNull(clusterState.blocks().globalBlockedException(ClusterBlockLevel.METADATA_READ)); + final var context = new Context(task, request, clusterState.metadata().indices(), listener); + ThrottledIterator.run( + shardIds(request, clusterState), + context::processShard, + clusterState.nodes().getDataNodes().size() * maxRequestsPerNode, + () -> {}, + context::finish + ); + })); } protected void shardExecute(Task task, Request request, ShardId shardId, ActionListener shardActionListener) { + assert Transports.assertNotTransportThread("per-shard requests might be high-volume"); ShardRequest 
shardRequest = newShardRequest(request, shardId); shardRequest.setParentTask(clusterService.localNode().getId(), task.getId()); client.executeLocally(replicatedBroadcastShardAction, shardRequest, shardActionListener); } /** - * @return all shard ids the request should run on + * @return all shard ids on which the request should run; exposed for tests */ - protected List shards(Request request, ClusterState clusterState) { - List shardIds = new ArrayList<>(); - String[] concreteIndices = indexNameExpressionResolver.concreteIndexNames(clusterState, request); - for (String index : concreteIndices) { - IndexMetadata indexMetadata = clusterState.metadata().getIndices().get(index); - if (indexMetadata != null) { - final IndexRoutingTable indexRoutingTable = clusterState.getRoutingTable().indicesRouting().get(index); - for (int i = 0; i < indexRoutingTable.size(); i++) { - shardIds.add(indexRoutingTable.shard(i).shardId()); - } - } - } - return shardIds; + Iterator shardIds(Request request, ClusterState clusterState) { + var indexMetadataByName = clusterState.metadata().indices(); + return Iterators.flatMap( + Iterators.forArray(indexNameExpressionResolver.concreteIndexNames(clusterState, request)), + indexName -> indexShardIds(indexMetadataByName.get(indexName)) + ); } - protected abstract ShardResponse newShardResponse(); - - protected abstract ShardRequest newShardRequest(Request request, ShardId shardId); - - private void finishAndNotifyListener(ActionListener listener, CopyOnWriteArrayList shardsResponses) { - logger.trace("{}: got all shard responses", actionName); - int successfulShards = 0; - int failedShards = 0; - int totalNumCopies = 0; - List shardFailures = null; - for (int i = 0; i < shardsResponses.size(); i++) { - ReplicationResponse shardResponse = shardsResponses.get(i); - if (shardResponse == null) { - // non active shard, ignore - } else { - failedShards += shardResponse.getShardInfo().getFailed(); - successfulShards += 
shardResponse.getShardInfo().getSuccessful(); - totalNumCopies += shardResponse.getShardInfo().getTotal(); - if (shardFailures == null) { - shardFailures = new ArrayList<>(); - } - for (ReplicationResponse.ShardInfo.Failure failure : shardResponse.getShardInfo().getFailures()) { - shardFailures.add( - new DefaultShardOperationFailedException( - new BroadcastShardOperationFailedException(failure.fullShardId(), failure.getCause()) - ) - ); - } - } + private static Iterator indexShardIds(@Nullable IndexMetadata indexMetadata) { + if (indexMetadata == null) { + return Collections.emptyIterator(); + } + var shardIds = new ShardId[indexMetadata.getNumberOfShards()]; + for (int i = 0; i < shardIds.length; i++) { + shardIds[i] = new ShardId(indexMetadata.getIndex(), i); } - listener.onResponse(newResponse(successfulShards, failedShards, totalNumCopies, shardFailures)); + return Iterators.forArray(shardIds); } + protected abstract ShardRequest newShardRequest(Request request, ShardId shardId); + protected abstract Response newResponse( int successfulShards, int failedShards, int totalNumCopies, List shardFailures ); + + private class Context { + private final Task task; + private final Request request; + private final Map indexMetadataByName; + private final ActionListener listener; + + private int totalNumCopies; + private int totalSuccessful; + private final List allFailures = new ArrayList<>(); + + Context(Task task, Request request, Map indexMetadataByName, ActionListener listener) { + this.task = task; + this.request = request; + this.indexMetadataByName = indexMetadataByName; + this.listener = listener; + } + + void processShard(ThrottledIterator.ItemRefs refs, ShardId shardId) { + shardExecute( + task, + request, + shardId, + new ThreadedActionListener<>( + logger, + clusterService.threadPool(), + executor, + ActionListener.releaseAfter(createListener(shardId), refs.acquire()), + false + ) + ); + } + + private ActionListener createListener(ShardId shardId) { + return 
new ActionListener<>() { + @Override + public void onResponse(ShardResponse shardResponse) { + assert shardResponse != null; + logger.trace("{}: got response from {}", actionName, shardId); + addShardResponse( + shardResponse.getShardInfo().getTotal(), + shardResponse.getShardInfo().getSuccessful(), + Arrays.stream(shardResponse.getShardInfo().getFailures()) + .map( + f -> new DefaultShardOperationFailedException( + new BroadcastShardOperationFailedException(shardId, f.getCause()) + ) + ) + .toList() + ); + } + + @Override + public void onFailure(Exception e) { + logger.trace("{}: got failure from {}", actionName, shardId); + final int numCopies = indexMetadataByName.get(shardId.getIndexName()).getNumberOfReplicas() + 1; + addShardResponse(numCopies, 0, createSyntheticFailures(numCopies, e)); + } + + private List createSyntheticFailures(int numCopies, Exception e) { + if (TransportActions.isShardNotAvailableException(e)) { + return List.of(); + } + + final var failures = new DefaultShardOperationFailedException[numCopies]; + Arrays.fill(failures, new DefaultShardOperationFailedException(new BroadcastShardOperationFailedException(shardId, e))); + return Arrays.asList(failures); + } + }; + } + + private synchronized void addShardResponse(int numCopies, int successful, List failures) { + totalNumCopies += numCopies; + totalSuccessful += successful; + allFailures.addAll(failures); + } + + void finish() { + // no need for synchronized here, the ThrottledIterator guarantees that all the addShardResponse calls happen-before this point + logger.trace("{}: got all shard responses", actionName); + listener.onResponse(newResponse(totalSuccessful, allFailures.size(), totalNumCopies, allFailures)); + } + } } diff --git a/server/src/main/java/org/elasticsearch/common/util/concurrent/ThrottledIterator.java b/server/src/main/java/org/elasticsearch/common/util/concurrent/ThrottledIterator.java new file mode 100644 index 0000000000000..ebc11f383eed0 --- /dev/null +++ 
b/server/src/main/java/org/elasticsearch/common/util/concurrent/ThrottledIterator.java @@ -0,0 +1,166 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.common.util.concurrent; + +import org.elasticsearch.core.AbstractRefCounted; +import org.elasticsearch.core.RefCounted; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Strings; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; + +import java.util.Iterator; +import java.util.Objects; +import java.util.concurrent.Semaphore; +import java.util.function.BiConsumer; + +public class ThrottledIterator implements Releasable { + + private static final Logger logger = LogManager.getLogger(ThrottledIterator.class); + + /** + * Iterate through the given items, performing an operation on each item which may fork background tasks, but with a limit on the + * number of such background tasks running concurrently to avoid overwhelming the rest of the system (e.g. starving other work of access + * to an executor). + * + * @param iterator The items to iterate. May be accessed by multiple threads, but accesses are all protected by synchronizing on itself. + * @param itemConsumer The operation to perform on each item. Each operation receives an {@link ItemRefs} which can be used to track + * the execution of any background tasks spawned for this item. This operation may run on the thread which + * originally called {@link #run}, if this method has not yet returned. Otherwise it will run on a thread on which a + * background task previously released a reference obtained from {@link ItemRefs#acquire()}. 
This operation should not throw + * any exceptions. + * @param maxConcurrency The maximum number of ongoing operations at any time. + * @param onItemCompletion Executed when each item is completed, which can be used for instance to report on progress. Must not throw + * exceptions. + * @param onCompletion Executed when all items are completed. + */ + public static void run( + Iterator iterator, + BiConsumer itemConsumer, + int maxConcurrency, + Runnable onItemCompletion, + Runnable onCompletion + ) { + try (var throttledIterator = new ThrottledIterator<>(iterator, itemConsumer, maxConcurrency, onItemCompletion, onCompletion)) { + throttledIterator.run(); + } + } + + private final RefCounted throttleRefs; + private final Iterator iterator; + private final BiConsumer itemConsumer; + private final Semaphore permits; + private final Runnable onItemCompletion; + + private ThrottledIterator( + Iterator iterator, + BiConsumer itemConsumer, + int maxConcurrency, + Runnable onItemCompletion, + Runnable onCompletion + ) { + this.iterator = Objects.requireNonNull(iterator); + this.itemConsumer = Objects.requireNonNull(itemConsumer); + if (maxConcurrency <= 0) { + throw new IllegalArgumentException("maxConcurrency must be positive"); + } + this.permits = new Semaphore(maxConcurrency); + this.onItemCompletion = Objects.requireNonNull(onItemCompletion); + this.throttleRefs = AbstractRefCounted.of(onCompletion); + } + + private void run() { + while (permits.tryAcquire()) { + final T item; + synchronized (iterator) { + if (iterator.hasNext()) { + item = iterator.next(); + } else { + permits.release(); + return; + } + } + try (var itemRefs = new ItemRefCounted()) { + itemConsumer.accept(itemRefs, item); + } catch (Exception e) { + logger.error(Strings.format("exception when processing [%s] with [%s]", item, itemConsumer), e); + assert false : e; + } + } + } + + @Override + public void close() { + throttleRefs.decRef(); + } + + public interface ItemRefs { + Releasable acquire(); + } 
+ + // A RefCounted for a single item, including protection against calling back into run() if it's created and closed within a single + // invocation of run(). + private class ItemRefCounted extends AbstractRefCounted implements Releasable, ItemRefs { + private boolean isRecursive = true; + + ItemRefCounted() { + throttleRefs.incRef(); + } + + @Override + protected void closeInternal() { + try { + onItemCompletion.run(); + } catch (Exception e) { + logger.error("exception in onItemCompletion", e); + assert false : e; + } finally { + permits.release(); + try { + // Someone must now pick up the next item. Here we might be called from the run() invocation which started processing + // the just-completed item (via close() -> decRef()) if that item's processing didn't fork or all its forked tasks + // finished first. If so, there's no need to call run() here, we can just return and the next iteration of the run() + // loop will continue the processing; moreover calling run() in this situation could lead to a stack overflow. However + // if we're not within that run() invocation then ... + if (isRecursive() == false) { + // ... we're not within any other run() invocation either, so it's safe (and necessary) to call run() here. + run(); + } + } finally { + throttleRefs.decRef(); + } + } + } + + // Note on blocking: we call both of these synchronized methods exactly once (and must enter close() before calling isRecursive()). + // If close() releases the last ref and calls closeInternal(), and hence isRecursive(), then there are no other threads involved and + // hence no blocking. In contrast if close() doesn't release the last ref then it exits immediately, so the call to isRecursive() + // will proceed without delay in this case too. 
+ + private synchronized boolean isRecursive() { + return isRecursive; + } + + @Override + public synchronized void close() { + decRef(); + isRecursive = false; + } + + @Override + public Releasable acquire() { + if (tryIncRef()) { + return this::decRef; + } else { + assert false; + throw new IllegalStateException("already closed"); + } + } + } +} diff --git a/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java b/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java index 51363e76d8adb..d0ecf6c2abb70 100644 --- a/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java +++ b/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java @@ -7,6 +7,7 @@ */ package org.elasticsearch.action.support.replication; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.NoShardAvailableActionException; @@ -20,10 +21,16 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.broadcast.BaseBroadcastResponse; import org.elasticsearch.action.support.broadcast.BroadcastRequest; +import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.routing.ShardRoutingState; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import 
org.elasticsearch.common.network.NetworkService; @@ -62,6 +69,9 @@ import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.state; import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.stateWithAssignedPrimariesAndOneReplica; import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.stateWithNoShard; +import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_REPLICAS; +import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS; +import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_VERSION_CREATED; import static org.elasticsearch.test.ClusterServiceUtils.createClusterService; import static org.elasticsearch.test.ClusterServiceUtils.setState; import static org.hamcrest.Matchers.equalTo; @@ -75,6 +85,7 @@ public class BroadcastReplicationTests extends ESTestCase { private ClusterService clusterService; private TransportService transportService; private TestBroadcastReplicationAction broadcastReplicationAction; + private static final int MAX_REQUESTS_PER_NODE = 10; @BeforeClass public static void beforeClass() { @@ -210,14 +221,14 @@ public void testResultCombine() throws InterruptedException, ExecutionException, assertBroadcastResponse(2 * numShards, succeeded, failed, response.get(), Exception.class); } - public void testNoShards() throws InterruptedException, ExecutionException, IOException { + public void testNoShards() { setState(clusterService, stateWithNoShard()); logger.debug("--> using initial state:\n{}", clusterService.state()); BaseBroadcastResponse response = executeAndAssertImmediateResponse(broadcastReplicationAction, new DummyBroadcastRequest()); assertBroadcastResponse(0, 0, 0, response, null); } - public void testShardsList() throws InterruptedException, ExecutionException { + public void testShardsIteratorOneShard() { final String index = "test"; final ShardId shardId = new ShardId(index, "_na_", 
0); ClusterState clusterState = state( @@ -227,9 +238,96 @@ public void testShardsList() throws InterruptedException, ExecutionException { ShardRoutingState.UNASSIGNED ); logger.debug("--> using initial state:\n{}", clusterService.state()); - List shards = broadcastReplicationAction.shards(new DummyBroadcastRequest().indices(shardId.getIndexName()), clusterState); - assertThat(shards.size(), equalTo(1)); - assertThat(shards.get(0), equalTo(shardId)); + var shards = broadcastReplicationAction.shardIds(new DummyBroadcastRequest().indices(shardId.getIndexName()), clusterState); + assertTrue(shards.hasNext()); + assertEquals(shardId, shards.next()); + assertFalse(shards.hasNext()); + } + + public void testShardsIterator() { + final var metadataBuilder = Metadata.builder(); + final var indexCount = between(1, 5); + for (int i = 0; i < indexCount; i++) { + metadataBuilder.put( + new IndexMetadata.Builder("index-" + i).settings( + Settings.builder() + .put(SETTING_VERSION_CREATED, Version.CURRENT) + .put(SETTING_NUMBER_OF_SHARDS, between(1, 3)) + .put(SETTING_NUMBER_OF_REPLICAS, between(0, 2)) + ) + ); + } + final var clusterState = ClusterState.builder(ClusterName.DEFAULT).metadata(metadataBuilder).build(); + final var allIndexNames = clusterState.metadata().indices().keySet(); + final var indexNames = randomSubsetOf(between(1, indexCount), allIndexNames).toArray(Strings.EMPTY_ARRAY); + + final var expectedShards = new HashSet(); + for (final var indexName : indexNames) { + final var indexMetadata = clusterState.metadata().index(indexName); + for (int i = 0; i < indexMetadata.getNumberOfShards(); i++) { + expectedShards.add(new ShardId(indexMetadata.getIndex(), i)); + } + } + + final var actualShards = new HashSet(); + final var iterator = broadcastReplicationAction.shardIds(new DummyBroadcastRequest().indices(indexNames), clusterState); + while (iterator.hasNext()) { + actualShards.add(iterator.next()); + } + + assertEquals(expectedShards, actualShards); + } + + public 
void testThrottling() { + final var shardCount = between(MAX_REQUESTS_PER_NODE, MAX_REQUESTS_PER_NODE * 5); + final var replicaCount = between(0, 2); + setState( + clusterService, + ClusterState.builder(ClusterName.DEFAULT) + .metadata( + Metadata.builder() + .put( + new IndexMetadata.Builder("test").settings( + Settings.builder() + .put(SETTING_VERSION_CREATED, Version.CURRENT) + .put(SETTING_NUMBER_OF_SHARDS, shardCount) + .put(SETTING_NUMBER_OF_REPLICAS, replicaCount) + ) + ) + ) + .nodes(DiscoveryNodes.builder().add(new DiscoveryNode("test", buildNewFakeTransportAddress(), Version.CURRENT))) + .build() + ); + + PlainActionFuture future = PlainActionFuture.newFuture(); + ActionTestUtils.execute(broadcastReplicationAction, null, new DummyBroadcastRequest().indices("test"), future); + + final var maxOutstandingRequests = clusterService.state().nodes().getDataNodes().size() * MAX_REQUESTS_PER_NODE; + assertThat(broadcastReplicationAction.capturedShardRequests.size(), equalTo(maxOutstandingRequests)); + assertFalse(future.isDone()); + final boolean[] handled = new boolean[clusterService.state().metadata().index("test").getNumberOfShards()]; + + var successes = 0; + while (broadcastReplicationAction.capturedShardRequests.isEmpty() == false) { + assertThat(broadcastReplicationAction.capturedShardRequests.size(), lessThanOrEqualTo(maxOutstandingRequests)); + final var request = randomFrom(broadcastReplicationAction.capturedShardRequests); + assertTrue(broadcastReplicationAction.capturedShardRequests.remove(request)); + assertFalse(handled[request.v1().id()]); + handled[request.v1().id()] = true; + if (randomBoolean()) { + successes += 1; + ReplicationResponse replicationResponse = new ReplicationResponse(); + replicationResponse.setShardInfo(new ReplicationResponse.ShardInfo(replicaCount + 1, replicaCount + 1)); + request.v2().onResponse(replicationResponse); + } else { + request.v2().onFailure(new ElasticsearchException("unexpected")); + } + } + + var response = 
future.actionGet(10, TimeUnit.SECONDS); + assertEquals(shardCount * (replicaCount + 1), response.getTotalShards()); + assertEquals(successes * (replicaCount + 1), response.getSuccessfulShards()); + assertEquals((shardCount - successes) * (replicaCount + 1), response.getFailedShards()); } private class TestBroadcastReplicationAction extends TransportBroadcastReplicationAction< @@ -254,15 +352,12 @@ private class TestBroadcastReplicationAction extends TransportBroadcastReplicati null, actionFilters, indexNameExpressionResolver, - null + null, + ThreadPool.Names.SAME, + MAX_REQUESTS_PER_NODE ); } - @Override - protected ReplicationResponse newShardResponse() { - return new ReplicationResponse(); - } - @Override protected BasicReplicationRequest newShardRequest(DummyBroadcastRequest request, ShardId shardId) { return new BasicReplicationRequest(shardId); diff --git a/server/src/test/java/org/elasticsearch/common/util/concurrent/ThrottledIteratorTests.java b/server/src/test/java/org/elasticsearch/common/util/concurrent/ThrottledIteratorTests.java new file mode 100644 index 0000000000000..0cde4e3c9632d --- /dev/null +++ b/server/src/test/java/org/elasticsearch/common/util/concurrent/ThrottledIteratorTests.java @@ -0,0 +1,111 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +package org.elasticsearch.common.util.concurrent; + +import org.apache.lucene.tests.util.LuceneTestCase; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.FixedExecutorBuilder; +import org.elasticsearch.threadpool.ScalingExecutorBuilder; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BooleanSupplier; +import java.util.stream.IntStream; + +public class ThrottledIteratorTests extends ESTestCase { + private static final String CONSTRAINED = "constrained"; + private static final String RELAXED = "relaxed"; + + public void testConcurrency() throws InterruptedException { + final var maxConstrainedThreads = between(1, 3); + final var maxRelaxedThreads = between(1, 100); + final var constrainedQueue = between(3, 6); + final var threadPool = new TestThreadPool( + "test", + new FixedExecutorBuilder(Settings.EMPTY, CONSTRAINED, maxConstrainedThreads, constrainedQueue, CONSTRAINED, false), + new ScalingExecutorBuilder(RELAXED, 1, maxRelaxedThreads, TimeValue.timeValueSeconds(30), true) + ); + try { + final var items = between(1, 10000); // large enough that inadvertent recursion will trigger a StackOverflowError + final var itemStartLatch = new CountDownLatch(items); + final var completedItems = new AtomicInteger(); + final var maxConcurrency = between(1, (constrainedQueue + maxConstrainedThreads) * 2); + final var itemPermits = new Semaphore(maxConcurrency); + final var completionLatch = new CountDownLatch(1); + final BooleanSupplier forkSupplier = randomFrom( + () -> false, + ESTestCase::randomBoolean, + LuceneTestCase::rarely, + LuceneTestCase::usually, + () -> true + ); + final var 
blockPermits = new Semaphore(between(0, Math.min(maxRelaxedThreads, maxConcurrency) - 1)); + + ThrottledIterator.run(IntStream.range(0, items).boxed().iterator(), (refs, item) -> { + assertTrue(itemPermits.tryAcquire()); + if (forkSupplier.getAsBoolean()) { + // noinspection resource + var ref = refs.acquire(); + final var executor = randomFrom(CONSTRAINED, RELAXED); + threadPool.executor(executor).execute(new AbstractRunnable() { + + @Override + public void onRejection(Exception e) { + assertEquals(CONSTRAINED, executor); + itemStartLatch.countDown(); + } + + @Override + protected void doRun() { + itemStartLatch.countDown(); + if (RELAXED.equals(executor) && randomBoolean() && blockPermits.tryAcquire()) { + // simulate at most (maxConcurrency-1) long-running operations, to demonstrate that they don't + // hold up the processing of the other operations + try { + assertTrue(itemStartLatch.await(30, TimeUnit.SECONDS)); + } catch (InterruptedException e) { + throw new AssertionError("unexpected", e); + } finally { + blockPermits.release(); + } + } + } + + @Override + public void onAfter() { + itemPermits.release(); + ref.close(); + } + + @Override + public void onFailure(Exception e) { + throw new AssertionError("unexpected", e); + } + }); + } else { + itemStartLatch.countDown(); + itemPermits.release(); + } + }, maxConcurrency, completedItems::incrementAndGet, completionLatch::countDown); + + assertTrue(completionLatch.await(30, TimeUnit.SECONDS)); + assertEquals(items, completedItems.get()); + assertTrue(itemPermits.tryAcquire(maxConcurrency)); + assertTrue(itemStartLatch.await(0, TimeUnit.SECONDS)); + } finally { + ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); + } + } +}