diff --git a/server/src/main/java/org/elasticsearch/common/util/concurrent/ThrottledIterator.java b/server/src/main/java/org/elasticsearch/common/util/concurrent/ThrottledIterator.java new file mode 100644 index 0000000000000..d6ac42a9211c9 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/common/util/concurrent/ThrottledIterator.java @@ -0,0 +1,154 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.common.util.concurrent; + +import org.elasticsearch.core.AbstractRefCounted; +import org.elasticsearch.core.RefCounted; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.Strings; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; + +import java.util.Iterator; +import java.util.Objects; +import java.util.concurrent.Semaphore; +import java.util.function.BiConsumer; + +public class ThrottledIterator implements Releasable { + + private static final Logger logger = LogManager.getLogger(ThrottledIterator.class); + + /** + * Iterate through the given collection, performing an operation on each item which may fork background tasks, but with a limit on the + * number of such background tasks running concurrently to avoid overwhelming the rest of the system (e.g. starving other work of access + * to an executor). + * + * @param iterator The items to iterate. May be accessed by multiple threads, but accesses are all protected by synchronizing on itself. + * @param itemConsumer The operation to perform on each item. Each operation receives a {@link RefCounted} which can be used to track + * the execution of any background tasks spawned for this item. This operation may run on the thread which + * originally called {@link #run}, if this method has not yet returned. Otherwise it will run on a thread on which a + * background task previously called {@link RefCounted#decRef()} on its ref count. This operation should not throw + * any exceptions. + * @param maxConcurrency The maximum number of ongoing operations at any time. + * @param onItemCompletion Executed when each item is completed, which can be used for instance to report on progress. Must not throw + * exceptions. + * @param onCompletion Executed when all items are completed. + */ + public static void run( + Iterator iterator, + BiConsumer itemConsumer, + int maxConcurrency, + Runnable onItemCompletion, + Runnable onCompletion + ) { + try (var throttledIterator = new ThrottledIterator<>(iterator, itemConsumer, maxConcurrency, onItemCompletion, onCompletion)) { + throttledIterator.run(); + } + } + + private final RefCounted refs; // one ref for each running item, plus one for the iterator if incomplete + private final Iterator iterator; + private final BiConsumer itemConsumer; + private final Semaphore permits; + private final Runnable onItemCompletion; + + private ThrottledIterator( + Iterator iterator, + BiConsumer itemConsumer, + int maxConcurrency, + Runnable onItemCompletion, + Runnable onCompletion + ) { + this.iterator = Objects.requireNonNull(iterator); + this.itemConsumer = Objects.requireNonNull(itemConsumer); + if (maxConcurrency <= 0) { + throw new IllegalArgumentException("maxConcurrency must be positive"); + } + this.permits = new Semaphore(maxConcurrency); + this.onItemCompletion = Objects.requireNonNull(onItemCompletion); + this.refs = AbstractRefCounted.of(onCompletion); + } + + private void run() { + while (permits.tryAcquire()) { + final T item; + synchronized (iterator) { + if (iterator.hasNext()) { + item = iterator.next(); + } else { + permits.release(); + return; + } + } + try (var itemRefs = new ItemRefCounted()) { + itemRefs.incRef(); + itemConsumer.accept(Releasables.releaseOnce(itemRefs::decRef), item); + } catch (Exception e) { + logger.error(Strings.format("exception when processing [%s] with [%s]", item, itemConsumer), e); + assert false : e; + } + } + } + + @Override + public void close() { + refs.decRef(); + } + + // A RefCounted for a single item, including protection against calling back into run() if it's created and closed within a single + // invocation of run(). + private class ItemRefCounted extends AbstractRefCounted implements Releasable { + private boolean isRecursive = true; + + ItemRefCounted() { + refs.incRef(); + } + + @Override + protected void closeInternal() { + try { + onItemCompletion.run(); + } catch (Exception e) { + logger.error("exception in onItemCompletion", e); + assert false : e; + } finally { + permits.release(); + try { + // Someone must now pick up the next item. Here we might be called from the run() invocation which started processing + // the just-completed item (via close() -> decRef()) if that item's processing didn't fork or all its forked tasks + // finished first. If so, there's no need to call run() here, we can just return and the next iteration of the run() + // loop will continue the processing; moreover calling run() in this situation could lead to a stack overflow. However + // if we're not within that run() invocation then ... + if (isRecursive() == false) { + // ... we're not within any other run() invocation either, so it's safe (and necessary) to call run() here. + run(); + } + } finally { + refs.decRef(); + } + } + } + + // Note on blocking: we call both of these synchronized methods exactly once (and must enter close() before calling isRecursive()). + // If close() releases the last ref and calls closeInternal(), and hence isRecursive(), then there's no other threads involved and + // hence no blocking. In contrast if close() doesn't release the last ref then it exits immediately, so the call to isRecursive() + // will proceed without delay in this case too. + + private synchronized boolean isRecursive() { + return isRecursive; + } + + @Override + public synchronized void close() { + decRef(); + isRecursive = false; + } + } +} diff --git a/server/src/test/java/org/elasticsearch/common/util/concurrent/ThrottledIteratorTests.java b/server/src/test/java/org/elasticsearch/common/util/concurrent/ThrottledIteratorTests.java new file mode 100644 index 0000000000000..9521677e2db5f --- /dev/null +++ b/server/src/test/java/org/elasticsearch/common/util/concurrent/ThrottledIteratorTests.java @@ -0,0 +1,113 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.common.util.concurrent; + +import org.apache.lucene.tests.util.LuceneTestCase; +import org.elasticsearch.action.support.RefCountingRunnable; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.FixedExecutorBuilder; +import org.elasticsearch.threadpool.ScalingExecutorBuilder; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BooleanSupplier; +import java.util.stream.IntStream; + +public class ThrottledIteratorTests extends ESTestCase { + private static final String CONSTRAINED = "constrained"; + private static final String RELAXED = "relaxed"; + + public void testConcurrency() throws InterruptedException { + final var maxConstrainedThreads = between(1, 3); + final var maxRelaxedThreads = between(1, 100); + final var constrainedQueue = between(3, 6); + final var threadPool = new TestThreadPool( + "test", + new FixedExecutorBuilder(Settings.EMPTY, CONSTRAINED, maxConstrainedThreads, constrainedQueue, CONSTRAINED, false), + new ScalingExecutorBuilder(RELAXED, 1, maxRelaxedThreads, TimeValue.timeValueSeconds(30), true) + ); + try { + final var items = between(1, 10000); // large enough that inadvertent recursion will trigger a StackOverflowError + final var itemStartLatch = new CountDownLatch(items); + final var completedItems = new AtomicInteger(); + final var maxConcurrency = between(1, (constrainedQueue + maxConstrainedThreads) * 2); + final var itemPermits = new Semaphore(maxConcurrency); + final var completionLatch = new CountDownLatch(1); + final BooleanSupplier forkSupplier = randomFrom( + () -> false, + ESTestCase::randomBoolean, + LuceneTestCase::rarely, + LuceneTestCase::usually, + () -> true + ); + final var blockPermits = new Semaphore(between(0, Math.min(maxRelaxedThreads, maxConcurrency) - 1)); + + ThrottledIterator.run(IntStream.range(0, items).boxed().iterator(), (releasable, item) -> { + try (var refs = new RefCountingRunnable(releasable::close)) { + assertTrue(itemPermits.tryAcquire()); + if (forkSupplier.getAsBoolean()) { + var ref = refs.acquire(); + final var executor = randomFrom(CONSTRAINED, RELAXED); + threadPool.executor(executor).execute(new AbstractRunnable() { + + @Override + public void onRejection(Exception e) { + assertEquals(CONSTRAINED, executor); + itemStartLatch.countDown(); + } + + @Override + protected void doRun() { + itemStartLatch.countDown(); + if (RELAXED.equals(executor) && randomBoolean() && blockPermits.tryAcquire()) { + // simulate at most (maxConcurrency-1) long-running operations, to demonstrate that they don't + // hold up the processing of the other operations + try { + assertTrue(itemStartLatch.await(30, TimeUnit.SECONDS)); + } catch (InterruptedException e) { + throw new AssertionError("unexpected", e); + } finally { + blockPermits.release(); + } + } + } + + @Override + public void onAfter() { + itemPermits.release(); + ref.close(); + } + + @Override + public void onFailure(Exception e) { + throw new AssertionError("unexpected", e); + } + }); + } else { + itemStartLatch.countDown(); + itemPermits.release(); + } + } + }, maxConcurrency, completedItems::incrementAndGet, completionLatch::countDown); + + assertTrue(completionLatch.await(30, TimeUnit.SECONDS)); + assertEquals(items, completedItems.get()); + assertTrue(itemPermits.tryAcquire(maxConcurrency)); + assertTrue(itemStartLatch.await(0, TimeUnit.SECONDS)); + } finally { + ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); + } + } +} diff --git a/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/RepositoryAnalyzeAction.java b/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/RepositoryAnalyzeAction.java index ec79367f2b57c..719ae6f2c3b1f 100644 --- a/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/RepositoryAnalyzeAction.java +++ b/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/RepositoryAnalyzeAction.java @@ -33,7 +33,9 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; +import org.elasticsearch.common.util.concurrent.ThrottledIterator; import org.elasticsearch.common.xcontent.StatusToXContentObject; +import org.elasticsearch.core.Releasable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.repositories.RepositoriesService; import org.elasticsearch.repositories.Repository; @@ -56,6 +58,7 @@ import java.util.Collection; import java.util.Collections; import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Queue; @@ -63,6 +66,7 @@ import java.util.Set; import java.util.concurrent.Semaphore; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import java.util.function.LongSupplier; import java.util.stream.IntStream; @@ -357,7 +361,7 @@ public static class AsyncAction { // choose the blob path nondeterministically to avoid clashes, assuming that the actual path doesn't matter for reproduction private final String blobPath = "temp-analysis-" + UUIDs.randomBase64UUID(); - private final Queue queue = ConcurrentCollections.newQueue(); + private final Queue> queue = ConcurrentCollections.newQueue(); private final AtomicReference failure = new AtomicReference<>(); private final Semaphore innerFailures = new Semaphore(5); // limit the number of suppressed failures private final RefCountingRunnable requestRefs = new RefCountingRunnable(this::runCleanUp); @@ -447,80 +451,87 @@ public void run() { final long targetLength = blobSizes.get(i); final boolean smallBlob = targetLength <= MAX_ATOMIC_WRITE_SIZE; // avoid the atomic API for larger blobs final boolean abortWrite = smallBlob && request.isAbortWritePermitted() && rarely(random); - final VerifyBlobTask verifyBlobTask = new VerifyBlobTask( - nodes.get(random.nextInt(nodes.size())), - new BlobAnalyzeAction.Request( - request.getRepositoryName(), - blobPath, - "test-blob-" + i + "-" + UUIDs.randomBase64UUID(random), - targetLength, - random.nextLong(), - nodes, - request.getReadNodeCount(), - request.getEarlyReadNodeCount(), - smallBlob && rarely(random), - repository.supportURLRepo() - && repository.hasAtomicOverwrites() - && smallBlob - && rarely(random) - && abortWrite == false, - abortWrite - ) + final BlobAnalyzeAction.Request blobAnalyzeRequest = new BlobAnalyzeAction.Request( + this.request.getRepositoryName(), + blobPath, + "test-blob-" + i + "-" + UUIDs.randomBase64UUID(random), + targetLength, + random.nextLong(), + nodes, + this.request.getReadNodeCount(), + this.request.getEarlyReadNodeCount(), + smallBlob && rarely(random), + repository.supportURLRepo() && repository.hasAtomicOverwrites() && smallBlob && rarely(random) && abortWrite == false, + abortWrite ); - queue.add(verifyBlobTask); + final DiscoveryNode node = nodes.get(random.nextInt(nodes.size())); + queue.add(ref -> runBlobAnalysis(ref, blobAnalyzeRequest, node)); } - try (var ignored = requestRefs) { - for (int i = 0; i < request.getConcurrency(); i++) { - processNextTask(); - } - } + ThrottledIterator.run( + getQueueIterator(), + (ref, task) -> task.accept(ref), + request.getConcurrency(), + () -> {}, + requestRefs::close + ); } private boolean rarely(Random random) { return random.nextDouble() < request.getRareActionProbability(); } - private void processNextTask() { - final VerifyBlobTask thisTask = queue.poll(); - if (isRunning() && thisTask != null) { - logger.trace("processing [{}]", thisTask); - // NB although all this is on the SAME thread, the per-blob verification runs on a SNAPSHOT thread so we don't have to worry - // about local requests resulting in a stack overflow here - final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout( - TimeValue.timeValueMillis(timeoutTimeMillis - currentTimeMillisSupplier.getAsLong()) - ); - transportService.sendChildRequest( - thisTask.node, - BlobAnalyzeAction.NAME, - thisTask.request, - task, - transportRequestOptions, - new ActionListenerResponseHandler<>(ActionListener.releaseAfter(new ActionListener<>() { - @Override - public void onResponse(BlobAnalyzeAction.Response response) { - logger.trace("finished [{}]", thisTask); - if (thisTask.request.getAbortWrite() == false) { - expectedBlobs.add(thisTask.request.getBlobName()); // each task cleans up its own mess on failure - } - if (request.detailed) { - synchronized (responses) { - responses.add(response); - } - } - summary.add(response); - processNextTask(); - } + private Iterator> getQueueIterator() { + return new Iterator<>() { + Consumer nextItem = queue.poll(); - @Override - public void onFailure(Exception exp) { - logger.debug(() -> "failed [" + thisTask + "]", exp); - fail(exp); + @Override + public boolean hasNext() { + return isRunning() && nextItem != null; + } + + @Override + public Consumer next() { + assert nextItem != null; + final var currentItem = nextItem; + nextItem = queue.poll(); + return currentItem; + } + }; + } + + private void runBlobAnalysis(Releasable ref, final BlobAnalyzeAction.Request request, DiscoveryNode node) { + logger.trace("processing [{}] on [{}]", request, node); + // NB although all this is on the SAME thread, the per-blob verification runs on a SNAPSHOT thread so we don't have to worry + // about local requests resulting in a stack overflow here + transportService.sendChildRequest( + node, + BlobAnalyzeAction.NAME, + request, + task, + TransportRequestOptions.timeout(TimeValue.timeValueMillis(timeoutTimeMillis - currentTimeMillisSupplier.getAsLong())), + new ActionListenerResponseHandler<>(ActionListener.releaseAfter(new ActionListener<>() { + @Override + public void onResponse(BlobAnalyzeAction.Response response) { + logger.trace("finished [{}] on [{}]", request, node); + if (request.getAbortWrite() == false) { + expectedBlobs.add(request.getBlobName()); // each task cleans up its own mess on failure } - }, requestRefs.acquire()), BlobAnalyzeAction.Response::new) - ); - } + if (AsyncAction.this.request.detailed) { + synchronized (responses) { + responses.add(response); + } + } + summary.add(response); + } + @Override + public void onFailure(Exception exp) { + logger.debug(() -> "failed [" + request + "] on [" + node + "]", exp); + fail(exp); + } + }, ref), BlobAnalyzeAction.Response::new) + ); } private BlobContainer getBlobContainer() { @@ -634,8 +645,6 @@ private void sendResponse(final long listingStartTimeNanos, final long deleteSta ); } } - - private record VerifyBlobTask(DiscoveryNode node, BlobAnalyzeAction.Request request) {} } public static class Request extends ActionRequest {