Skip to content

Commit a41dac9

Browse files
committed
MinScore implementation in Linear retriever
1 parent 473c4da commit a41dac9

File tree

10 files changed

+293
-51
lines changed

10 files changed

+293
-51
lines changed

server/src/main/java/org/elasticsearch/index/query/RankDocsQueryBuilder.java

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,13 @@ public class RankDocsQueryBuilder extends AbstractQueryBuilder<RankDocsQueryBuil
3232
private final RankDoc[] rankDocs;
3333
private final QueryBuilder[] queryBuilders;
3434
private final boolean onlyRankDocs;
35+
private final float minScore;
3536

36-
public RankDocsQueryBuilder(RankDoc[] rankDocs, QueryBuilder[] queryBuilders, boolean onlyRankDocs) {
37+
public RankDocsQueryBuilder(RankDoc[] rankDocs, QueryBuilder[] queryBuilders, boolean onlyRankDocs, float minScore) {
3738
this.rankDocs = rankDocs;
3839
this.queryBuilders = queryBuilders;
3940
this.onlyRankDocs = onlyRankDocs;
41+
this.minScore = minScore;
4042
}
4143

4244
public RankDocsQueryBuilder(StreamInput in) throws IOException {
@@ -45,9 +47,11 @@ public RankDocsQueryBuilder(StreamInput in) throws IOException {
4547
if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) {
4648
this.queryBuilders = in.readOptionalArray(c -> c.readNamedWriteable(QueryBuilder.class), QueryBuilder[]::new);
4749
this.onlyRankDocs = in.readBoolean();
50+
this.minScore = in.readFloat();
4851
} else {
4952
this.queryBuilders = null;
5053
this.onlyRankDocs = false;
54+
this.minScore = Float.MIN_VALUE;
5155
}
5256
}
5357

@@ -70,7 +74,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
7074
changed |= newQueryBuilders[i] != queryBuilders[i];
7175
}
7276
if (changed) {
73-
RankDocsQueryBuilder clone = new RankDocsQueryBuilder(rankDocs, newQueryBuilders, onlyRankDocs);
77+
RankDocsQueryBuilder clone = new RankDocsQueryBuilder(rankDocs, newQueryBuilders, onlyRankDocs, minScore);
7478
clone.queryName(queryName());
7579
return clone;
7680
}
@@ -88,6 +92,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
8892
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) {
8993
out.writeOptionalArray(StreamOutput::writeNamedWriteable, queryBuilders);
9094
out.writeBoolean(onlyRankDocs);
95+
out.writeFloat(minScore);
9196
}
9297
}
9398

@@ -115,7 +120,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
115120
queries = new Query[0];
116121
queryNames = Strings.EMPTY_ARRAY;
117122
}
118-
return new RankDocsQuery(reader, shardRankDocs, queries, queryNames, onlyRankDocs);
123+
return new RankDocsQuery(reader, shardRankDocs, queries, queryNames, onlyRankDocs, minScore);
119124
}
120125

121126
@Override
@@ -135,12 +140,13 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
135140
protected boolean doEquals(RankDocsQueryBuilder other) {
136141
return Arrays.equals(rankDocs, other.rankDocs)
137142
&& Arrays.equals(queryBuilders, other.queryBuilders)
138-
&& onlyRankDocs == other.onlyRankDocs;
143+
&& onlyRankDocs == other.onlyRankDocs
144+
&& minScore == other.minScore;
139145
}
140146

141147
@Override
142148
protected int doHashCode() {
143-
return Objects.hash(Arrays.hashCode(rankDocs), Arrays.hashCode(queryBuilders), onlyRankDocs);
149+
return Objects.hash(Arrays.hashCode(rankDocs), Arrays.hashCode(queryBuilders), onlyRankDocs, minScore);
144150
}
145151

146152
@Override

