Skip to content

Commit 48144ba

Browse files
Fix SearchResponse reference count leaks in ML module (#103009)
Fixing all kinds of leaks in both ml prod and test code. Added a new utility for a very common operation in tests that I'm planning on replacing other use sites with in a follow up.
1 parent 7413e41 commit 48144ba

File tree

34 files changed

+574
-407
lines changed

34 files changed

+574
-407
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0 and the Server Side Public License, v 1; you may not use this file except
5+
* in compliance with, at your election, the Elastic License 2.0 or the Server
6+
* Side Public License, v 1.
7+
*/
8+
package org.elasticsearch.search;
9+
10+
import org.elasticsearch.action.search.SearchRequestBuilder;
11+
12+
public enum SearchResponseUtils {
13+
;
14+
15+
public static long getTotalHitsValue(SearchRequestBuilder request) {
16+
var resp = request.get();
17+
try {
18+
return resp.getHits().getTotalHits().value;
19+
} finally {
20+
resp.decRef();
21+
}
22+
}
23+
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ClientHelperTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ private void assertExecutionWithOrigin(Map<String, String> storedHeaders, Client
340340
assertThat(headers, not(hasEntry(AuthenticationServiceField.RUN_AS_USER_HEADER, "anything")));
341341

342342
return client.search(new SearchRequest()).actionGet();
343-
});
343+
}).decRef();
344344
}
345345

