Skip to content

Commit 4da03a9

Browse files
committed
Add streaming search with configurable scoring modes
Introduces streaming search infrastructure that enables progressive emission of search results with three configurable scoring modes. The implementation extends the existing streaming transport layer to support partial result computation at the coordinator level.

Scoring modes:
- NO_SCORING: Immediate result emission without confidence requirements
- CONFIDENCE_BASED: Statistical emission using Hoeffding inequality bounds
- FULL_SCORING: Complete scoring before result emission

The implementation leverages OpenSearch's inter-node streaming capabilities to reduce query latency through early result emission. Partial reductions are triggered based on the selected scoring mode, with results accumulated at the coordinator before final response generation.

Key changes:
- Add HoeffdingBounds for statistical confidence calculation
- Extend QueryPhaseResultConsumer to support streaming reduction
- Add StreamingScoringCollector wrapping TopScoreDocCollector
- Integrate streaming scorer selection in QueryPhase
- Add REST parameter stream_scoring_mode for mode selection
- Include streaming metadata in SearchResponse

The current implementation operates within architectural constraints where streaming is limited to inter-node communication. Client-facing streaming will be addressed in a follow-up contribution.

Addresses opensearch-project#18725

Signed-off-by: Atri Sharma <atri.jiit@gmail.com>
1 parent f967a72 commit 4da03a9

28 files changed

+3061
-9
lines changed

server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,12 @@ private MergeResult partialReduce(
235235
SearchShardTarget target = result.getSearchShardTarget();
236236
processedShards.add(new SearchShard(target.getClusterAlias(), target.getShardId()));
237237
}
238-
progressListener.notifyPartialReduce(processedShards, topDocsStats.getTotalHits(), newAggs, numReducePhases);
238+
// For streaming search with TopDocs, use the new notification method
239+
if (hasTopDocs && newTopDocs != null) {
240+
progressListener.notifyPartialReduceWithTopDocs(processedShards, topDocsStats.getTotalHits(), newTopDocs, newAggs, numReducePhases);
241+
} else {
242+
progressListener.notifyPartialReduce(processedShards, topDocsStats.getTotalHits(), newAggs, numReducePhases);
243+
}
239244
// we leave the results un-serialized because serializing is slow but we compute the serialized
240245
// size as an estimate of the memory used by the newly reduced aggregations.
241246
long serializedSize = hasAggs ? newAggs.getSerializedSize() : 0;

server/src/main/java/org/opensearch/action/search/SearchProgressListener.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,22 @@ protected void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exc
9999
* @param reducePhase The version number for this reduce.
100100
*/
101101
protected void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {}
102+
103+
/**
104+
* Executed when a partial reduce with TopDocs is created for streaming search.
105+
*
106+
* @param shards The list of shards that are part of this reduce.
107+
* @param totalHits The total number of hits in this reduce.
108+
* @param topDocs The partial TopDocs result (may be null if no docs).
109+
* @param aggs The partial result for aggregations.
110+
* @param reducePhase The version number for this reduce.
111+
*/
112+
protected void onPartialReduceWithTopDocs(List<SearchShard> shards, TotalHits totalHits,
113+
org.apache.lucene.search.TopDocs topDocs,
114+
InternalAggregations aggs, int reducePhase) {
115+
// Default implementation delegates to the original method for backward compatibility
116+
onPartialReduce(shards, totalHits, aggs, reducePhase);
117+
}
102118

