From d5f4861f14880d26adb19c3a4a0c8f9fce63f960 Mon Sep 17 00:00:00 2001 From: David Turner Date: Fri, 6 Jan 2023 16:53:40 +0000 Subject: [PATCH 1/7] Improve scalability of BroadcastReplicationActions BroadcastReplicationAction derivatives (`POST /<indices>/_refresh` and `POST /<indices>/_flush`) are pretty inefficient when targeting high shard counts due to how `TransportBroadcastReplicationAction` works: - It computes the list of all target shards up-front on the calling (transport) thread. - It eagerly sends one request for every target shard in a tight loop on the calling (transport) thread. - It accumulates responses in a `CopyOnWriteArrayList` which takes quadratic work to populate, even though nothing reads this list until it's fully populated. - It then mostly discards the accumulated responses, keeping only the total number of shards, the number of successful shards, and a list of any failures. - Each failure is wrapped up in a `ReplicationResponse.ShardInfo.Failure` but then unwrapped at the end to be re-wrapped in a `DefaultShardOperationFailedException`. This commit fixes all this: - It avoids allocating a list of all target shards, instead iterating over the target indices and generating shard IDs on the fly. - The computation of the list of shards, and the sending of the per-shard requests, now happens on the relevant threadpool (`REFRESH` or `FLUSH`) rather than a transport thread. - The per-shard requests are now throttled, with a meaningful yet fairly generous concurrency limit of `#(data nodes) * 10`. - Rather than accumulating the full responses for later processing we track the counts and failures directly. - The failures are tracked in a regular `ArrayList`, avoiding the accidentally-quadratic complexity. - The failures are tracked in their final form, skipping the unwrap-and-rewrap step at the end. 
Relates #77466 --- .../elasticsearch/action/ActionListener.java | 31 +++ .../indices/flush/TransportFlushAction.java | 9 +- .../flush/TransportShardFlushAction.java | 6 + .../refresh/TransportRefreshAction.java | 9 +- .../refresh/TransportShardRefreshAction.java | 6 + .../TransportBroadcastReplicationAction.java | 221 ++++++++++-------- .../util/concurrent/ThrottledIterator.java | 166 +++++++++++++ .../action/ActionListenerTests.java | 83 +++++++ .../BroadcastReplicationTests.java | 114 ++++++++- .../concurrent/ThrottledIteratorTests.java | 111 +++++++++ 10 files changed, 636 insertions(+), 120 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/common/util/concurrent/ThrottledIterator.java create mode 100644 server/src/test/java/org/elasticsearch/common/util/concurrent/ThrottledIteratorTests.java diff --git a/server/src/main/java/org/elasticsearch/action/ActionListener.java b/server/src/main/java/org/elasticsearch/action/ActionListener.java index 14425338ecb7a..37f9e9b41630c 100644 --- a/server/src/main/java/org/elasticsearch/action/ActionListener.java +++ b/server/src/main/java/org/elasticsearch/action/ActionListener.java @@ -13,6 +13,8 @@ import org.elasticsearch.core.CheckedConsumer; import org.elasticsearch.core.CheckedFunction; import org.elasticsearch.core.CheckedRunnable; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; import java.util.ArrayList; import java.util.List; @@ -252,6 +254,13 @@ public String toString() { } } + /** + * Creates a listener which releases the given resource on completion (whether success or failure) + */ + static ActionListener releasing(Releasable releasable) { + return wrap(runnableFromReleasable(releasable)); + } + /** * Creates a listener that listens for a response (or failure) and executes the * corresponding runnable when the response (or failure) is received. 
@@ -362,6 +371,14 @@ static ActionListener runAfter(ActionListener del return new RunAfterActionListener<>(delegate, runAfter); } + /** + * Wraps a given listener and returns a new listener which releases the provided {@code releaseAfter} + * resource when the listener is notified via either {@code #onResponse} or {@code #onFailure}. + */ + static ActionListener releaseAfter(ActionListener delegate, Releasable releaseAfter) { + return new RunAfterActionListener<>(delegate, runnableFromReleasable(releaseAfter)); + } + final class RunAfterActionListener extends Delegating { private final Runnable runAfter; @@ -498,4 +515,18 @@ static void completeWith(ActionListener listener, CheckedSu throw ex; } } + + private static Runnable runnableFromReleasable(Releasable releasable) { + return new Runnable() { + @Override + public void run() { + Releasables.closeExpectNoException(releasable); + } + + @Override + public String toString() { + return "release[" + releasable + "]"; + } + }; + } } diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportFlushAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportFlushAction.java index a828f6e413d77..6fce79e31b911 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportFlushAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportFlushAction.java @@ -17,6 +17,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import java.util.List; @@ -46,15 +47,11 @@ public TransportFlushAction( client, actionFilters, indexNameExpressionResolver, - TransportShardFlushAction.TYPE + TransportShardFlushAction.TYPE, + ThreadPool.Names.FLUSH ); } - @Override - protected ReplicationResponse newShardResponse() { - return 
new ReplicationResponse(); - } - @Override protected ShardFlushRequest newShardRequest(FlushRequest request, ShardId shardId) { return new ShardFlushRequest(request, shardId); diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportShardFlushAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportShardFlushAction.java index 32e67e95d1936..22c509deccc2f 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportShardFlushAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportShardFlushAction.java @@ -15,6 +15,7 @@ import org.elasticsearch.action.support.replication.ReplicationResponse; import org.elasticsearch.action.support.replication.TransportReplicationAction; import org.elasticsearch.cluster.action.shard.ShardStateAction; +import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.io.stream.StreamInput; @@ -68,6 +69,11 @@ public TransportShardFlushAction( ); } + @Override + protected ClusterBlockLevel globalBlockLevel() { + return ClusterBlockLevel.METADATA_READ; + } + @Override protected ReplicationResponse newResponseInstance(StreamInput in) throws IOException { return new ReplicationResponse(in); diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportRefreshAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportRefreshAction.java index ff9f6640b4120..ceb940502da5d 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportRefreshAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportRefreshAction.java @@ -19,6 +19,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import 
org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import java.util.List; @@ -48,15 +49,11 @@ public TransportRefreshAction( client, actionFilters, indexNameExpressionResolver, - TransportShardRefreshAction.TYPE + TransportShardRefreshAction.TYPE, + ThreadPool.Names.REFRESH ); } - @Override - protected ReplicationResponse newShardResponse() { - return new ReplicationResponse(); - } - @Override protected BasicReplicationRequest newShardRequest(RefreshRequest request, ShardId shardId) { BasicReplicationRequest replicationRequest = new BasicReplicationRequest(shardId); diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportShardRefreshAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportShardRefreshAction.java index 27e185e98a9f4..9bac02530197e 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportShardRefreshAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportShardRefreshAction.java @@ -15,6 +15,7 @@ import org.elasticsearch.action.support.replication.ReplicationResponse; import org.elasticsearch.action.support.replication.TransportReplicationAction; import org.elasticsearch.cluster.action.shard.ShardStateAction; +import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.io.stream.StreamInput; @@ -64,6 +65,11 @@ protected ReplicationResponse newResponseInstance(StreamInput in) throws IOExcep return new ReplicationResponse(in); } + @Override + protected ClusterBlockLevel globalBlockLevel() { + return ClusterBlockLevel.METADATA_READ; + } + @Override protected void shardOperationOnPrimary( BasicReplicationRequest shardRequest, diff --git 
a/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java b/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java index 62a2d3d38e061..da33845c7c273 100644 --- a/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java +++ b/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java @@ -8,12 +8,12 @@ package org.elasticsearch.action.support.replication; -import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.DefaultShardOperationFailedException; import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.action.support.ThreadedActionListener; import org.elasticsearch.action.support.TransportActions; import org.elasticsearch.action.support.broadcast.BaseBroadcastResponse; import org.elasticsearch.action.support.broadcast.BroadcastRequest; @@ -22,18 +22,22 @@ import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; -import org.elasticsearch.cluster.routing.IndexRoutingTable; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.util.concurrent.CountDown; +import org.elasticsearch.common.util.concurrent.ThrottledIterator; +import org.elasticsearch.core.Nullable; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.transport.Transports; import java.util.ArrayList; import java.util.Arrays; +import 
java.util.Collections; +import java.util.Iterator; import java.util.List; -import java.util.concurrent.CopyOnWriteArrayList; +import java.util.Map; /** * Base class for requests that should be executed on all shards of an index or several indices. @@ -48,9 +52,10 @@ public abstract class TransportBroadcastReplicationAction< private final ActionType replicatedBroadcastShardAction; private final ClusterService clusterService; private final IndexNameExpressionResolver indexNameExpressionResolver; + private final String executor; private final NodeClient client; - public TransportBroadcastReplicationAction( + protected TransportBroadcastReplicationAction( String name, Writeable.Reader requestReader, ClusterService clusterService, @@ -58,126 +63,148 @@ public TransportBroadcastReplicationAction( NodeClient client, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, - ActionType replicatedBroadcastShardAction + ActionType replicatedBroadcastShardAction, + String executor ) { - super(name, transportService, actionFilters, requestReader); + super(name, transportService, actionFilters, requestReader, executor); this.client = client; this.replicatedBroadcastShardAction = replicatedBroadcastShardAction; this.clusterService = clusterService; this.indexNameExpressionResolver = indexNameExpressionResolver; + this.executor = executor; } @Override protected void doExecute(Task task, Request request, ActionListener listener) { - final ClusterState clusterState = clusterService.state(); - List shards = shards(request, clusterState); - final CopyOnWriteArrayList shardsResponses = new CopyOnWriteArrayList<>(); - if (shards.size() == 0) { - finishAndNotifyListener(listener, shardsResponses); - } - final CountDown responsesCountDown = new CountDown(shards.size()); - for (final ShardId shardId : shards) { - ActionListener shardActionListener = new ActionListener() { - @Override - public void onResponse(ShardResponse shardResponse) { - 
shardsResponses.add(shardResponse); - logger.trace("{}: got response from {}", actionName, shardId); - if (responsesCountDown.countDown()) { - finishAndNotifyListener(listener, shardsResponses); - } - } - - @Override - public void onFailure(Exception e) { - logger.trace("{}: got failure from {}", actionName, shardId); - int totalNumCopies = clusterState.getMetadata().getIndexSafe(shardId.getIndex()).getNumberOfReplicas() + 1; - ShardResponse shardResponse = newShardResponse(); - ReplicationResponse.ShardInfo.Failure[] failures; - if (TransportActions.isShardNotAvailableException(e)) { - failures = new ReplicationResponse.ShardInfo.Failure[0]; - } else { - ReplicationResponse.ShardInfo.Failure failure = new ReplicationResponse.ShardInfo.Failure( - shardId, - null, - e, - ExceptionsHelper.status(e), - true - ); - failures = new ReplicationResponse.ShardInfo.Failure[totalNumCopies]; - Arrays.fill(failures, failure); - } - shardResponse.setShardInfo(new ReplicationResponse.ShardInfo(totalNumCopies, 0, failures)); - shardsResponses.add(shardResponse); - if (responsesCountDown.countDown()) { - finishAndNotifyListener(listener, shardsResponses); - } - } - }; - shardExecute(task, request, shardId, shardActionListener); - } + final var clusterState = clusterService.state(); + final var context = new Context(task, request, clusterState.metadata().indices(), listener); + ThrottledIterator.run( + shardIds(request, clusterState), + context::processShard, + clusterState.nodes().getDataNodes().size() * 10, + () -> {}, + context::finish + ); } protected void shardExecute(Task task, Request request, ShardId shardId, ActionListener shardActionListener) { + assert Transports.assertNotTransportThread("per-shard requests might be high-volume"); ShardRequest shardRequest = newShardRequest(request, shardId); shardRequest.setParentTask(clusterService.localNode().getId(), task.getId()); client.executeLocally(replicatedBroadcastShardAction, shardRequest, shardActionListener); } /** - * 
@return all shard ids the request should run on + * @return all shard ids on which the request should run; exposed for tests */ - protected List shards(Request request, ClusterState clusterState) { - List shardIds = new ArrayList<>(); - String[] concreteIndices = indexNameExpressionResolver.concreteIndexNames(clusterState, request); - for (String index : concreteIndices) { - IndexMetadata indexMetadata = clusterState.metadata().getIndices().get(index); - if (indexMetadata != null) { - final IndexRoutingTable indexRoutingTable = clusterState.getRoutingTable().indicesRouting().get(index); - for (int i = 0; i < indexRoutingTable.size(); i++) { - shardIds.add(indexRoutingTable.shard(i).shardId()); - } - } - } - return shardIds; + Iterator shardIds(Request request, ClusterState clusterState) { + var indexMetadataByName = clusterState.metadata().indices(); + return Iterators.flatMap( + Iterators.forArray(indexNameExpressionResolver.concreteIndexNames(clusterState, request)), + indexName -> indexShardIds(indexMetadataByName.get(indexName)) + ); } - protected abstract ShardResponse newShardResponse(); - - protected abstract ShardRequest newShardRequest(Request request, ShardId shardId); - - private void finishAndNotifyListener(ActionListener listener, CopyOnWriteArrayList shardsResponses) { - logger.trace("{}: got all shard responses", actionName); - int successfulShards = 0; - int failedShards = 0; - int totalNumCopies = 0; - List shardFailures = null; - for (int i = 0; i < shardsResponses.size(); i++) { - ReplicationResponse shardResponse = shardsResponses.get(i); - if (shardResponse == null) { - // non active shard, ignore - } else { - failedShards += shardResponse.getShardInfo().getFailed(); - successfulShards += shardResponse.getShardInfo().getSuccessful(); - totalNumCopies += shardResponse.getShardInfo().getTotal(); - if (shardFailures == null) { - shardFailures = new ArrayList<>(); - } - for (ReplicationResponse.ShardInfo.Failure failure : 
shardResponse.getShardInfo().getFailures()) { - shardFailures.add( - new DefaultShardOperationFailedException( - new BroadcastShardOperationFailedException(failure.fullShardId(), failure.getCause()) - ) - ); - } - } + private static Iterator indexShardIds(@Nullable IndexMetadata indexMetadata) { + if (indexMetadata == null) { + return Collections.emptyIterator(); + } + var shardIds = new ShardId[indexMetadata.getNumberOfShards()]; + for (int i = 0; i < shardIds.length; i++) { + shardIds[i] = new ShardId(indexMetadata.getIndex(), i); } - listener.onResponse(newResponse(successfulShards, failedShards, totalNumCopies, shardFailures)); + return Iterators.forArray(shardIds); } + protected abstract ShardRequest newShardRequest(Request request, ShardId shardId); + protected abstract Response newResponse( int successfulShards, int failedShards, int totalNumCopies, List shardFailures ); + + private class Context { + private final Task task; + private final Request request; + private final Map indexMetadataByName; + private final ActionListener listener; + + private int totalNumCopies; + private int totalSuccessful; + private final List allFailures = new ArrayList<>(); + + Context(Task task, Request request, Map indexMetadataByName, ActionListener listener) { + this.task = task; + this.request = request; + this.indexMetadataByName = indexMetadataByName; + this.listener = listener; + } + + void processShard(ThrottledIterator.ItemRefs refs, ShardId shardId) { + shardExecute( + task, + request, + shardId, + new ThreadedActionListener<>( + logger, + clusterService.threadPool(), + executor, + ActionListener.releaseAfter(createListener(shardId), refs.acquire()), + false + ) + ); + } + + private ActionListener createListener(ShardId shardId) { + return new ActionListener<>() { + @Override + public void onResponse(ShardResponse shardResponse) { + assert shardResponse != null; + logger.trace("{}: got response from {}", actionName, shardId); + addShardResponse( + 
shardResponse.getShardInfo().getTotal(), + shardResponse.getShardInfo().getSuccessful(), + Arrays.stream(shardResponse.getShardInfo().getFailures()) + .map( + f -> new DefaultShardOperationFailedException( + new BroadcastShardOperationFailedException(shardId, f.getCause()) + ) + ) + .toList() + ); + } + + @Override + public void onFailure(Exception e) { + logger.trace("{}: got failure from {}", actionName, shardId); + final int numCopies = indexMetadataByName.get(shardId.getIndexName()).getNumberOfReplicas() + 1; + addShardResponse(numCopies, 0, createSyntheticFailures(numCopies, e)); + } + + private List createSyntheticFailures(int numCopies, Exception e) { + if (TransportActions.isShardNotAvailableException(e)) { + return List.of(); + } + + final var failures = new DefaultShardOperationFailedException[numCopies]; + Arrays.fill(failures, new DefaultShardOperationFailedException(new BroadcastShardOperationFailedException(shardId, e))); + return Arrays.asList(failures); + } + }; + } + + private synchronized void addShardResponse(int numCopies, int successful, List failures) { + totalNumCopies += numCopies; + totalSuccessful += successful; + allFailures.addAll(failures); + } + + void finish() { + // no need for synchronized here, the ThrottledIterator guarantees that all the addShardResponse calls happen-before this point + logger.trace("{}: got all shard responses", actionName); + listener.onResponse(newResponse(totalSuccessful, allFailures.size(), totalNumCopies, allFailures)); + } + } } diff --git a/server/src/main/java/org/elasticsearch/common/util/concurrent/ThrottledIterator.java b/server/src/main/java/org/elasticsearch/common/util/concurrent/ThrottledIterator.java new file mode 100644 index 0000000000000..ebc11f383eed0 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/common/util/concurrent/ThrottledIterator.java @@ -0,0 +1,166 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.common.util.concurrent; + +import org.elasticsearch.core.AbstractRefCounted; +import org.elasticsearch.core.RefCounted; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Strings; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; + +import java.util.Iterator; +import java.util.Objects; +import java.util.concurrent.Semaphore; +import java.util.function.BiConsumer; + +public class ThrottledIterator implements Releasable { + + private static final Logger logger = LogManager.getLogger(ThrottledIterator.class); + + /** + * Iterate through the given collection, performing an operation on each item which may fork background tasks, but with a limit on the + * number of such background tasks running concurrently to avoid overwhelming the rest of the system (e.g. starving other work of access + * to an executor). + * + * @param iterator The items to iterate. May be accessed by multiple threads, but accesses are all protected by synchronizing on itself. + * @param itemConsumer The operation to perform on each item. Each operation receives a {@link RefCounted} which can be used to track + * the execution of any background tasks spawned for this item. This operation may run on the thread which + * originally called {@link #run}, if this method has not yet returned. Otherwise it will run on a thread on which a + * background task previously called {@link RefCounted#decRef()} on its ref count. This operation should not throw + * any exceptions. + * @param maxConcurrency The maximum number of ongoing operations at any time. 
+ * @param onItemCompletion Executed when each item is completed, which can be used for instance to report on progress. Must not throw + * exceptions. + * @param onCompletion Executed when all items are completed. + */ + public static void run( + Iterator iterator, + BiConsumer itemConsumer, + int maxConcurrency, + Runnable onItemCompletion, + Runnable onCompletion + ) { + try (var throttledIterator = new ThrottledIterator<>(iterator, itemConsumer, maxConcurrency, onItemCompletion, onCompletion)) { + throttledIterator.run(); + } + } + + private final RefCounted throttleRefs; + private final Iterator iterator; + private final BiConsumer itemConsumer; + private final Semaphore permits; + private final Runnable onItemCompletion; + + private ThrottledIterator( + Iterator iterator, + BiConsumer itemConsumer, + int maxConcurrency, + Runnable onItemCompletion, + Runnable onCompletion + ) { + this.iterator = Objects.requireNonNull(iterator); + this.itemConsumer = Objects.requireNonNull(itemConsumer); + if (maxConcurrency <= 0) { + throw new IllegalArgumentException("maxConcurrency must be positive"); + } + this.permits = new Semaphore(maxConcurrency); + this.onItemCompletion = Objects.requireNonNull(onItemCompletion); + this.throttleRefs = AbstractRefCounted.of(onCompletion); + } + + private void run() { + while (permits.tryAcquire()) { + final T item; + synchronized (iterator) { + if (iterator.hasNext()) { + item = iterator.next(); + } else { + permits.release(); + return; + } + } + try (var itemRefs = new ItemRefCounted()) { + itemConsumer.accept(itemRefs, item); + } catch (Exception e) { + logger.error(Strings.format("exception when processing [%s] with [%s]", item, itemConsumer), e); + assert false : e; + } + } + } + + @Override + public void close() { + throttleRefs.decRef(); + } + + public interface ItemRefs { + Releasable acquire(); + } + + // A RefCounted for a single item, including protection against calling back into run() if it's created and closed within a 
single + // invocation of run(). + private class ItemRefCounted extends AbstractRefCounted implements Releasable, ItemRefs { + private boolean isRecursive = true; + + ItemRefCounted() { + throttleRefs.incRef(); + } + + @Override + protected void closeInternal() { + try { + onItemCompletion.run(); + } catch (Exception e) { + logger.error("exception in onItemCompletion", e); + assert false : e; + } finally { + permits.release(); + try { + // Someone must now pick up the next item. Here we might be called from the run() invocation which started processing + // the just-completed item (via close() -> decRef()) if that item's processing didn't fork or all its forked tasks + // finished first. If so, there's no need to call run() here, we can just return and the next iteration of the run() + // loop will continue the processing; moreover calling run() in this situation could lead to a stack overflow. However + // if we're not within that run() invocation then ... + if (isRecursive() == false) { + // ... we're not within any other run() invocation either, so it's safe (and necessary) to call run() here. + run(); + } + } finally { + throttleRefs.decRef(); + } + } + } + + // Note on blocking: we call both of these synchronized methods exactly once (and must enter close() before calling isRecursive()). + // If close() releases the last ref and calls closeInternal(), and hence isRecursive(), then there's no other threads involved and + // hence no blocking. In contrast if close() doesn't release the last ref then it exits immediately, so the call to isRecursive() + // will proceed without delay in this case too. 
+ + private synchronized boolean isRecursive() { + return isRecursive; + } + + @Override + public synchronized void close() { + decRef(); + isRecursive = false; + } + + @Override + public Releasable acquire() { + if (tryIncRef()) { + return this::decRef; + } else { + assert false; + throw new IllegalStateException("already closed"); + } + } + } +} diff --git a/server/src/test/java/org/elasticsearch/action/ActionListenerTests.java b/server/src/test/java/org/elasticsearch/action/ActionListenerTests.java index 188e3e2915d51..08ef030e90706 100644 --- a/server/src/test/java/org/elasticsearch/action/ActionListenerTests.java +++ b/server/src/test/java/org/elasticsearch/action/ActionListenerTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.core.CheckedConsumer; +import org.elasticsearch.core.Releasable; import org.elasticsearch.test.ESTestCase; import java.io.IOException; @@ -22,6 +23,7 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -441,4 +443,85 @@ public void onFailure(Exception e) { mapped.onFailure(new IllegalStateException()); assertThat(exReference.get(), instanceOf(IllegalStateException.class)); } + + public void testReleasing() { + runReleasingTest(true); + runReleasingTest(false); + } + + private static void runReleasingTest(boolean successResponse) { + final AtomicBoolean releasedFlag = new AtomicBoolean(); + final ActionListener l = ActionListener.releasing(makeReleasable(releasedFlag)); + assertThat(l.toString(), containsString("release[test releasable]}")); + completeListener(successResponse, l); + assertTrue(releasedFlag.get()); + } + + private static void completeListener(boolean successResponse, ActionListener listener) { + if (successResponse) { + try { + listener.onResponse(null); + } catch (Exception e) 
{ + // ok + } + } else { + listener.onFailure(new RuntimeException("simulated")); + } + } + + public void testReleaseAfter() { + runReleaseAfterTest(true, false); + runReleaseAfterTest(true, true); + runReleaseAfterTest(false, false); + } + + private static void runReleaseAfterTest(boolean successResponse, final boolean throwFromOnResponse) { + final AtomicBoolean released = new AtomicBoolean(); + final ActionListener l = ActionListener.releaseAfter(new ActionListener<>() { + @Override + public void onResponse(Void unused) { + if (throwFromOnResponse) { + throw new RuntimeException("onResponse"); + } + } + + @Override + public void onFailure(Exception e) { + // ok + } + + @Override + public String toString() { + return "test listener"; + } + }, makeReleasable(released)); + assertThat(l.toString(), containsString("test listener/release[test releasable]")); + + if (successResponse) { + try { + l.onResponse(null); + } catch (Exception e) { + // ok + } + } else { + l.onFailure(new RuntimeException("supplied")); + } + + assertTrue(released.get()); + } + + private static Releasable makeReleasable(AtomicBoolean releasedFlag) { + return new Releasable() { + @Override + public void close() { + assertTrue(releasedFlag.compareAndSet(false, true)); + } + + @Override + public String toString() { + return "test releasable"; + } + }; + } + } diff --git a/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java b/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java index 51363e76d8adb..8d49df665db2f 100644 --- a/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java +++ b/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java @@ -7,6 +7,7 @@ */ package org.elasticsearch.action.support.replication; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.Version; import 
org.elasticsearch.action.ActionListener; import org.elasticsearch.action.NoShardAvailableActionException; @@ -20,10 +21,16 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.broadcast.BaseBroadcastResponse; import org.elasticsearch.action.support.broadcast.BroadcastRequest; +import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.routing.ShardRoutingState; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.network.NetworkService; @@ -62,6 +69,9 @@ import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.state; import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.stateWithAssignedPrimariesAndOneReplica; import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.stateWithNoShard; +import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_REPLICAS; +import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS; +import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_VERSION_CREATED; import static org.elasticsearch.test.ClusterServiceUtils.createClusterService; import static org.elasticsearch.test.ClusterServiceUtils.setState; import static org.hamcrest.Matchers.equalTo; @@ -210,14 +220,14 @@ public void testResultCombine() throws InterruptedException, ExecutionException, assertBroadcastResponse(2 * numShards, 
succeeded, failed, response.get(), Exception.class); } - public void testNoShards() throws InterruptedException, ExecutionException, IOException { + public void testNoShards() { setState(clusterService, stateWithNoShard()); logger.debug("--> using initial state:\n{}", clusterService.state()); BaseBroadcastResponse response = executeAndAssertImmediateResponse(broadcastReplicationAction, new DummyBroadcastRequest()); assertBroadcastResponse(0, 0, 0, response, null); } - public void testShardsList() throws InterruptedException, ExecutionException { + public void testShardsIteratorOneShard() { final String index = "test"; final ShardId shardId = new ShardId(index, "_na_", 0); ClusterState clusterState = state( @@ -227,9 +237,95 @@ public void testShardsList() throws InterruptedException, ExecutionException { ShardRoutingState.UNASSIGNED ); logger.debug("--> using initial state:\n{}", clusterService.state()); - List shards = broadcastReplicationAction.shards(new DummyBroadcastRequest().indices(shardId.getIndexName()), clusterState); - assertThat(shards.size(), equalTo(1)); - assertThat(shards.get(0), equalTo(shardId)); + var shards = broadcastReplicationAction.shardIds(new DummyBroadcastRequest().indices(shardId.getIndexName()), clusterState); + assertTrue(shards.hasNext()); + assertEquals(shardId, shards.next()); + assertFalse(shards.hasNext()); + } + + public void testShardsIterator() { + final var metadataBuilder = Metadata.builder(); + final var indexCount = between(1, 5); + for (int i = 0; i < indexCount; i++) { + metadataBuilder.put( + new IndexMetadata.Builder("index-" + i).settings( + Settings.builder() + .put(SETTING_VERSION_CREATED, Version.CURRENT) + .put(SETTING_NUMBER_OF_SHARDS, between(1, 3)) + .put(SETTING_NUMBER_OF_REPLICAS, between(0, 2)) + ) + ); + } + final var clusterState = ClusterState.builder(ClusterName.DEFAULT).metadata(metadataBuilder).build(); + final var allIndexNames = clusterState.metadata().indices().keySet(); + final var indexNames = 
randomSubsetOf(between(1, indexCount), allIndexNames).toArray(Strings.EMPTY_ARRAY); + + final var expectedShards = new HashSet(); + for (final var indexName : indexNames) { + final var indexMetadata = clusterState.metadata().index(indexName); + for (int i = 0; i < indexMetadata.getNumberOfShards(); i++) { + expectedShards.add(new ShardId(indexMetadata.getIndex(), i)); + } + } + + final var actualShards = new HashSet(); + final var iterator = broadcastReplicationAction.shardIds(new DummyBroadcastRequest().indices(indexNames), clusterState); + while (iterator.hasNext()) { + actualShards.add(iterator.next()); + } + + assertEquals(expectedShards, actualShards); + } + + public void testThrottling() { + final var replicaCount = between(0, 2); + setState( + clusterService, + ClusterState.builder(ClusterName.DEFAULT) + .metadata( + Metadata.builder() + .put( + new IndexMetadata.Builder("test").settings( + Settings.builder() + .put(SETTING_VERSION_CREATED, Version.CURRENT) + .put(SETTING_NUMBER_OF_SHARDS, 25) + .put(SETTING_NUMBER_OF_REPLICAS, replicaCount) + ) + ) + ) + .nodes(DiscoveryNodes.builder().add(new DiscoveryNode("test", buildNewFakeTransportAddress(), Version.CURRENT))) + .build() + ); + + PlainActionFuture future = PlainActionFuture.newFuture(); + ActionTestUtils.execute(broadcastReplicationAction, null, new DummyBroadcastRequest().indices("test"), future); + + final var maxOutstandingRequests = clusterService.state().nodes().getDataNodes().size() * 10; + assertThat(broadcastReplicationAction.capturedShardRequests.size(), equalTo(maxOutstandingRequests)); + assertFalse(future.isDone()); + final boolean[] handled = new boolean[clusterService.state().metadata().index("test").getNumberOfShards()]; + + var successes = 0; + while (broadcastReplicationAction.capturedShardRequests.isEmpty() == false) { + assertThat(broadcastReplicationAction.capturedShardRequests.size(), lessThanOrEqualTo(maxOutstandingRequests)); + final var request = 
randomFrom(broadcastReplicationAction.capturedShardRequests); + assertTrue(broadcastReplicationAction.capturedShardRequests.remove(request)); + assertFalse(handled[request.v1().id()]); + handled[request.v1().id()] = true; + if (randomBoolean()) { + successes += 1; + ReplicationResponse replicationResponse = new ReplicationResponse(); + replicationResponse.setShardInfo(new ReplicationResponse.ShardInfo(replicaCount + 1, replicaCount + 1)); + request.v2().onResponse(replicationResponse); + } else { + request.v2().onFailure(new ElasticsearchException("unexpected")); + } + } + + var response = future.actionGet(10, TimeUnit.SECONDS); + assertEquals(25 * (replicaCount + 1), response.getTotalShards()); + assertEquals(successes * (replicaCount + 1), response.getSuccessfulShards()); + assertEquals((25 - successes) * (replicaCount + 1), response.getFailedShards()); } private class TestBroadcastReplicationAction extends TransportBroadcastReplicationAction< @@ -254,15 +350,11 @@ private class TestBroadcastReplicationAction extends TransportBroadcastReplicati null, actionFilters, indexNameExpressionResolver, - null + null, + ThreadPool.Names.SAME ); } - @Override - protected ReplicationResponse newShardResponse() { - return new ReplicationResponse(); - } - @Override protected BasicReplicationRequest newShardRequest(DummyBroadcastRequest request, ShardId shardId) { return new BasicReplicationRequest(shardId); diff --git a/server/src/test/java/org/elasticsearch/common/util/concurrent/ThrottledIteratorTests.java b/server/src/test/java/org/elasticsearch/common/util/concurrent/ThrottledIteratorTests.java new file mode 100644 index 0000000000000..0cde4e3c9632d --- /dev/null +++ b/server/src/test/java/org/elasticsearch/common/util/concurrent/ThrottledIteratorTests.java @@ -0,0 +1,111 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.common.util.concurrent; + +import org.apache.lucene.tests.util.LuceneTestCase; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.FixedExecutorBuilder; +import org.elasticsearch.threadpool.ScalingExecutorBuilder; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BooleanSupplier; +import java.util.stream.IntStream; + +public class ThrottledIteratorTests extends ESTestCase { + private static final String CONSTRAINED = "constrained"; + private static final String RELAXED = "relaxed"; + + public void testConcurrency() throws InterruptedException { + final var maxConstrainedThreads = between(1, 3); + final var maxRelaxedThreads = between(1, 100); + final var constrainedQueue = between(3, 6); + final var threadPool = new TestThreadPool( + "test", + new FixedExecutorBuilder(Settings.EMPTY, CONSTRAINED, maxConstrainedThreads, constrainedQueue, CONSTRAINED, false), + new ScalingExecutorBuilder(RELAXED, 1, maxRelaxedThreads, TimeValue.timeValueSeconds(30), true) + ); + try { + final var items = between(1, 10000); // large enough that inadvertent recursion will trigger a StackOverflowError + final var itemStartLatch = new CountDownLatch(items); + final var completedItems = new AtomicInteger(); + final var maxConcurrency = between(1, (constrainedQueue + maxConstrainedThreads) * 2); + final var itemPermits = new Semaphore(maxConcurrency); + final var 
completionLatch = new CountDownLatch(1); + final BooleanSupplier forkSupplier = randomFrom( + () -> false, + ESTestCase::randomBoolean, + LuceneTestCase::rarely, + LuceneTestCase::usually, + () -> true + ); + final var blockPermits = new Semaphore(between(0, Math.min(maxRelaxedThreads, maxConcurrency) - 1)); + + ThrottledIterator.run(IntStream.range(0, items).boxed().iterator(), (refs, item) -> { + assertTrue(itemPermits.tryAcquire()); + if (forkSupplier.getAsBoolean()) { + // noinspection resource + var ref = refs.acquire(); + final var executor = randomFrom(CONSTRAINED, RELAXED); + threadPool.executor(executor).execute(new AbstractRunnable() { + + @Override + public void onRejection(Exception e) { + assertEquals(CONSTRAINED, executor); + itemStartLatch.countDown(); + } + + @Override + protected void doRun() { + itemStartLatch.countDown(); + if (RELAXED.equals(executor) && randomBoolean() && blockPermits.tryAcquire()) { + // simulate at most (maxConcurrency-1) long-running operations, to demonstrate that they don't + // hold up the processing of the other operations + try { + assertTrue(itemStartLatch.await(30, TimeUnit.SECONDS)); + } catch (InterruptedException e) { + throw new AssertionError("unexpected", e); + } finally { + blockPermits.release(); + } + } + } + + @Override + public void onAfter() { + itemPermits.release(); + ref.close(); + } + + @Override + public void onFailure(Exception e) { + throw new AssertionError("unexpected", e); + } + }); + } else { + itemStartLatch.countDown(); + itemPermits.release(); + } + }, maxConcurrency, completedItems::incrementAndGet, completionLatch::countDown); + + assertTrue(completionLatch.await(30, TimeUnit.SECONDS)); + assertEquals(items, completedItems.get()); + assertTrue(itemPermits.tryAcquire(maxConcurrency)); + assertTrue(itemStartLatch.await(0, TimeUnit.SECONDS)); + } finally { + ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); + } + } +} From f26829aa93f9ad8997f9428bcccf20410addb574 Mon Sep 17 00:00:00 2001 
From: David Turner Date: Fri, 6 Jan 2023 19:21:13 +0000 Subject: [PATCH 2/7] Update docs/changelog/92729.yaml --- docs/changelog/92729.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/92729.yaml diff --git a/docs/changelog/92729.yaml b/docs/changelog/92729.yaml new file mode 100644 index 0000000000000..8c4a34c9dcbc3 --- /dev/null +++ b/docs/changelog/92729.yaml @@ -0,0 +1,5 @@ +pr: 92729 +summary: Improve scalability of `BroadcastReplicationActions` +area: Network +type: bug +issues: [] From fd611122bc9205692e8ac00d5638106308965027 Mon Sep 17 00:00:00 2001 From: David Turner Date: Fri, 6 Jan 2023 19:40:20 +0000 Subject: [PATCH 3/7] Explain MAX_REQUESTS_PER_NODE --- .../TransportBroadcastReplicationAction.java | 4 +++- .../support/replication/BroadcastReplicationTests.java | 10 ++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java b/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java index da33845c7c273..16cccb012e5b1 100644 --- a/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java +++ b/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java @@ -49,6 +49,8 @@ public abstract class TransportBroadcastReplicationAction< ShardRequest extends ReplicationRequest, ShardResponse extends ReplicationResponse> extends HandledTransportAction { + static int MAX_REQUESTS_PER_NODE = 10; // The REFRESH threadpool maxes out at 10 by default so this is enough to keep everyone busy. 
+ private final ActionType replicatedBroadcastShardAction; private final ClusterService clusterService; private final IndexNameExpressionResolver indexNameExpressionResolver; @@ -81,7 +83,7 @@ protected void doExecute(Task task, Request request, ActionListener li ThrottledIterator.run( shardIds(request, clusterState), context::processShard, - clusterState.nodes().getDataNodes().size() * 10, + clusterState.nodes().getDataNodes().size() * MAX_REQUESTS_PER_NODE, () -> {}, context::finish ); diff --git a/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java b/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java index 8d49df665db2f..bff78a28a5024 100644 --- a/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java +++ b/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java @@ -69,6 +69,7 @@ import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.state; import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.stateWithAssignedPrimariesAndOneReplica; import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.stateWithNoShard; +import static org.elasticsearch.action.support.replication.TransportBroadcastReplicationAction.MAX_REQUESTS_PER_NODE; import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_REPLICAS; import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS; import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_VERSION_CREATED; @@ -278,6 +279,7 @@ public void testShardsIterator() { } public void testThrottling() { + final var shardCount = between(MAX_REQUESTS_PER_NODE, MAX_REQUESTS_PER_NODE * 5); final var replicaCount = between(0, 2); setState( clusterService, @@ -288,7 +290,7 @@ public void testThrottling() { new 
IndexMetadata.Builder("test").settings( Settings.builder() .put(SETTING_VERSION_CREATED, Version.CURRENT) - .put(SETTING_NUMBER_OF_SHARDS, 25) + .put(SETTING_NUMBER_OF_SHARDS, shardCount) .put(SETTING_NUMBER_OF_REPLICAS, replicaCount) ) ) @@ -300,7 +302,7 @@ public void testThrottling() { PlainActionFuture future = PlainActionFuture.newFuture(); ActionTestUtils.execute(broadcastReplicationAction, null, new DummyBroadcastRequest().indices("test"), future); - final var maxOutstandingRequests = clusterService.state().nodes().getDataNodes().size() * 10; + final var maxOutstandingRequests = clusterService.state().nodes().getDataNodes().size() * MAX_REQUESTS_PER_NODE; assertThat(broadcastReplicationAction.capturedShardRequests.size(), equalTo(maxOutstandingRequests)); assertFalse(future.isDone()); final boolean[] handled = new boolean[clusterService.state().metadata().index("test").getNumberOfShards()]; @@ -323,9 +325,9 @@ public void testThrottling() { } var response = future.actionGet(10, TimeUnit.SECONDS); - assertEquals(25 * (replicaCount + 1), response.getTotalShards()); + assertEquals(shardCount * (replicaCount + 1), response.getTotalShards()); assertEquals(successes * (replicaCount + 1), response.getSuccessfulShards()); - assertEquals((25 - successes) * (replicaCount + 1), response.getFailedShards()); + assertEquals((shardCount - successes) * (replicaCount + 1), response.getFailedShards()); } private class TestBroadcastReplicationAction extends TransportBroadcastReplicationAction< From 35c61f0f00210cfac9a2b94f16a3bbbd5381c4a0 Mon Sep 17 00:00:00 2001 From: David Turner Date: Fri, 6 Jan 2023 19:50:00 +0000 Subject: [PATCH 4/7] Drop failing assertion --- .../replication/TransportBroadcastReplicationAction.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java 
b/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java index 16cccb012e5b1..73dfee89cc8c7 100644 --- a/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java +++ b/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java @@ -30,7 +30,6 @@ import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.transport.Transports; import java.util.ArrayList; import java.util.Arrays; @@ -90,7 +89,7 @@ protected void doExecute(Task task, Request request, ActionListener li } protected void shardExecute(Task task, Request request, ShardId shardId, ActionListener shardActionListener) { - assert Transports.assertNotTransportThread("per-shard requests might be high-volume"); + // assert Transports.assertNotTransportThread("per-shard requests might be high-volume"); TODO Yikes! 
ShardRequest shardRequest = newShardRequest(request, shardId); shardRequest.setParentTask(clusterService.localNode().getId(), task.getId()); client.executeLocally(replicatedBroadcastShardAction, shardRequest, shardActionListener); From d67c82aa28220ac71aa6565f37c22adba72352c3 Mon Sep 17 00:00:00 2001 From: David Turner Date: Tue, 10 Jan 2023 15:06:40 +0000 Subject: [PATCH 5/7] Fix forking --- .../indices/flush/TransportFlushAction.java | 3 +- .../refresh/TransportRefreshAction.java | 3 +- .../TransportBroadcastReplicationAction.java | 36 +++++++++++-------- .../BroadcastReplicationTests.java | 5 +-- 4 files changed, 29 insertions(+), 18 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportFlushAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportFlushAction.java index 6fce79e31b911..790eb1f18641a 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportFlushAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportFlushAction.java @@ -48,7 +48,8 @@ public TransportFlushAction( actionFilters, indexNameExpressionResolver, TransportShardFlushAction.TYPE, - ThreadPool.Names.FLUSH + ThreadPool.Names.FLUSH, + 5 // the FLUSH threadpool has at most 5 threads by default, so this should be enough to keep everyone busy ); } diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportRefreshAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportRefreshAction.java index ceb940502da5d..813c4a77fbe27 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportRefreshAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportRefreshAction.java @@ -50,7 +50,8 @@ public TransportRefreshAction( actionFilters, indexNameExpressionResolver, TransportShardRefreshAction.TYPE, - ThreadPool.Names.REFRESH + 
ThreadPool.Names.REFRESH, + 10 // the REFRESH threadpool has at most 10 threads by default, so this should be enough to keep everyone busy ); } diff --git a/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java b/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java index 73dfee89cc8c7..dcc3d4ac87290 100644 --- a/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java +++ b/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java @@ -9,6 +9,7 @@ package org.elasticsearch.action.support.replication; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.DefaultShardOperationFailedException; @@ -29,7 +30,9 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.transport.Transports; import java.util.ArrayList; import java.util.Arrays; @@ -48,13 +51,12 @@ public abstract class TransportBroadcastReplicationAction< ShardRequest extends ReplicationRequest, ShardResponse extends ReplicationResponse> extends HandledTransportAction { - static int MAX_REQUESTS_PER_NODE = 10; // The REFRESH threadpool maxes out at 10 by default so this is enough to keep everyone busy. 
- private final ActionType replicatedBroadcastShardAction; private final ClusterService clusterService; private final IndexNameExpressionResolver indexNameExpressionResolver; private final String executor; private final NodeClient client; + private final int maxRequestsPerNode; protected TransportBroadcastReplicationAction( String name, @@ -65,31 +67,37 @@ protected TransportBroadcastReplicationAction( ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, ActionType replicatedBroadcastShardAction, - String executor + String executor, + int maxRequestsPerNode ) { - super(name, transportService, actionFilters, requestReader, executor); + // Explicitly SAME since the REST layer runs this directly via the NodeClient so it doesn't fork even if we tell it to (see #92730) + super(name, transportService, actionFilters, requestReader, ThreadPool.Names.SAME); + this.client = client; this.replicatedBroadcastShardAction = replicatedBroadcastShardAction; this.clusterService = clusterService; this.indexNameExpressionResolver = indexNameExpressionResolver; this.executor = executor; + this.maxRequestsPerNode = maxRequestsPerNode; } @Override protected void doExecute(Task task, Request request, ActionListener listener) { - final var clusterState = clusterService.state(); - final var context = new Context(task, request, clusterState.metadata().indices(), listener); - ThrottledIterator.run( - shardIds(request, clusterState), - context::processShard, - clusterState.nodes().getDataNodes().size() * MAX_REQUESTS_PER_NODE, - () -> {}, - context::finish - ); + clusterService.threadPool().executor(executor).execute(ActionRunnable.wrap(listener, l -> { + final var clusterState = clusterService.state(); + final var context = new Context(task, request, clusterState.metadata().indices(), listener); + ThrottledIterator.run( + shardIds(request, clusterState), + context::processShard, + clusterState.nodes().getDataNodes().size() * maxRequestsPerNode, + () -> {}, + 
context::finish + ); + })); } protected void shardExecute(Task task, Request request, ShardId shardId, ActionListener shardActionListener) { - // assert Transports.assertNotTransportThread("per-shard requests might be high-volume"); TODO Yikes! + assert Transports.assertNotTransportThread("per-shard requests might be high-volume"); ShardRequest shardRequest = newShardRequest(request, shardId); shardRequest.setParentTask(clusterService.localNode().getId(), task.getId()); client.executeLocally(replicatedBroadcastShardAction, shardRequest, shardActionListener); diff --git a/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java b/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java index bff78a28a5024..3e6b17a1b62d7 100644 --- a/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java +++ b/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java @@ -69,7 +69,6 @@ import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.state; import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.stateWithAssignedPrimariesAndOneReplica; import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.stateWithNoShard; -import static org.elasticsearch.action.support.replication.TransportBroadcastReplicationAction.MAX_REQUESTS_PER_NODE; import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_REPLICAS; import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS; import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_VERSION_CREATED; @@ -86,6 +85,7 @@ public class BroadcastReplicationTests extends ESTestCase { private ClusterService clusterService; private TransportService transportService; private TestBroadcastReplicationAction broadcastReplicationAction; + private static final 
int MAX_REQUESTS_PER_NODE = between(1, 10); @BeforeClass public static void beforeClass() { @@ -353,7 +353,8 @@ private class TestBroadcastReplicationAction extends TransportBroadcastReplicati actionFilters, indexNameExpressionResolver, null, - ThreadPool.Names.SAME + ThreadPool.Names.SAME, + MAX_REQUESTS_PER_NODE ); } From fd64608b4e9bdc610c93a625214b23238ec2535d Mon Sep 17 00:00:00 2001 From: David Turner Date: Tue, 10 Jan 2023 15:10:19 +0000 Subject: [PATCH 6/7] Check block always --- .../admin/indices/flush/TransportShardFlushAction.java | 6 ------ .../admin/indices/refresh/TransportShardRefreshAction.java | 6 ------ .../replication/TransportBroadcastReplicationAction.java | 3 +++ 3 files changed, 3 insertions(+), 12 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportShardFlushAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportShardFlushAction.java index 22c509deccc2f..32e67e95d1936 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportShardFlushAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportShardFlushAction.java @@ -15,7 +15,6 @@ import org.elasticsearch.action.support.replication.ReplicationResponse; import org.elasticsearch.action.support.replication.TransportReplicationAction; import org.elasticsearch.cluster.action.shard.ShardStateAction; -import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.io.stream.StreamInput; @@ -69,11 +68,6 @@ public TransportShardFlushAction( ); } - @Override - protected ClusterBlockLevel globalBlockLevel() { - return ClusterBlockLevel.METADATA_READ; - } - @Override protected ReplicationResponse newResponseInstance(StreamInput in) throws IOException { return new ReplicationResponse(in); diff --git 
a/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportShardRefreshAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportShardRefreshAction.java index 41c230a84906b..b1b45b9e3edb1 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportShardRefreshAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportShardRefreshAction.java @@ -15,7 +15,6 @@ import org.elasticsearch.action.support.replication.ReplicationResponse; import org.elasticsearch.action.support.replication.TransportReplicationAction; import org.elasticsearch.cluster.action.shard.ShardStateAction; -import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.io.stream.StreamInput; @@ -70,11 +69,6 @@ protected ReplicationResponse newResponseInstance(StreamInput in) throws IOExcep return new ReplicationResponse(in); } - @Override - protected ClusterBlockLevel globalBlockLevel() { - return ClusterBlockLevel.METADATA_READ; - } - @Override protected void shardOperationOnPrimary( BasicReplicationRequest shardRequest, diff --git a/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java b/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java index dcc3d4ac87290..dbe15d35e511c 100644 --- a/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java +++ b/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java @@ -8,6 +8,7 @@ package org.elasticsearch.action.support.replication; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRunnable; import 
org.elasticsearch.action.ActionType; @@ -21,6 +22,7 @@ import org.elasticsearch.action.support.broadcast.BroadcastShardOperationFailedException; import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.service.ClusterService; @@ -85,6 +87,7 @@ protected TransportBroadcastReplicationAction( protected void doExecute(Task task, Request request, ActionListener listener) { clusterService.threadPool().executor(executor).execute(ActionRunnable.wrap(listener, l -> { final var clusterState = clusterService.state(); + ExceptionsHelper.reThrowIfNotNull(clusterState.blocks().globalBlockedException(ClusterBlockLevel.METADATA_READ)); final var context = new Context(task, request, clusterState.metadata().indices(), listener); ThrottledIterator.run( shardIds(request, clusterState), From 66c9fa0e139ca9d776234ed9d8bba463009d6a10 Mon Sep 17 00:00:00 2001 From: David Turner Date: Tue, 10 Jan 2023 15:42:25 +0000 Subject: [PATCH 7/7] Too early for randomness --- .../action/support/replication/BroadcastReplicationTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java b/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java index 3e6b17a1b62d7..d0ecf6c2abb70 100644 --- a/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java +++ b/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java @@ -85,7 +85,7 @@ public class BroadcastReplicationTests extends ESTestCase { private ClusterService clusterService; private TransportService transportService; private TestBroadcastReplicationAction 
broadcastReplicationAction; - private static final int MAX_REQUESTS_PER_NODE = between(1, 10); + private static final int MAX_REQUESTS_PER_NODE = 10; @BeforeClass public static void beforeClass() {