346346
/**
@@ -356,7 +356,7 @@ public void assertRunAsExecution(Map<String, String> storedHeaders, Consumer<Map
356356

357357
consumer.accept(client.threadPool().getThreadContext().getHeaders());
358358
return client.search(new SearchRequest()).actionGet();
359-
});
359+
}).decRef();
360360
}
361361

362362
public void testFilterSecurityHeaders() {

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationHousePricingIT.java

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1567,36 +1567,38 @@ public void testFeatureImportanceValues() throws Exception {
15671567

15681568
client().admin().indices().refresh(new RefreshRequest(destIndex));
15691569
SearchResponse sourceData = prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
1570-
1571-
// obtain addition information for investigation of #90599
1572-
String modelId = getModelId(jobId);
1573-
TrainedModelMetadata modelMetadata = getModelMetadata(modelId);
1574-
assertThat(modelMetadata.getHyperparameters().size(), greaterThan(0));
1575-
StringBuilder hyperparameters = new StringBuilder(); // used to investigate #90019
1576-
for (Hyperparameters hyperparameter : modelMetadata.getHyperparameters()) {
1577-
hyperparameters.append(hyperparameter.hyperparameterName).append(": ").append(hyperparameter.value).append("\n");
1578-
}
1579-
TrainedModelDefinition modelDefinition = getModelDefinition(modelId);
1580-
Ensemble ensemble = (Ensemble) modelDefinition.getTrainedModel();
1581-
int numberTrees = ensemble.getModels().size();
1582-
String str = "Failure: failed for modelId %s numberTrees %d\n";
1583-
for (SearchHit hit : sourceData.getHits()) {
1584-
Map<String, Object> destDoc = getDestDoc(config, hit);
1585-
assertNotNull(destDoc);
1586-
Map<String, Object> resultsObject = getFieldValue(destDoc, "ml");
1587-
assertThat(resultsObject.containsKey(predictionField), is(true));
1588-
String predictionValue = (String) resultsObject.get(predictionField);
1589-
assertNotNull(predictionValue);
1590-
assertThat(resultsObject.containsKey("feature_importance"), is(true));
1591-
@SuppressWarnings("unchecked")
1592-
List<Map<String, Object>> importanceArray = (List<Map<String, Object>>) resultsObject.get("feature_importance");
1593-
assertThat(
1594-
Strings.format(str, modelId, numberTrees) + predictionValue + hyperparameters + modelDefinition,
1595-
importanceArray,
1596-
hasSize(greaterThan(0))
1597-
);
1570+
try {
1571+
// obtain addition information for investigation of #90599
1572+
String modelId = getModelId(jobId);
1573+
TrainedModelMetadata modelMetadata = getModelMetadata(modelId);
1574+
assertThat(modelMetadata.getHyperparameters().size(), greaterThan(0));
1575+
StringBuilder hyperparameters = new StringBuilder(); // used to investigate #90019
1576+
for (Hyperparameters hyperparameter : modelMetadata.getHyperparameters()) {
1577+
hyperparameters.append(hyperparameter.hyperparameterName).append(": ").append(hyperparameter.value).append("\n");
1578+
}
1579+
TrainedModelDefinition modelDefinition = getModelDefinition(modelId);
1580+
Ensemble ensemble = (Ensemble) modelDefinition.getTrainedModel();
1581+
int numberTrees = ensemble.getModels().size();
1582+
String str = "Failure: failed for modelId %s numberTrees %d\n";
1583+
for (SearchHit hit : sourceData.getHits()) {
1584+
Map<String, Object> destDoc = getDestDoc(config, hit);
1585+
assertNotNull(destDoc);
1586+
Map<String, Object> resultsObject = getFieldValue(destDoc, "ml");
1587+
assertThat(resultsObject.containsKey(predictionField), is(true));
1588+
String predictionValue = (String) resultsObject.get(predictionField);
1589+
assertNotNull(predictionValue);
1590+
assertThat(resultsObject.containsKey("feature_importance"), is(true));
1591+
@SuppressWarnings("unchecked")
1592+
List<Map<String, Object>> importanceArray = (List<Map<String, Object>>) resultsObject.get("feature_importance");
1593+
assertThat(
1594+
Strings.format(str, modelId, numberTrees) + predictionValue + hyperparameters + modelDefinition,
1595+
importanceArray,
1596+
hasSize(greaterThan(0))
1597+
);
1598+
}
1599+
} finally {
1600+
sourceData.decRef();
15981601
}
1599-
16001602
}
16011603

16021604
static void indexData(String sourceIndex) {

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/DatafeedWithAggsIT.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import java.util.List;
3636
import java.util.concurrent.TimeUnit;
3737

38+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
3839
import static org.hamcrest.Matchers.equalTo;
3940
import static org.hamcrest.Matchers.greaterThan;
4041
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -163,16 +164,16 @@ private void testDfWithAggs(AggregatorFactories.Builder aggs, Detector.Builder d
163164
bucket.getEventCount()
164165
);
165166
// Confirm that it's possible to search for the same buckets by @timestamp - proves that @timestamp works as a field alias
166-
assertThat(
167+
assertHitCount(
167168
prepareSearch(AnomalyDetectorsIndex.jobResultsAliasedName(jobId)).setQuery(
168169
QueryBuilders.boolQuery()
169170
.filter(QueryBuilders.termQuery("job_id", jobId))
170171
.filter(QueryBuilders.termQuery("result_type", "bucket"))
171172
.filter(
172173
QueryBuilders.rangeQuery("@timestamp").gte(bucket.getTimestamp().getTime()).lte(bucket.getTimestamp().getTime())
173174
)
174-
).setTrackTotalHits(true).get().getHits().getTotalHits().value,
175-
equalTo(1L)
175+
).setTrackTotalHits(true),
176+
1
176177
);
177178
}
178179
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/DeleteExpiredDataIT.java

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.elasticsearch.core.TimeValue;
2121
import org.elasticsearch.index.query.QueryBuilders;
2222
import org.elasticsearch.search.SearchHit;
23+
import org.elasticsearch.search.SearchResponseUtils;
2324
import org.elasticsearch.xcontent.ToXContent;
2425
import org.elasticsearch.xcontent.XContentBuilder;
2526
import org.elasticsearch.xcontent.XContentFactory;
@@ -268,14 +269,13 @@ private void testExpiredDeletion(Float customThrottle, int numUnusedState) throw
268269

269270
retainAllSnapshots("snapshots-retention-with-retain");
270271

271-
long totalModelSizeStatsBeforeDelete = prepareSearch("*").setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN_CLOSED_HIDDEN)
272-
.setQuery(QueryBuilders.termQuery("result_type", "model_size_stats"))
273-
.get()
274-
.getHits()
275-
.getTotalHits().value;
276-
long totalNotificationsCountBeforeDelete = prepareSearch(NotificationsIndex.NOTIFICATIONS_INDEX).get()
277-
.getHits()
278-
.getTotalHits().value;
272+
long totalModelSizeStatsBeforeDelete = SearchResponseUtils.getTotalHitsValue(
273+
prepareSearch("*").setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN_CLOSED_HIDDEN)
274+
.setQuery(QueryBuilders.termQuery("result_type", "model_size_stats"))
275+
);
276+
long totalNotificationsCountBeforeDelete = SearchResponseUtils.getTotalHitsValue(
277+
prepareSearch(NotificationsIndex.NOTIFICATIONS_INDEX)
278+
);
279279
assertThat(totalModelSizeStatsBeforeDelete, greaterThan(0L));
280280
assertThat(totalNotificationsCountBeforeDelete, greaterThan(0L));
281281

@@ -319,14 +319,13 @@ private void testExpiredDeletion(Float customThrottle, int numUnusedState) throw
319319
assertThat(getRecords("results-and-snapshots-retention").size(), equalTo(0));
320320
assertThat(getModelSnapshots("results-and-snapshots-retention").size(), equalTo(1));
321321

322-
long totalModelSizeStatsAfterDelete = prepareSearch("*").setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN_CLOSED_HIDDEN)
323-
.setQuery(QueryBuilders.termQuery("result_type", "model_size_stats"))
324-
.get()
325-
.getHits()
326-
.getTotalHits().value;
327-
long totalNotificationsCountAfterDelete = prepareSearch(NotificationsIndex.NOTIFICATIONS_INDEX).get()
328-
.getHits()
329-
.getTotalHits().value;
322+
long totalModelSizeStatsAfterDelete = SearchResponseUtils.getTotalHitsValue(
323+
prepareSearch("*").setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN_CLOSED_HIDDEN)
324+
.setQuery(QueryBuilders.termQuery("result_type", "model_size_stats"))
325+
);
326+
long totalNotificationsCountAfterDelete = SearchResponseUtils.getTotalHitsValue(
327+
prepareSearch(NotificationsIndex.NOTIFICATIONS_INDEX)
328+
);
330329
assertThat(totalModelSizeStatsAfterDelete, equalTo(totalModelSizeStatsBeforeDelete));
331330
assertThat(totalNotificationsCountAfterDelete, greaterThanOrEqualTo(totalNotificationsCountBeforeDelete));
332331

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.elasticsearch.index.query.QueryBuilders;
2121
import org.elasticsearch.rest.RestStatus;
2222
import org.elasticsearch.search.SearchHit;
23+
import org.elasticsearch.search.SearchResponseUtils;
2324
import org.elasticsearch.xcontent.XContentType;
2425
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
2526
import org.elasticsearch.xpack.core.ml.action.NodeAcknowledgedResponse;
@@ -396,11 +397,12 @@ public void testStopOutlierDetectionWithEnoughDocumentsToScroll() throws Excepti
396397

397398
assertResponse(prepareSearch(config.getDest().getIndex()).setTrackTotalHits(true), searchResponse -> {
398399
if (searchResponse.getHits().getTotalHits().value == docCount) {
399-
searchResponse = prepareSearch(config.getDest().getIndex()).setTrackTotalHits(true)
400-
.setQuery(QueryBuilders.existsQuery("custom_ml.outlier_score"))
401-
.get();
402-
logger.debug("We stopped during analysis: [{}] < [{}]", searchResponse.getHits().getTotalHits().value, docCount);
403-
assertThat(searchResponse.getHits().getTotalHits().value, lessThan((long) docCount));
400+
long seenCount = SearchResponseUtils.getTotalHitsValue(
401+
prepareSearch(config.getDest().getIndex()).setTrackTotalHits(true)
402+
.setQuery(QueryBuilders.existsQuery("custom_ml.outlier_score"))
403+
);
404+
logger.debug("We stopped during analysis: [{}] < [{}]", seenCount, docCount);
405+
assertThat(seenCount, lessThan((long) docCount));
404406
} else {
405407
logger.debug("We stopped during reindexing: [{}] < [{}]", searchResponse.getHits().getTotalHits().value, docCount);
406408
}

x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/license/MachineLearningLicensingIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -756,7 +756,7 @@ public void testInferenceAggRestricted() {
756756

757757
SearchRequest search = new SearchRequest(index);
758758
search.source().aggregation(termsAgg);
759-
client().search(search).actionGet();
759+
client().search(search).actionGet().decRef();
760760

761761
// Pick a license that does not allow machine learning
762762
License.OperationMode mode = randomInvalidLicenseType();

x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/BucketCorrelationAggregationIT.java

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import org.elasticsearch.action.bulk.BulkRequestBuilder;
1313
import org.elasticsearch.action.bulk.BulkResponse;
1414
import org.elasticsearch.action.index.IndexRequest;
15-
import org.elasticsearch.action.search.SearchResponse;
1615
import org.elasticsearch.action.support.WriteRequest;
1716
import org.elasticsearch.core.Tuple;
1817
import org.elasticsearch.search.aggregations.AggregationBuilders;
@@ -31,6 +30,7 @@
3130
import java.util.concurrent.atomic.AtomicLong;
3231
import java.util.stream.Stream;
3332

33+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
3434
import static org.hamcrest.Matchers.closeTo;
3535

3636
public class BucketCorrelationAggregationIT extends MlSingleNodeTestCase {
@@ -71,34 +71,42 @@ public void testCountCorrelation() {
7171

7272
AtomicLong counter = new AtomicLong();
7373
double[] steps = Stream.generate(() -> counter.getAndAdd(2L)).limit(50).mapToDouble(l -> (double) l).toArray();
74-
SearchResponse percentilesSearch = client().prepareSearch("data")
75-
.addAggregation(AggregationBuilders.percentiles("percentiles").field("metric").percentiles(steps))
76-
.setSize(0)
77-
.setTrackTotalHits(true)
78-
.get();
79-
long totalHits = percentilesSearch.getHits().getTotalHits().value;
80-
Percentiles percentiles = percentilesSearch.getAggregations().get("percentiles");
81-
Tuple<RangeAggregationBuilder, BucketCorrelationAggregationBuilder> aggs = buildRangeAggAndSetExpectations(
82-
percentiles,
83-
steps,
84-
totalHits,
85-
"metric"
74+
assertResponse(
75+
client().prepareSearch("data")
76+
.addAggregation(AggregationBuilders.percentiles("percentiles").field("metric").percentiles(steps))
77+
.setSize(0)
78+
.setTrackTotalHits(true),
79+
percentilesSearch -> {
80+
long totalHits = percentilesSearch.getHits().getTotalHits().value;
81+
Percentiles percentiles = percentilesSearch.getAggregations().get("percentiles");
82+
Tuple<RangeAggregationBuilder, BucketCorrelationAggregationBuilder> aggs = buildRangeAggAndSetExpectations(
83+
percentiles,
84+
steps,
85+
totalHits,
86+
"metric"
87+
);
88+
89+
assertResponse(
90+
client().prepareSearch("data")
91+
.setSize(0)
92+
.setTrackTotalHits(false)
93+
.addAggregation(
94+
AggregationBuilders.terms("buckets").field("term").subAggregation(aggs.v1()).subAggregation(aggs.v2())
95+
),
96+
countCorrelations -> {
97+
98+
Terms terms = countCorrelations.getAggregations().get("buckets");
99+
Terms.Bucket catBucket = terms.getBucketByKey("cat");
100+
Terms.Bucket dogBucket = terms.getBucketByKey("dog");
101+
NumericMetricsAggregation.SingleValue approxCatCorrelation = catBucket.getAggregations().get("correlates");
102+
NumericMetricsAggregation.SingleValue approxDogCorrelation = dogBucket.getAggregations().get("correlates");
103+
104+
assertThat(approxCatCorrelation.value(), closeTo(catCorrelation, 0.1));
105+
assertThat(approxDogCorrelation.value(), closeTo(dogCorrelation, 0.1));
106+
}
107+
);
108+
}
86109
);
87-
88-
SearchResponse countCorrelations = client().prepareSearch("data")
89-
.setSize(0)
90-
.setTrackTotalHits(false)
91-
.addAggregation(AggregationBuilders.terms("buckets").field("term").subAggregation(aggs.v1()).subAggregation(aggs.v2()))
92-
.get();
93-
94-
Terms terms = countCorrelations.getAggregations().get("buckets");
95-
Terms.Bucket catBucket = terms.getBucketByKey("cat");
96-
Terms.Bucket dogBucket = terms.getBucketByKey("dog");
97-
NumericMetricsAggregation.SingleValue approxCatCorrelation = catBucket.getAggregations().get("correlates");
98-
NumericMetricsAggregation.SingleValue approxDogCorrelation = dogBucket.getAggregations().get("correlates");
99-
100-
assertThat(approxCatCorrelation.value(), closeTo(catCorrelation, 0.1));
101-
assertThat(approxDogCorrelation.value(), closeTo(dogCorrelation, 0.1));
102110
}
103111

104112
private static Tuple<RangeAggregationBuilder, BucketCorrelationAggregationBuilder> buildRangeAggAndSetExpectations(

0 commit comments

Comments
 (0)