Skip to content

Commit 4a2abab

Browse files
authored
Refactor how SearchHits are processed in ML module (elastic#120258)
Avoid accumulating Row objects in memory, in order to reduce heap usage.
1 parent 9782179 commit 4a2abab

File tree

6 files changed

+114
-103
lines changed

6 files changed

+114
-103
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPreviewDataFrameAnalyticsAction.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,10 @@ public TransportPreviewDataFrameAnalyticsAction(
7777
this.clusterService = clusterService;
7878
}
7979

80-
private static Map<String, Object> mergeRow(DataFrameDataExtractor.Row row, List<String> fieldNames) {
81-
return row.getValues() == null
80+
private static Map<String, Object> mergeRow(String[] row, List<String> fieldNames) {
81+
return row == null
8282
? Collections.emptyMap()
83-
: IntStream.range(0, row.getValues().length).boxed().collect(Collectors.toMap(fieldNames::get, i -> row.getValues()[i]));
83+
: IntStream.range(0, row.length).boxed().collect(Collectors.toMap(fieldNames::get, i -> row[i]));
8484
}
8585

8686
@Override
@@ -121,7 +121,7 @@ void preview(Task task, DataFrameAnalyticsConfig config, ActionListener<Response
121121
).newExtractor(false);
122122
extractor.preview(delegate.delegateFailureAndWrap((l, rows) -> {
123123
List<String> fieldNames = extractor.getFieldNames();
124-
l.onResponse(new Response(rows.stream().map((r) -> mergeRow(r, fieldNames)).collect(Collectors.toList())));
124+
l.onResponse(new Response(rows.stream().map(r -> mergeRow(r, fieldNames)).collect(Collectors.toList())));
125125
}));
126126
}));
127127
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java

Lines changed: 21 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import org.elasticsearch.index.query.QueryBuilder;
2020
import org.elasticsearch.index.query.QueryBuilders;
2121
import org.elasticsearch.search.SearchHit;
22-
import org.elasticsearch.search.SearchHits;
2322
import org.elasticsearch.search.fetch.StoredFieldsContext;
2423
import org.elasticsearch.search.sort.SortOrder;
2524
import org.elasticsearch.xpack.core.ClientHelper;
@@ -107,14 +106,14 @@ public void cancel() {
107106
isCancelled = true;
108107
}
109108