server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
201201
public QueryBuilder topDocsQuery() {
202202
assert queryVector != null : "query vector must be materialized at this point";
203203
assert rankDocs != null : "rankDocs should have been materialized by now";
204-
var rankDocsQuery = new RankDocsQueryBuilder(rankDocs, null, true);
204+
var rankDocsQuery = new RankDocsQueryBuilder(rankDocs, null, true, Float.MIN_VALUE);
205205
if (preFilterQueryBuilders.isEmpty()) {
206206
return rankDocsQuery.queryName(retrieverName);
207207
}
@@ -217,7 +217,8 @@ public QueryBuilder explainQuery() {
217217
var rankDocsQuery = new RankDocsQueryBuilder(
218218
rankDocs,
219219
new QueryBuilder[] { new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector.get()), field, similarity) },
220-
true
220+
false,
221+
Float.MIN_VALUE
221222
);
222223
if (preFilterQueryBuilders.isEmpty()) {
223224
return rankDocsQuery.queryName(retrieverName);

server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ public QueryBuilder explainQuery() {
9393
var explainQuery = new RankDocsQueryBuilder(
9494
rankDocs.get(),
9595
sources.stream().map(RetrieverBuilder::explainQuery).toArray(QueryBuilder[]::new),
96-
true
96+
true,
97+
Float.MIN_VALUE
9798
);
9899
explainQuery.queryName(retrieverName());
99100
return explainQuery;
@@ -113,17 +114,19 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
113114
rankQuery = new RankDocsQueryBuilder(
114115
rankDocResults,
115116
sources.stream().map(RetrieverBuilder::topDocsQuery).toArray(QueryBuilder[]::new),
116-
false
117+
false,
118+
Float.MIN_VALUE
117119
);
118120
} else {
119121
rankQuery = new RankDocsQueryBuilder(
120122
rankDocResults,
121123
sources.stream().map(RetrieverBuilder::explainQuery).toArray(QueryBuilder[]::new),
122-
false
124+
false,
125+
Float.MIN_VALUE
123126
);
124127
}
125128
} else {
126-
rankQuery = new RankDocsQueryBuilder(rankDocResults, null, false);
129+
rankQuery = new RankDocsQueryBuilder(rankDocResults, null, false, Float.MIN_VALUE);
127130
}
128131
rankQuery.queryName(retrieverName());
129132
// ignore prefilters of this level, they were already propagated to children

server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQuery.java

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,7 @@ public float getMaxScore(int docId) {
164164
}
165165

