Changes from 5 commits
5 changes: 5 additions & 0 deletions docs/changelog/115668.yaml
@@ -0,0 +1,5 @@
pr: 115668
summary: Limit the number of tasks that a single search can submit
area: Search
type: bug
issues: []
DefaultSearchContext.java
@@ -87,6 +87,7 @@
import java.util.TreeSet;
import java.util.concurrent.Executor;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.LongSupplier;
import java.util.function.ToLongFunction;

@@ -202,7 +203,7 @@ final class DefaultSearchContext extends SearchContext {
engineSearcher.getQueryCache(),
engineSearcher.getQueryCachingPolicy(),
lowLevelCancellation,
executor,
wrapExecutor(executor),
maximumNumberOfSlices,
minimumDocsPerSlice
);
@@ -229,6 +230,32 @@ final class DefaultSearchContext extends SearchContext {
}
}

private static Executor wrapExecutor(Executor executor) {
if (executor instanceof ThreadPoolExecutor tpe) {
// let this searcher fork to a limited maximum number of tasks, to protect against situations where Lucene may
// submit too many segment-level tasks. With enough parallel search requests and segments per shard, they may all see
// an empty queue and start parallelizing, filling up the queue very quickly and causing rejections, because the
// queue holds many small tasks that become no-ops once the active caller thread executes them instead.
// Note that even though all tasks complete, TaskExecutor#invokeAll leaves the leftover no-op tasks in the queue,
// hence they contribute to the queue size until they are removed from it.
AtomicInteger segmentLevelTasks = new AtomicInteger(0);
return command -> {
if (segmentLevelTasks.incrementAndGet() > tpe.getMaximumPoolSize()) {
Member Author:
I am open to opinions on the threshold; it is quite conservative. For operations like knn query rewrite that parallelize on the number of segments, we end up creating far fewer tasks on small nodes. Still, it is probably a good idea not to create more tasks than there are available threads, and there could be multiple shard-level requests happening at the same time on the same node, deriving from the same search or from others, so the total number of tasks is still potentially higher than the max pool size anyway.

We should probably improve this in Lucene as a follow-up, but this provides some protection, mostly for 8.x, which is based on Lucene 9.12.

command.run();
Contributor:
Don't you need to decrement the counter on this code path as well?

} else {
executor.execute(() -> {
try {
command.run();
} finally {
segmentLevelTasks.decrementAndGet();
}
});
}
Contributor:
Should we decrement the counter when the task is done executing in order to allow search parallelism again later on?

Member Author:
Good point, I've been going back and forth on it.

The current impl removes the need for a CAS once we have used up the budget. It is also much simpler to test. These are maybe minor points, though. The current solution may look too conservative around the number of tasks that get created: if they are executed fast enough, we could indeed create more tasks than the total number of threads, although not all at the same time. I wonder how likely that is as a real scenario, given that TaskExecutor submits all tasks at once, not gradually. That is why I think this simple solution provides the most value, assuming that all tasks are submitted at once. I guess this impl may hurt latency a little in favor of throughput. Also, we already apply the same limit to the number of slices, statically. We would effectively just apply the same limit to knn query rewrite and to degenerate cases where we end up parallelizing from a segment-level task, which seems wrong anyway and is something we should protect against.
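As a standalone illustration of the budget-once idea (all names here are hypothetical; this is a reduction of the approach under discussion, not the committed code), the wrapper boils down to the following. The plain read before the increment is an optional fast path, not part of the change itself, that avoids even the increment's CAS once the budget is spent:

import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicInteger;

final class BudgetedExecutor {
    // Fork at most `budget` tasks over the lifetime of this searcher; once the
    // budget is spent, every subsequent command runs on the caller thread,
    // mirroring what Lucene does when a task is rejected.
    static Executor wrapWithBudget(Executor executor, int budget) {
        AtomicInteger forked = new AtomicInteger(0);
        return command -> {
            // The get() check is a fast path: once the budget is exhausted it
            // avoids the CAS that incrementAndGet would otherwise perform.
            if (forked.get() >= budget || forked.incrementAndGet() > budget) {
                command.run(); // over budget: execute on the caller
            } else {
                executor.execute(command);
            }
        };
    }
}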

Additional thoughts?

Contributor:
I guess my main concern is that if a query parallelizes at rewrite time (or createWeight time) and then again at collection time, and the query rewrite uses up the whole parallelism budget, then you won't get any parallelism at collection time.
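To make the concern concrete, a hypothetical two-stage sequence; Lucene's TaskExecutor is real, while BudgetedExecutor reuses the illustrative sketch above and everything else is made up for the example:

import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.lucene.search.TaskExecutor;

public class BudgetExhaustionDemo {
    public static void main(String[] args) throws Exception {
        ExecutorService pool = Executors.newFixedThreadPool(4);
        // Budget equal to the pool size, as in the change under review.
        TaskExecutor taskExecutor = new TaskExecutor(BudgetedExecutor.wrapWithBudget(pool, 4));

        // Stage 1 (think rewrite): these tasks may fork to pool threads and they
        // consume the entire budget. The caller can also pick some of them up.
        List<Callable<String>> rewrite = List.of(
            () -> Thread.currentThread().getName(),
            () -> Thread.currentThread().getName(),
            () -> Thread.currentThread().getName(),
            () -> Thread.currentThread().getName()
        );
        System.out.println("rewrite ran on: " + taskExecutor.invokeAll(rewrite));

        // Stage 2 (think collection): the budget is spent, so every task runs
        // sequentially on the caller thread ("main").
        List<Callable<String>> collect = List.of(
            () -> Thread.currentThread().getName(),
            () -> Thread.currentThread().getName()
        );
        System.out.println("collect ran on: " + taskExecutor.invokeAll(collect));
        pool.shutdown();
    }
}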

Member Author:
Yes, my comment above does not take into account that invokeAll may be called multiple times at different stages of a search, for instance at rewrite and later at collection. It is also true that when we do parallelize at rewrite, we care less about parallelizing during collection, but that is an assumption that may only hold in the short term. I do think that we should eventually adjust things in Lucene and remove this conditional in ES in the long run.

I will look further into how we can make this a little better without over-complicating things.

Contributor:
I was trying the same thing at one point, and it's somewhat tricky to avoid a serious penalty for tiny tasks if you start increasing the amount of ref-counting that is done overall like this.
The beauty of this solution is that it's pretty much a single thread doing all the counting; the CAS cost is far from trivial here. I couldn't make any scheme that added another CAS (on top of the existing CAS we do when enqueuing work in Lucene) work without a measurable regression.
My vote would be to see if we can find a more helpful API on the Lucene end to deal with the various contention/scheduling tradeoffs and issues we have today, rather than add complexity here => I'd go with Luca's solution as is, I think :)

Member Author:
I pushed a new commit that does the decrement. Testing has become much less predictable, as I expected. I do also worry about the CAS for each task. We should probably benchmark both solutions and see what the difference is: what do we lose by e.g. only parallelizing the knn query rewrite and not collection when both happen, compared to what we lose by performing the CAS every time?
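For reference, a minimal sketch of the decrementing variant under discussion, with the counter released on both paths so that later invokeAll rounds regain their budget (names are illustrative; this is not necessarily the committed code):

import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicInteger;

final class ReleasingBudgetExecutor {
    static Executor wrap(Executor executor, int maxOutstanding) {
        AtomicInteger outstanding = new AtomicInteger(0);
        return command -> {
            if (outstanding.incrementAndGet() > maxOutstanding) {
                try {
                    command.run(); // over budget: run on the caller
                } finally {
                    outstanding.decrementAndGet(); // release on this path too
                }
            } else {
                executor.execute(() -> {
                    try {
                        command.run();
                    } finally {
                        outstanding.decrementAndGet(); // release once the forked task completes
                    }
                });
            }
        };
    }
}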

};
}
return executor;
}

static long getFieldCardinality(String field, IndexService indexService, DirectoryReader directoryReader) {
MappedFieldType mappedFieldType = indexService.mapperService().fieldType(field);
if (mappedFieldType == null) {
@@ -290,6 +317,8 @@ static int determineMaximumNumberOfSlices(
boolean enableQueryPhaseParallelCollection,
ToLongFunction<String> fieldCardinality
) {
// Note: although this method refers to parallel collection, it affects any kind of parallelism, including query rewrite,
// given that if the returned value is 1, no executor is provided to the searcher.
return executor instanceof ThreadPoolExecutor tpe
&& tpe.getQueue().size() <= tpe.getMaximumPoolSize()
&& isParallelCollectionSupportedForResults(resultsType, request.source(), fieldCardinality, enableQueryPhaseParallelCollection)
DefaultSearchContextTests.java
@@ -35,6 +35,7 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.IndexService;
import org.elasticsearch.index.IndexSettings;
@@ -78,17 +79,33 @@
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RunnableFuture;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.function.ToLongFunction;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
@@ -959,11 +976,161 @@ public void testGetFieldCardinalityRuntimeField() {
assertEquals(-1, DefaultSearchContext.getFieldCardinality("field", indexService, null));
}

public void testSingleThreadNoSearchConcurrency() throws IOException, ExecutionException, InterruptedException {
// with a single thread in the pool the max number of slices will always be 1, hence we won't provide the executor to the searcher
int executorPoolSize = 1;
int numIters = randomIntBetween(10, 50);
int numSegmentTasks = randomIntBetween(50, 100);
AtomicInteger completedTasks = new AtomicInteger(0);
ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(executorPoolSize);
try {
doTestSearchConcurrency(executor, numIters, numSegmentTasks, completedTasks);
} finally {
terminate(executor);
}
// Tasks are still created, but the internal executor is a direct one, hence there is no parallelism in practice
assertEquals((long) numIters * numSegmentTasks + numIters, completedTasks.get());
assertEquals(numIters, executor.getCompletedTaskCount());
}

@SuppressForbidden(reason = "need to provide queue to ThreadPoolExecutor")
public void testNoSearchConcurrencyWhenQueueing() throws IOException, ExecutionException, InterruptedException {
// with multiple threads, but constant queueing, the max number of slices will always be 1, hence we won't provide the
// executor to the searcher
int executorPoolSize = randomIntBetween(2, 5);
int numIters = randomIntBetween(10, 50);
int numSegmentTasks = randomIntBetween(50, 100);
AtomicInteger completedTasks = new AtomicInteger(0);
final AtomicBoolean terminating = new AtomicBoolean(false);
LinkedBlockingQueue<Runnable> queue = new LinkedBlockingQueue<>() {
@Override
public int size() {
// for the purpose of this test we pretend that we always have more items in the queue than threads, but we need to revert
// to normal behaviour to ensure graceful shutdown
if (terminating.get()) {
return super.size();
}
return randomIntBetween(executorPoolSize + 1, Integer.MAX_VALUE);
}
};
ThreadPoolExecutor executor = new ThreadPoolExecutor(executorPoolSize, executorPoolSize, 0L, TimeUnit.MILLISECONDS, queue);
try {
doTestSearchConcurrency(executor, numIters, numSegmentTasks, completedTasks);
terminating.set(true);
} finally {
terminate(executor);
}
// Tasks are still created, but the internal executor is a direct one, hence there is no parallelism in practice
assertEquals((long) numIters * numSegmentTasks + numIters, completedTasks.get());
assertEquals(numIters, executor.getCompletedTaskCount());
}

@SuppressForbidden(reason = "need to provide queue to ThreadPoolExecutor")
public void testSearchConcurrencyDoesNotCreateMoreTasksThanThreads() throws Exception {
// with multiple threads, but not enough queueing to disable parallelism, we will provide the executor to the searcher
int executorPoolSize = randomIntBetween(2, 5);
int numIters = randomIntBetween(10, 50);
int numSegmentTasks = randomIntBetween(50, 100);
AtomicInteger completedTasks = new AtomicInteger(0);
final AtomicBoolean terminating = new AtomicBoolean(false);
LinkedBlockingQueue<Runnable> queue = new LinkedBlockingQueue<>() {
@Override
public int size() {
int size = super.size();
// for the purpose of this test we pretend that we only ever have at most as many items in the queue as there are
// threads, but we need to revert to normal behaviour to ensure graceful shutdown
if (size <= executorPoolSize || terminating.get()) {
return size;
}
return randomIntBetween(0, executorPoolSize);
}
};
ThreadPoolExecutor executor = new ThreadPoolExecutor(executorPoolSize, executorPoolSize, 0L, TimeUnit.MILLISECONDS, queue);
try {
doTestSearchConcurrency(executor, numIters, numSegmentTasks, completedTasks);
terminating.set(true);
} finally {
terminate(executor);
}
// make sure that we do parallelize execution: each operation will use at minimum as many tasks as there are threads available
assertThat(executor.getCompletedTaskCount(), greaterThanOrEqualTo((long) numIters * executorPoolSize));
// while we parallelize we also limit the number of tasks that each searcher submits
assertThat(executor.getCompletedTaskCount(), lessThan((long) numIters * numSegmentTasks));
// *2 is just a wild guess to account for tasks that get executed while we are still submitting
assertThat(executor.getCompletedTaskCount(), lessThan((long) numIters * executorPoolSize * 2));
}

private void doTestSearchConcurrency(ThreadPoolExecutor executor, int numIters, int numSegmentTasks, AtomicInteger completedTasks)
throws IOException, ExecutionException, InterruptedException {
DefaultSearchContext[] contexts = new DefaultSearchContext[numIters];
for (int i = 0; i < numIters; i++) {
contexts[i] = createDefaultSearchContext(executor, randomFrom(SearchService.ResultsType.DFS, SearchService.ResultsType.QUERY));
}
List<Future<?>> futures = new ArrayList<>(numIters);
try {
for (int i = 0; i < numIters; i++) {
// simulate multiple concurrent search operations that each parallelize their execution across many segment-level tasks
// via Lucene's TaskExecutor. Segment-level tasks are never rejected (they execute on the caller upon rejection), but
// the top-level execute call is subject to rejection once the queue is filled with segment-level tasks. That is why
// we want to limit the number of tasks that each search can parallelize to.
// NOTE: DefaultSearchContext does not provide the executor to the searcher once it sees maxPoolSize items in the queue.
DefaultSearchContext searchContext = contexts[i];
AtomicInteger segmentTasksCompleted = new AtomicInteger(0);
RunnableFuture<Void> task = new FutureTask<>(() -> {
Collection<Callable<Void>> tasks = new ArrayList<>();
for (int j = 0; j < numSegmentTasks; j++) {
tasks.add(() -> {
segmentTasksCompleted.incrementAndGet();
completedTasks.incrementAndGet();
return null;
});
}
try {
searchContext.searcher().getTaskExecutor().invokeAll(tasks);
// TODO additional calls to invokeAll

// invokeAll is blocking, hence at this point we are done executing all the sub-tasks, but the queue may
// still be filled up with no-op leftover tasks
assertEquals(numSegmentTasks, segmentTasksCompleted.get());
} catch (IOException e) {
throw new UncheckedIOException(e);
} finally {
completedTasks.incrementAndGet();
}
return null;
});
futures.add(task);
executor.execute(task);
}
for (Future<?> future : futures) {
future.get();
}
} finally {
for (DefaultSearchContext searchContext : contexts) {
searchContext.indexShard().getThreadPool().shutdown();
searchContext.close();
}
}
}

private DefaultSearchContext createDefaultSearchContext(Executor executor, SearchService.ResultsType resultsType) throws IOException {
return createDefaultSearchContext(Settings.EMPTY, null, executor, resultsType);
}

private DefaultSearchContext createDefaultSearchContext(Settings providedIndexSettings) throws IOException {
return createDefaultSearchContext(providedIndexSettings, null);
}

private DefaultSearchContext createDefaultSearchContext(Settings providedIndexSettings, XContentBuilder mappings) throws IOException {
return createDefaultSearchContext(providedIndexSettings, mappings, null, randomFrom(SearchService.ResultsType.values()));
}

private DefaultSearchContext createDefaultSearchContext(
Settings providedIndexSettings,
XContentBuilder mappings,
Executor executor,
SearchService.ResultsType resultsType
) throws IOException {
TimeValue timeout = new TimeValue(randomIntBetween(1, 100));
ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
when(shardSearchRequest.searchType()).thenReturn(SearchType.DEFAULT);
@@ -1047,9 +1214,9 @@ protected Engine.Searcher acquireSearcherInternal(String source) {
timeout,
null,
false,
null,
randomFrom(SearchService.ResultsType.values()),
randomBoolean(),
executor,
resultsType,
executor != null || randomBoolean(),
randomInt()
);
}