Skip to content

Commit 5886ec8

Browse files
benwtrent authored and benchaplin committed
Reapply "Adds unused lower level ivf knn query (elastic#127852)" (elastic#128003) (elastic#128052)
* Reapply "Adds unused lower level ivf knn query (elastic#127852)" (elastic#128003)

  This reverts commit 648d74b.

* Fixing tests
1 parent e23cfa8 commit 5886ec8

File tree

7 files changed

+1467
-11
lines changed

7 files changed

+1467
-11
lines changed

muted-tests.yml

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -447,18 +447,9 @@ tests:
447447
- class: org.elasticsearch.indices.stats.IndexStatsIT
448448
method: testThrottleStats
449449
issue: https://github.com/elastic/elasticsearch/issues/126359
450-
- class: org.elasticsearch.search.vectors.IVFKnnFloatVectorQueryTests
451-
method: testRandomWithFilter
452-
issue: https://github.com/elastic/elasticsearch/issues/127963
453-
- class: org.elasticsearch.search.vectors.IVFKnnFloatVectorQueryTests
454-
method: testSearchBoost
455-
issue: https://github.com/elastic/elasticsearch/issues/127969
456450
- class: org.elasticsearch.packaging.test.DockerTests
457451
method: test040JavaUsesTheOsProvidedKeystore
458452
issue: https://github.com/elastic/elasticsearch/issues/127437
459-
- class: org.elasticsearch.search.vectors.IVFKnnFloatVectorQueryTests
460-
method: testFindFewer
461-
issue: https://github.com/elastic/elasticsearch/issues/128002
462453
- class: org.elasticsearch.xpack.remotecluster.RemoteClusterSecurityRestIT
463454
method: testTaskCancellation
464455
issue: https://github.com/elastic/elasticsearch/issues/128009

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.apache.lucene.util.FixedBitSet;
3333
import org.apache.lucene.util.hnsw.NeighborQueue;
3434
import org.elasticsearch.core.IOUtils;
35+
import org.elasticsearch.search.vectors.IVFKnnSearchStrategy;
3536

3637
import java.io.IOException;
3738
import java.util.function.IntPredicate;
@@ -243,8 +244,11 @@ public final void search(String field, float[] target, KnnCollector knnCollector
243244
rawVectorsReader.search(field, target, knnCollector, acceptDocs);
244245
return;
245246
}
246-
// TODO add new ivf search strategy
247-
int nProbe = 10;
247+
if (fieldInfo.getVectorDimension() != target.length) {
248+
throw new IllegalArgumentException(
249+
"vector query dimension: " + target.length + " differs from field dimension: " + fieldInfo.getVectorDimension()
250+
);
251+
}
248252
float percentFiltered = 1f;
249253
if (acceptDocs instanceof BitSet bitSet) {
250254
percentFiltered = Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length()));
@@ -257,6 +261,13 @@ public final void search(String field, float[] target, KnnCollector knnCollector
257261
}
258262
return visitedDocs.getAndSet(docId) == false;
259263
};
264+
final int nProbe;
265+
if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfSearchStrategy) {
266+
nProbe = ivfSearchStrategy.getNProbe();
267+
} else {
268+
// TODO calculate nProbe given the number of centroids vs. number of vectors for given `k`
269+
nProbe = 10;
270+
}
260271