103119
/**
104120
* Executed once when the final reduce is created.
@@ -164,6 +180,16 @@ final void notifyPartialReduce(List<SearchShard> shards, TotalHits totalHits, In
164180
logger.warn(() -> new ParameterizedMessage("Failed to execute progress listener on partial reduce"), e);
165181
}
166182
}
183+
184+
final void notifyPartialReduceWithTopDocs(List<SearchShard> shards, TotalHits totalHits,
185+
org.apache.lucene.search.TopDocs topDocs,
186+
InternalAggregations aggs, int reducePhase) {
187+
try {
188+
onPartialReduceWithTopDocs(shards, totalHits, topDocs, aggs, reducePhase);
189+
} catch (Exception e) {
190+
logger.warn(() -> new ParameterizedMessage("Failed to execute progress listener on partial reduce with TopDocs"), e);
191+
}
192+
}
167193

168194
protected final void notifyFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
169195
try {

server/src/main/java/org/opensearch/action/search/SearchRequest.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,9 @@ public class SearchRequest extends ActionRequest implements IndicesRequest.Repla
125125

126126
private Boolean phaseTook = null;
127127

128+
private boolean streamingScoring = false;
129+
private String streamingScoringMode = null; // Will use StreamingScoringMode.DEFAULT if null
130+
128131
public SearchRequest() {
129132
this.localClusterAlias = null;
130133
this.absoluteStartMillis = DEFAULT_ABSOLUTE_START_MILLIS;
@@ -142,6 +145,7 @@ public SearchRequest(SearchRequest searchRequest) {
142145
searchRequest.absoluteStartMillis,
143146
searchRequest.finalReduce
144147
);
148+
this.streamingScoring = searchRequest.streamingScoring;
145149
}
146150

147151
/**
@@ -656,6 +660,34 @@ public void setPhaseTook(Boolean phaseTook) {
656660
this.phaseTook = phaseTook;
657661
}
658662

663+
/**
664+
* Enable streaming scoring for this search request.
665+
*/
666+
public void setStreamingScoring(boolean streamingScoring) {
667+
this.streamingScoring = streamingScoring;
668+
}
669+
670+
/**
671+
* Check if streaming scoring is enabled for this search request.
672+
*/
673+
public boolean isStreamingScoring() {
674+
return streamingScoring;
675+
}
676+
677+
/**
678+
* Set the streaming scoring mode.
679+
*/
680+
public void setStreamingScoringMode(String mode) {
681+
this.streamingScoringMode = mode;
682+
}
683+
684+
/**
685+
* Get the streaming scoring mode.
686+
*/
687+
public String getStreamingScoringMode() {
688+
return streamingScoringMode;
689+
}
690+
659691
/**
660692
* Returns a threshold that enforces a pre-filter roundtrip to pre-filter search shards based on query rewriting if the number of shards
661693
* the search request expands to exceeds the threshold, or <code>null</code> if the threshold is unspecified.

server/src/main/java/org/opensearch/action/search/SearchResponse.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,11 @@ public class SearchResponse extends ActionResponse implements StatusToXContentOb
102102
private final Clusters clusters;
103103
private final long tookInMillis;
104104
private final PhaseTook phaseTook;
105+
106+
// Fields for streaming responses
107+
private boolean isPartial = false;
108+
private int sequenceNumber = 0;
109+
private int totalPartials = 0;
105110

106111
public SearchResponse(StreamInput in) throws IOException {
107112
super(in);
@@ -301,6 +306,31 @@ public ShardSearchFailure[] getShardFailures() {
301306
public String getScrollId() {
302307
return scrollId;
303308
}
309+
310+
// Streaming response methods
311+
public boolean isPartial() {
312+
return isPartial;
313+
}
314+
315+
public void setPartial(boolean partial) {
316+
this.isPartial = partial;
317+
}
318+
319+
public int getSequenceNumber() {
320+
return sequenceNumber;
321+
}
322+
323+
public void setSequenceNumber(int sequenceNumber) {
324+
this.sequenceNumber = sequenceNumber;
325+
}
326+
327+
public int getTotalPartials() {
328+
return totalPartials;
329+
}
330+
331+
public void setTotalPartials(int totalPartials) {
332+
this.totalPartials = totalPartials;
333+
}
304334

305335
/**
306336
* Returns the encoded string of the search context that the search request is used to executed

server/src/main/java/org/opensearch/action/search/StreamQueryPhaseResultConsumer.java

Lines changed: 125 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,19 @@
88

99
package org.opensearch.action.search;
1010

11+
import org.apache.logging.log4j.LogManager;
12+
import org.apache.logging.log4j.Logger;
1113
import org.opensearch.core.common.breaker.CircuitBreaker;
1214
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
1315
import org.opensearch.search.SearchPhaseResult;
1416
import org.opensearch.search.query.QuerySearchResult;
17+
import org.opensearch.search.query.HoeffdingBounds;
18+
import org.opensearch.search.query.StreamingScoringMode;
19+
import org.apache.lucene.search.TopDocs;
20+
import org.apache.lucene.search.ScoreDoc;
1521

22+
import java.util.HashMap;
23+
import java.util.Map;
1624
import java.util.concurrent.Executor;
1725
import java.util.function.Consumer;
1826

@@ -22,6 +30,13 @@
2230
* @opensearch.internal
2331
*/
2432
public class StreamQueryPhaseResultConsumer extends QueryPhaseResultConsumer {
33+
private static final Logger logger = LogManager.getLogger(StreamQueryPhaseResultConsumer.class);
34+
35+
private final Map<Integer, HoeffdingBounds> shardBounds = new HashMap<>();
36+
private final double confidence = 0.95; // Default confidence level
37+
private int streamingEmissions = 0;
38+
private int totalDocsProcessed = 0;
39+
private final StreamingScoringMode scoringMode;
2540

2641
public StreamQueryPhaseResultConsumer(
2742
SearchRequest request,
@@ -43,22 +58,130 @@ public StreamQueryPhaseResultConsumer(
4358
expectedResultSize,
4459
onPartialMergeFailure
4560
);
61+
62+
// Determine scoring mode from request
63+
this.scoringMode = StreamingScoringMode.fromString(request.getStreamingScoringMode());
4664
}
4765

4866
/**
49-
* For stream search, the minBatchReduceSize is set higher than shard number
67+
* Adjust batch reduce size based on scoring mode.
5068
*
5169
* @param minBatchReduceSize: pass as number of shard
5270
*/
5371
@Override
5472
int getBatchReduceSize(int requestBatchedReduceSize, int minBatchReduceSize) {
55-
return super.getBatchReduceSize(requestBatchedReduceSize, minBatchReduceSize * 10);
73+
switch (scoringMode) {
74+
case NO_SCORING:
75+
// Emit immediately as results arrive
76+
return 1;
77+
78+
case CONFIDENCE_BASED:
79+
// Emit based on confidence threshold
80+
if (confidence > 0.9) {
81+
return minBatchReduceSize;
82+
}
83+
return minBatchReduceSize * 2;
84+
85+
case FULL_SCORING:
86+
// Wait for all shards before reducing
87+
return Integer.MAX_VALUE;
88+
89+
default:
90+
return super.getBatchReduceSize(requestBatchedReduceSize, minBatchReduceSize * 10);
91+
}
5692
}
5793

5894
void consumeStreamResult(SearchPhaseResult result, Runnable next) {
5995
// For streaming, we skip the ArraySearchPhaseResults.consumeResult() call
6096
// since it doesn't support multiple results from the same shard.
6197
QuerySearchResult querySearchResult = result.queryResult();
98+
99+
// Track scores for Hoeffding bounds if this is a scoring query
100+
if (querySearchResult.hasConsumedTopDocs()) {
101+
updateHoeffdingBounds(result.getShardIndex(), querySearchResult);
102+
}
103+
62104
pendingMerges.consume(querySearchResult, next);
63105
}
106+
107+
/**
108+
* Update Hoeffding bounds for a shard based on its scores.
109+
*/
110+
private void updateHoeffdingBounds(int shardIndex, QuerySearchResult queryResult) {
111+
var topDocsAndMaxScore = queryResult.topDocs();
112+
if (topDocsAndMaxScore != null && topDocsAndMaxScore.topDocs != null && topDocsAndMaxScore.topDocs.scoreDocs != null) {
113+
TopDocs topDocs = topDocsAndMaxScore.topDocs;
114+
// Get or create bounds for this shard
115+
HoeffdingBounds bounds = shardBounds.computeIfAbsent(
116+
shardIndex,
117+
k -> new HoeffdingBounds(confidence, 100.0)
118+
);
119+
120+
// Add scores to bounds tracker
121+
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
122+
bounds.addScore(scoreDoc.score);
123+
}
124+
125+
// Check if we should emit based on confidence
126+
if (shouldEmitStreamingResults()) {
127+
streamingEmissions++;
128+
totalDocsProcessed += topDocs.scoreDocs.length;
129+
130+
double maxBound = getMaxHoeffdingBound();
131+
logger.info("Streaming emission #{}: {} shards, {} docs, bound={}",
132+
streamingEmissions, shardBounds.size(), totalDocsProcessed, maxBound);
133+
134+
// Adjusted batch size triggers more frequent partial reductions
135+
}
136+
}
137+
}
138+
139+
/**
140+
* Check if we should emit streaming results based on scoring mode.
141+
*/
142+
private boolean shouldEmitStreamingResults() {
143+
switch (scoringMode) {
144+
case NO_SCORING:
145+
// Always emit immediately
146+
return true;
147+
148+
case CONFIDENCE_BASED:
149+
// Check Hoeffding bounds
150+
if (shardBounds.isEmpty()) {
151+
return false;
152+
}
153+
double maxBound = getMaxHoeffdingBound();
154+
boolean shouldEmit = maxBound <= 0.1; // Threshold for confidence
155+
156+
if (logger.isDebugEnabled()) {
157+
logger.debug("Hoeffding bound check: {}, emit={}",
158+
maxBound, shouldEmit);
159+
}
160+
return shouldEmit;
161+
162+
case FULL_SCORING:
163+
// Never emit early - wait for all results
164+
return false;
165+
166+
default:
167+
return false;
168+
}
169+
}
170+
171+
/**
172+
* Get the maximum Hoeffding bound across all shards.
173+
*/
174+
private double getMaxHoeffdingBound() {
175+
return shardBounds.values().stream()
176+
.mapToDouble(HoeffdingBounds::getBound)
177+
.max()
178+
.orElse(Double.MAX_VALUE);
179+
}
180+
181+
/**
182+
* Get the number of streaming emissions for monitoring.
183+
*/
184+
public int getStreamingEmissions() {
185+
return streamingEmissions;
186+
}
64187
}