166166
@Override
167-
public float score() {
168-
// We could still end up with a valid 0 score for a RankDoc
169-
// so here we want to differentiate between this and all the tailQuery matches
170-
// that would also produce a 0 score due to filtering, by setting the score to `Float.MIN_VALUE` instead for
171-
// RankDoc matches.
167+
public float score() throws IOException {
172168
return Math.max(docs[upTo].score, Float.MIN_VALUE);
173169
}
174170

@@ -234,6 +230,7 @@ public int hashCode() {
234230
// RankDocs provided. This query does not contribute to scoring, as it is set as filter when creating the weight
235231
private final Query tailQuery;
236232
private final boolean onlyRankDocs;
233+
private final float minScore;
237234

238235
/**
239236
* Creates a {@code RankDocsQuery} based on the provided docs.
@@ -242,8 +239,9 @@ public int hashCode() {
242239
* @param sources The original queries that were used to compute the top documents
243240
* @param queryNames The names (if present) of the original retrievers
244241
* @param onlyRankDocs Whether the query should only match the provided rank docs
242+
* @param minScore The minimum score threshold for documents to be included in total hits
245243
*/
246-
public RankDocsQuery(IndexReader reader, RankDoc[] rankDocs, Query[] sources, String[] queryNames, boolean onlyRankDocs) {
244+
public RankDocsQuery(IndexReader reader, RankDoc[] rankDocs, Query[] sources, String[] queryNames, boolean onlyRankDocs, float minScore) {
247245
assert sources.length == queryNames.length;
248246
// clone to avoid side-effect after sorting
249247
this.docs = rankDocs.clone();
@@ -260,13 +258,15 @@ public RankDocsQuery(IndexReader reader, RankDoc[] rankDocs, Query[] sources, St
260258
this.tailQuery = null;
261259
}
262260
this.onlyRankDocs = onlyRankDocs;
261+
this.minScore = minScore;
263262
}
264263

265264
private RankDocsQuery(RankDoc[] docs, Query topQuery, Query tailQuery, boolean onlyRankDocs) {
266265
this.docs = docs;
267266
this.topQuery = topQuery;
268267
this.tailQuery = tailQuery;
269268
this.onlyRankDocs = onlyRankDocs;
269+
this.minScore = Float.MIN_VALUE;
270270
}
271271

272272
private static int binarySearch(RankDoc[] docs, int fromIndex, int toIndex, int key) {
@@ -346,7 +346,41 @@ public Matches matches(LeafReaderContext context, int doc) throws IOException {
346346

347347
@Override
348348
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
349-
return combinedWeight.scorerSupplier(context);
349+
return new ScorerSupplier() {
350+
private final ScorerSupplier supplier = combinedWeight.scorerSupplier(context);
351+
352+
@Override
353+
public Scorer get(long leadCost) throws IOException {
354+
Scorer scorer = supplier.get(leadCost);
355+
return new Scorer() {
356+
@Override
357+
public DocIdSetIterator iterator() {
358+
return scorer.iterator();
359+
}
360+
361+
@Override
362+
public float getMaxScore(int docId) throws IOException {
363+
return scorer.getMaxScore(docId);
364+
}
365+
366+
@Override
367+
public float score() throws IOException {
368+
float score = scorer.score();
369+
return score >= minScore ? score : 0f;
370+
}
371+
372+
@Override
373+
public int docID() {
374+
return scorer.docID();
375+
}
376+
};
377+
}
378+
379+
@Override
380+
public long cost() {
381+
return supplier.cost();
382+
}
383+
};
350384
}
351385
};
352386
}

server/src/test/java/org/elasticsearch/index/query/RankDocsQueryBuilderTests.java

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.apache.lucene.search.Query;
2121
import org.apache.lucene.search.ScoreDoc;
2222
import org.apache.lucene.search.TopScoreDocCollectorManager;
23+
import org.apache.lucene.search.MatchAllDocsQuery;
2324
import org.apache.lucene.store.Directory;
2425
import org.apache.lucene.tests.index.RandomIndexWriter;
2526
import org.elasticsearch.search.rank.RankDoc;
@@ -31,6 +32,7 @@
3132
import java.util.Random;
3233

3334
import static org.hamcrest.Matchers.equalTo;
35+
import static org.hamcrest.Matchers.instanceOf;
3436
import static org.hamcrest.Matchers.lessThanOrEqualTo;
3537

3638
public class RankDocsQueryBuilderTests extends AbstractQueryTestCase<RankDocsQueryBuilder> {
@@ -50,14 +52,30 @@ private RankDoc[] generateRandomRankDocs() {
5052
@Override
5153
protected RankDocsQueryBuilder doCreateTestQueryBuilder() {
5254
RankDoc[] rankDocs = generateRandomRankDocs();
53-
return new RankDocsQueryBuilder(rankDocs, null, false);
55+
return new RankDocsQueryBuilder(rankDocs, null, false, Float.MIN_VALUE);
5456
}
5557

5658
@Override
5759
protected void doAssertLuceneQuery(RankDocsQueryBuilder queryBuilder, Query query, SearchExecutionContext context) throws IOException {
58-
assertTrue(query instanceof RankDocsQuery);
60+
assertThat(query, instanceOf(RankDocsQuery.class));
5961
RankDocsQuery rankDocsQuery = (RankDocsQuery) query;
60-
assertArrayEquals(queryBuilder.rankDocs(), rankDocsQuery.rankDocs());
62+
assertThat(rankDocsQuery.rankDocs(), equalTo(queryBuilder.rankDocs()));
63+
}
64+
65+
protected Query createTestQuery() throws IOException {
66+
return createRandomQuery().toQuery(createSearchExecutionContext());
67+
}
68+
69+
private RankDocsQueryBuilder createQueryBuilder() {
70+
return createRandomQuery();
71+
}
72+
73+
private RankDocsQueryBuilder createRandomQuery() {
74+
RankDoc[] rankDocs = new RankDoc[randomIntBetween(1, 5)];
75+
for (int i = 0; i < rankDocs.length; i++) {
76+
rankDocs[i] = new RankDoc(randomInt(), randomFloat(), randomIntBetween(0, 2));
77+
}
78+
return new RankDocsQueryBuilder(rankDocs, null, randomBoolean(), Float.MIN_VALUE);
6179
}
6280

6381
/**
@@ -151,7 +169,8 @@ public void testRankDocsQueryEarlyTerminate() throws IOException {
151169
rankDocs,
152170
new Query[] { NumericDocValuesField.newSlowExactQuery("active", 1) },
153171
new String[1],
154-
false
172+
false,
173+
Float.MIN_VALUE
155174
);
156175
var topDocsManager = new TopScoreDocCollectorManager(topSize, null, totalHitsThreshold);
157176
var col = searcher.search(q, topDocsManager);
@@ -172,7 +191,8 @@ public void testRankDocsQueryEarlyTerminate() throws IOException {
172191
rankDocs,
173192
new Query[] { NumericDocValuesField.newSlowExactQuery("active", 1) },
174193
new String[1],
175-
false
194+
false,
195+
Float.MIN_VALUE
176196
);
177197
var topDocsManager = new TopScoreDocCollectorManager(topSize, null, Integer.MAX_VALUE);
178198
var col = searcher.search(q, topDocsManager);
@@ -187,7 +207,8 @@ public void testRankDocsQueryEarlyTerminate() throws IOException {
187207
rankDocs,
188208
new Query[] { NumericDocValuesField.newSlowExactQuery("active", 1) },
189209
new String[1],
190-
true
210+
true,
211+
Float.MIN_VALUE
191212
);
192213
var topDocsManager = new TopScoreDocCollectorManager(topSize, null, Integer.MAX_VALUE);
193214
var col = searcher.search(q, topDocsManager);
@@ -204,7 +225,8 @@ public void testRankDocsQueryEarlyTerminate() throws IOException {
204225
singleRankDoc,
205226
new Query[] { NumericDocValuesField.newSlowExactQuery("active", 1) },
206227
new String[1],
207-
false
228+
false,
229+
Float.MIN_VALUE
208230
);
209231
var topDocsManager = new TopScoreDocCollectorManager(1, null, 0);
210232
var col = searcher.search(q, topDocsManager);
@@ -257,10 +279,29 @@ public void shouldThrowForNegativeScores() throws IOException {
257279
iw.addDocument(new Document());
258280
try (IndexReader reader = iw.getReader()) {
259281
SearchExecutionContext context = createSearchExecutionContext(newSearcher(reader));
260-
RankDocsQueryBuilder queryBuilder = new RankDocsQueryBuilder(new RankDoc[] { new RankDoc(0, -1.0f, 0) }, null, false);
282+
RankDocsQueryBuilder queryBuilder = new RankDocsQueryBuilder(new RankDoc[] { new RankDoc(0, -1.0f, 0) }, null, false, Float.MIN_VALUE);
261283
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> queryBuilder.doToQuery(context));
262284
assertEquals("RankDoc scores must be positive values. Missing a normalization step?", ex.getMessage());
263285
}
264286
}
265287
}
288+
289+
public void testCreateQuery() throws IOException {
290+
try (Directory directory = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) {
291+
iw.addDocument(new Document());
292+
try (IndexReader reader = iw.getReader()) {
293+
RankDoc[] rankDocs = new RankDoc[] { new RankDoc(0, randomFloat(), 0) };
294+
RankDocsQuery q = new RankDocsQuery(
295+
reader,
296+
rankDocs,
297+
new Query[] { new MatchAllDocsQuery() },
298+
new String[] { "test" },
299+
false,
300+
Float.MIN_VALUE
301+
);
302+
assertNotNull(q);
303+
assertArrayEquals(rankDocs, q.rankDocs());
304+
}
305+
}
306+
}
266307
}

server/src/test/java/org/elasticsearch/search/rank/AbstractRankDocWireSerializingTestCase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public void testRankDocSerialization() throws IOException {
5050
for (int i = 0; i < totalDocs; i++) {
5151
docs.add(createTestRankDoc());
5252
}
53-
RankDocsQueryBuilder rankDocsQueryBuilder = new RankDocsQueryBuilder(docs.toArray((T[]) new RankDoc[0]), null, randomBoolean());
53+
RankDocsQueryBuilder rankDocsQueryBuilder = new RankDocsQueryBuilder(docs.toArray((T[]) new RankDoc[0]), null, randomBoolean(), Float.MIN_VALUE);
5454
RankDocsQueryBuilder copy = (RankDocsQueryBuilder) copyNamedWriteable(rankDocsQueryBuilder, writableRegistry(), QueryBuilder.class);
5555
assertThat(rankDocsQueryBuilder, equalTo(copy));
5656
}

0 commit comments

Comments
 (0)