110-
public Optional<List<Row>> next() throws IOException {
109+
public Optional<SearchHit[]> next() throws IOException {
111110
if (hasNext() == false) {
112111
throw new NoSuchElementException();
113112
}
114113

115-
Optional<List<Row>> hits = Optional.ofNullable(nextSearch());
116-
if (hits.isPresent() && hits.get().isEmpty() == false) {
117-
lastSortKey = hits.get().get(hits.get().size() - 1).getSortKey();
114+
Optional<SearchHit[]> hits = Optional.ofNullable(nextSearch());
115+
if (hits.isPresent() && hits.get().length > 0) {
116+
lastSortKey = (long) hits.get()[hits.get().length - 1].getSortValues()[0];
118117
} else {
119118
hasNext = false;
120119
}
@@ -126,7 +125,7 @@ public Optional<List<Row>> next() throws IOException {
126125
* Does no sorting of the results.
127126
* @param listener To alert with the extracted rows
128127
*/
129-
public void preview(ActionListener<List<Row>> listener) {
128+
public void preview(ActionListener<List<String[]>> listener) {
130129

131130
SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client)
132131
// This ensures the search throws if there are failures and the scroll context gets cleared automatically
@@ -155,22 +154,24 @@ public void preview(ActionListener<List<Row>> listener) {
155154
return;
156155
}
157156

158-
List<Row> rows = new ArrayList<>(searchResponse.getHits().getHits().length);
157+
List<String[]> rows = new ArrayList<>(searchResponse.getHits().getHits().length);
159158
for (SearchHit hit : searchResponse.getHits().getHits()) {
160-
var unpooled = hit.asUnpooled();
161-
String[] extractedValues = extractValues(unpooled);
162-
rows.add(extractedValues == null ? new Row(null, unpooled, true) : new Row(extractedValues, unpooled, false));
159+
String[] extractedValues = extractValues(hit);
160+
rows.add(extractedValues);
163161
}
164162
delegate.onResponse(rows);
165163
})
166164
);
167165
}
168166

169-
protected List<Row> nextSearch() throws IOException {
167+
protected SearchHit[] nextSearch() throws IOException {
168+
if (isCancelled) {
169+
return null;
170+
}
170171
return tryRequestWithSearchResponse(() -> executeSearchRequest(buildSearchRequest()));
171172
}
172173

173-
private List<Row> tryRequestWithSearchResponse(Supplier<SearchResponse> request) throws IOException {
174+
private SearchHit[] tryRequestWithSearchResponse(Supplier<SearchResponse> request) throws IOException {
174175
try {
175176

176177
// We've set allow_partial_search_results to false which means if something
@@ -179,7 +180,7 @@ private List<Row> tryRequestWithSearchResponse(Supplier<SearchResponse> request)
179180
try {
180181
LOGGER.trace(() -> "[" + context.jobId + "] Search response was obtained");
181182

182-
List<Row> rows = processSearchResponse(searchResponse);
183+
SearchHit[] rows = processSearchResponse(searchResponse);
183184

184185
// Request was successfully executed and processed so we can restore the flag to retry if a future failure occurs
185186
hasPreviousSearchFailed = false;
@@ -246,22 +247,12 @@ private void setFetchSource(SearchRequestBuilder searchRequestBuilder) {
246247
}
247248
}
248249

249-
private List<Row> processSearchResponse(SearchResponse searchResponse) {
250-
if (searchResponse.getHits().getHits().length == 0) {
250+
private SearchHit[] processSearchResponse(SearchResponse searchResponse) {
251+
if (isCancelled || searchResponse.getHits().getHits().length == 0) {
251252
hasNext = false;
252253
return null;
253254
}
254-
255-
SearchHits hits = searchResponse.getHits();
256-
List<Row> rows = new ArrayList<>(hits.getHits().length);
257-
for (SearchHit hit : hits) {
258-
if (isCancelled) {
259-
hasNext = false;
260-
break;
261-
}
262-
rows.add(createRow(hit));
263-
}
264-
return rows;
255+
return searchResponse.getHits().asUnpooled().getHits();
265256
}
266257

267258
private String extractNonProcessedValues(SearchHit hit, String organicFeature) {
@@ -317,14 +308,13 @@ private String[] extractProcessedValue(ProcessedField processedField, SearchHit
317308
return extractedValue;
318309
}
319310

320-
private Row createRow(SearchHit hit) {
321-
var unpooled = hit.asUnpooled();
322-
String[] extractedValues = extractValues(unpooled);
311+
public Row createRow(SearchHit hit) {
312+
String[] extractedValues = extractValues(hit);
323313
if (extractedValues == null) {
324-
return new Row(null, unpooled, true);
314+
return new Row(null, hit, true);
325315
}
326316
boolean isTraining = trainTestSplitter.get().isTraining(extractedValues);
327-
Row row = new Row(extractedValues, unpooled, isTraining);
317+
Row row = new Row(extractedValues, hit, isTraining);
328318
LOGGER.trace(
329319
() -> format(
330320
"[%s] Extracted row: sort key = [%s], is_training = [%s], values = %s",

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.common.settings.Settings;
1717
import org.elasticsearch.common.util.concurrent.ThreadContext;
1818
import org.elasticsearch.index.query.QueryBuilders;
19+
import org.elasticsearch.search.SearchHit;
1920
import org.elasticsearch.threadpool.ThreadPool;
2021
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
2122
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
@@ -256,9 +257,14 @@ private static void writeDataRows(
256257
long rowsProcessed = 0;
257258

258259
while (dataExtractor.hasNext()) {
259-
Optional<List<DataFrameDataExtractor.Row>> rows = dataExtractor.next();
260+
Optional<SearchHit[]> rows = dataExtractor.next();
260261
if (rows.isPresent()) {
261-
for (DataFrameDataExtractor.Row row : rows.get()) {
262+
for (SearchHit searchHit : rows.get()) {
263+
if (dataExtractor.isCancelled()) {
264+
break;
265+
}
266+
rowsProcessed++;
267+
DataFrameDataExtractor.Row row = dataExtractor.createRow(searchHit);
262268
if (row.shouldSkip()) {
263269
dataCountsTracker.incrementSkippedDocsCount();
264270
} else {
@@ -271,7 +277,6 @@ private static void writeDataRows(
271277
}
272278
}
273279
}
274-
rowsProcessed += rows.get().size();
275280
progressTracker.updateLoadingDataProgress(rowsProcessed >= totalRows ? 100 : (int) (rowsProcessed * 100.0 / totalRows));
276281
}
277282
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.common.settings.Settings;
1515
import org.elasticsearch.core.Nullable;
1616
import org.elasticsearch.search.SearchHit;
17+
import org.elasticsearch.search.SearchHits;
1718
import org.elasticsearch.tasks.TaskId;
1819
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1920
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
@@ -22,11 +23,9 @@
2223
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
2324

2425
import java.io.IOException;
25-
import java.util.Collections;
2626
import java.util.Iterator;
2727
import java.util.LinkedHashMap;
2828
import java.util.LinkedList;
29-
import java.util.List;
3029
import java.util.Map;
3130
import java.util.Objects;
3231
import java.util.Optional;
@@ -97,6 +96,9 @@ private void addResultAndJoinIfEndOfBatch(RowResults rowResults) {
9796
private void joinCurrentResults() {
9897
try (LimitAwareBulkIndexer bulkIndexer = new LimitAwareBulkIndexer(settings, this::executeBulkRequest)) {
9998
while (currentResults.isEmpty() == false) {
99+
if (dataExtractor.isCancelled()) {
100+
break;
101+
}
100102
RowResults result = currentResults.pop();
101103
DataFrameDataExtractor.Row row = dataFrameRowsIterator.next();
102104
checkChecksumsMatch(row, result);
@@ -164,20 +166,20 @@ private void consumeDataExtractor() throws IOException {
164166

165167
private class ResultMatchingDataFrameRows implements Iterator<DataFrameDataExtractor.Row> {
166168

167-
private List<DataFrameDataExtractor.Row> currentDataFrameRows = Collections.emptyList();
169+
private SearchHit[] currentDataFrameRows = SearchHits.EMPTY;
168170
private int currentDataFrameRowsIndex;
169171

170172
@Override
171173
public boolean hasNext() {
172-
return dataExtractor.hasNext() || currentDataFrameRowsIndex < currentDataFrameRows.size();
174+
return dataExtractor.hasNext() || currentDataFrameRowsIndex < currentDataFrameRows.length;
173175
}
174176

175177
@Override
176178
public DataFrameDataExtractor.Row next() {
177179
DataFrameDataExtractor.Row row = null;
178180
while (hasNoMatch(row) && hasNext()) {
179181
advanceToNextBatchIfNecessary();
180-
row = currentDataFrameRows.get(currentDataFrameRowsIndex++);
182+
row = dataExtractor.createRow(currentDataFrameRows[currentDataFrameRowsIndex++]);
181183
}
182184

183185
if (hasNoMatch(row)) {
@@ -191,13 +193,13 @@ private static boolean hasNoMatch(DataFrameDataExtractor.Row row) {
191193
}
192194

193195
private void advanceToNextBatchIfNecessary() {
194-
if (currentDataFrameRowsIndex >= currentDataFrameRows.size()) {
195-
currentDataFrameRows = getNextDataRowsBatch().orElse(Collections.emptyList());
196+
if (currentDataFrameRowsIndex >= currentDataFrameRows.length) {
197+
currentDataFrameRows = getNextDataRowsBatch().orElse(SearchHits.EMPTY);
196198
currentDataFrameRowsIndex = 0;
197199
}
198200
}
199201

200-
private Optional<List<DataFrameDataExtractor.Row>> getNextDataRowsBatch() {
202+
private Optional<SearchHit[]> getNextDataRowsBatch() {
201203
try {
202204
return dataExtractor.next();
203205
} catch (IOException e) {

0 commit comments

Comments (0)