diff --git a/docs/changelog/126452.yaml b/docs/changelog/126452.yaml new file mode 100644 index 0000000000000..a67c1db5211d2 --- /dev/null +++ b/docs/changelog/126452.yaml @@ -0,0 +1,5 @@ +pr: 126452 +summary: Run `newShardSnapshotTask` tasks concurrently +area: Snapshot/Restore +type: bug +issues: [] diff --git a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/SnapshotShardsServiceIT.java b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/SnapshotShardsServiceIT.java index feb73d4b6f4c1..7030605034752 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/SnapshotShardsServiceIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/SnapshotShardsServiceIT.java @@ -9,16 +9,23 @@ package org.elasticsearch.snapshots; +import org.elasticsearch.action.admin.cluster.snapshots.create.CreateSnapshotRequest; +import org.elasticsearch.action.admin.cluster.snapshots.create.TransportCreateSnapshotAction; +import org.elasticsearch.cluster.SnapshotsInProgress; import org.elasticsearch.index.snapshots.IndexShardSnapshotStatus; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.snapshots.mockstore.MockRepository; +import org.elasticsearch.test.ClusterServiceUtils; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.test.disruption.NetworkDisruption; import org.elasticsearch.test.transport.MockTransportService; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.threadpool.ThreadPoolStats; import java.util.Arrays; import java.util.Collection; import java.util.List; +import java.util.concurrent.CyclicBarrier; import java.util.concurrent.TimeUnit; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; @@ -89,4 +96,58 @@ public void testRetryPostingSnapshotStatusMessages() throws Exception { assertThat(snapshotInfo.successfulShards(), equalTo(shards)); }, 30L, TimeUnit.SECONDS); } + + public void 
testStartSnapshotsConcurrently() { + internalCluster().startMasterOnlyNode(); + final var dataNode = internalCluster().startDataOnlyNode(); + + final var repoName = randomIdentifier(); + createRepository(repoName, "fs"); + + final var threadPool = internalCluster().getInstance(ThreadPool.class, dataNode); + final var snapshotThreadCount = threadPool.info(ThreadPool.Names.SNAPSHOT).getMax(); + + final var indexName = randomIdentifier(); + final var shardCount = between(1, snapshotThreadCount * 2); + assertAcked(prepareCreate(indexName, 0, indexSettingsNoReplicas(shardCount))); + indexRandomDocs(indexName, scaledRandomIntBetween(50, 100)); + + final var snapshotExecutor = threadPool.executor(ThreadPool.Names.SNAPSHOT); + final var barrier = new CyclicBarrier(snapshotThreadCount + 1); + for (int i = 0; i < snapshotThreadCount; i++) { + snapshotExecutor.submit(() -> { + safeAwait(barrier); + safeAwait(barrier); + }); + } + + // wait until every submitted blocker task has reached the barrier, i.e. the snapshot threads are all blocked + safeAwait(barrier); + + safeGet( + client().execute( + TransportCreateSnapshotAction.TYPE, + new CreateSnapshotRequest(TEST_REQUEST_TIMEOUT, repoName, randomIdentifier()) + ) + ); + + // one task for each snapshot thread (throttled) or shard (if fewer), plus one for runSyncTasksEagerly() + assertEquals(Math.min(snapshotThreadCount, shardCount) + 1, getSnapshotQueueLength(threadPool)); + + // release all the snapshot threads from their second barrier await + safeAwait(barrier); + + // wait for completion, i.e. until the listener observes a cluster state with an empty SnapshotsInProgress + safeAwait(ClusterServiceUtils.addMasterTemporaryStateListener(cs -> SnapshotsInProgress.get(cs).isEmpty())); + } + + private static int getSnapshotQueueLength(ThreadPool threadPool) { + for (ThreadPoolStats.Stats stats : threadPool.stats().stats()) { + if (stats.name().equals(ThreadPool.Names.SNAPSHOT)) { + return stats.queue(); + } + } + + throw new AssertionError("threadpool stats for [" + ThreadPool.Names.SNAPSHOT + "] not found"); + } } diff --git 
a/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java b/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java index abc5f36eef7da..f42cd1f7924e3 100644 --- a/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java +++ b/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java @@ -29,6 +29,7 @@ import org.elasticsearch.common.util.Maps; import org.elasticsearch.common.util.concurrent.ThrottledTaskRunner; import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Releasable; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.engine.Engine; import org.elasticsearch.index.seqno.SequenceNumbers; @@ -54,7 +55,6 @@ import org.elasticsearch.transport.TransportService; import java.io.IOException; -import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.List; @@ -87,6 +87,9 @@ public final class SnapshotShardsService extends AbstractLifecycleComponent impl // A map of snapshots to the shardIds that we already reported to the master as failed private final ResultDeduplicator remoteFailedRequestDeduplicator; + // Runs the tasks that start each shard snapshot (e.g. 
acquiring the index commit) + private final ThrottledTaskRunner startShardSnapshotTaskRunner; + // Runs the tasks that promptly notify shards of aborted snapshots so that resources can be released ASAP private final ThrottledTaskRunner notifyOnAbortTaskRunner; @@ -114,6 +117,11 @@ public SnapshotShardsService( threadPool.info(ThreadPool.Names.SNAPSHOT).getMax(), threadPool.generic() ); + this.startShardSnapshotTaskRunner = new ThrottledTaskRunner( + "start-shard-snapshots", + threadPool.info(ThreadPool.Names.SNAPSHOT).getMax(), + threadPool.executor(ThreadPool.Names.SNAPSHOT) + ); } @Override @@ -304,7 +312,6 @@ private void startNewShardSnapshots(String localNodeId, SnapshotsInProgress.Entr final var newSnapshotShards = shardSnapshots.computeIfAbsent(snapshot, s -> new HashMap<>()); - final List shardSnapshotTasks = new ArrayList<>(shardsToStart.size()); for (final Map.Entry shardEntry : shardsToStart.entrySet()) { final ShardId shardId = shardEntry.getKey(); final IndexShardSnapshotStatus snapshotStatus = IndexShardSnapshotStatus.newInitializing(shardEntry.getValue()); @@ -316,10 +323,36 @@ private void startNewShardSnapshots(String localNodeId, SnapshotsInProgress.Entr : "Found non-null, non-numeric shard generation [" + snapshotStatus.generation() + "] for snapshot with old-format compatibility"; - shardSnapshotTasks.add(newShardSnapshotTask(shardId, snapshot, indexId, snapshotStatus, entry.version(), entry.startTime())); + final var shardSnapshotTask = newShardSnapshotTask( + shardId, + snapshot, + indexId, + snapshotStatus, + entry.version(), + entry.startTime() + ); + startShardSnapshotTaskRunner.enqueueTask(new ActionListener<>() { + @Override + public void onResponse(Releasable releasable) { + try (releasable) { + shardSnapshotTask.run(); + } + } + + @Override + public void onFailure(Exception e) { + final var wrapperException = new IllegalStateException( + "impossible failure starting shard snapshot for " + shardId + " in " + snapshot, + e + ); + 
logger.error(wrapperException.getMessage(), wrapperException); + assert false : wrapperException; // impossible + } + }); } - threadPool.executor(ThreadPool.Names.SNAPSHOT).execute(() -> shardSnapshotTasks.forEach(Runnable::run)); + // apply some backpressure by reserving one SNAPSHOT thread for the startup work + startShardSnapshotTaskRunner.runSyncTasksEagerly(threadPool.executor(ThreadPool.Names.SNAPSHOT)); } private void pauseShardSnapshotsForNodeRemoval(String localNodeId, SnapshotsInProgress.Entry entry) {