261272
FieldEntry entry = fields.get(fieldInfo.number);
262273
CentroidQueryScorer centroidQueryScorer = getCentroidScorer(
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.vectors;
11+
12+
import org.apache.lucene.index.IndexReader;
13+
import org.apache.lucene.index.LeafReader;
14+
import org.apache.lucene.index.LeafReaderContext;
15+
import org.apache.lucene.search.BooleanClause;
16+
import org.apache.lucene.search.BooleanQuery;
17+
import org.apache.lucene.search.DocIdSetIterator;
18+
import org.apache.lucene.search.FieldExistsQuery;
19+
import org.apache.lucene.search.FilteredDocIdSetIterator;
20+
import org.apache.lucene.search.IndexSearcher;
21+
import org.apache.lucene.search.KnnCollector;
22+
import org.apache.lucene.search.MatchNoDocsQuery;
23+
import org.apache.lucene.search.Query;
24+
import org.apache.lucene.search.QueryVisitor;
25+
import org.apache.lucene.search.ScoreDoc;
26+
import org.apache.lucene.search.ScoreMode;
27+
import org.apache.lucene.search.Scorer;
28+
import org.apache.lucene.search.TaskExecutor;
29+
import org.apache.lucene.search.TopDocs;
30+
import org.apache.lucene.search.TopDocsCollector;
31+
import org.apache.lucene.search.TopKnnCollector;
32+
import org.apache.lucene.search.Weight;
33+
import org.apache.lucene.search.knn.KnnCollectorManager;
34+
import org.apache.lucene.search.knn.KnnSearchStrategy;
35+
import org.apache.lucene.util.BitSet;
36+
import org.apache.lucene.util.BitSetIterator;
37+
import org.apache.lucene.util.Bits;
38+
import org.elasticsearch.search.profile.query.QueryProfiler;
39+
40+
import java.io.IOException;
41+
import java.util.ArrayList;
42+
import java.util.List;
43+
import java.util.Objects;
44+
import java.util.concurrent.Callable;
45+
46+
abstract class AbstractIVFKnnVectorQuery extends Query implements QueryProfilerProvider {
47+
48+
static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
49+
50+
protected final String field;
51+
protected final int nProbe;
52+
protected final int k;
53+
protected final Query filter;
54+
protected final KnnSearchStrategy searchStrategy;
55+
protected int vectorOpsCount;
56+
57+
protected AbstractIVFKnnVectorQuery(String field, int nProbe, int k, Query filter) {
58+
this.field = field;
59+
this.nProbe = nProbe;
60+
this.k = k;
61+
this.filter = filter;
62+
this.searchStrategy = new IVFKnnSearchStrategy(nProbe);
63+
}
64+
65+
@Override
66+
public void visit(QueryVisitor visitor) {
67+
if (visitor.acceptField(field)) {
68+
visitor.visitLeaf(this);
69+
}
70+
}
71+
72+
@Override
73+
public boolean equals(Object o) {
74+
if (this == o) return true;
75+
if (o == null || getClass() != o.getClass()) return false;
76+
AbstractIVFKnnVectorQuery that = (AbstractIVFKnnVectorQuery) o;
77+
return k == that.k
78+
&& Objects.equals(field, that.field)
79+
&& Objects.equals(filter, that.filter)
80+
&& Objects.equals(nProbe, that.nProbe);
81+
}
82+
83+
@Override
84+
public int hashCode() {
85+
return Objects.hash(field, k, filter, nProbe);
86+
}
87+
88+
@Override
89+
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
90+
vectorOpsCount = 0;
91+
IndexReader reader = indexSearcher.getIndexReader();
92+
93+
final Weight filterWeight;
94+
if (filter != null) {
95+
BooleanQuery booleanQuery = new BooleanQuery.Builder().add(filter, BooleanClause.Occur.FILTER)
96+
.add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
97+
.build();
98+
Query rewritten = indexSearcher.rewrite(booleanQuery);
99+
if (rewritten.getClass() == MatchNoDocsQuery.class) {
100+
return rewritten;
101+
}
102+
filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
103+
} else {
104+
filterWeight = null;
105+
}
106+
KnnCollectorManager knnCollectorManager = getKnnCollectorManager(k, indexSearcher);
107+
TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
108+
List<LeafReaderContext> leafReaderContexts = reader.leaves();
109+
List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());
110+
for (LeafReaderContext context : leafReaderContexts) {
111+
tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager));
112+
}
113+
TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new);
114+
115+
// Merge sort the results
116+
TopDocs topK = TopDocs.merge(k, perLeafResults);
117+
vectorOpsCount = (int) topK.totalHits.value();
118+
if (topK.scoreDocs.length == 0) {
119+
return new MatchNoDocsQuery();
120+
}
121+
return new KnnScoreDocQuery(topK.scoreDocs, reader);
122+
}
123+
124+
private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) throws IOException {
125+
TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager);
126+
if (ctx.docBase > 0) {
127+
for (ScoreDoc scoreDoc : results.scoreDocs) {
128+
scoreDoc.doc += ctx.docBase;
129+
}
130+
}
131+
return results;
132+
}
133+
134+
TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) throws IOException {
135+
final LeafReader reader = ctx.reader();
136+
final Bits liveDocs = reader.getLiveDocs();
137+
138+
if (filterWeight == null) {
139+
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, knnCollectorManager);
140+
}
141+
142+
Scorer scorer = filterWeight.scorer(ctx);
143+
if (scorer == null) {
144+
return TopDocsCollector.EMPTY_TOPDOCS;
145+
}
146+
147+
BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, reader.maxDoc());
148+
final int cost = acceptDocs.cardinality();
149+
return approximateSearch(ctx, acceptDocs, cost + 1, knnCollectorManager);
150+
}
151+
152+
abstract TopDocs approximateSearch(
153+
LeafReaderContext context,
154+
Bits acceptDocs,
155+
int visitedLimit,
156+
KnnCollectorManager knnCollectorManager
157+
) throws IOException;
158+
159+
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
160+
return new IVFCollectorManager(k, nProbe);
161+
}
162+
163+
@Override
164+
public final void profile(QueryProfiler queryProfiler) {
165+
queryProfiler.addVectorOpsCount(vectorOpsCount);
166+
}
167+
168+
BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc) throws IOException {
169+
if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) {
170+
// If we already have a BitSet and no deletions, reuse the BitSet
171+
return bitSetIterator.getBitSet();
172+
} else {
173+
// Create a new BitSet from matching and live docs
174+
FilteredDocIdSetIterator filterIterator = new FilteredDocIdSetIterator(iterator) {
175+
@Override
176+
protected boolean match(int doc) {
177+
return liveDocs == null || liveDocs.get(doc);
178+
}
179+
};
180+
return BitSet.of(filterIterator, maxDoc);
181+
}
182+
}
183+
184+
static class IVFCollectorManager implements KnnCollectorManager {
185+
private final int k;
186+
private final int nprobe;
187+
188+
IVFCollectorManager(int k, int nprobe) {
189+
this.k = k;
190+
this.nprobe = nprobe;
191+
}
192+
193+
@Override
194+
public KnnCollector newCollector(int visitedLimit, KnnSearchStrategy searchStrategy, LeafReaderContext context) throws IOException {
195+
return new TopKnnCollector(k, visitedLimit, new IVFKnnSearchStrategy(nprobe));
196+
}
197+
}
198+
}
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
package org.elasticsearch.search.vectors;
10+
11+
import org.apache.lucene.index.FloatVectorValues;
12+
import org.apache.lucene.index.LeafReader;
13+
import org.apache.lucene.index.LeafReaderContext;
14+
import org.apache.lucene.search.KnnCollector;
15+
import org.apache.lucene.search.Query;
16+
import org.apache.lucene.search.TopDocs;
17+
import org.apache.lucene.search.knn.KnnCollectorManager;
18+
import org.apache.lucene.util.Bits;
19+
20+
import java.io.IOException;
21+
import java.util.Arrays;
22+
23+
/** A {@link IVFKnnFloatVectorQuery} that uses the IVF search strategy. */
24+
public class IVFKnnFloatVectorQuery extends AbstractIVFKnnVectorQuery {
25+
26+
private final float[] query;
27+
28+
/**
29+
* Creates a new {@link IVFKnnFloatVectorQuery} with the given parameters.
30+
* @param field the field to search
31+
* @param query the query vector
32+
* @param k the number of nearest neighbors to return
33+
* @param filter the filter to apply to the results
34+
* @param nProbe the number of probes to use for the IVF search strategy
35+
*/
36+
public IVFKnnFloatVectorQuery(String field, float[] query, int k, Query filter, int nProbe) {
37+
super(field, nProbe, k, filter);
38+
if (k < 1) {
39+
throw new IllegalArgumentException("k must be at least 1, got: " + k);
40+
}
41+
if (nProbe < 1) {
42+
throw new IllegalArgumentException("nProbe must be at least 1, got: " + nProbe);
43+
}
44+
this.query = query;
45+
}
46+
47+
@Override
48+
public String toString(String field) {
49+
StringBuilder buffer = new StringBuilder();
50+
buffer.append(getClass().getSimpleName())
51+
.append(":")
52+
.append(this.field)
53+
.append("[")
54+
.append(query[0])
55+
.append(",...]")
56+
.append("[")
57+
.append(k)
58+
.append("]");
59+
if (this.filter != null) {
60+
buffer.append("[").append(this.filter).append("]");
61+
}
62+
return buffer.toString();
63+
}
64+
65+
@Override
66+
public boolean equals(Object o) {
67+
if (this == o) return true;
68+
if (super.equals(o) == false) return false;
69+
IVFKnnFloatVectorQuery that = (IVFKnnFloatVectorQuery) o;
70+
return Arrays.equals(query, that.query);
71+
}
72+
73+
@Override
74+
public int hashCode() {
75+
int result = super.hashCode();
76+
result = 31 * result + Arrays.hashCode(query);
77+
return result;
78+
}
79+
80+
@Override
81+
protected TopDocs approximateSearch(
82+
LeafReaderContext context,
83+
Bits acceptDocs,
84+
int visitedLimit,
85+
KnnCollectorManager knnCollectorManager
86+
) throws IOException {
87+
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, searchStrategy, context);
88+
LeafReader reader = context.reader();
89+
FloatVectorValues floatVectorValues = reader.getFloatVectorValues(field);
90+
if (floatVectorValues == null) {
91+
FloatVectorValues.checkField(reader, field);
92+
return NO_RESULTS;
93+
}
94+
if (Math.min(knnCollector.k(), floatVectorValues.size()) == 0) {
95+
return NO_RESULTS;
96+
}
97+
reader.searchNearestVectors(field, query, knnCollector, acceptDocs);
98+
TopDocs results = knnCollector.topDocs();
99+
return results != null ? results : NO_RESULTS;
100+
}
101+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
package org.elasticsearch.search.vectors;
10+
11+
import org.apache.lucene.search.knn.KnnSearchStrategy;
12+
13+
import java.util.Objects;
14+
15+
public class IVFKnnSearchStrategy extends KnnSearchStrategy {
16+
private final int nProbe;
17+
18+
IVFKnnSearchStrategy(int nProbe) {
19+
this.nProbe = nProbe;
20+
}
21+
22+
public int getNProbe() {
23+
return nProbe;
24+
}
25+
26+
@Override
27+
public boolean equals(Object o) {
28+
if (this == o) return true;
29+
if (o == null || getClass() != o.getClass()) return false;
30+
IVFKnnSearchStrategy that = (IVFKnnSearchStrategy) o;
31+
return nProbe == that.nProbe;
32+
}
33+
34+
@Override
35+
public int hashCode() {
36+
return Objects.hashCode(nProbe);
37+
}
38+
39+
@Override
40+
public void nextVectorsBlock() {
41+
// do nothing
42+
}
43+
}

0 commit comments

Comments
 (0)