
Commit 23bf992

luigidellaquila authored and afoucret committed
Fix SearchContext CB memory accounting (elastic#138002)
1 parent: 27fa31d · commit: 23bf992

File tree

8 files changed: +264 -35 lines changed

docs/changelog/138002.yaml

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+pr: 138002
+summary: Fix `SearchContext` CB memory accounting
+area: Aggregations
+type: bug
+issues: []
.../org/elasticsearch/search/aggregations/metrics/LargeTopHitsIT.java

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+package org.elasticsearch.search.aggregations.metrics;
+
+import org.apache.logging.log4j.util.Strings;
+import org.elasticsearch.action.index.IndexRequestBuilder;
+import org.elasticsearch.action.search.SearchRequestBuilder;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.search.aggregations.bucket.terms.Terms;
+import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregatorFactory.ExecutionMode;
+import org.elasticsearch.search.sort.SortBuilders;
+import org.elasticsearch.search.sort.SortOrder;
+import org.elasticsearch.test.ESIntegTestCase;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.elasticsearch.search.aggregations.AggregationBuilders.terms;
+import static org.elasticsearch.search.aggregations.AggregationBuilders.topHits;
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertFailures;
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailuresAndResponse;
+import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.notNullValue;
+
+@ESIntegTestCase.SuiteScopeTestCase()
+public class LargeTopHitsIT extends ESIntegTestCase {
+
+    private static final String TERMS_AGGS_FIELD_1 = "terms1";
+    private static final String TERMS_AGGS_FIELD_2 = "terms2";
+    private static final String TERMS_AGGS_FIELD_3 = "terms3";
+    private static final String SORT_FIELD = "sort";
+
+    @Override
+    protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
+        return Settings.builder().put(super.nodeSettings(nodeOrdinal, otherSettings)).put("indices.breaker.request.type", "memory").build();
+    }
+
+    public static String randomExecutionHint() {
+        return randomBoolean() ? null : randomFrom(ExecutionMode.values()).toString();
+    }
+
+    @Override
+    public void setupSuiteScopeCluster() throws Exception {
+        initSmallIdx();
+        ensureSearchable();
+    }
+
+    private void initSmallIdx() throws IOException {
+        createIndex("small_idx");
+        ensureGreen("small_idx");
+        populateIndex("small_idx", 5, 40_000);
+    }
+
+    private void initLargeIdx() throws IOException {
+        createIndex("large_idx");
+        ensureGreen("large_idx");
+        populateIndex("large_idx", 70, 50_000);
+    }
+
+    public void testSimple() {
+        assertNoFailuresAndResponse(query("small_idx"), response -> {
+            Terms terms = response.getAggregations().get("terms");
+            assertThat(terms, notNullValue());
+        });
+    }
+
+    public void test500Queries() {
+        for (int i = 0; i < 500; i++) {
+            // make sure we are not leaking memory over multiple queries
+            assertNoFailuresAndResponse(query("small_idx"), response -> {
+                Terms terms = response.getAggregations().get("terms");
+                assertThat(terms, notNullValue());
+            });
+        }
+    }
+
+    // This works most of the time, but it's not consistent: it still triggers OOM sometimes.
+    // The test env is too small and non-deterministic to hold all these data and results.
+    @AwaitsFix(bugUrl = "see comment above")
+    public void testBreakAndRecover() throws IOException {
+        initLargeIdx();
+        assertNoFailuresAndResponse(query("small_idx"), response -> {
+            Terms terms = response.getAggregations().get("terms");
+            assertThat(terms, notNullValue());
+        });
+
+        assertFailures(query("large_idx"), RestStatus.TOO_MANY_REQUESTS, containsString("Data too large"));
+
+        assertNoFailuresAndResponse(query("small_idx"), response -> {
+            Terms terms = response.getAggregations().get("terms");
+            assertThat(terms, notNullValue());
+        });
+    }
+
+    private void createIndex(String idxName) {
+        assertAcked(
+            prepareCreate(idxName).setMapping(
+                TERMS_AGGS_FIELD_1,
+                "type=keyword",
+                TERMS_AGGS_FIELD_2,
+                "type=keyword",
+                TERMS_AGGS_FIELD_3,
+                "type=keyword",
+                "text",
+                "type=text,store=true",
+                "large_text_1",
+                "type=text,store=false",
+                "large_text_2",
+                "type=text,store=false",
+                "large_text_3",
+                "type=text,store=false",
+                "large_text_4",
+                "type=text,store=false",
+                "large_text_5",
+                "type=text,store=false"
+            )
+        );
+    }
+
+    private void populateIndex(String idxName, int nDocs, int size) throws IOException {
+        for (int i = 0; i < nDocs; i++) {
+            List<IndexRequestBuilder> builders = new ArrayList<>();
+            builders.add(
+                prepareIndex(idxName).setId(Integer.toString(i))
+                    .setSource(
+                        jsonBuilder().startObject()
+                            .field(TERMS_AGGS_FIELD_1, "val" + i % 53)
+                            .field(TERMS_AGGS_FIELD_2, "val" + i % 23)
+                            .field(TERMS_AGGS_FIELD_3, "val" + i % 10)
+                            .field(SORT_FIELD, i)
+                            .field("text", "some text to entertain")
+                            .field("large_text_1", Strings.repeat("this is a text field 1 ", size))
+                            .field("large_text_2", Strings.repeat("this is a text field 2 ", size))
+                            .field("large_text_3", Strings.repeat("this is a text field 3 ", size))
+                            .field("large_text_4", Strings.repeat("this is a text field 4 ", size))
+                            .field("large_text_5", Strings.repeat("this is a text field 5 ", size))
+                            .field("field1", 5)
+                            .field("field2", 2.71)
+                            .endObject()
+                    )
+            );
+
+            indexRandom(true, builders);
+        }
+    }
+
+    private static SearchRequestBuilder query(String indexName) {
+        return prepareSearch(indexName).addAggregation(
+            terms("terms").executionHint(randomExecutionHint())
+                .field(TERMS_AGGS_FIELD_1)
+                .subAggregation(
+                    terms("terms").executionHint(randomExecutionHint())
+                        .field(TERMS_AGGS_FIELD_2)
+                        .subAggregation(
+                            terms("terms").executionHint(randomExecutionHint())
+                                .field(TERMS_AGGS_FIELD_2)
+                                .subAggregation(topHits("hits").sort(SortBuilders.fieldSort(SORT_FIELD).order(SortOrder.DESC)))
+                        )
+                )
+        );
+    }
+}
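
A rough size check (not in the commit; it assumes one byte per character) explains the two fixtures: each repeated string "this is a text field N " is 23 bytes, so a small_idx document carries 5 fields × 23 B × 40,000 repeats ≈ 4.6 MB (about 23 MB across its 5 documents), while a large_idx document carries 5 × 23 B × 50,000 ≈ 5.75 MB, or roughly 400 MB across 70 documents. That difference is what lets testBreakAndRecover expect a "Data too large" rejection for large_idx while the small_idx queries keep succeeding before and after the break.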

server/src/main/java/org/elasticsearch/search/aggregations/metrics/TopHitsAggregator.java

Lines changed: 4 additions & 3 deletions
@@ -54,6 +54,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.function.BiConsumer;
+import java.util.function.IntConsumer;
 
 class TopHitsAggregator extends MetricsAggregator {
 
@@ -198,7 +199,7 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) throws IOE
         for (int i = 0; i < topDocs.scoreDocs.length; i++) {
             docIdsToLoad[i] = topDocs.scoreDocs[i].doc;
         }
-        FetchSearchResult fetchResult = runFetchPhase(subSearchContext, docIdsToLoad);
+        FetchSearchResult fetchResult = runFetchPhase(subSearchContext, docIdsToLoad, this::addRequestCircuitBreakerBytes);
         if (fetchProfiles != null) {
             fetchProfiles.add(fetchResult.profileResult());
         }
@@ -222,7 +223,7 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) throws IOE
         );
     }
 
-    private static FetchSearchResult runFetchPhase(SubSearchContext subSearchContext, int[] docIdsToLoad) {
+    private static FetchSearchResult runFetchPhase(SubSearchContext subSearchContext, int[] docIdsToLoad, IntConsumer memoryChecker) {
         // Fork the search execution context for each slice, because the fetch phase does not support concurrent execution yet.
         SearchExecutionContext searchExecutionContext = new SearchExecutionContext(subSearchContext.getSearchExecutionContext());
         // InnerHitSubContext is not thread-safe, so we fork it as well to support concurrent execution
@@ -242,7 +243,7 @@ public InnerHitsContext innerHits() {
             }
         };
 
-        fetchSubSearchContext.fetchPhase().execute(fetchSubSearchContext, docIdsToLoad, null);
+        fetchSubSearchContext.fetchPhase().execute(fetchSubSearchContext, docIdsToLoad, null, memoryChecker);
         return fetchSubSearchContext.fetchResult();
     }
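
The practical effect of this hunk: top_hits now reports the bytes of fetched sources back to the aggregator instead of leaving them untracked. A minimal sketch of the wiring, assuming addRequestCircuitBreakerBytes follows the usual AggregatorBase accounting (bytes reserved through it are released when the aggregator closes); the method reference satisfies IntConsumer because its int argument widens to long:

    // Sketch only; identifiers come from the diff above, not new API.
    IntConsumer memoryChecker = this::addRequestCircuitBreakerBytes; // int widens to long
    FetchSearchResult fetchResult = runFetchPhase(subSearchContext, docIdsToLoad, memoryChecker);
    // Bytes reported through memoryChecker stay reserved for as long as the
    // aggregator (and thus its hits) is alive, instead of going unaccounted.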

server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java

Lines changed: 54 additions & 22 deletions
@@ -14,6 +14,7 @@
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.search.TotalHits;
 import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.index.fieldvisitor.LeafStoredFieldLoader;
 import org.elasticsearch.index.fieldvisitor.StoredFieldLoader;
 import org.elasticsearch.index.mapper.IdLoader;
@@ -47,6 +48,7 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.function.IntConsumer;
 import java.util.function.Supplier;
 
 import static org.elasticsearch.index.get.ShardGetService.maybeExcludeVectorFields;
@@ -67,6 +69,17 @@ public FetchPhase(List<FetchSubPhase> fetchSubPhases) {
     }
 
     public void execute(SearchContext context, int[] docIdsToLoad, RankDocShardInfo rankDocs) {
+        execute(context, docIdsToLoad, rankDocs, null);
+    }
+
+    /**
+     *
+     * @param context
+     * @param docIdsToLoad
+     * @param rankDocs
+     * @param memoryChecker if not provided, the fetch phase will use the circuit breaker to check memory usage
+     */
+    public void execute(SearchContext context, int[] docIdsToLoad, RankDocShardInfo rankDocs, @Nullable IntConsumer memoryChecker) {
         if (LOGGER.isTraceEnabled()) {
             LOGGER.trace("{}", new SearchContextSourcePrinter(context));
         }
@@ -88,7 +101,7 @@ public void execute(SearchContext context, int[] docIdsToLoad, RankDocShardInfo
             : Profilers.startProfilingFetchPhase();
         SearchHits hits = null;
         try {
-            hits = buildSearchHits(context, docIdsToLoad, profiler, rankDocs);
+            hits = buildSearchHits(context, docIdsToLoad, profiler, rankDocs, memoryChecker);
         } finally {
             try {
                 // Always finish profiling
@@ -116,7 +129,13 @@ public Source getSource(LeafReaderContext ctx, int doc) {
         }
     }
 
-    private SearchHits buildSearchHits(SearchContext context, int[] docIdsToLoad, Profiler profiler, RankDocShardInfo rankDocs) {
+    private SearchHits buildSearchHits(
+        SearchContext context,
+        int[] docIdsToLoad,
+        Profiler profiler,
+        RankDocShardInfo rankDocs,
+        IntConsumer memoryChecker
+    ) {
         var lookup = context.getSearchExecutionContext().getMappingLookup();
 
         // Optionally remove sparse and dense vector fields early to:
@@ -180,6 +199,14 @@ private SearchHits buildSearchHits(SearchContext context, int[] docIdsToLoad, Pr
             SourceLoader.Leaf leafSourceLoader;
             IdLoader.Leaf leafIdLoader;
 
+            IntConsumer memChecker = memoryChecker != null ? memoryChecker : bytes -> {
+                locallyAccumulatedBytes[0] += bytes;
+                if (context.checkCircuitBreaker(locallyAccumulatedBytes[0], "fetch source")) {
+                    addRequestBreakerBytes(locallyAccumulatedBytes[0]);
+                    locallyAccumulatedBytes[0] = 0;
+                }
+            };
+
             @Override
             protected void setNextReader(LeafReaderContext ctx, int[] docsInLeaf) throws IOException {
                 Timer timer = profiler.startNextReader();
@@ -206,10 +233,6 @@ protected SearchHit nextDoc(int doc) throws IOException {
                 if (context.isCancelled()) {
                     throw new TaskCancelledException("cancelled");
                 }
-                if (context.checkRealMemoryCB(locallyAccumulatedBytes[0], "fetch source")) {
-                    // if we checked the real memory breaker, we restart our local accounting
-                    locallyAccumulatedBytes[0] = 0;
-                }
 
                 HitContext hit = prepareHitContext(
                     context,
@@ -233,7 +256,9 @@ protected SearchHit nextDoc(int doc) throws IOException {
 
                 BytesReference sourceRef = hit.hit().getSourceRef();
                 if (sourceRef != null) {
-                    locallyAccumulatedBytes[0] += sourceRef.length();
+                    // This is an empirical value that seems to work well.
+                    // Deserializing a large source would also mean serializing it to HTTP response later on, so x2 seems reasonable
+                    memChecker.accept(sourceRef.length() * 2);
                 }
                 success = true;
                 return hit.hit();
@@ -245,24 +270,31 @@ protected SearchHit nextDoc(int doc) throws IOException {
             }
         };
 
-        SearchHit[] hits = docsIterator.iterate(
-            context.shardTarget(),
-            context.searcher().getIndexReader(),
-            docIdsToLoad,
-            context.request().allowPartialSearchResults(),
-            context.queryResult()
-        );
+        try {
+            SearchHit[] hits = docsIterator.iterate(
+                context.shardTarget(),
+                context.searcher().getIndexReader(),
+                docIdsToLoad,
+                context.request().allowPartialSearchResults(),
+                context.queryResult()
+            );
 
-        if (context.isCancelled()) {
-            for (SearchHit hit : hits) {
-                // release all hits that would otherwise become owned and eventually released by SearchHits below
-                hit.decRef();
+            if (context.isCancelled()) {
+                for (SearchHit hit : hits) {
+                    // release all hits that would otherwise become owned and eventually released by SearchHits below
+                    hit.decRef();
+                }
+                throw new TaskCancelledException("cancelled");
             }
-            throw new TaskCancelledException("cancelled");
-        }
 
-        TotalHits totalHits = context.getTotalHits();
-        return new SearchHits(hits, totalHits, context.getMaxScore());
+            TotalHits totalHits = context.getTotalHits();
+            return new SearchHits(hits, totalHits, context.getMaxScore());
+        } finally {
+            long bytes = docsIterator.getRequestBreakerBytes();
+            if (bytes > 0L) {
+                context.circuitBreaker().addWithoutBreaking(-bytes);
+            }
+        }
     }
 
     List<FetchSubPhaseProcessor> getProcessors(SearchShardTarget target, FetchContext context, Profiler profiler) {
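
Pulling the pieces of this file together: source bytes are accumulated locally, charged to the breaker in buffer-sized chunks, remembered on the iterator, and returned in the finally block. Below is a minimal, self-contained sketch of that pattern; Breaker is a simplified stand-in for org.elasticsearch.common.breaker.CircuitBreaker, and the field and method names only mirror the diff:

    import java.util.function.IntConsumer;

    // Simplified stand-in for CircuitBreaker.
    interface Breaker {
        void addEstimateBytesAndMaybeBreak(long bytes, String label); // throws when tripping
        void addWithoutBreaking(long bytes);                          // negative delta releases
    }

    class FetchAccountingSketch {
        private final Breaker breaker;
        private final long bufferSize; // plays the role of memAccountingBufferSize()
        private long reserved;         // plays the role of requestBreakerBytes

        FetchAccountingSketch(Breaker breaker, long bufferSize) {
            this.breaker = breaker;
            this.bufferSize = bufferSize;
        }

        IntConsumer memChecker() {
            long[] local = { 0 }; // locallyAccumulatedBytes: mutable capture for the lambda
            return bytes -> {
                local[0] += bytes;
                if (local[0] >= bufferSize) { // cf. SearchContext.checkCircuitBreaker
                    breaker.addEstimateBytesAndMaybeBreak(local[0], "fetch source");
                    reserved += local[0];     // cf. addRequestBreakerBytes
                    local[0] = 0;
                }
            };
        }

        void release() { // the finally block above, in miniature
            if (reserved > 0) {
                breaker.addWithoutBreaking(-reserved);
                reserved = 0;
            }
        }
    }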

server/src/main/java/org/elasticsearch/search/fetch/FetchPhaseDocsIterator.java

Lines changed: 14 additions & 0 deletions
@@ -32,6 +32,20 @@
  */
 abstract class FetchPhaseDocsIterator {
 
+    /**
+     * Accounts for FetchPhase memory usage.
+     * It gets cleaned up after each fetch phase and should not be accessed/modified by subclasses.
+     */
+    private long requestBreakerBytes;
+
+    public void addRequestBreakerBytes(long delta) {
+        requestBreakerBytes += delta;
+    }
+
+    public long getRequestBreakerBytes() {
+        return requestBreakerBytes;
+    }
+
     /**
      * Called when a new leaf reader is reached
      * @param ctx the leaf reader for this set of doc ids
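
Read together with the FetchPhase hunk above, the counter's lifecycle is symmetric: the default memChecker calls addRequestBreakerBytes(n) only after addEstimateBytesAndMaybeBreak(n, ...) succeeds, and the finally block releases exactly getRequestBreakerBytes(), so every byte charged during one fetch phase is handed back when that phase ends, even if an exception escapes.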

server/src/main/java/org/elasticsearch/search/internal/SearchContext.java

Lines changed: 5 additions & 5 deletions
@@ -386,15 +386,15 @@ public Query rewrittenQuery() {
     public abstract long memAccountingBufferSize();
 
     /**
-     * Checks if the accumulated bytes are greater than the buffer size and if so, checks the available memory in the parent breaker
-     * (the real memory breaker).
+     * Checks if the accumulated bytes are greater than the buffer size and if so, checks the circuit breaker.
+     * IMPORTANT: the caller is responsible for cleaning up the circuit breaker.
      * @param locallyAccumulatedBytes the number of bytes accumulated locally
      * @param label the label to use in the breaker
-     * @return true if the real memory breaker is called and false otherwise
+     * @return true if the circuit breaker is called and false otherwise
      */
-    public final boolean checkRealMemoryCB(int locallyAccumulatedBytes, String label) {
+    public final boolean checkCircuitBreaker(int locallyAccumulatedBytes, String label) {
         if (locallyAccumulatedBytes >= memAccountingBufferSize()) {
-            circuitBreaker().addEstimateBytesAndMaybeBreak(0, label);
+            circuitBreaker().addEstimateBytesAndMaybeBreak(locallyAccumulatedBytes, label);
             return true;
         }
         return false;
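
This one-line body change is the heart of the fix: the old checkRealMemoryCB passed 0 to addEstimateBytesAndMaybeBreak, merely probing the parent (real-memory) breaker while reserving nothing, so the bytes accumulated during fetch were never actually accounted. A minimal caller-side sketch of the new contract; sources is a hypothetical stand-in for a stream of fetched source references, the rest follows the diff:

    int[] local = { 0 }; // stays below memAccountingBufferSize() between checks
    long reserved = 0;   // bytes the breaker is currently holding for us
    for (BytesReference source : sources) {
        local[0] += source.length();
        if (context.checkCircuitBreaker(local[0], "fetch source")) {
            reserved += local[0]; // the breaker reserved these bytes; remember them
            local[0] = 0;
        }
    }
    // "the caller is responsible for cleaning up the circuit breaker":
    context.circuitBreaker().addWithoutBreaking(-reserved);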
