Skip to content

Commit fc31076

Browse files
benwtrentcbuescher
authored andcommitted
Fix bbq for Lucene 10
1 parent 70ce7a1 commit fc31076

9 files changed

+237
-270
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/BinarizedByteVectorValues.java

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,40 +19,71 @@
1919
*/
2020
package org.elasticsearch.index.codec.vectors;
2121

22-
import org.apache.lucene.search.DocIdSetIterator;
22+
import org.apache.lucene.index.ByteVectorValues;
2323
import org.apache.lucene.search.VectorScorer;
24+
import org.apache.lucene.util.VectorUtil;
2425

2526
import java.io.IOException;
2627

2728
/**
2829
* Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10
2930
*/
30-
public abstract class BinarizedByteVectorValues extends DocIdSetIterator {
31+
public abstract class BinarizedByteVectorValues extends ByteVectorValues {
3132

32-
public abstract float[] getCorrectiveTerms();
33-
34-
public abstract byte[] vectorValue() throws IOException;
33+
public abstract float[] getCorrectiveTerms(int vectorOrd) throws IOException;
3534

3635
/** Return the dimension of the vectors */
3736
public abstract int dimension();
3837

38+
/** Returns the centroid distance for the vector */
39+
public abstract float getCentroidDistance(int vectorOrd) throws IOException;
40+
41+
/** Returns the vector magnitude for the vector */
42+
public abstract float getVectorMagnitude(int vectorOrd) throws IOException;
43+
44+
/** Returns OOQ corrective factor for the given vector ordinal */
45+
public abstract float getOOQ(int targetOrd) throws IOException;
46+
47+
/**
48+
* Returns the norm of the target vector w the centroid corrective factor for the given vector
49+
* ordinal
50+
*/
51+
public abstract float getNormOC(int targetOrd) throws IOException;
52+
53+
/**
54+
* Returns the target vector dot product the centroid corrective factor for the given vector
55+
* ordinal
56+
*/
57+
public abstract float getODotC(int targetOrd) throws IOException;
58+
59+
/**
60+
* @return the quantizer used to quantize the vectors
61+
*/
62+
public abstract BinaryQuantizer getQuantizer();
63+
64+
public abstract float[] getCentroid() throws IOException;
65+
3966
/**
4067
* Return the number of vectors for this field.
4168
*
4269
* @return the number of vectors returned by this iterator
4370
*/
4471
public abstract int size();
4572

46-
@Override
47-
public final long cost() {
48-
return size();
49-
}
50-
5173
/**
5274
* Return a {@link VectorScorer} for the given query vector.
5375
*
5476
* @param query the query vector
5577
* @return a {@link VectorScorer} instance or null
5678
*/
5779
public abstract VectorScorer scorer(float[] query) throws IOException;
80+
81+
@Override
82+
public abstract BinarizedByteVectorValues copy() throws IOException;
83+
84+
float getCentroidDP() throws IOException {
85+
// this only gets executed on-merge
86+
float[] centroid = getCentroid();
87+
return VectorUtil.dotProduct(centroid, centroid);
88+
}
5889
}

server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryFlatVectorsScorer.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
package org.elasticsearch.index.codec.vectors;
2121

2222
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
23+
import org.apache.lucene.index.KnnVectorValues;
2324
import org.apache.lucene.index.VectorSimilarityFunction;
2425
import org.apache.lucene.util.ArrayUtil;
2526
import org.apache.lucene.util.VectorUtil;
26-
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
2727
import org.apache.lucene.util.hnsw.RandomVectorScorer;
2828
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
2929
import org.elasticsearch.simdvec.ESVectorUtil;
@@ -45,9 +45,9 @@ public ES816BinaryFlatVectorsScorer(FlatVectorsScorer nonQuantizedDelegate) {
4545
@Override
4646
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
4747
VectorSimilarityFunction similarityFunction,
48-
RandomAccessVectorValues vectorValues
48+
KnnVectorValues vectorValues
4949
) throws IOException {
50-
if (vectorValues instanceof RandomAccessBinarizedByteVectorValues) {
50+
if (vectorValues instanceof BinarizedByteVectorValues) {
5151
throw new UnsupportedOperationException(
5252
"getRandomVectorScorerSupplier(VectorSimilarityFunction,RandomAccessVectorValues) not implemented for binarized format"
5353
);
@@ -58,10 +58,10 @@ public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
5858
@Override
5959
public RandomVectorScorer getRandomVectorScorer(
6060
VectorSimilarityFunction similarityFunction,
61-
RandomAccessVectorValues vectorValues,
61+
KnnVectorValues vectorValues,
6262
float[] target
6363
) throws IOException {
64-
if (vectorValues instanceof RandomAccessBinarizedByteVectorValues binarizedVectors) {
64+
if (vectorValues instanceof BinarizedByteVectorValues binarizedVectors) {
6565
BinaryQuantizer quantizer = binarizedVectors.getQuantizer();
6666
float[] centroid = binarizedVectors.getCentroid();
6767
// FIXME: precompute this once?
@@ -82,7 +82,7 @@ public RandomVectorScorer getRandomVectorScorer(
8282
@Override
8383
public RandomVectorScorer getRandomVectorScorer(
8484
VectorSimilarityFunction similarityFunction,
85-
RandomAccessVectorValues vectorValues,
85+
KnnVectorValues vectorValues,
8686
byte[] target
8787
) throws IOException {
8888
return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
@@ -91,7 +91,7 @@ public RandomVectorScorer getRandomVectorScorer(
9191
RandomVectorScorerSupplier getRandomVectorScorerSupplier(
9292
VectorSimilarityFunction similarityFunction,
9393
ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues scoringVectors,
94-
RandomAccessBinarizedByteVectorValues targetVectors
94+
BinarizedByteVectorValues targetVectors
9595
) {
9696
return new BinarizedRandomVectorScorerSupplier(scoringVectors, targetVectors, similarityFunction);
9797
}
@@ -104,12 +104,12 @@ public String toString() {
104104
/** Vector scorer supplier over binarized vector values */
105105
static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
106106
private final ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors;
107-
private final RandomAccessBinarizedByteVectorValues targetVectors;
107+
private final BinarizedByteVectorValues targetVectors;
108108
private final VectorSimilarityFunction similarityFunction;
109109

110110
BinarizedRandomVectorScorerSupplier(
111111
ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors,
112-
RandomAccessBinarizedByteVectorValues targetVectors,
112+
BinarizedByteVectorValues targetVectors,
113113
VectorSimilarityFunction similarityFunction
114114
) {
115115
this.queryVectors = queryVectors;
@@ -149,14 +149,14 @@ public record BinaryQueryVector(byte[] vector, BinaryQuantizer.QueryFactors fact
149149
/** Vector scorer over binarized vector values */
150150
public static class BinarizedRandomVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer {
151151
private final BinaryQueryVector queryVector;
152-
private final RandomAccessBinarizedByteVectorValues targetVectors;
152+
private final BinarizedByteVectorValues targetVectors;
153153
private final VectorSimilarityFunction similarityFunction;
154154

155155
private final float sqrtDimensions;
156156

157157
public BinarizedRandomVectorScorer(
158158
BinaryQueryVector queryVectors,
159-
RandomAccessBinarizedByteVectorValues targetVectors,
159+
BinarizedByteVectorValues targetVectors,
160160
VectorSimilarityFunction similarityFunction
161161
) {
162162
super(targetVectors);

server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsReader.java

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import org.apache.lucene.store.ChecksumIndexInput;
3737
import org.apache.lucene.store.IOContext;
3838
import org.apache.lucene.store.IndexInput;
39+
import org.apache.lucene.store.ReadAdvice;
3940
import org.apache.lucene.util.Bits;
4041
import org.apache.lucene.util.IOUtils;
4142
import org.apache.lucene.util.RamUsageEstimator;
@@ -78,7 +79,7 @@ public ES816BinaryQuantizedVectorsReader(
7879
ES816BinaryQuantizedVectorsFormat.META_EXTENSION
7980
);
8081
boolean success = false;
81-
try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName, state.context)) {
82+
try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) {
8283
Throwable priorE = null;
8384
try {
8485
versionMeta = CodecUtil.checkIndexHeader(
@@ -102,7 +103,7 @@ public ES816BinaryQuantizedVectorsReader(
102103
ES816BinaryQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME,
103104
// Quantized vectors are accessed randomly from their node ID stored in the HNSW
104105
// graph.
105-
state.context.withRandomAccess()
106+
state.context.withReadAdvice(ReadAdvice.RANDOM)
106107
);
107108
success = true;
108109
} finally {
@@ -357,9 +358,9 @@ static FieldEntry create(IndexInput input, VectorEncoding vectorEncoding, Vector
357358
/** Binarized vector values holding row and quantized vector values */
358359
protected static final class BinarizedVectorValues extends FloatVectorValues {
359360
private final FloatVectorValues rawVectorValues;
360-
private final OffHeapBinarizedVectorValues quantizedVectorValues;
361+
private final BinarizedByteVectorValues quantizedVectorValues;
361362

362-
BinarizedVectorValues(FloatVectorValues rawVectorValues, OffHeapBinarizedVectorValues quantizedVectorValues) {
363+
BinarizedVectorValues(FloatVectorValues rawVectorValues, BinarizedByteVectorValues quantizedVectorValues) {
363364
this.rawVectorValues = rawVectorValues;
364365
this.quantizedVectorValues = quantizedVectorValues;
365366
}
@@ -375,37 +376,36 @@ public int size() {
375376
}
376377

377378
@Override
378-
public float[] vectorValue() throws IOException {
379-
return rawVectorValues.vectorValue();
379+
public float[] vectorValue(int ord) throws IOException {
380+
return rawVectorValues.vectorValue(ord);
380381
}
381382

382383
@Override
383-
public int docID() {
384-
return rawVectorValues.docID();
384+
public BinarizedVectorValues copy() throws IOException {
385+
return new BinarizedVectorValues(rawVectorValues.copy(), quantizedVectorValues.copy());
385386
}
386387

387388
@Override
388-
public int nextDoc() throws IOException {
389-
int rawDocId = rawVectorValues.nextDoc();
390-
int quantizedDocId = quantizedVectorValues.nextDoc();
391-
assert rawDocId == quantizedDocId;
392-
return quantizedDocId;
389+
public Bits getAcceptOrds(Bits acceptDocs) {
390+
return rawVectorValues.getAcceptOrds(acceptDocs);
393391
}
394392

395393
@Override
396-
public int advance(int target) throws IOException {
397-
int rawDocId = rawVectorValues.advance(target);
398-
int quantizedDocId = quantizedVectorValues.advance(target);
399-
assert rawDocId == quantizedDocId;
400-
return quantizedDocId;
394+
public int ordToDoc(int ord) {
395+
return rawVectorValues.ordToDoc(ord);
396+
}
397+
398+
@Override
399+
public DocIndexIterator iterator() {
400+
return rawVectorValues.iterator();
401401
}
402402

403403
@Override
404404
public VectorScorer scorer(float[] query) throws IOException {
405405
return quantizedVectorValues.scorer(query);
406406
}
407407

408-
protected OffHeapBinarizedVectorValues getQuantizedVectorValues() throws IOException {
408+
protected BinarizedByteVectorValues getQuantizedVectorValues() throws IOException {
409409
return quantizedVectorValues;
410410
}
411411
}

0 commit comments

Comments
 (0)