diff --git a/docs/changelog/138002.yaml b/docs/changelog/138002.yaml
new file mode 100644
index 0000000000000..6fad9db993d9f
--- /dev/null
+++ b/docs/changelog/138002.yaml
@@ -0,0 +1,5 @@
+pr: 138002
+summary: Fix `SearchContext` CB memory accounting
+area: Aggregations
+type: bug
+issues: []
diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/metrics/LargeTopHitsIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/metrics/LargeTopHitsIT.java
new file mode 100644
index 0000000000000..9d446104fea6b
--- /dev/null
+++ b/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/metrics/LargeTopHitsIT.java
@@ -0,0 +1,172 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+package org.elasticsearch.search.aggregations.metrics;
+
+import org.apache.logging.log4j.util.Strings;
+import org.elasticsearch.action.index.IndexRequestBuilder;
+import org.elasticsearch.action.search.SearchRequestBuilder;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.search.aggregations.bucket.terms.Terms;
+import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregatorFactory.ExecutionMode;
+import org.elasticsearch.search.sort.SortBuilders;
+import org.elasticsearch.search.sort.SortOrder;
+import org.elasticsearch.test.ESIntegTestCase;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.elasticsearch.search.aggregations.AggregationBuilders.terms;
+import static org.elasticsearch.search.aggregations.AggregationBuilders.topHits;
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertFailures;
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailuresAndResponse;
+import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.notNullValue;
+
+@ESIntegTestCase.SuiteScopeTestCase()
+public class LargeTopHitsIT extends ESIntegTestCase {
+
+    private static final String TERMS_AGGS_FIELD_1 = "terms1";
+    private static final String TERMS_AGGS_FIELD_2 = "terms2";
+    private static final String TERMS_AGGS_FIELD_3 = "terms3";
+    private static final String SORT_FIELD = "sort";
+
+    @Override
+    protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
+        return Settings.builder().put(super.nodeSettings(nodeOrdinal, otherSettings)).put("indices.breaker.request.type", "memory").build();
+    }
+
+    public static String randomExecutionHint() {
+        return randomBoolean() ? null : randomFrom(ExecutionMode.values()).toString();
+    }
+
+    @Override
+    public void setupSuiteScopeCluster() throws Exception {
+        initSmallIdx();
+        ensureSearchable();
+    }
+
+    private void initSmallIdx() throws IOException {
+        createIndex("small_idx");
+        ensureGreen("small_idx");
+        populateIndex("small_idx", 5, 40_000);
+    }
+
+    private void initLargeIdx() throws IOException {
+        createIndex("large_idx");
+        ensureGreen("large_idx");
+        populateIndex("large_idx", 70, 50_000);
+    }
+
+    public void testSimple() {
+        assertNoFailuresAndResponse(query("small_idx"), response -> {
+            Terms terms = response.getAggregations().get("terms");
+            assertThat(terms, notNullValue());
+        });
+    }
+
+    public void test500Queries() {
+        for (int i = 0; i < 500; i++) {
+            // make sure we are not leaking memory over multiple queries
+            assertNoFailuresAndResponse(query("small_idx"), response -> {
+                Terms terms = response.getAggregations().get("terms");
+                assertThat(terms, notNullValue());
+            });
+        }
+    }
+
+    // This works most of the time, but it's not consistent: it still triggers OOM sometimes.
+    // The test env is too small and non-deterministic to hold all this data and the results.
+    @AwaitsFix(bugUrl = "see comment above")
+    public void testBreakAndRecover() throws IOException {
+        initLargeIdx();
+        assertNoFailuresAndResponse(query("small_idx"), response -> {
+            Terms terms = response.getAggregations().get("terms");
+            assertThat(terms, notNullValue());
+        });
+
+        assertFailures(query("large_idx"), RestStatus.TOO_MANY_REQUESTS, containsString("Data too large"));
+
+        assertNoFailuresAndResponse(query("small_idx"), response -> {
+            Terms terms = response.getAggregations().get("terms");
+            assertThat(terms, notNullValue());
+        });
+    }
+
+    private void createIndex(String idxName) {
+        assertAcked(
+            prepareCreate(idxName).setMapping(
+                TERMS_AGGS_FIELD_1,
+                "type=keyword",
+                TERMS_AGGS_FIELD_2,
+                "type=keyword",
+                TERMS_AGGS_FIELD_3,
+                "type=keyword",
+                "text",
+                "type=text,store=true",
+                "large_text_1",
+                "type=text,store=false",
+                "large_text_2",
+                "type=text,store=false",
+                "large_text_3",
+                "type=text,store=false",
+                "large_text_4",
+                "type=text,store=false",
+                "large_text_5",
+                "type=text,store=false"
+            )
+        );
+    }
+
+    private void populateIndex(String idxName, int nDocs, int size) throws IOException {
+        for (int i = 0; i < nDocs; i++) {
+            List<IndexRequestBuilder> builders = new ArrayList<>();
+            builders.add(
+                prepareIndex(idxName).setId(Integer.toString(i))
+                    .setSource(
+                        jsonBuilder().startObject()
+                            .field(TERMS_AGGS_FIELD_1, "val" + i % 53)
+                            .field(TERMS_AGGS_FIELD_2, "val" + i % 23)
+                            .field(TERMS_AGGS_FIELD_3, "val" + i % 10)
+                            .field(SORT_FIELD, i)
+                            .field("text", "some text to entertain")
+                            .field("large_text_1", Strings.repeat("this is a text field 1 ", size))
+                            .field("large_text_2", Strings.repeat("this is a text field 2 ", size))
+                            .field("large_text_3", Strings.repeat("this is a text field 3 ", size))
+                            .field("large_text_4", Strings.repeat("this is a text field 4 ", size))
+                            .field("large_text_5", Strings.repeat("this is a text field 5 ", size))
+                            .field("field1", 5)
+                            .field("field2", 2.71)
+                            .endObject()
+                    )
+            );
+
+            indexRandom(true, builders);
+        }
+    }
+
+    private static SearchRequestBuilder query(String indexName) {
+        return prepareSearch(indexName).addAggregation(
+            terms("terms").executionHint(randomExecutionHint())
+                .field(TERMS_AGGS_FIELD_1)
+                .subAggregation(
+                    terms("terms").executionHint(randomExecutionHint())
+                        .field(TERMS_AGGS_FIELD_2)
+                        .subAggregation(
+                            terms("terms").executionHint(randomExecutionHint())
+                                .field(TERMS_AGGS_FIELD_3)
+                                .subAggregation(topHits("hits").sort(SortBuilders.fieldSort(SORT_FIELD).order(SortOrder.DESC)))
+                        )
+                )
+        );
+    }
+}
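A note on the fixture sizes above: each repeated string is 23 characters, so every document carries five large text fields of roughly 23 x 40,000 ~ 0.9 MB each. That puts small_idx at about 4.6 MB of _source per document (~23 MB across its 5 documents), while large_idx holds 70 documents of roughly 5.8 MB each, around 400 MB in total. That is comfortably enough to trip a memory-based request breaker once top_hits starts loading sources, which is what testBreakAndRecover relies on.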
diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java
index 1d8550747fa85..c832190813b08 100644
--- a/server/src/main/java/org/elasticsearch/search/SearchService.java
+++ b/server/src/main/java/org/elasticsearch/search/SearchService.java
@@ -971,7 +971,7 @@ public void executeRankFeaturePhase(RankFeatureShardRequest request, SearchShard
                 return searchContext.rankFeatureResult();
             }
             RankFeatureShardPhase.prepareForFetch(searchContext, request);
-            fetchPhase.execute(searchContext, docIds, null);
+            fetchPhase.execute(searchContext, docIds, null, i -> {});
             RankFeatureShardPhase.processFetch(searchContext);
             var rankFeatureResult = searchContext.rankFeatureResult();
             rankFeatureResult.incRef();
@@ -988,7 +988,7 @@ private QueryFetchSearchResult executeFetchPhase(ReaderContext reader, SearchCon
         var opsListener = context.indexShard().getSearchOperationListener();
         try (Releasable scope = tracer.withScope(context.getTask());) {
             opsListener.onPreFetchPhase(context);
-            fetchPhase.execute(context, shortcutDocIdsToLoad(context), null);
+            fetchPhase.execute(context, shortcutDocIdsToLoad(context), null, i -> {});
             if (reader.singleSession()) {
                 freeReaderContext(reader.id());
             }
@@ -1204,7 +1204,7 @@ public void executeFetchPhase(ShardFetchRequest request, CancellableTask task, A
             var opsListener = searchContext.indexShard().getSearchOperationListener();
             opsListener.onPreFetchPhase(searchContext);
             try {
-                fetchPhase.execute(searchContext, request.docIds(), request.getRankDocks());
+                fetchPhase.execute(searchContext, request.docIds(), request.getRankDocks(), i -> {});
                 if (readerContext.singleSession()) {
                     freeReaderContext(request.contextId());
                 }
diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/TopHitsAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/TopHitsAggregator.java
index 0d21f09e699b5..8ae4555479317 100644
--- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/TopHitsAggregator.java
+++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/TopHitsAggregator.java
@@ -54,6 +54,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.function.BiConsumer;
+import java.util.function.IntConsumer;
 
 class TopHitsAggregator extends MetricsAggregator {
 
@@ -198,7 +199,7 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) throws IOE
             for (int i = 0; i < topDocs.scoreDocs.length; i++) {
                 docIdsToLoad[i] = topDocs.scoreDocs[i].doc;
             }
-            FetchSearchResult fetchResult = runFetchPhase(subSearchContext, docIdsToLoad);
+            FetchSearchResult fetchResult = runFetchPhase(subSearchContext, docIdsToLoad, this::addRequestCircuitBreakerBytes);
             if (fetchProfiles != null) {
                 fetchProfiles.add(fetchResult.profileResult());
             }
@@ -222,7 +223,7 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) throws IOE
         );
     }
 
-    private static FetchSearchResult runFetchPhase(SubSearchContext subSearchContext, int[] docIdsToLoad) {
+    private static FetchSearchResult runFetchPhase(SubSearchContext subSearchContext, int[] docIdsToLoad, IntConsumer memoryChecker) {
         // Fork the search execution context for each slice, because the fetch phase does not support concurrent execution yet.
         SearchExecutionContext searchExecutionContext = new SearchExecutionContext(subSearchContext.getSearchExecutionContext());
         // InnerHitSubContext is not thread-safe, so we fork it as well to support concurrent execution
@@ -242,7 +243,7 @@ public InnerHitsContext innerHits() {
             }
         };
 
-        fetchSubSearchContext.fetchPhase().execute(fetchSubSearchContext, docIdsToLoad, null);
+        fetchSubSearchContext.fetchPhase().execute(fetchSubSearchContext, docIdsToLoad, null, memoryChecker);
 
         return fetchSubSearchContext.fetchResult();
    }
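The IntConsumer threaded through runFetchPhase above is this::addRequestCircuitBreakerBytes, the aggregator's existing request-breaker accounting hook. A rough sketch of what such a consumer has to do, with an invented class name and label (the real logic lives in AggregatorBase, not here): reserve the bytes as they are reported, and release the total when the aggregator is released.

    import org.elasticsearch.common.breaker.CircuitBreaker;

    class RequestBreakerAccounting {
        private final CircuitBreaker breaker;
        private long requestBytesUsed = 0;

        RequestBreakerAccounting(CircuitBreaker breaker) {
            this.breaker = breaker;
        }

        // Shape-compatible with IntConsumer: reserving may throw CircuitBreakingException,
        // which surfaces to the client as HTTP 429 "Data too large".
        void addRequestCircuitBreakerBytes(int bytes) {
            breaker.addEstimateBytesAndMaybeBreak(bytes, "<illustrative label>");
            requestBytesUsed += bytes;
        }

        // Hand the reserved bytes back on close, so repeated queries do not leak
        // breaker accounting (the scenario test500Queries guards against).
        void close() {
            breaker.addWithoutBreaking(-requestBytesUsed);
        }
    }

Because the reservation happens while hits are still being fetched, an oversized top_hits trips the breaker before the node exhausts its heap rather than after the hits are already materialized, which is the behavior testBreakAndRecover exercises.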
diff --git a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java
index aa27e7d2f0c82..f21149bce8522 100644
--- a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java
+++ b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java
@@ -47,6 +47,7 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.function.IntConsumer;
 import java.util.function.Supplier;
 
 import static org.elasticsearch.index.get.ShardGetService.maybeExcludeVectorFields;
@@ -66,7 +67,7 @@ public FetchPhase(List<FetchSubPhase> fetchSubPhases) {
         this.fetchSubPhases[fetchSubPhases.size()] = new InnerHitsPhase(this);
     }
 
-    public void execute(SearchContext context, int[] docIdsToLoad, RankDocShardInfo rankDocs) {
+    public void execute(SearchContext context, int[] docIdsToLoad, RankDocShardInfo rankDocs, IntConsumer memoryChecker) {
         if (LOGGER.isTraceEnabled()) {
             LOGGER.trace("{}", new SearchContextSourcePrinter(context));
         }
@@ -88,7 +89,7 @@ public void execute(SearchContext context, int[] docIdsToLoad, RankDocShardInfo
             : Profilers.startProfilingFetchPhase();
         SearchHits hits = null;
         try {
-            hits = buildSearchHits(context, docIdsToLoad, profiler, rankDocs);
+            hits = buildSearchHits(context, docIdsToLoad, profiler, rankDocs, memoryChecker);
         } finally {
             try {
                 // Always finish profiling
@@ -116,7 +117,13 @@ public Source getSource(LeafReaderContext ctx, int doc) {
         }
     }
 
-    private SearchHits buildSearchHits(SearchContext context, int[] docIdsToLoad, Profiler profiler, RankDocShardInfo rankDocs) {
+    private SearchHits buildSearchHits(
+        SearchContext context,
+        int[] docIdsToLoad,
+        Profiler profiler,
+        RankDocShardInfo rankDocs,
+        IntConsumer memoryChecker
+    ) {
         var lookup = context.getSearchExecutionContext().getMappingLookup();
 
         // Optionally remove sparse and dense vector fields early to:
@@ -169,7 +176,6 @@ private SearchHits buildSearchHits(SearchContext context, int[] docIdsToLoad, Pr
         StoredFieldLoader storedFieldLoader = profiler.storedFields(StoredFieldLoader.fromSpec(storedFieldsSpec));
         IdLoader idLoader = context.newIdLoader();
         boolean requiresSource = storedFieldsSpec.requiresSource();
-        final int[] locallyAccumulatedBytes = new int[1];
 
         NestedDocuments nestedDocuments = context.getSearchExecutionContext().getNestedDocuments();
         FetchPhaseDocsIterator docsIterator = new FetchPhaseDocsIterator() {
@@ -206,10 +212,6 @@ protected SearchHit nextDoc(int doc) throws IOException {
                 if (context.isCancelled()) {
                     throw new TaskCancelledException("cancelled");
                 }
-                if (context.checkRealMemoryCB(locallyAccumulatedBytes[0], "fetch source")) {
-                    // if we checked the real memory breaker, we restart our local accounting
-                    locallyAccumulatedBytes[0] = 0;
-                }
 
                 HitContext hit = prepareHitContext(
                     context,
@@ -233,7 +235,9 @@ protected SearchHit nextDoc(int doc) throws IOException {
 
                 BytesReference sourceRef = hit.hit().getSourceRef();
                 if (sourceRef != null) {
-                    locallyAccumulatedBytes[0] += sourceRef.length();
+                    // This is an empirical value that seems to work well.
+                    // Deserializing a large source would also mean serializing it to the HTTP response later on, so x2 seems reasonable.
+                    memoryChecker.accept(sourceRef.length() * 2);
                 }
                 success = true;
                 return hit.hit();
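With the hunks above, per-hit accounting moves out of FetchPhase's local buffer and into the caller: for every hit that carries a _source, the fetch phase reports 2 * sourceRef.length() to the supplied memoryChecker. As a worked example, a hit with a 5 MB source is reported as 10 MB, one copy for the deserialized source and, per the comment above, one for its later serialization into the HTTP response. Callers that do not want per-hit accounting, such as the SearchService paths earlier and the inner-hits phase below, pass a no-op i -> {}.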
diff --git a/server/src/main/java/org/elasticsearch/search/fetch/subphase/InnerHitsPhase.java b/server/src/main/java/org/elasticsearch/search/fetch/subphase/InnerHitsPhase.java
index 374c96fdefe86..50a54a60b4c37 100644
--- a/server/src/main/java/org/elasticsearch/search/fetch/subphase/InnerHitsPhase.java
+++ b/server/src/main/java/org/elasticsearch/search/fetch/subphase/InnerHitsPhase.java
@@ -93,7 +93,7 @@ private void hitExecute(Map<String, InnerHitsContext.InnerHitSubContext> innerHi
         innerHitsContext.setRootId(hit.getId());
         innerHitsContext.setRootLookup(rootSource);
 
-        fetchPhase.execute(innerHitsContext, docIdsToLoad, null);
+        fetchPhase.execute(innerHitsContext, docIdsToLoad, null, i -> {});
         FetchSearchResult fetchResult = innerHitsContext.fetchResult();
         SearchHit[] internalHits = fetchResult.fetchResult().hits().getHits();
         for (int j = 0; j < internalHits.length; j++) {
diff --git a/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java b/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java
index 7d018a7ef4ba9..21e9d8594cdec 100644
--- a/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java
+++ b/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java
@@ -394,7 +394,7 @@ public Query rewrittenQuery() {
      */
     public final boolean checkRealMemoryCB(int locallyAccumulatedBytes, String label) {
         if (locallyAccumulatedBytes >= memAccountingBufferSize()) {
-            circuitBreaker().addEstimateBytesAndMaybeBreak(0, label);
+            circuitBreaker().addEstimateBytesAndMaybeBreak(locallyAccumulatedBytes, label);
             return true;
         }
         return false;
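This one-line hunk is the core of the fix. Once the local buffer crosses memAccountingBufferSize(), the old code called addEstimateBytesAndMaybeBreak(0, label), which merely forces the real-memory breaker to re-sample heap usage and never reserves the accumulated bytes, so fetched sources stayed invisible to the breaker's own bookkeeping; the new code actually reserves them. A minimal usage sketch; everything except checkRealMemoryCB and the "fetch source" label is invented for the example:

    // Accumulate locally, and let checkRealMemoryCB decide when to push the
    // running total to the circuit breaker.
    static void accountFetchedSources(SearchContext context, List<BytesReference> sources) {
        int localBytes = 0;
        for (BytesReference source : sources) {
            localBytes += source.length();
            if (context.checkRealMemoryCB(localBytes, "fetch source")) {
                // The accumulated bytes are now reserved on the breaker (previously a
                // 0-byte probe), so local accounting restarts from zero.
                localBytes = 0;
            }
        }
    }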
diff --git a/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java
index b6ca12368f762..71839d5408a1c 100644
--- a/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java
+++ b/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java
@@ -790,7 +790,7 @@ public void testFetchTimeoutWithPartialResults() throws IOException {
         ContextIndexSearcher contextIndexSearcher = createSearcher(r);
         try (SearchContext searchContext = createSearchContext(contextIndexSearcher, true)) {
             FetchPhase fetchPhase = createFetchPhase(contextIndexSearcher);
-            fetchPhase.execute(searchContext, new int[] { 0, 1, 2 }, null);
+            fetchPhase.execute(searchContext, new int[] { 0, 1, 2 }, null, i -> {});
             assertTrue(searchContext.queryResult().searchTimedOut());
             assertEquals(1, searchContext.fetchResult().hits().getHits().length);
         } finally {
@@ -811,7 +811,7 @@ public void testFetchTimeoutNoPartialResults() throws IOException {
 
         try (SearchContext searchContext = createSearchContext(contextIndexSearcher, false)) {
             FetchPhase fetchPhase = createFetchPhase(contextIndexSearcher);
-            expectThrows(SearchTimeoutException.class, () -> fetchPhase.execute(searchContext, new int[] { 0, 1, 2 }, null));
+            expectThrows(SearchTimeoutException.class, () -> fetchPhase.execute(searchContext, new int[] { 0, 1, 2 }, null, i -> {}));
             assertNull(searchContext.fetchResult().hits());
         } finally {
             r.close();
@@ -867,8 +867,13 @@ public StoredFieldsSpec storedFieldsSpec() {
                     return StoredFieldsSpec.NEEDS_SOURCE;
                 }
             }));
-            fetchPhase.execute(searchContext, IntStream.range(0, 100).toArray(), null);
-            assertThat(breakerCalledCount.get(), is(4));
+            fetchPhase.execute(
+                searchContext,
+                IntStream.range(0, 100).toArray(),
+                null,
+                i -> breakingCircuitBreaker.addEstimateBytesAndMaybeBreak(i, "test")
+            );
+            assertThat(breakerCalledCount.get(), is(100));
         } finally {
             r.close();
             dir.close();
@@ -923,7 +928,7 @@ public StoredFieldsSpec storedFieldsSpec() {
             }));
             FetchPhaseExecutionException fetchPhaseExecutionException = assertThrows(
                 FetchPhaseExecutionException.class,
-                () -> fetchPhase.execute(searchContext, IntStream.range(0, 100).toArray(), null)
+                () -> fetchPhase.execute(searchContext, IntStream.range(0, 100).toArray(), null, i -> {})
             );
             assertThat(fetchPhaseExecutionException.getCause().getMessage(), is("bad things"));
         } finally {
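The assertion change from is(4) to is(100) in the breaker test reflects the new contract: under the old scheme the breaker was consulted only when the locally buffered source bytes crossed memAccountingBufferSize(), which at that fixture's source sizes happened four times over the 100 fetched documents; the test's checker now forwards every hit's bytes straight to the breaker, so it fires once per document.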