docs/changelog/92729.yaml (5 additions, 0 deletions)
@@ -0,0 +1,5 @@
pr: 92729
summary: Improve scalability of `BroadcastReplicationActions`
area: Network
type: bug
issues: []
server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportFlushAction.java
@@ -17,6 +17,7 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;

import java.util.List;
@@ -46,15 +47,12 @@ public TransportFlushAction(
client,
actionFilters,
indexNameExpressionResolver,
TransportShardFlushAction.TYPE
TransportShardFlushAction.TYPE,
ThreadPool.Names.FLUSH,
5 // the FLUSH threadpool has at most 5 threads by default, so this should be enough to keep everyone busy
);
}

@Override
protected ReplicationResponse newShardResponse() {
return new ReplicationResponse();
}

@Override
protected ShardFlushRequest newShardRequest(FlushRequest request, ShardId shardId) {
return new ShardFlushRequest(request, shardId);
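The hard-coded 5 matches the FLUSH pool's default maximum size, and the base class multiplies this per-node cap by the number of data nodes to bound the total number of in-flight shard-level requests (see the `ThrottledIterator.run` call in `TransportBroadcastReplicationAction` below). A worked illustration of that arithmetic, with hypothetical numbers:

```java
// Illustrative only: the cluster-wide concurrency bound is the product of the
// per-node cap and the data-node count. The numbers below are hypothetical.
public class ConcurrencyBudget {
    public static void main(String[] args) {
        int dataNodes = 3;          // say, a three-data-node cluster
        int maxRequestsPerNode = 5; // the flush default chosen above
        int maxConcurrentShardRequests = dataNodes * maxRequestsPerNode;
        System.out.println(maxConcurrentShardRequests); // 15
    }
}
```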
server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportRefreshAction.java
@@ -19,6 +19,7 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;

import java.util.List;
@@ -48,15 +49,12 @@ public TransportRefreshAction(
client,
actionFilters,
indexNameExpressionResolver,
TransportShardRefreshAction.TYPE
TransportShardRefreshAction.TYPE,
ThreadPool.Names.REFRESH,
10 // the REFRESH threadpool has at most 10 threads by default, so this should be enough to keep everyone busy
);
}

@Override
protected ReplicationResponse newShardResponse() {
return new ReplicationResponse();
}

@Override
protected BasicReplicationRequest newShardRequest(RefreshRequest request, ShardId shardId) {
BasicReplicationRequest replicationRequest = new BasicReplicationRequest(shardId);
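The refresh cap of 10 likewise mirrors the REFRESH pool's default maximum. For reference, a sketch of how that default is plausibly derived; this assumes the half-of-allocated-processors-capped-at-ten sizing used by Elasticsearch's scaling pools, so verify against your version's `ThreadPool`:

```java
// Assumption: the REFRESH pool is a scaling pool whose maximum size is half
// the allocated processors, bounded to [1, 10]. Processor counts are made up.
public class RefreshPoolSize {
    static int halfProcessorsMaxTen(int allocatedProcessors) {
        return Math.max(1, Math.min(10, allocatedProcessors / 2));
    }

    public static void main(String[] args) {
        System.out.println(halfProcessorsMaxTen(8));  // 4
        System.out.println(halfProcessorsMaxTen(32)); // 10, the cap
    }
}
```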
server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java
@@ -10,30 +10,38 @@

import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.DefaultShardOperationFailedException;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.action.support.ThreadedActionListener;
import org.elasticsearch.action.support.TransportActions;
import org.elasticsearch.action.support.broadcast.BaseBroadcastResponse;
import org.elasticsearch.action.support.broadcast.BroadcastRequest;
import org.elasticsearch.action.support.broadcast.BroadcastShardOperationFailedException;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.routing.IndexRoutingTable;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.concurrent.CountDown;
import org.elasticsearch.common.util.concurrent.ThrottledIterator;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.transport.Transports;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.Map;

/**
* Base class for requests that should be executed on all shards of an index or several indices.
@@ -48,136 +56,167 @@ public abstract class TransportBroadcastReplicationAction<
private final ActionType<ShardResponse> replicatedBroadcastShardAction;
private final ClusterService clusterService;
private final IndexNameExpressionResolver indexNameExpressionResolver;
private final String executor;
private final NodeClient client;
private final int maxRequestsPerNode;

public TransportBroadcastReplicationAction(
protected TransportBroadcastReplicationAction(
String name,
Writeable.Reader<Request> requestReader,
ClusterService clusterService,
TransportService transportService,
NodeClient client,
ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver,
ActionType<ShardResponse> replicatedBroadcastShardAction
ActionType<ShardResponse> replicatedBroadcastShardAction,
String executor,
int maxRequestsPerNode
) {
super(name, transportService, actionFilters, requestReader);
// Explicitly SAME since the REST layer runs this directly via the NodeClient so it doesn't fork even if we tell it to (see #92730)
super(name, transportService, actionFilters, requestReader, ThreadPool.Names.SAME);

this.client = client;
this.replicatedBroadcastShardAction = replicatedBroadcastShardAction;
this.clusterService = clusterService;
this.indexNameExpressionResolver = indexNameExpressionResolver;
this.executor = executor;
this.maxRequestsPerNode = maxRequestsPerNode;
}

@Override
protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
final ClusterState clusterState = clusterService.state();
List<ShardId> shards = shards(request, clusterState);
final CopyOnWriteArrayList<ShardResponse> shardsResponses = new CopyOnWriteArrayList<>();
if (shards.size() == 0) {
finishAndNotifyListener(listener, shardsResponses);
}
final CountDown responsesCountDown = new CountDown(shards.size());
for (final ShardId shardId : shards) {
ActionListener<ShardResponse> shardActionListener = new ActionListener<ShardResponse>() {
@Override
public void onResponse(ShardResponse shardResponse) {
shardsResponses.add(shardResponse);
logger.trace("{}: got response from {}", actionName, shardId);
if (responsesCountDown.countDown()) {
finishAndNotifyListener(listener, shardsResponses);
}
}

@Override
public void onFailure(Exception e) {
logger.trace("{}: got failure from {}", actionName, shardId);
int totalNumCopies = clusterState.getMetadata().getIndexSafe(shardId.getIndex()).getNumberOfReplicas() + 1;
ShardResponse shardResponse = newShardResponse();
ReplicationResponse.ShardInfo.Failure[] failures;
if (TransportActions.isShardNotAvailableException(e)) {
failures = new ReplicationResponse.ShardInfo.Failure[0];
} else {
ReplicationResponse.ShardInfo.Failure failure = new ReplicationResponse.ShardInfo.Failure(
shardId,
null,
e,
ExceptionsHelper.status(e),
true
);
failures = new ReplicationResponse.ShardInfo.Failure[totalNumCopies];
Arrays.fill(failures, failure);
}
shardResponse.setShardInfo(new ReplicationResponse.ShardInfo(totalNumCopies, 0, failures));
shardsResponses.add(shardResponse);
if (responsesCountDown.countDown()) {
finishAndNotifyListener(listener, shardsResponses);
}
}
};
shardExecute(task, request, shardId, shardActionListener);
}
clusterService.threadPool().executor(executor).execute(ActionRunnable.wrap(listener, l -> {
final var clusterState = clusterService.state();
ExceptionsHelper.reThrowIfNotNull(clusterState.blocks().globalBlockedException(ClusterBlockLevel.METADATA_READ));
final var context = new Context(task, request, clusterState.metadata().indices(), listener);
ThrottledIterator.run(
shardIds(request, clusterState),
context::processShard,
clusterState.nodes().getDataNodes().size() * maxRequestsPerNode,
() -> {},
context::finish
);
}));
}
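The rewritten `doExecute` forks off the transport thread onto the configured executor (the handler itself is registered with `ThreadPool.Names.SAME`, per the constructor comment above, since the REST layer would not fork anyway) and then fans the per-shard requests out through `ThrottledIterator.run`, which keeps at most `dataNodes * maxRequestsPerNode` requests in flight and fires the completion callback once the last item finishes. A minimal sketch of that throttled fan-out using only JDK primitives; unlike the real `ThrottledIterator` this version blocks the submitting thread to acquire a permit, and every name in it is illustrative:

```java
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicInteger;

// Process items with bounded concurrency, then run a completion callback
// exactly once, after the final item has been processed.
public class ThrottledFanOut {
    public static void main(String[] args) throws InterruptedException {
        List<Integer> shards = List.of(0, 1, 2, 3, 4, 5, 6, 7);
        int maxConcurrent = 3; // stands in for dataNodes * maxRequestsPerNode
        ExecutorService executor = Executors.newFixedThreadPool(maxConcurrent);
        Semaphore permits = new Semaphore(maxConcurrent);
        AtomicInteger remaining = new AtomicInteger(shards.size());
        Iterator<Integer> iterator = shards.iterator();
        while (iterator.hasNext()) {
            int shard = iterator.next();
            permits.acquire(); // blocks once maxConcurrent items are in flight
            executor.execute(() -> {
                try {
                    System.out.println("processing shard " + shard);
                } finally {
                    permits.release();
                    if (remaining.decrementAndGet() == 0) {
                        System.out.println("all shard responses received");
                    }
                }
            });
        }
        executor.shutdown();
    }
}
```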

protected void shardExecute(Task task, Request request, ShardId shardId, ActionListener<ShardResponse> shardActionListener) {
assert Transports.assertNotTransportThread("per-shard requests might be high-volume");
ShardRequest shardRequest = newShardRequest(request, shardId);
shardRequest.setParentTask(clusterService.localNode().getId(), task.getId());
client.executeLocally(replicatedBroadcastShardAction, shardRequest, shardActionListener);
}

/**
* @return all shard ids the request should run on
* @return all shard ids on which the request should run; exposed for tests
*/
protected List<ShardId> shards(Request request, ClusterState clusterState) {
List<ShardId> shardIds = new ArrayList<>();
String[] concreteIndices = indexNameExpressionResolver.concreteIndexNames(clusterState, request);
for (String index : concreteIndices) {
IndexMetadata indexMetadata = clusterState.metadata().getIndices().get(index);
if (indexMetadata != null) {
final IndexRoutingTable indexRoutingTable = clusterState.getRoutingTable().indicesRouting().get(index);
for (int i = 0; i < indexRoutingTable.size(); i++) {
shardIds.add(indexRoutingTable.shard(i).shardId());
}
}
}
return shardIds;
Iterator<? extends ShardId> shardIds(Request request, ClusterState clusterState) {
var indexMetadataByName = clusterState.metadata().indices();
return Iterators.flatMap(
Iterators.forArray(indexNameExpressionResolver.concreteIndexNames(clusterState, request)),
indexName -> indexShardIds(indexMetadataByName.get(indexName))
);
}
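`shardIds` now returns a lazy iterator: `Iterators.flatMap` expands each concrete index into its shard ids only as the throttled loop pulls them, where the old `shards` method eagerly filled an `ArrayList` with every shard id up front. A rough JDK-stream analogue of the same lazy flattening, with hypothetical index names and shard counts:

```java
import java.util.Iterator;
import java.util.Map;
import java.util.stream.IntStream;

// Lazily flatten (index name -> shard count) pairs into shard labels; elements
// are produced on demand as the iterator is consumed, not materialized up front.
public class LazyShardIds {
    public static void main(String[] args) {
        Map<String, Integer> shardCounts = Map.of("logs", 3, "metrics", 2);
        Iterator<String> shardIds = shardCounts.entrySet().stream()
            .flatMap(e -> IntStream.range(0, e.getValue())
                .mapToObj(i -> e.getKey() + "[" + i + "]"))
            .iterator();
        shardIds.forEachRemaining(System.out::println);
    }
}
```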

protected abstract ShardResponse newShardResponse();

protected abstract ShardRequest newShardRequest(Request request, ShardId shardId);

private void finishAndNotifyListener(ActionListener<Response> listener, CopyOnWriteArrayList<ShardResponse> shardsResponses) {
logger.trace("{}: got all shard responses", actionName);
int successfulShards = 0;
int failedShards = 0;
int totalNumCopies = 0;
List<DefaultShardOperationFailedException> shardFailures = null;
for (int i = 0; i < shardsResponses.size(); i++) {
ReplicationResponse shardResponse = shardsResponses.get(i);
if (shardResponse == null) {
// non active shard, ignore
} else {
failedShards += shardResponse.getShardInfo().getFailed();
successfulShards += shardResponse.getShardInfo().getSuccessful();
totalNumCopies += shardResponse.getShardInfo().getTotal();
if (shardFailures == null) {
shardFailures = new ArrayList<>();
}
for (ReplicationResponse.ShardInfo.Failure failure : shardResponse.getShardInfo().getFailures()) {
shardFailures.add(
new DefaultShardOperationFailedException(
new BroadcastShardOperationFailedException(failure.fullShardId(), failure.getCause())
)
);
}
}
private static Iterator<ShardId> indexShardIds(@Nullable IndexMetadata indexMetadata) {
if (indexMetadata == null) {
return Collections.emptyIterator();
}
var shardIds = new ShardId[indexMetadata.getNumberOfShards()];
for (int i = 0; i < shardIds.length; i++) {
shardIds[i] = new ShardId(indexMetadata.getIndex(), i);
}
listener.onResponse(newResponse(successfulShards, failedShards, totalNumCopies, shardFailures));
return Iterators.forArray(shardIds);
}

protected abstract ShardRequest newShardRequest(Request request, ShardId shardId);

protected abstract Response newResponse(
int successfulShards,
int failedShards,
int totalNumCopies,
List<DefaultShardOperationFailedException> shardFailures
);

private class Context {
private final Task task;
private final Request request;
private final Map<String, IndexMetadata> indexMetadataByName;
private final ActionListener<Response> listener;

private int totalNumCopies;
private int totalSuccessful;
private final List<DefaultShardOperationFailedException> allFailures = new ArrayList<>();

Context(Task task, Request request, Map<String, IndexMetadata> indexMetadataByName, ActionListener<Response> listener) {
this.task = task;
this.request = request;
this.indexMetadataByName = indexMetadataByName;
this.listener = listener;
}

void processShard(ThrottledIterator.ItemRefs refs, ShardId shardId) {
shardExecute(
task,
request,
shardId,
new ThreadedActionListener<>(
logger,
clusterService.threadPool(),
executor,
ActionListener.releaseAfter(createListener(shardId), refs.acquire()),
false
)
);
}

private ActionListener<ShardResponse> createListener(ShardId shardId) {
return new ActionListener<>() {
@Override
public void onResponse(ShardResponse shardResponse) {
assert shardResponse != null;
logger.trace("{}: got response from {}", actionName, shardId);
addShardResponse(
shardResponse.getShardInfo().getTotal(),
shardResponse.getShardInfo().getSuccessful(),
Arrays.stream(shardResponse.getShardInfo().getFailures())
.map(
f -> new DefaultShardOperationFailedException(
new BroadcastShardOperationFailedException(shardId, f.getCause())
)
)
.toList()
);
}

@Override
public void onFailure(Exception e) {
logger.trace("{}: got failure from {}", actionName, shardId);
final int numCopies = indexMetadataByName.get(shardId.getIndexName()).getNumberOfReplicas() + 1;
addShardResponse(numCopies, 0, createSyntheticFailures(numCopies, e));
}

private List<DefaultShardOperationFailedException> createSyntheticFailures(int numCopies, Exception e) {
if (TransportActions.isShardNotAvailableException(e)) {
return List.of();
}

final var failures = new DefaultShardOperationFailedException[numCopies];
Arrays.fill(failures, new DefaultShardOperationFailedException(new BroadcastShardOperationFailedException(shardId, e)));
return Arrays.asList(failures);
}
};
}
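On a shard-level failure the action cannot tell which individual copies failed, so unless the exception merely means the shard was unavailable it reports the failure once per copy: with one replica, `numCopies = getNumberOfReplicas() + 1 = 2`, and a single exception becomes two identical failure entries. A standalone sketch of that expansion, with plain strings standing in for the Elasticsearch failure types:

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Expand one shard-level exception into one failure entry per shard copy,
// except when the shard was simply not available, which is not a failure.
public class SyntheticFailures {
    static List<String> syntheticFailures(int numCopies, Exception e, boolean shardNotAvailable) {
        if (shardNotAvailable) {
            return Collections.emptyList();
        }
        String[] failures = new String[numCopies];
        Arrays.fill(failures, "failed copy: " + e.getMessage());
        return Arrays.asList(failures);
    }

    public static void main(String[] args) {
        // numberOfReplicas = 1, so numCopies = 1 + 1 = 2
        System.out.println(syntheticFailures(2, new Exception("node left"), false));
    }
}
```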

private synchronized void addShardResponse(int numCopies, int successful, List<DefaultShardOperationFailedException> failures) {
totalNumCopies += numCopies;
totalSuccessful += successful;
allFailures.addAll(failures);
}

void finish() {
// no need for synchronized here, the ThrottledIterator guarantees that all the addShardResponse calls happen-before this point
logger.trace("{}: got all shard responses", actionName);
listener.onResponse(newResponse(totalSuccessful, allFailures.size(), totalNumCopies, allFailures));
}
}
}
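A closing note on the concurrency in `Context`: `addShardResponse` is `synchronized` because shard responses arrive concurrently from the throttled requests, while `finish` reads the accumulated totals without locking, relying on `ThrottledIterator` invoking the completion callback only after every per-item ref has been released, which establishes the necessary happens-before edge. A small JDK sketch of the same accumulate-then-read pattern, with a `CountDownLatch` as the barrier and all names hypothetical:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;

// Workers accumulate under a lock; the final read needs no lock because the
// latch establishes happens-before between every add and the read.
public class AccumulateThenFinish {
    private final List<String> results = new ArrayList<>();

    private synchronized void add(String result) {
        results.add(result);
    }

    public static void main(String[] args) throws InterruptedException {
        AccumulateThenFinish context = new AccumulateThenFinish();
        int shards = 4;
        CountDownLatch done = new CountDownLatch(shards);
        for (int i = 0; i < shards; i++) {
            final int shard = i;
            new Thread(() -> {
                context.add("shard-" + shard);
                done.countDown(); // the release side of the barrier
            }).start();
        }
        done.await(); // the acquire side: all adds are now visible
        System.out.println(context.results.size()); // safe unlocked read: 4
    }
}
```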