server/src/main/java/org/opensearch/action/search/StreamSearchQueryThenFetchAsyncAction.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.opensearch.search.SearchPhaseResult;
1616
import org.opensearch.search.SearchShardTarget;
1717
import org.opensearch.search.internal.AliasFilter;
18+
import org.opensearch.search.internal.ShardSearchRequest;
1819
import org.opensearch.telemetry.tracing.Tracer;
1920
import org.opensearch.transport.Transport;
2021

@@ -188,4 +189,5 @@ private void successfulStreamExecution() {
188189
onPhaseFailure(this, "The phase has failed", ex);
189190
}
190191
}
192+
191193
}

server/src/main/java/org/opensearch/action/search/StreamSearchTransportService.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,24 @@ public static void registerStreamRequestHandler(StreamTransportService transport
123123
ThreadPool.Names.STREAM_SEARCH
124124
)
125125
);
126+
127+
// Override QUERY_ACTION_NAME to enable streaming for query phase
128+
transportService.registerRequestHandler(
129+
QUERY_ACTION_NAME,
130+
ThreadPool.Names.SAME,
131+
false,
132+
true,
133+
AdmissionControlActionType.SEARCH,
134+
ShardSearchRequest::new,
135+
(request, channel, task) -> searchService.executeQueryPhase(
136+
request,
137+
false,
138+
(SearchShardTask) task,
139+
new StreamSearchChannelListener<>(channel, QUERY_ACTION_NAME, request),
140+
ThreadPool.Names.STREAM_SEARCH,
141+
true // isStreamSearch = true for streaming
142+
)
143+
);
126144
}
127145

128146
@Override

0 commit comments

Comments
 (0)