diff --git a/docs/changelog/126452.yaml b/docs/changelog/126452.yaml new file mode 100644 index 0000000000000..a67c1db5211d2 --- /dev/null +++ b/docs/changelog/126452.yaml @@ -0,0 +1,5 @@ +pr: 126452 +summary: Run `newShardSnapshotTask` tasks concurrently +area: Snapshot/Restore +type: bug +issues: [] diff --git a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/SnapshotShardsServiceIT.java b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/SnapshotShardsServiceIT.java index 12c92365d5536..d14d543da05eb 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/SnapshotShardsServiceIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/SnapshotShardsServiceIT.java @@ -13,23 +13,30 @@ import org.apache.logging.log4j.core.LogEvent; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.admin.cluster.snapshots.create.CreateSnapshotRequest; +import org.elasticsearch.action.admin.cluster.snapshots.create.TransportCreateSnapshotAction; import org.elasticsearch.action.support.ActionTestUtils; import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.cluster.SnapshotsInProgress; import org.elasticsearch.cluster.coordination.Coordinator; import org.elasticsearch.index.snapshots.IndexShardSnapshotStatus; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.snapshots.mockstore.MockRepository; +import org.elasticsearch.test.ClusterServiceUtils; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.test.MockLog; import org.elasticsearch.test.disruption.NetworkDisruption; import org.elasticsearch.test.junit.annotations.TestLogging; import org.elasticsearch.test.transport.MockTransportService; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.threadpool.ThreadPoolStats; import org.elasticsearch.transport.TestTransportChannel; import org.elasticsearch.transport.TransportResponse; import java.util.Arrays; import java.util.Collection; import java.util.List; +import java.util.concurrent.CyclicBarrier; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -166,6 +173,59 @@ public boolean innerMatch(LogEvent event) { } } ); + } + + public void testStartSnapshotsConcurrently() { + internalCluster().startMasterOnlyNode(); + final var dataNode = internalCluster().startDataOnlyNode(); + + final var repoName = randomIdentifier(); + createRepository(repoName, "fs"); + + final var threadPool = internalCluster().getInstance(ThreadPool.class, dataNode); + final var snapshotThreadCount = threadPool.info(ThreadPool.Names.SNAPSHOT).getMax(); + + final var indexName = randomIdentifier(); + final var shardCount = between(1, snapshotThreadCount * 2); + assertAcked(prepareCreate(indexName, 0, indexSettingsNoReplicas(shardCount))); + indexRandomDocs(indexName, scaledRandomIntBetween(50, 100)); + + final var snapshotExecutor = threadPool.executor(ThreadPool.Names.SNAPSHOT); + final var barrier = new CyclicBarrier(snapshotThreadCount + 1); + for (int i = 0; i < snapshotThreadCount; i++) { + snapshotExecutor.submit(() -> { + safeAwait(barrier); + safeAwait(barrier); + }); + } + + // wait until the snapshot threads are all blocked + safeAwait(barrier); + + safeGet( + client().execute( + TransportCreateSnapshotAction.TYPE, + new CreateSnapshotRequest(TEST_REQUEST_TIMEOUT, repoName, randomIdentifier()) + ) + ); + + // one task for each snapshot thread (throttled) or shard (if fewer), plus one for runSyncTasksEagerly() + assertEquals(Math.min(snapshotThreadCount, shardCount) + 1, getSnapshotQueueLength(threadPool)); + + // release all the snapshot threads + safeAwait(barrier); + + // wait for completion + safeAwait(ClusterServiceUtils.addMasterTemporaryStateListener(cs -> SnapshotsInProgress.get(cs).isEmpty())); + } + + private static int getSnapshotQueueLength(ThreadPool threadPool) { + for (ThreadPoolStats.Stats stats : threadPool.stats().stats()) { + if (stats.name().equals(ThreadPool.Names.SNAPSHOT)) { + return stats.queue(); + } + } + throw new AssertionError("threadpool stats for [" + ThreadPool.Names.SNAPSHOT + "] not found"); } } diff --git a/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java b/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java index 30b08740b4818..e840acd049b71 100644 --- a/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java +++ b/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java @@ -31,6 +31,7 @@ import org.elasticsearch.common.util.Maps; import org.elasticsearch.common.util.concurrent.ThrottledTaskRunner; import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Releasable; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.engine.Engine; import org.elasticsearch.index.seqno.SequenceNumbers; @@ -56,7 +57,6 @@ import org.elasticsearch.transport.TransportService; import java.io.IOException; -import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -98,6 +98,9 @@ public final class SnapshotShardsService extends AbstractLifecycleComponent impl // A map of snapshots to the shardIds that we already reported to the master as failed private final ResultDeduplicator remoteFailedRequestDeduplicator; + // Runs the tasks that start each shard snapshot (e.g. acquiring the index commit) + private final ThrottledTaskRunner startShardSnapshotTaskRunner; + // Runs the tasks that promptly notify shards of aborted snapshots so that resources can be released ASAP private final ThrottledTaskRunner notifyOnAbortTaskRunner; @@ -131,6 +134,11 @@ public SnapshotShardsService( threadPool.info(ThreadPool.Names.SNAPSHOT).getMax(), threadPool.generic() ); + this.startShardSnapshotTaskRunner = new ThrottledTaskRunner( + "start-shard-snapshots", + threadPool.info(ThreadPool.Names.SNAPSHOT).getMax(), + threadPool.executor(ThreadPool.Names.SNAPSHOT) + ); } @Override @@ -384,7 +392,6 @@ private void startNewShardSnapshots(String localNodeId, SnapshotsInProgress.Entr final var newSnapshotShards = shardSnapshots.computeIfAbsent(snapshot, s -> new HashMap<>()); - final List shardSnapshotTasks = new ArrayList<>(shardsToStart.size()); for (final Map.Entry shardEntry : shardsToStart.entrySet()) { final ShardId shardId = shardEntry.getKey(); final IndexShardSnapshotStatus snapshotStatus = IndexShardSnapshotStatus.newInitializing(shardEntry.getValue()); @@ -396,11 +403,37 @@ private void startNewShardSnapshots(String localNodeId, SnapshotsInProgress.Entr : "Found non-null, non-numeric shard generation [" + snapshotStatus.generation() + "] for snapshot with old-format compatibility"; - shardSnapshotTasks.add(newShardSnapshotTask(shardId, snapshot, indexId, snapshotStatus, entry.version(), entry.startTime())); - snapshotStatus.updateStatusDescription("shard snapshot scheduled to start"); + final var shardSnapshotTask = newShardSnapshotTask( + shardId, + snapshot, + indexId, + snapshotStatus, + entry.version(), + entry.startTime() + ); + startShardSnapshotTaskRunner.enqueueTask(new ActionListener<>() { + @Override + public void onResponse(Releasable releasable) { + try (releasable) { + shardSnapshotTask.run(); + } + } + + @Override + public void onFailure(Exception e) { + final var wrapperException = new IllegalStateException( + "impossible failure starting shard snapshot for " + shardId + " in " + snapshot, + e + ); + logger.error(wrapperException.getMessage(), wrapperException); + assert false : wrapperException; // impossible + } + }); + snapshotStatus.updateStatusDescription("shard snapshot enqueued to start"); } - threadPool.executor(ThreadPool.Names.SNAPSHOT).execute(() -> shardSnapshotTasks.forEach(Runnable::run)); + // apply some backpressure by reserving one SNAPSHOT thread for the startup work + startShardSnapshotTaskRunner.runSyncTasksEagerly(threadPool.executor(ThreadPool.Names.SNAPSHOT)); } private void pauseShardSnapshotsForNodeRemoval(String localNodeId, SnapshotsInProgress.Entry entry) {