Skip to content

Commit 2371ec3

Browse files
committed
always apply diversification when nested and flat
1 parent 9a1b722 commit 2371ec3

File tree

5 files changed

+217
-118
lines changed

5 files changed

+217
-118
lines changed

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ TopDocs doVectorQuery(float[] vector, IndexSearcher searcher) throws IOException
309309
}
310310
if (overSamplingFactor > 1f) {
311311
// oversample the topK results to get more candidates for the final result
312-
knnQuery = RescoreKnnVectorQuery.fromInnerQuery(VECTOR_FIELD, vector, similarityFunction, this.topK, topK, knnQuery, null);
312+
knnQuery = RescoreKnnVectorQuery.fromInnerQuery(VECTOR_FIELD, vector, similarityFunction, this.topK, topK, knnQuery);
313313
}
314314
QueryProfiler profiler = new QueryProfiler();
315315
TopDocs docs = searcher.search(knnQuery, this.topK);

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
import org.elasticsearch.search.lookup.Source;
8080
import org.elasticsearch.search.vectors.DenseVectorQuery;
8181
import org.elasticsearch.search.vectors.DiversifyingChildrenIVFKnnFloatVectorQuery;
82+
import org.elasticsearch.search.vectors.DiversifyingParentBlockQuery;
8283
import org.elasticsearch.search.vectors.ESDiversifyingChildrenByteKnnVectorQuery;
8384
import org.elasticsearch.search.vectors.ESDiversifyingChildrenFloatKnnVectorQuery;
8485
import org.elasticsearch.search.vectors.ESKnnByteVectorQuery;
@@ -2546,9 +2547,12 @@ private Query createKnnBitQuery(
25462547
elementType.checkDimensions(dims, queryVector.length);
25472548
Query knnQuery;
25482549
if (indexOptions != null && indexOptions.isFlat()) {
2550+
var exactKnnQuery = parentFilter != null
2551+
? new DiversifyingParentBlockQuery(parentFilter, createExactKnnBitQuery(queryVector))
2552+
: createExactKnnBitQuery(queryVector);
25492553
knnQuery = filter == null
25502554
? createExactKnnBitQuery(queryVector)
2551-
: new BooleanQuery.Builder().add(createExactKnnBitQuery(queryVector), BooleanClause.Occur.SHOULD)
2555+
: new BooleanQuery.Builder().add(exactKnnQuery, BooleanClause.Occur.SHOULD)
25522556
.add(filter, BooleanClause.Occur.FILTER)
25532557
.build();
25542558
} else {
@@ -2584,9 +2588,12 @@ private Query createKnnByteQuery(
25842588

25852589
Query knnQuery;
25862590
if (indexOptions != null && indexOptions.isFlat()) {
2591+
var exactKnnQuery = parentFilter != null
2592+
? new DiversifyingParentBlockQuery(parentFilter, createExactKnnByteQuery(queryVector))
2593+
: createExactKnnByteQuery(queryVector);
25872594
knnQuery = filter == null
25882595
? createExactKnnByteQuery(queryVector)
2589-
: new BooleanQuery.Builder().add(createExactKnnByteQuery(queryVector), BooleanClause.Occur.SHOULD)
2596+
: new BooleanQuery.Builder().add(exactKnnQuery, BooleanClause.Occur.SHOULD)
25902597
.add(filter, BooleanClause.Occur.FILTER)
25912598
.build();
25922599
} else {
@@ -2647,9 +2654,12 @@ && isNotUnitVector(squaredMagnitude)) {
26472654
}
26482655
Query knnQuery;
26492656
if (indexOptions != null && indexOptions.isFlat()) {
2657+
var exactKnnQuery = parentFilter != null
2658+
? new DiversifyingParentBlockQuery(parentFilter, createExactKnnFloatQuery(queryVector))
2659+
: createExactKnnFloatQuery(queryVector);
26502660
knnQuery = filter == null
26512661
? createExactKnnFloatQuery(queryVector)
2652-
: new BooleanQuery.Builder().add(createExactKnnFloatQuery(queryVector), BooleanClause.Occur.SHOULD)
2662+
: new BooleanQuery.Builder().add(exactKnnQuery, BooleanClause.Occur.SHOULD)
26532663
.add(filter, BooleanClause.Occur.FILTER)
26542664
.build();
26552665
} else if (indexOptions instanceof BBQIVFIndexOptions bbqIndexOptions) {
@@ -2684,8 +2694,7 @@ && isNotUnitVector(squaredMagnitude)) {
26842694
similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT),
26852695
k,
26862696
adjustedK,
2687-
knnQuery,
2688-
parentFilter
2697+
knnQuery
26892698
);
26902699
}
26912700
if (similarityThreshold != null) {
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.vectors;
11+
12+
import org.apache.lucene.index.LeafReaderContext;
13+
import org.apache.lucene.search.BooleanClause;
14+
import org.apache.lucene.search.DocIdSetIterator;
15+
import org.apache.lucene.search.Explanation;
16+
import org.apache.lucene.search.IndexSearcher;
17+
import org.apache.lucene.search.Query;
18+
import org.apache.lucene.search.QueryVisitor;
19+
import org.apache.lucene.search.ScoreMode;
20+
import org.apache.lucene.search.Scorer;
21+
import org.apache.lucene.search.ScorerSupplier;
22+
import org.apache.lucene.search.Weight;
23+
import org.apache.lucene.search.join.BitSetProducer;
24+
25+
import java.io.IOException;
26+
import java.util.Objects;
27+
28+
/**
29+
* A Lucene query that selects the highest-scoring child document for each parent block.
30+
* <p>
31+
* Children are scored using the {@code innerQuery}, and for each parent (as defined by the
32+
* {@code parentFilter}), the single best-scoring child is returned.
33+
*/
34+
public class DiversifyingParentBlockQuery extends Query {
35+
private final BitSetProducer parentFilter;
36+
private final Query innerQuery;
37+
38+
public DiversifyingParentBlockQuery(BitSetProducer parentFilter, Query innerQuery) {
39+
this.parentFilter = Objects.requireNonNull(parentFilter);
40+
this.innerQuery = Objects.requireNonNull(innerQuery);
41+
}
42+
43+
@Override
44+
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
45+
Query rewritten = innerQuery.rewrite(indexSearcher);
46+
if (rewritten != innerQuery) {
47+
return new DiversifyingParentBlockQuery(parentFilter, rewritten);
48+
}
49+
return this;
50+
}
51+
52+
@Override
53+
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
54+
Weight innerWeight = innerQuery.createWeight(searcher, scoreMode, boost);
55+
return new DiversifyingParentBlockWeight(this, innerWeight, parentFilter);
56+
}
57+
58+
@Override
59+
public String toString(String field) {
60+
return "DiversifyingBlockQuery(inner=" + innerQuery.toString(field) + ")";
61+
}
62+
63+
@Override
64+
public void visit(QueryVisitor visitor) {
65+
innerQuery.visit(visitor.getSubVisitor(BooleanClause.Occur.MUST, this));
66+
}
67+
68+
@Override
69+
public boolean equals(Object o) {
70+
if (this == o) return true;
71+
if (o == null || getClass() != o.getClass()) return false;
72+
DiversifyingParentBlockQuery that = (DiversifyingParentBlockQuery) o;
73+
return Objects.equals(innerQuery, that.innerQuery) && parentFilter == that.parentFilter;
74+
}
75+
76+
@Override
77+
public int hashCode() {
78+
return Objects.hash(innerQuery, parentFilter);
79+
}
80+
81+
private static class DiversifyingParentBlockWeight extends Weight {
82+
private final Weight innerWeight;
83+
private final BitSetProducer parentFilter;
84+
85+
DiversifyingParentBlockWeight(Query query, Weight innerWeight, BitSetProducer parentFilter) {
86+
super(query);
87+
this.innerWeight = innerWeight;
88+
this.parentFilter = parentFilter;
89+
}
90+
91+
@Override
92+
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
93+
return innerWeight.explain(context, doc);
94+
}
95+
96+
@Override
97+
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
98+
var innerSupplier = innerWeight.scorerSupplier(context);
99+
var parentBits = parentFilter.getBitSet(context);
100+
if (parentBits == null || innerSupplier == null) {
101+
return null;
102+
}
103+
104+
return new ScorerSupplier() {
105+
@Override
106+
public Scorer get(long leadCost) throws IOException {
107+
var innerScorer = innerSupplier.get(leadCost);
108+
var innerIterator = innerScorer.iterator();
109+
return new Scorer() {
110+
int currentDoc = -1;
111+
float currentScore = Float.NaN;
112+
113+
@Override
114+
public int docID() {
115+
return currentDoc;
116+
}
117+
118+
@Override
119+
public DocIdSetIterator iterator() {
120+
return new DocIdSetIterator() {
121+
boolean exhausted = false;
122+
123+
@Override
124+
public int docID() {
125+
return currentDoc;
126+
}
127+
128+
@Override
129+
public int nextDoc() throws IOException {
130+
return advance(currentDoc + 1);
131+
}
132+
133+
@Override
134+
public int advance(int target) throws IOException {
135+
if (exhausted) {
136+
return NO_MORE_DOCS;
137+
}
138+
if (currentDoc == -1 || innerIterator.docID() < target) {
139+
if (innerIterator.advance(target) == NO_MORE_DOCS) {
140+
exhausted = true;
141+
return currentDoc = NO_MORE_DOCS;
142+
}
143+
}
144+
145+
int bestChild = innerIterator.docID();
146+
float bestScore = innerScorer.score();
147+
int parent = parentBits.nextSetBit(bestChild);
148+
149+
int innerDoc;
150+
while ((innerDoc = innerIterator.nextDoc()) < parent) {
151+
float score = innerScorer.score();
152+
if (score > bestScore) {
153+
bestChild = innerIterator.docID();
154+
bestScore = score;
155+
}
156+
}
157+
if (innerDoc == NO_MORE_DOCS) {
158+
exhausted = true;
159+
}
160+
currentScore = bestScore;
161+
return currentDoc = bestChild;
162+
}
163+
164+
@Override
165+
public long cost() {
166+
return innerIterator.cost();
167+
}
168+
};
169+
}
170+
171+
@Override
172+
public float score() throws IOException {
173+
return currentScore;
174+
}
175+
176+
@Override
177+
public float getMaxScore(int upTo) throws IOException {
178+
return innerScorer.getMaxScore(upTo);
179+
}
180+
};
181+
}
182+
183+
@Override
184+
public long cost() {
185+
return innerSupplier.cost();
186+
}
187+
};
188+
}
189+
190+
@Override
191+
public boolean isCacheable(LeafReaderContext ctx) {
192+
return false;
193+
}
194+
}
195+
}

0 commit comments

Comments
 (0)