@@ -0,0 +1,172 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.search.aggregations.metrics;

import org.apache.logging.log4j.util.Strings;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregatorFactory.ExecutionMode;
import org.elasticsearch.search.sort.SortBuilders;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.test.ESIntegTestCase;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import static org.elasticsearch.search.aggregations.AggregationBuilders.terms;
import static org.elasticsearch.search.aggregations.AggregationBuilders.topHits;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertFailures;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailuresAndResponse;
import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.notNullValue;

@ESIntegTestCase.SuiteScopeTestCase()
public class LargeTopHitsIT extends ESIntegTestCase {

private static final String TERMS_AGGS_FIELD_1 = "terms1";
private static final String TERMS_AGGS_FIELD_2 = "terms2";
private static final String TERMS_AGGS_FIELD_3 = "terms3";
private static final String SORT_FIELD = "sort";

@Override
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
return Settings.builder().put(super.nodeSettings(nodeOrdinal, otherSettings)).put("indices.breaker.request.type", "memory").build();
}

public static String randomExecutionHint() {
return randomBoolean() ? null : randomFrom(ExecutionMode.values()).toString();
}

@Override
public void setupSuiteScopeCluster() throws Exception {
initSmallIdx();
ensureSearchable();
}

private void initSmallIdx() throws IOException {
createIndex("small_idx");
ensureGreen("small_idx");
populateIndex("small_idx", 5, 40_000);
}

private void initLargeIdx() throws IOException {
createIndex("large_idx");
ensureGreen("large_idx");
populateIndex("large_idx", 70, 50_000);
}

public void testSimple() {
assertNoFailuresAndResponse(query("small_idx"), response -> {
Terms terms = response.getAggregations().get("terms");
assertThat(terms, notNullValue());
});
}

public void test500Queries() {
for (int i = 0; i < 500; i++) {
// make sure we are not leaking memory over multiple queries
assertNoFailuresAndResponse(query("small_idx"), response -> {
Terms terms = response.getAggregations().get("terms");
assertThat(terms, notNullValue());
});
}
}

// This works most of the time, but it's not consistent: it still triggers OOM sometimes.
// The test env is too small and non-deterministic to hold all this data and the results.
@AwaitsFix(bugUrl = "see comment above")
public void testBreakAndRecover() throws IOException {
initLargeIdx();
assertNoFailuresAndResponse(query("small_idx"), response -> {
Terms terms = response.getAggregations().get("terms");
assertThat(terms, notNullValue());
});

assertFailures(query("large_idx"), RestStatus.TOO_MANY_REQUESTS, containsString("Data too large"));

assertNoFailuresAndResponse(query("small_idx"), response -> {
Terms terms = response.getAggregations().get("terms");
assertThat(terms, notNullValue());
});
}

private void createIndex(String idxName) {
assertAcked(
prepareCreate(idxName).setMapping(
TERMS_AGGS_FIELD_1,
"type=keyword",
TERMS_AGGS_FIELD_2,
"type=keyword",
TERMS_AGGS_FIELD_3,
"type=keyword",
"text",
"type=text,store=true",
"large_text_1",
"type=text,store=false",
"large_text_2",
"type=text,store=false",
"large_text_3",
"type=text,store=false",
"large_text_4",
"type=text,store=false",
"large_text_5",
"type=text,store=false"
)
);
}

private void populateIndex(String idxName, int nDocs, int size) throws IOException {
for (int i = 0; i < nDocs; i++) {
List<IndexRequestBuilder> builders = new ArrayList<>();
builders.add(
prepareIndex(idxName).setId(Integer.toString(i))
.setSource(
jsonBuilder().startObject()
.field(TERMS_AGGS_FIELD_1, "val" + i % 53)
.field(TERMS_AGGS_FIELD_2, "val" + i % 23)
.field(TERMS_AGGS_FIELD_3, "val" + i % 10)
.field(SORT_FIELD, i)
.field("text", "some text to entertain")
.field("large_text_1", Strings.repeat("this is a text field 1 ", size))
.field("large_text_2", Strings.repeat("this is a text field 2 ", size))
.field("large_text_3", Strings.repeat("this is a text field 3 ", size))
.field("large_text_4", Strings.repeat("this is a text field 4 ", size))
.field("large_text_5", Strings.repeat("this is a text field 5 ", size))
.field("field1", 5)
.field("field2", 2.71)
.endObject()
)
);

indexRandom(true, builders);
}
}

private static SearchRequestBuilder query(String indexName) {
return prepareSearch(indexName).addAggregation(
terms("terms").executionHint(randomExecutionHint())
.field(TERMS_AGGS_FIELD_1)
.subAggregation(
terms("terms").executionHint(randomExecutionHint())
.field(TERMS_AGGS_FIELD_2)
.subAggregation(
terms("terms").executionHint(randomExecutionHint())
.field(TERMS_AGGS_FIELD_3)
.subAggregation(topHits("hits").sort(SortBuilders.fieldSort(SORT_FIELD).order(SortOrder.DESC)))
)
)
);
}
}
@@ -971,7 +971,7 @@ public void executeRankFeaturePhase(RankFeatureShardRequest request, SearchShard
return searchContext.rankFeatureResult();
}
RankFeatureShardPhase.prepareForFetch(searchContext, request);
fetchPhase.execute(searchContext, docIds, null);
fetchPhase.execute(searchContext, docIds, null, i -> {});
Contributor Author (@luigidellaquila):

These did not account for memory before anyway, but now they could.
The hard part is releasing the CB; I don't see close+release logic around here, and I'm not very familiar with this code. Maybe this could be a follow-up.
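A minimal sketch of what such a follow-up could look like, assuming the request breaker is reachable from these call sites; the helper class and its names are hypothetical and not part of this change:

```java
// Hypothetical helper (not in this PR): reserve fetched-source bytes on the
// request circuit breaker and hand callers a Releasable that returns them.
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.core.Releasable;

final class FetchMemoryAccounting implements Releasable {
    private final CircuitBreaker breaker;
    private long accountedBytes = 0;

    FetchMemoryAccounting(CircuitBreaker breaker) {
        this.breaker = breaker;
    }

    // Could be passed as the new IntConsumer argument, e.g. accounting::add
    void add(int bytes) {
        breaker.addEstimateBytesAndMaybeBreak(bytes, "fetch source");
        accountedBytes += bytes;
    }

    @Override
    public void close() {
        // Give back everything that was reserved once the hits have been consumed
        breaker.addWithoutBreaking(-accountedBytes);
        accountedBytes = 0;
    }
}
```

Tying close() to the reader context lifecycle (the surrounding code already frees single-session readers right after the fetch) would be one option for the release side.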

RankFeatureShardPhase.processFetch(searchContext);
var rankFeatureResult = searchContext.rankFeatureResult();
rankFeatureResult.incRef();
@@ -988,7 +988,7 @@ private QueryFetchSearchResult executeFetchPhase(ReaderContext reader, SearchCon
var opsListener = context.indexShard().getSearchOperationListener();
try (Releasable scope = tracer.withScope(context.getTask());) {
opsListener.onPreFetchPhase(context);
fetchPhase.execute(context, shortcutDocIdsToLoad(context), null);
fetchPhase.execute(context, shortcutDocIdsToLoad(context), null, i -> {});
if (reader.singleSession()) {
freeReaderContext(reader.id());
}
@@ -1204,7 +1204,7 @@ public void executeFetchPhase(ShardFetchRequest request, CancellableTask task, A
var opsListener = searchContext.indexShard().getSearchOperationListener();
opsListener.onPreFetchPhase(searchContext);
try {
fetchPhase.execute(searchContext, request.docIds(), request.getRankDocks());
fetchPhase.execute(searchContext, request.docIds(), request.getRankDocks(), i -> {});
if (readerContext.singleSession()) {
freeReaderContext(request.contextId());
}
@@ -54,6 +54,7 @@
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.IntConsumer;

class TopHitsAggregator extends MetricsAggregator {

@@ -198,7 +199,7 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) throws IOE
for (int i = 0; i < topDocs.scoreDocs.length; i++) {
docIdsToLoad[i] = topDocs.scoreDocs[i].doc;
}
FetchSearchResult fetchResult = runFetchPhase(subSearchContext, docIdsToLoad);
FetchSearchResult fetchResult = runFetchPhase(subSearchContext, docIdsToLoad, this::addRequestCircuitBreakerBytes);
Contributor Author @luigidellaquila (Nov 14, 2025):

Should we batch here and avoid invoking the CB for every document?
Maybe addRequestCircuitBreakerBytes should take care of this?

I suspect that fetching source is way more expensive than invoking the CB, so I'm not sure we want more complication here.
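If batching ever turns out to matter, one option (a sketch only; the wrapper and the 1 MiB threshold are assumptions, not something this PR introduces) would be to accumulate locally and consult the breaker only once enough bytes have piled up:

```java
import java.util.function.IntConsumer;

// Hypothetical wrapper: forwards to the real accounting callback only after
// `threshold` bytes have accumulated, so the breaker is consulted less often.
static IntConsumer batching(IntConsumer delegate, int threshold) {
    final int[] pending = new int[1]; // mutable holder usable from the lambda
    return bytes -> {
        pending[0] += bytes;
        if (pending[0] >= threshold) {
            delegate.accept(pending[0]);
            pending[0] = 0;
        }
    };
}

// e.g. runFetchPhase(subSearchContext, docIdsToLoad,
//     batching(this::addRequestCircuitBreakerBytes, 1 << 20));
// Note: bytes still pending below the threshold would need a final flush.
```

Given the point above that fetching and deserializing the source likely dwarfs a breaker call, the simpler per-document callback is probably fine as it stands.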

if (fetchProfiles != null) {
fetchProfiles.add(fetchResult.profileResult());
}
@@ -222,7 +223,7 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) throws IOE
);
}

private static FetchSearchResult runFetchPhase(SubSearchContext subSearchContext, int[] docIdsToLoad) {
private static FetchSearchResult runFetchPhase(SubSearchContext subSearchContext, int[] docIdsToLoad, IntConsumer memoryChecker) {
// Fork the search execution context for each slice, because the fetch phase does not support concurrent execution yet.
SearchExecutionContext searchExecutionContext = new SearchExecutionContext(subSearchContext.getSearchExecutionContext());
// InnerHitSubContext is not thread-safe, so we fork it as well to support concurrent execution
@@ -242,7 +243,7 @@ public InnerHitsContext innerHits() {
}
};

fetchSubSearchContext.fetchPhase().execute(fetchSubSearchContext, docIdsToLoad, null);
fetchSubSearchContext.fetchPhase().execute(fetchSubSearchContext, docIdsToLoad, null, memoryChecker);
return fetchSubSearchContext.fetchResult();
}

@@ -47,6 +47,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.IntConsumer;
import java.util.function.Supplier;

import static org.elasticsearch.index.get.ShardGetService.maybeExcludeVectorFields;
@@ -66,7 +67,7 @@ public FetchPhase(List<FetchSubPhase> fetchSubPhases) {
this.fetchSubPhases[fetchSubPhases.size()] = new InnerHitsPhase(this);
}

public void execute(SearchContext context, int[] docIdsToLoad, RankDocShardInfo rankDocs) {
public void execute(SearchContext context, int[] docIdsToLoad, RankDocShardInfo rankDocs, IntConsumer memoryChecker) {
if (LOGGER.isTraceEnabled()) {
LOGGER.trace("{}", new SearchContextSourcePrinter(context));
}
@@ -88,7 +89,7 @@ public void execute(SearchContext context, int[] docIdsToLoad, RankDocShardInfo
: Profilers.startProfilingFetchPhase();
SearchHits hits = null;
try {
hits = buildSearchHits(context, docIdsToLoad, profiler, rankDocs);
hits = buildSearchHits(context, docIdsToLoad, profiler, rankDocs, memoryChecker);
} finally {
try {
// Always finish profiling
@@ -116,7 +117,13 @@ public Source getSource(LeafReaderContext ctx, int doc) {
}
}

private SearchHits buildSearchHits(SearchContext context, int[] docIdsToLoad, Profiler profiler, RankDocShardInfo rankDocs) {
private SearchHits buildSearchHits(
SearchContext context,
int[] docIdsToLoad,
Profiler profiler,
RankDocShardInfo rankDocs,
IntConsumer memoryChecker
) {
var lookup = context.getSearchExecutionContext().getMappingLookup();

// Optionally remove sparse and dense vector fields early to:
@@ -169,7 +176,6 @@ private SearchHits buildSearchHits(SearchContext context, int[] docIdsToLoad, Pr
StoredFieldLoader storedFieldLoader = profiler.storedFields(StoredFieldLoader.fromSpec(storedFieldsSpec));
IdLoader idLoader = context.newIdLoader();
boolean requiresSource = storedFieldsSpec.requiresSource();
final int[] locallyAccumulatedBytes = new int[1];
NestedDocuments nestedDocuments = context.getSearchExecutionContext().getNestedDocuments();

FetchPhaseDocsIterator docsIterator = new FetchPhaseDocsIterator() {
@@ -206,10 +212,6 @@ protected SearchHit nextDoc(int doc) throws IOException {
if (context.isCancelled()) {
throw new TaskCancelledException("cancelled");
}
if (context.checkRealMemoryCB(locallyAccumulatedBytes[0], "fetch source")) {
// if we checked the real memory breaker, we restart our local accounting
locallyAccumulatedBytes[0] = 0;
Contributor Author @luigidellaquila (Nov 14, 2025):

Most of the time these batches were too small, so the check never triggered.

}

HitContext hit = prepareHitContext(
context,
@@ -233,7 +235,9 @@

BytesReference sourceRef = hit.hit().getSourceRef();
if (sourceRef != null) {
locallyAccumulatedBytes[0] += sourceRef.length();
// This is an empirical value that seems to work well.
// Deserializing a large source would also mean serializing it to HTTP response later on, so x2 seems reasonable
memoryChecker.accept(sourceRef.length() * 2);
}
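For a sense of scale with the LargeTopHitsIT data above: each large_text field is a 23-byte phrase repeated 40,000 times (roughly 0.9 MB), so the five fields put each _source at roughly 4.6 MB, and with the x2 factor every hit fetched through this path hands on the order of 9 MB to the memoryChecker (about 11.5 MB per hit for the 50,000-repeat large_idx). Back-of-the-envelope numbers only, ignoring JSON overhead and the smaller fields.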
success = true;
return hit.hit();
@@ -93,7 +93,7 @@ private void hitExecute(Map<String, InnerHitsContext.InnerHitSubContext> innerHi
innerHitsContext.setRootId(hit.getId());
innerHitsContext.setRootLookup(rootSource);

fetchPhase.execute(innerHitsContext, docIdsToLoad, null);
fetchPhase.execute(innerHitsContext, docIdsToLoad, null, i -> {});
FetchSearchResult fetchResult = innerHitsContext.fetchResult();
SearchHit[] internalHits = fetchResult.fetchResult().hits().getHits();
for (int j = 0; j < internalHits.length; j++) {
@@ -394,7 +394,7 @@ public Query rewrittenQuery() {
*/
public final boolean checkRealMemoryCB(int locallyAccumulatedBytes, String label) {
if (locallyAccumulatedBytes >= memAccountingBufferSize()) {
circuitBreaker().addEstimateBytesAndMaybeBreak(0, label);
circuitBreaker().addEstimateBytesAndMaybeBreak(locallyAccumulatedBytes, label);
Contributor Author (@luigidellaquila):

This was the crux: the check previously passed 0, so the locally accumulated bytes were never actually added to the breaker's accounting.
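For reference, the caller-side pattern below paraphrases the accounting loop this PR removes from FetchPhase (an illustration, not compilable as-is) and shows why the zero mattered:

```java
// Accumulate locally, check periodically, reset after a successful check.
locallyAccumulatedBytes[0] += sourceRef.length();
if (context.checkRealMemoryCB(locallyAccumulatedBytes[0], "fetch source")) {
    // With this fix the accumulated bytes are reserved on the breaker;
    // before, addEstimateBytesAndMaybeBreak(0, ...) recorded nothing, so the
    // breaker never saw the memory held by already-fetched sources.
    locallyAccumulatedBytes[0] = 0;
}
```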

return true;
}
return false;
@@ -790,7 +790,7 @@ public void testFetchTimeoutWithPartialResults() throws IOException {
ContextIndexSearcher contextIndexSearcher = createSearcher(r);
try (SearchContext searchContext = createSearchContext(contextIndexSearcher, true)) {
FetchPhase fetchPhase = createFetchPhase(contextIndexSearcher);
fetchPhase.execute(searchContext, new int[] { 0, 1, 2 }, null);
fetchPhase.execute(searchContext, new int[] { 0, 1, 2 }, null, i -> {});
assertTrue(searchContext.queryResult().searchTimedOut());
assertEquals(1, searchContext.fetchResult().hits().getHits().length);
} finally {
@@ -811,7 +811,7 @@

try (SearchContext searchContext = createSearchContext(contextIndexSearcher, false)) {
FetchPhase fetchPhase = createFetchPhase(contextIndexSearcher);
expectThrows(SearchTimeoutException.class, () -> fetchPhase.execute(searchContext, new int[] { 0, 1, 2 }, null));
expectThrows(SearchTimeoutException.class, () -> fetchPhase.execute(searchContext, new int[] { 0, 1, 2 }, null, i -> {}));
assertNull(searchContext.fetchResult().hits());
} finally {
r.close();
@@ -867,8 +867,13 @@ public StoredFieldsSpec storedFieldsSpec() {
return StoredFieldsSpec.NEEDS_SOURCE;
}
}));
fetchPhase.execute(searchContext, IntStream.range(0, 100).toArray(), null);
assertThat(breakerCalledCount.get(), is(4));
fetchPhase.execute(
searchContext,
IntStream.range(0, 100).toArray(),
null,
i -> breakingCircuitBreaker.addEstimateBytesAndMaybeBreak(i, "test")
);
assertThat(breakerCalledCount.get(), is(100));
} finally {
r.close();
dir.close();
@@ -923,7 +928,7 @@
}));
FetchPhaseExecutionException fetchPhaseExecutionException = assertThrows(
FetchPhaseExecutionException.class,
() -> fetchPhase.execute(searchContext, IntStream.range(0, 100).toArray(), null)
() -> fetchPhase.execute(searchContext, IntStream.range(0, 100).toArray(), null, i -> {})
);
assertThat(fetchPhaseExecutionException.getCause().getMessage(), is("bad things"));
} finally {