From d2da66dc883007e337e93e3465d542195e4a0b7c Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Wed, 25 Dec 2024 14:25:41 +0100 Subject: [PATCH 1/3] Use inheritance instead of composition to simplify search phase transitions We only need the extensibility for testing and it's a lot easier to reason about the code if we have explicit methods instead of overly complicated composition with lots of redundant references being retained all over the place. -> lets simplify to inheritance and get shorter code that performs more predictably (especially when it comes to memory) as a first step. This also opens up the possibility of further simplifications and removing more retained state/memory as we go through the search phases. --- .../action/search/DfsQueryPhase.java | 20 ++-- .../action/search/ExpandSearchPhase.java | 31 ++++-- .../action/search/FetchLookupFieldsPhase.java | 14 +-- .../action/search/FetchSearchPhase.java | 38 ++------ .../SearchDfsQueryThenFetchAsyncAction.java | 11 +-- .../action/search/DfsQueryPhaseTests.java | 42 ++++---- .../action/search/ExpandSearchPhaseTests.java | 95 ++++++++----------- .../action/search/FetchSearchPhaseTests.java | 82 +++++++--------- 8 files changed, 151 insertions(+), 182 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java index cc8c4becea9a9..2226ccc191e21 100644 --- a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java @@ -10,6 +10,7 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.join.ScoreMode; +import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.index.query.NestedQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; @@ -29,7 +30,6 @@ import java.util.ArrayList; import java.util.Comparator; import java.util.List; -import java.util.function.Function; /** * This search phase fans out to every shards to execute a distributed search with a pre-collected distributed frequencies for all @@ -38,35 +38,39 @@ * operation. * @see CountedCollector#onFailure(int, SearchShardTarget, Exception) */ -final class DfsQueryPhase extends SearchPhase { +class DfsQueryPhase extends SearchPhase { private final SearchPhaseResults queryResult; private final List searchResults; private final AggregatedDfs dfs; private final List knnResults; - private final Function, SearchPhase> nextPhaseFactory; + private final Client client; private final AbstractSearchAsyncAction context; private final SearchTransportService searchTransportService; private final SearchProgressListener progressListener; DfsQueryPhase( List searchResults, - AggregatedDfs dfs, List knnResults, SearchPhaseResults queryResult, - Function, SearchPhase> nextPhaseFactory, + Client client, AbstractSearchAsyncAction context ) { super("dfs_query"); this.progressListener = context.getTask().getProgressListener(); this.queryResult = queryResult; this.searchResults = searchResults; - this.dfs = dfs; + this.dfs = SearchPhaseController.aggregateDfs(searchResults); this.knnResults = knnResults; - this.nextPhaseFactory = nextPhaseFactory; + this.client = client; this.context = context; this.searchTransportService = context.getSearchTransport(); } + // protected for testing + protected SearchPhase nextPhase() { + return SearchQueryThenFetchAsyncAction.nextPhase(client, context, queryResult, dfs); + } + @Override public void run() { // TODO we can potentially also consume the actual per shard results from the initial phase here in the aggregateDfs @@ -74,7 +78,7 @@ public void run() { final CountedCollector counter = new CountedCollector<>( queryResult, searchResults.size(), - () -> context.executeNextPhase(this, () -> nextPhaseFactory.apply(queryResult)), + () -> context.executeNextPhase(this, this::nextPhase), context ); diff --git a/server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java index e8d94c32bdcc7..bedbfd077c742 100644 --- a/server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java @@ -12,34 +12,44 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.Strings; import org.elasticsearch.common.util.Maps; +import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.InnerHitBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.collapse.CollapseBuilder; import java.util.Iterator; import java.util.List; -import java.util.function.Supplier; /** * This search phase is an optional phase that will be executed once all hits are fetched from the shards that executes * field-collapsing on the inner hits. This phase only executes if field collapsing is requested in the search request and otherwise * forwards to the next phase immediately. */ -final class ExpandSearchPhase extends SearchPhase { +class ExpandSearchPhase extends SearchPhase { private final AbstractSearchAsyncAction context; - private final SearchHits searchHits; - private final Supplier nextPhase; + private final SearchResponseSections searchResponseSections; + private final AtomicArray queryPhaseResults; - ExpandSearchPhase(AbstractSearchAsyncAction context, SearchHits searchHits, Supplier nextPhase) { + ExpandSearchPhase( + AbstractSearchAsyncAction context, + SearchResponseSections searchResponseSections, + AtomicArray queryPhaseResults + ) { super("expand"); this.context = context; - this.searchHits = searchHits; - this.nextPhase = nextPhase; + this.searchResponseSections = searchResponseSections; + this.queryPhaseResults = queryPhaseResults; + } + + // protected for tests + protected SearchPhase nextPhase() { + return new FetchLookupFieldsPhase(context, searchResponseSections, queryPhaseResults); } /** @@ -52,14 +62,15 @@ private boolean isCollapseRequest() { @Override public void run() { + var searchHits = searchResponseSections.hits(); if (isCollapseRequest() == false || searchHits.getHits().length == 0) { onPhaseDone(); } else { - doRun(); + doRun(searchHits); } } - private void doRun() { + private void doRun(SearchHits searchHits) { SearchRequest searchRequest = context.getRequest(); CollapseBuilder collapseBuilder = searchRequest.source().collapse(); final List innerHitBuilders = collapseBuilder.getInnerHits(); @@ -168,6 +179,6 @@ private static SearchSourceBuilder buildExpandSearchSourceBuilder(InnerHitBuilde } private void onPhaseDone() { - context.executeNextPhase(this, nextPhase); + context.executeNextPhase(this, this::nextPhase); } } diff --git a/server/src/main/java/org/elasticsearch/action/search/FetchLookupFieldsPhase.java b/server/src/main/java/org/elasticsearch/action/search/FetchLookupFieldsPhase.java index d8671bcadf86d..c89905b2ecd07 100644 --- a/server/src/main/java/org/elasticsearch/action/search/FetchLookupFieldsPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/FetchLookupFieldsPhase.java @@ -48,9 +48,7 @@ final class FetchLookupFieldsPhase extends SearchPhase { this.queryResults = queryResults; } - private record Cluster(String clusterAlias, List hitsWithLookupFields, List lookupFields) { - - } + private record Cluster(String clusterAlias, List hitsWithLookupFields, List lookupFields) {} private static List groupLookupFieldsByClusterAlias(SearchHits searchHits) { final Map> perClusters = new HashMap<>(); @@ -77,7 +75,7 @@ private static List groupLookupFieldsByClusterAlias(SearchHits searchHi public void run() { final List clusters = groupLookupFieldsByClusterAlias(searchResponse.hits); if (clusters.isEmpty()) { - context.sendSearchResponse(searchResponse, queryResults); + sendResponse(); return; } doRun(clusters); @@ -129,9 +127,9 @@ public void onResponse(MultiSearchResponse items) { } } if (failure != null) { - context.onPhaseFailure(FetchLookupFieldsPhase.this, "failed to fetch lookup fields", failure); + onFailure(failure); } else { - context.sendSearchResponse(searchResponse, queryResults); + sendResponse(); } } @@ -141,4 +139,8 @@ public void onFailure(Exception e) { } }); } + + private void sendResponse() { + context.sendSearchResponse(searchResponse, queryResults); + } } diff --git a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java index 0fbface3793a8..09989c60b9d2b 100644 --- a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java @@ -27,15 +27,13 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; /** * This search phase merges the query results from the previous phase together and calculates the topN hits for this search. * Then it reaches out to all relevant shards to fetch the topN hits. */ -final class FetchSearchPhase extends SearchPhase { +class FetchSearchPhase extends SearchPhase { private final AtomicArray searchPhaseShardResults; - private final BiFunction, SearchPhase> nextPhaseFactory; private final AbstractSearchAsyncAction context; private final Logger logger; private final SearchProgressListener progressListener; @@ -49,26 +47,6 @@ final class FetchSearchPhase extends SearchPhase { AggregatedDfs aggregatedDfs, AbstractSearchAsyncAction context, @Nullable SearchPhaseController.ReducedQueryPhase reducedQueryPhase - ) { - this( - resultConsumer, - aggregatedDfs, - context, - reducedQueryPhase, - (response, queryPhaseResults) -> new ExpandSearchPhase( - context, - response.hits, - () -> new FetchLookupFieldsPhase(context, response, queryPhaseResults) - ) - ); - } - - FetchSearchPhase( - SearchPhaseResults resultConsumer, - AggregatedDfs aggregatedDfs, - AbstractSearchAsyncAction context, - @Nullable SearchPhaseController.ReducedQueryPhase reducedQueryPhase, - BiFunction, SearchPhase> nextPhaseFactory ) { super("fetch"); if (context.getNumShards() != resultConsumer.getNumShards()) { @@ -81,7 +59,6 @@ final class FetchSearchPhase extends SearchPhase { } this.searchPhaseShardResults = resultConsumer.getAtomicArray(); this.aggregatedDfs = aggregatedDfs; - this.nextPhaseFactory = nextPhaseFactory; this.context = context; this.logger = context.getLogger(); this.progressListener = context.getTask().getProgressListener(); @@ -89,8 +66,13 @@ final class FetchSearchPhase extends SearchPhase { this.resultConsumer = reducedQueryPhase == null ? resultConsumer : null; } + // protected for tests + protected SearchPhase nextPhase(SearchResponseSections searchResponseSections, AtomicArray queryPhaseResults) { + return new ExpandSearchPhase(context, searchResponseSections, queryPhaseResults); + } + @Override - public void run() { + public final void run() { context.execute(new AbstractRunnable() { @Override @@ -112,7 +94,7 @@ private void innerRun() throws Exception { final int numShards = context.getNumShards(); // Usually when there is a single shard, we force the search type QUERY_THEN_FETCH. But when there's kNN, we might // still use DFS_QUERY_THEN_FETCH, which does not perform the "query and fetch" optimization during the query phase. - final boolean queryAndFetchOptimization = searchPhaseShardResults.length() == 1 + final boolean queryAndFetchOptimization = numShards == 1 && context.getRequest().hasKnnSearch() == false && reducedQueryPhase.queryPhaseRankCoordinatorContext() == null && (context.getRequest().source() == null || context.getRequest().source().rankBuilder() == null); @@ -127,7 +109,7 @@ private void innerRun() throws Exception { // we have to release contexts here to free up resources searchPhaseShardResults.asList() .forEach(searchPhaseShardResult -> releaseIrrelevantSearchContext(searchPhaseShardResult, context)); - moveToNextPhase(new AtomicArray<>(numShards), reducedQueryPhase); + moveToNextPhase(new AtomicArray<>(0), reducedQueryPhase); } else { innerRunFetch(scoreDocs, numShards, reducedQueryPhase); } @@ -272,7 +254,7 @@ private void moveToNextPhase( context.executeNextPhase(this, () -> { var resp = SearchPhaseController.merge(context.getRequest().scroll() != null, reducedQueryPhase, fetchResultsArr); context.addReleasable(resp::decRef); - return nextPhaseFactory.apply(resp, searchPhaseShardResults); + return nextPhase(resp, searchPhaseShardResults); }); } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java index 25d59a06664da..83e787b828c8a 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java @@ -17,7 +17,6 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchShardTarget; -import org.elasticsearch.search.dfs.AggregatedDfs; import org.elasticsearch.search.dfs.DfsKnnResults; import org.elasticsearch.search.dfs.DfsSearchResult; import org.elasticsearch.search.internal.AliasFilter; @@ -93,16 +92,8 @@ protected void executePhaseOnShard( @Override protected SearchPhase getNextPhase() { final List dfsSearchResults = results.getAtomicArray().asList(); - final AggregatedDfs aggregatedDfs = SearchPhaseController.aggregateDfs(dfsSearchResults); final List mergedKnnResults = SearchPhaseController.mergeKnnResults(getRequest(), dfsSearchResults); - return new DfsQueryPhase( - dfsSearchResults, - aggregatedDfs, - mergedKnnResults, - queryPhaseResultConsumer, - (queryResults) -> SearchQueryThenFetchAsyncAction.nextPhase(client, this, queryResults, aggregatedDfs), - this - ); + return new DfsQueryPhase(dfsSearchResults, mergedKnnResults, queryPhaseResultConsumer, client, this); } @Override diff --git a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java index 193855a4c835f..daa709db59f80 100644 --- a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java @@ -139,12 +139,7 @@ public void sendExecuteQuery( exc -> {} ) ) { - DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, null, consumer, (response) -> new SearchPhase("test") { - @Override - public void run() throws IOException { - responseRef.set(((QueryPhaseResultConsumer) response).results); - } - }, mockSearchPhaseContext); + DfsQueryPhase phase = makeDfsPhase(results, consumer, mockSearchPhaseContext, responseRef); assertEquals("dfs_query", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -225,12 +220,7 @@ public void sendExecuteQuery( exc -> {} ) ) { - DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, null, consumer, (response) -> new SearchPhase("test") { - @Override - public void run() throws IOException { - responseRef.set(((QueryPhaseResultConsumer) response).results); - } - }, mockSearchPhaseContext); + DfsQueryPhase phase = makeDfsPhase(results, consumer, mockSearchPhaseContext, responseRef); assertEquals("dfs_query", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -313,12 +303,7 @@ public void sendExecuteQuery( exc -> {} ) ) { - DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, null, consumer, (response) -> new SearchPhase("test") { - @Override - public void run() throws IOException { - responseRef.set(((QueryPhaseResultConsumer) response).results); - } - }, mockSearchPhaseContext); + DfsQueryPhase phase = makeDfsPhase(results, consumer, mockSearchPhaseContext, responseRef); assertEquals("dfs_query", phase.getName()); phase.run(); assertThat(mockSearchPhaseContext.failures, hasSize(1)); @@ -328,6 +313,25 @@ public void run() throws IOException { } } + private static DfsQueryPhase makeDfsPhase( + AtomicArray results, + SearchPhaseResults consumer, + MockSearchPhaseContext mockSearchPhaseContext, + AtomicReference> responseRef + ) { + return new DfsQueryPhase(results.asList(), null, consumer, null, mockSearchPhaseContext) { + @Override + protected SearchPhase nextPhase() { + return new SearchPhase("test") { + @Override + public void run() { + responseRef.set(((QueryPhaseResultConsumer) consumer).results); + } + }; + } + }; + } + public void testRewriteShardSearchRequestWithRank() { List dkrs = List.of( new DfsKnnResults(null, new ScoreDoc[] { new ScoreDoc(1, 3.0f, 1), new ScoreDoc(4, 1.5f, 1), new ScoreDoc(7, 0.1f, 2) }), @@ -338,7 +342,7 @@ public void testRewriteShardSearchRequestWithRank() { ); MockSearchPhaseContext mspc = new MockSearchPhaseContext(2); mspc.searchTransport = new SearchTransportService(null, null, null); - DfsQueryPhase dqp = new DfsQueryPhase(null, null, dkrs, mock(QueryPhaseResultConsumer.class), null, mspc); + DfsQueryPhase dqp = new DfsQueryPhase(List.of(), dkrs, mock(QueryPhaseResultConsumer.class), null, mspc); QueryBuilder bm25 = new TermQueryBuilder("field", "term"); SearchSourceBuilder ssb = new SearchSourceBuilder().query(bm25) diff --git a/server/src/test/java/org/elasticsearch/action/search/ExpandSearchPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/ExpandSearchPhaseTests.java index 23184be02f9c3..f2a2884545dfa 100644 --- a/server/src/test/java/org/elasticsearch/action/search/ExpandSearchPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/ExpandSearchPhaseTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.search.AbstractSearchTestCase; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.collapse.CollapseBuilder; @@ -117,18 +118,9 @@ void sendExecuteMultiSearch(MultiSearchRequest request, SearchTask task, ActionL SearchHit hit = new SearchHit(1, "ID"); hit.setDocumentField("someField", new DocumentField("someField", Collections.singletonList(collapseValue))); SearchHits hits = new SearchHits(new SearchHit[] { hit }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F); + SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, null, null, 1); try { - ExpandSearchPhase phase = new ExpandSearchPhase(mockSearchPhaseContext, hits, () -> new SearchPhase("test") { - @Override - public void run() { - var sections = new SearchResponseSections(hits, null, null, false, null, null, 1); - try { - mockSearchPhaseContext.sendSearchResponse(sections, null); - } finally { - sections.decRef(); - } - } - }); + ExpandSearchPhase phase = newExpandSearchPhase(mockSearchPhaseContext, searchResponseSections, null); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -145,6 +137,7 @@ public void run() { assertTrue(executedMultiSearch.get()); } finally { + searchResponseSections.decRef(); hits.decRef(); } } finally { @@ -211,18 +204,9 @@ void sendExecuteMultiSearch(MultiSearchRequest request, SearchTask task, ActionL SearchHit hit2 = new SearchHit(2, "ID2"); hit2.setDocumentField("someField", new DocumentField("someField", Collections.singletonList(collapseValue))); SearchHits hits = new SearchHits(new SearchHit[] { hit1, hit2 }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F); + SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, null, null, 1); try { - ExpandSearchPhase phase = new ExpandSearchPhase(mockSearchPhaseContext, hits, () -> new SearchPhase("test") { - @Override - public void run() { - var sections = new SearchResponseSections(hits, null, null, false, null, null, 1); - try { - mockSearchPhaseContext.sendSearchResponse(sections, null); - } finally { - sections.decRef(); - } - } - }); + ExpandSearchPhase phase = newExpandSearchPhase(mockSearchPhaseContext, searchResponseSections, null); phase.run(); assertThat(mockSearchPhaseContext.phaseFailure.get(), Matchers.instanceOf(RuntimeException.class)); assertEquals("boom", mockSearchPhaseContext.phaseFailure.get().getMessage()); @@ -230,6 +214,7 @@ public void run() { assertNull(mockSearchPhaseContext.searchResponse.get()); } finally { mockSearchPhaseContext.results.close(); + searchResponseSections.decRef(); hits.decRef(); collapsedHits.decRef(); } @@ -250,22 +235,14 @@ void sendExecuteMultiSearch(MultiSearchRequest request, SearchTask task, ActionL SearchHit hit2 = new SearchHit(2, "ID2"); hit2.setDocumentField("someField", new DocumentField("someField", Collections.singletonList(null))); SearchHits hits = new SearchHits(new SearchHit[] { hit1, hit2 }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F); + SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, null, null, 1); try { - ExpandSearchPhase phase = new ExpandSearchPhase(mockSearchPhaseContext, hits, () -> new SearchPhase("test") { - @Override - public void run() { - var sections = new SearchResponseSections(hits, null, null, false, null, null, 1); - try { - mockSearchPhaseContext.sendSearchResponse(sections, null); - } finally { - sections.decRef(); - } - } - }); + ExpandSearchPhase phase = newExpandSearchPhase(mockSearchPhaseContext, searchResponseSections, null); phase.run(); mockSearchPhaseContext.assertNoFailure(); assertNotNull(mockSearchPhaseContext.searchResponse.get()); } finally { + searchResponseSections.decRef(); hits.decRef(); } } finally { @@ -294,12 +271,8 @@ void sendExecuteMultiSearch(MultiSearchRequest request, SearchTask task, ActionL ); SearchHits hits = SearchHits.empty(new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0f); - ExpandSearchPhase phase = new ExpandSearchPhase(mockSearchPhaseContext, hits, () -> new SearchPhase("test") { - @Override - public void run() { - mockSearchPhaseContext.sendSearchResponse(new SearchResponseSections(hits, null, null, false, null, null, 1), null); - } - }); + final SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, null, null, 1); + ExpandSearchPhase phase = newExpandSearchPhase(mockSearchPhaseContext, searchResponseSections, null); phase.run(); mockSearchPhaseContext.assertNoFailure(); assertNotNull(mockSearchPhaseContext.searchResponse.get()); @@ -343,16 +316,13 @@ void sendExecuteMultiSearch(MultiSearchRequest request, SearchTask task, ActionL SearchHit hit = new SearchHit(1, "ID"); hit.setDocumentField("someField", new DocumentField("someField", Collections.singletonList("foo"))); SearchHits hits = new SearchHits(new SearchHit[] { hit }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F); + final SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, null, null, 1); try { - ExpandSearchPhase phase = new ExpandSearchPhase(mockSearchPhaseContext, hits, () -> new SearchPhase("test") { - @Override - public void run() { - mockSearchPhaseContext.sendSearchResponse(new SearchResponseSections(hits, null, null, false, null, null, 1), null); - } - }); + ExpandSearchPhase phase = newExpandSearchPhase(mockSearchPhaseContext, searchResponseSections, null); phase.run(); mockSearchPhaseContext.assertNoFailure(); } finally { + searchResponseSections.decRef(); hits.decRef(); } } finally { @@ -364,6 +334,29 @@ public void run() { } } + private static ExpandSearchPhase newExpandSearchPhase( + MockSearchPhaseContext mockSearchPhaseContext, + SearchResponseSections searchResponseSections, + AtomicArray queryPhaseResults + ) { + return new ExpandSearchPhase(mockSearchPhaseContext, searchResponseSections, queryPhaseResults) { + @Override + protected SearchPhase nextPhase() { + searchResponseSections.mustIncRef(); + return new SearchPhase("test") { + @Override + public void run() { + try { + mockSearchPhaseContext.sendSearchResponse(searchResponseSections, queryPhaseResults); + } finally { + searchResponseSections.decRef(); + } + } + }; + } + }; + } + public void testExpandSearchRespectsOriginalPIT() { MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1); final PointInTimeBuilder pit = new PointInTimeBuilder(new BytesArray("foo")); @@ -392,20 +385,14 @@ void sendExecuteMultiSearch(MultiSearchRequest request, SearchTask task, ActionL SearchHit hit = new SearchHit(1, "ID"); hit.setDocumentField("someField", new DocumentField("someField", Collections.singletonList("foo"))); SearchHits hits = new SearchHits(new SearchHit[] { hit }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F); + SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, null, null, 1); try { - ExpandSearchPhase phase = new ExpandSearchPhase(mockSearchPhaseContext, hits, () -> new SearchPhase("test") { - @Override - public void run() { - mockSearchPhaseContext.sendSearchResponse( - new SearchResponseSections(hits, null, null, false, null, null, 1), - new AtomicArray<>(0) - ); - } - }); + ExpandSearchPhase phase = newExpandSearchPhase(mockSearchPhaseContext, searchResponseSections, new AtomicArray<>(0)); phase.run(); mockSearchPhaseContext.assertNoFailure(); } finally { hits.decRef(); + searchResponseSections.decRef(); } } finally { mockSearchPhaseContext.results.close(); diff --git a/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java index dda20dfb37e9d..a23f446c0a36e 100644 --- a/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java @@ -134,13 +134,7 @@ public void testShortcutQueryAndFetchOptimization() throws Exception { numHits = 0; } SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); - FetchSearchPhase phase = new FetchSearchPhase( - results, - null, - mockSearchPhaseContext, - reducedQueryPhase, - searchPhaseFactory(mockSearchPhaseContext) - ); + FetchSearchPhase phase = getFetchSearchPhase(results, mockSearchPhaseContext, reducedQueryPhase); assertEquals("fetch", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -263,13 +257,7 @@ public void sendExecuteFetch( } }; SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); - FetchSearchPhase phase = new FetchSearchPhase( - results, - null, - mockSearchPhaseContext, - reducedQueryPhase, - searchPhaseFactory(mockSearchPhaseContext) - ); + FetchSearchPhase phase = getFetchSearchPhase(results, mockSearchPhaseContext, reducedQueryPhase); assertEquals("fetch", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -373,13 +361,7 @@ public void sendExecuteFetch( } }; SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); - FetchSearchPhase phase = new FetchSearchPhase( - results, - null, - mockSearchPhaseContext, - reducedQueryPhase, - searchPhaseFactory(mockSearchPhaseContext) - ); + FetchSearchPhase phase = getFetchSearchPhase(results, mockSearchPhaseContext, reducedQueryPhase); assertEquals("fetch", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -481,19 +463,21 @@ public void sendExecuteFetch( }; CountDownLatch latch = new CountDownLatch(1); SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); - FetchSearchPhase phase = new FetchSearchPhase( - results, - null, - mockSearchPhaseContext, - reducedQueryPhase, - (searchResponse, scrollId) -> new SearchPhase("test") { - @Override - public void run() { - mockSearchPhaseContext.sendSearchResponse(searchResponse, null); - latch.countDown(); - } + FetchSearchPhase phase = new FetchSearchPhase(results, null, mockSearchPhaseContext, reducedQueryPhase) { + @Override + protected SearchPhase nextPhase( + SearchResponseSections searchResponseSections, + AtomicArray queryPhaseResults + ) { + return new SearchPhase("test") { + @Override + public void run() { + mockSearchPhaseContext.sendSearchResponse(searchResponseSections, null); + latch.countDown(); + } + }; } - ); + }; assertEquals("fetch", phase.getName()); phase.run(); latch.await(); @@ -621,13 +605,7 @@ public void sendExecuteFetch( } }; SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); - FetchSearchPhase phase = new FetchSearchPhase( - results, - null, - mockSearchPhaseContext, - reducedQueryPhase, - searchPhaseFactory(mockSearchPhaseContext) - ); + FetchSearchPhase phase = getFetchSearchPhase(results, mockSearchPhaseContext, reducedQueryPhase); assertEquals("fetch", phase.getName()); phase.run(); assertNotNull(mockSearchPhaseContext.searchResponse.get()); @@ -641,6 +619,22 @@ public void sendExecuteFetch( } } + private static FetchSearchPhase getFetchSearchPhase( + SearchPhaseResults results, + MockSearchPhaseContext mockSearchPhaseContext, + SearchPhaseController.ReducedQueryPhase reducedQueryPhase + ) { + return new FetchSearchPhase(results, null, mockSearchPhaseContext, reducedQueryPhase) { + @Override + protected SearchPhase nextPhase( + SearchResponseSections searchResponseSections, + AtomicArray queryPhaseResults + ) { + return searchPhaseFactory(mockSearchPhaseContext).apply(searchResponseSections, queryPhaseResults); + } + }; + } + public void testCleanupIrrelevantContexts() throws Exception { // contexts that are not fetched should be cleaned up MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); SearchPhaseController controller = new SearchPhaseController((t, s) -> InternalAggregationTestCase.emptyReduceContextBuilder()); @@ -723,13 +717,7 @@ public void sendExecuteFetch( } }; SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); - FetchSearchPhase phase = new FetchSearchPhase( - results, - null, - mockSearchPhaseContext, - reducedQueryPhase, - searchPhaseFactory(mockSearchPhaseContext) - ); + FetchSearchPhase phase = getFetchSearchPhase(results, mockSearchPhaseContext, reducedQueryPhase); assertEquals("fetch", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); From 01920ba45c5ea82767f5e1d266e4fe493133fe29 Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Wed, 25 Dec 2024 15:39:20 +0100 Subject: [PATCH 2/3] save more --- .../action/search/DfsQueryPhase.java | 38 ++++++------------- .../SearchDfsQueryThenFetchAsyncAction.java | 6 +-- .../action/search/DfsQueryPhaseTests.java | 13 +++++-- 3 files changed, 22 insertions(+), 35 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java index 2226ccc191e21..1dfd3aad49a42 100644 --- a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java @@ -40,55 +40,45 @@ */ class DfsQueryPhase extends SearchPhase { private final SearchPhaseResults queryResult; - private final List searchResults; - private final AggregatedDfs dfs; - private final List knnResults; private final Client client; private final AbstractSearchAsyncAction context; - private final SearchTransportService searchTransportService; private final SearchProgressListener progressListener; - DfsQueryPhase( - List searchResults, - List knnResults, - SearchPhaseResults queryResult, - Client client, - AbstractSearchAsyncAction context - ) { + DfsQueryPhase(SearchPhaseResults queryResult, Client client, AbstractSearchAsyncAction context) { super("dfs_query"); this.progressListener = context.getTask().getProgressListener(); this.queryResult = queryResult; - this.searchResults = searchResults; - this.dfs = SearchPhaseController.aggregateDfs(searchResults); - this.knnResults = knnResults; this.client = client; this.context = context; - this.searchTransportService = context.getSearchTransport(); } // protected for testing - protected SearchPhase nextPhase() { + protected SearchPhase nextPhase(AggregatedDfs dfs) { return SearchQueryThenFetchAsyncAction.nextPhase(client, context, queryResult, dfs); } + @SuppressWarnings("unchecked") @Override public void run() { + List searchResults = (List) context.results.getAtomicArray().asList(); + AggregatedDfs dfs = SearchPhaseController.aggregateDfs(searchResults); // TODO we can potentially also consume the actual per shard results from the initial phase here in the aggregateDfs // to free up memory early final CountedCollector counter = new CountedCollector<>( queryResult, searchResults.size(), - () -> context.executeNextPhase(this, this::nextPhase), + () -> context.executeNextPhase(this, () -> nextPhase(dfs)), context ); + List knnResults = SearchPhaseController.mergeKnnResults(context.getRequest(), searchResults); for (final DfsSearchResult dfsResult : searchResults) { final SearchShardTarget shardTarget = dfsResult.getSearchShardTarget(); final int shardIndex = dfsResult.getShardIndex(); QuerySearchRequest querySearchRequest = new QuerySearchRequest( context.getOriginalIndices(shardIndex), dfsResult.getContextId(), - rewriteShardSearchRequest(dfsResult.getShardSearchRequest()), + rewriteShardSearchRequest(knnResults, dfsResult.getShardSearchRequest()), dfs ); final Transport.Connection connection; @@ -98,11 +88,8 @@ public void run() { shardFailure(e, querySearchRequest, shardIndex, shardTarget, counter); continue; } - searchTransportService.sendExecuteQuery( - connection, - querySearchRequest, - context.getTask(), - new SearchActionListener<>(shardTarget, shardIndex) { + context.getSearchTransport() + .sendExecuteQuery(connection, querySearchRequest, context.getTask(), new SearchActionListener<>(shardTarget, shardIndex) { @Override protected void innerOnResponse(QuerySearchResult response) { @@ -127,8 +114,7 @@ public void onFailure(Exception exception) { } } } - } - ); + }); } } @@ -145,7 +131,7 @@ private void shardFailure( } // package private for testing - ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) { + ShardSearchRequest rewriteShardSearchRequest(List knnResults, ShardSearchRequest request) { SearchSourceBuilder source = request.source(); if (source == null || source.knnSearch().isEmpty()) { return request; diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java index 83e787b828c8a..069dbb9a87010 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java @@ -17,12 +17,10 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchShardTarget; -import org.elasticsearch.search.dfs.DfsKnnResults; import org.elasticsearch.search.dfs.DfsSearchResult; import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.transport.Transport; -import java.util.List; import java.util.Map; import java.util.concurrent.Executor; import java.util.function.BiFunction; @@ -91,9 +89,7 @@ protected void executePhaseOnShard( @Override protected SearchPhase getNextPhase() { - final List dfsSearchResults = results.getAtomicArray().asList(); - final List mergedKnnResults = SearchPhaseController.mergeKnnResults(getRequest(), dfsSearchResults); - return new DfsQueryPhase(dfsSearchResults, mergedKnnResults, queryPhaseResultConsumer, client, this); + return new DfsQueryPhase(queryPhaseResultConsumer, client, this); } @Override diff --git a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java index daa709db59f80..43292c4f65245 100644 --- a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.dfs.AggregatedDfs; import org.elasticsearch.search.dfs.DfsKnnResults; import org.elasticsearch.search.dfs.DfsSearchResult; import org.elasticsearch.search.internal.AliasFilter; @@ -319,9 +320,13 @@ private static DfsQueryPhase makeDfsPhase( MockSearchPhaseContext mockSearchPhaseContext, AtomicReference> responseRef ) { - return new DfsQueryPhase(results.asList(), null, consumer, null, mockSearchPhaseContext) { + int shards = mockSearchPhaseContext.numShards; + for (int i = 0; i < shards; i++) { + mockSearchPhaseContext.results.getAtomicArray().set(i, results.get(i)); + } + return new DfsQueryPhase(consumer, null, mockSearchPhaseContext) { @Override - protected SearchPhase nextPhase() { + protected SearchPhase nextPhase(AggregatedDfs dfs) { return new SearchPhase("test") { @Override public void run() { @@ -342,7 +347,7 @@ public void testRewriteShardSearchRequestWithRank() { ); MockSearchPhaseContext mspc = new MockSearchPhaseContext(2); mspc.searchTransport = new SearchTransportService(null, null, null); - DfsQueryPhase dqp = new DfsQueryPhase(List.of(), dkrs, mock(QueryPhaseResultConsumer.class), null, mspc); + DfsQueryPhase dqp = new DfsQueryPhase(mock(QueryPhaseResultConsumer.class), null, mspc); QueryBuilder bm25 = new TermQueryBuilder("field", "term"); SearchSourceBuilder ssb = new SearchSourceBuilder().query(bm25) @@ -356,7 +361,7 @@ public void testRewriteShardSearchRequestWithRank() { SearchRequest sr = new SearchRequest().allowPartialSearchResults(true).source(ssb); ShardSearchRequest ssr = new ShardSearchRequest(null, sr, new ShardId("test", "testuuid", 1), 1, 1, null, 1.0f, 0, null); - dqp.rewriteShardSearchRequest(ssr); + dqp.rewriteShardSearchRequest(dkrs, ssr); KnnScoreDocQueryBuilder ksdqb0 = new KnnScoreDocQueryBuilder( new ScoreDoc[] { new ScoreDoc(1, 3.0f, 1), new ScoreDoc(4, 1.5f, 1) }, From 361b1a92eb645198c0e6670788e8eaf528d8f9b7 Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Tue, 18 Feb 2025 16:47:31 +0100 Subject: [PATCH 3/3] fix compile --- .../action/search/SearchDfsQueryThenFetchAsyncAction.java | 1 + 1 file changed, 1 insertion(+) diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java index 7f6b06ea21cc6..dd97f02dd8f40 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java @@ -20,6 +20,7 @@ import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.transport.Transport; +import java.util.List; import java.util.Map; import java.util.concurrent.Executor; import java.util.function.BiFunction;