From 090e7b8cd4c281de3c8cf88e72b7ec7d91e37de4 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 8 Oct 2024 16:38:36 -0400 Subject: [PATCH] Fix bbq for Lucene 10 --- .../vectors/BinarizedByteVectorValues.java | 51 +++++-- .../vectors/ES816BinaryFlatVectorsScorer.java | 22 +-- .../ES816BinaryQuantizedVectorsReader.java | 38 ++--- .../ES816BinaryQuantizedVectorsWriter.java | 143 ++++++++++-------- .../vectors/OffHeapBinarizedVectorValues.java | 100 +++--------- ...RandomAccessBinarizedByteVectorValues.java | 70 --------- .../ES816BinaryFlatVectorsScorerTests.java | 43 +++++- ...S816BinaryQuantizedVectorsFormatTests.java | 28 ++-- ...HnswBinaryQuantizedVectorsFormatTests.java | 12 +- 9 files changed, 237 insertions(+), 270 deletions(-) delete mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/RandomAccessBinarizedByteVectorValues.java diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/BinarizedByteVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/BinarizedByteVectorValues.java index 73dd4273a794e..1929c36def141 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/BinarizedByteVectorValues.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/BinarizedByteVectorValues.java @@ -19,23 +19,50 @@ */ package org.elasticsearch.index.codec.vectors; -import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.util.VectorUtil; import java.io.IOException; /** * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 */ -public abstract class BinarizedByteVectorValues extends DocIdSetIterator { +public abstract class BinarizedByteVectorValues extends ByteVectorValues { - public abstract float[] getCorrectiveTerms(); - - public abstract byte[] vectorValue() throws IOException; + public abstract float[] getCorrectiveTerms(int vectorOrd) throws IOException; /** Return the dimension of the vectors */ public abstract int dimension(); + /** Returns the centroid distance for the vector */ + public abstract float getCentroidDistance(int vectorOrd) throws IOException; + + /** Returns the vector magnitude for the vector */ + public abstract float getVectorMagnitude(int vectorOrd) throws IOException; + + /** Returns OOQ corrective factor for the given vector ordinal */ + public abstract float getOOQ(int targetOrd) throws IOException; + + /** + * Returns the norm of the target vector w the centroid corrective factor for the given vector + * ordinal + */ + public abstract float getNormOC(int targetOrd) throws IOException; + + /** + * Returns the target vector dot product the centroid corrective factor for the given vector + * ordinal + */ + public abstract float getODotC(int targetOrd) throws IOException; + + /** + * @return the quantizer used to quantize the vectors + */ + public abstract BinaryQuantizer getQuantizer(); + + public abstract float[] getCentroid() throws IOException; + /** * Return the number of vectors for this field. * @@ -43,11 +70,6 @@ public abstract class BinarizedByteVectorValues extends DocIdSetIterator { */ public abstract int size(); - @Override - public final long cost() { - return size(); - } - /** * Return a {@link VectorScorer} for the given query vector. * @@ -55,4 +77,13 @@ public final long cost() { * @return a {@link VectorScorer} instance or null */ public abstract VectorScorer scorer(float[] query) throws IOException; + + @Override + public abstract BinarizedByteVectorValues copy() throws IOException; + + float getCentroidDP() throws IOException { + // this only gets executed on-merge + float[] centroid = getCentroid(); + return VectorUtil.dotProduct(centroid, centroid); + } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryFlatVectorsScorer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryFlatVectorsScorer.java index 78fa282709098..cd360c6d6ff20 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryFlatVectorsScorer.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryFlatVectorsScorer.java @@ -20,10 +20,10 @@ package org.elasticsearch.index.codec.vectors; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.VectorUtil; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.elasticsearch.simdvec.ESVectorUtil; @@ -45,9 +45,9 @@ public ES816BinaryFlatVectorsScorer(FlatVectorsScorer nonQuantizedDelegate) { @Override public RandomVectorScorerSupplier getRandomVectorScorerSupplier( VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues + KnnVectorValues vectorValues ) throws IOException { - if (vectorValues instanceof RandomAccessBinarizedByteVectorValues) { + if (vectorValues instanceof BinarizedByteVectorValues) { throw new UnsupportedOperationException( "getRandomVectorScorerSupplier(VectorSimilarityFunction,RandomAccessVectorValues) not implemented for binarized format" ); @@ -58,10 +58,10 @@ public RandomVectorScorerSupplier getRandomVectorScorerSupplier( @Override public RandomVectorScorer getRandomVectorScorer( VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, + KnnVectorValues vectorValues, float[] target ) throws IOException { - if (vectorValues instanceof RandomAccessBinarizedByteVectorValues binarizedVectors) { + if (vectorValues instanceof BinarizedByteVectorValues binarizedVectors) { BinaryQuantizer quantizer = binarizedVectors.getQuantizer(); float[] centroid = binarizedVectors.getCentroid(); // FIXME: precompute this once? @@ -82,7 +82,7 @@ public RandomVectorScorer getRandomVectorScorer( @Override public RandomVectorScorer getRandomVectorScorer( VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, + KnnVectorValues vectorValues, byte[] target ) throws IOException { return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); @@ -91,7 +91,7 @@ public RandomVectorScorer getRandomVectorScorer( RandomVectorScorerSupplier getRandomVectorScorerSupplier( VectorSimilarityFunction similarityFunction, ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues scoringVectors, - RandomAccessBinarizedByteVectorValues targetVectors + BinarizedByteVectorValues targetVectors ) { return new BinarizedRandomVectorScorerSupplier(scoringVectors, targetVectors, similarityFunction); } @@ -104,12 +104,12 @@ public String toString() { /** Vector scorer supplier over binarized vector values */ static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { private final ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors; - private final RandomAccessBinarizedByteVectorValues targetVectors; + private final BinarizedByteVectorValues targetVectors; private final VectorSimilarityFunction similarityFunction; BinarizedRandomVectorScorerSupplier( ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors, - RandomAccessBinarizedByteVectorValues targetVectors, + BinarizedByteVectorValues targetVectors, VectorSimilarityFunction similarityFunction ) { this.queryVectors = queryVectors; @@ -149,14 +149,14 @@ public record BinaryQueryVector(byte[] vector, BinaryQuantizer.QueryFactors fact /** Vector scorer over binarized vector values */ public static class BinarizedRandomVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { private final BinaryQueryVector queryVector; - private final RandomAccessBinarizedByteVectorValues targetVectors; + private final BinarizedByteVectorValues targetVectors; private final VectorSimilarityFunction similarityFunction; private final float sqrtDimensions; public BinarizedRandomVectorScorer( BinaryQueryVector queryVectors, - RandomAccessBinarizedByteVectorValues targetVectors, + BinarizedByteVectorValues targetVectors, VectorSimilarityFunction similarityFunction ) { super(targetVectors); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsReader.java index b0378fee6793d..21c4a5c449387 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsReader.java @@ -36,6 +36,7 @@ import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.ReadAdvice; import org.apache.lucene.util.Bits; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; @@ -78,7 +79,7 @@ public ES816BinaryQuantizedVectorsReader( ES816BinaryQuantizedVectorsFormat.META_EXTENSION ); boolean success = false; - try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName, state.context)) { + try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) { Throwable priorE = null; try { versionMeta = CodecUtil.checkIndexHeader( @@ -102,7 +103,7 @@ public ES816BinaryQuantizedVectorsReader( ES816BinaryQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME, // Quantized vectors are accessed randomly from their node ID stored in the HNSW // graph. - state.context.withRandomAccess() + state.context.withReadAdvice(ReadAdvice.RANDOM) ); success = true; } finally { @@ -357,9 +358,9 @@ static FieldEntry create(IndexInput input, VectorEncoding vectorEncoding, Vector /** Binarized vector values holding row and quantized vector values */ protected static final class BinarizedVectorValues extends FloatVectorValues { private final FloatVectorValues rawVectorValues; - private final OffHeapBinarizedVectorValues quantizedVectorValues; + private final BinarizedByteVectorValues quantizedVectorValues; - BinarizedVectorValues(FloatVectorValues rawVectorValues, OffHeapBinarizedVectorValues quantizedVectorValues) { + BinarizedVectorValues(FloatVectorValues rawVectorValues, BinarizedByteVectorValues quantizedVectorValues) { this.rawVectorValues = rawVectorValues; this.quantizedVectorValues = quantizedVectorValues; } @@ -375,29 +376,28 @@ public int size() { } @Override - public float[] vectorValue() throws IOException { - return rawVectorValues.vectorValue(); + public float[] vectorValue(int ord) throws IOException { + return rawVectorValues.vectorValue(ord); } @Override - public int docID() { - return rawVectorValues.docID(); + public BinarizedVectorValues copy() throws IOException { + return new BinarizedVectorValues(rawVectorValues.copy(), quantizedVectorValues.copy()); } @Override - public int nextDoc() throws IOException { - int rawDocId = rawVectorValues.nextDoc(); - int quantizedDocId = quantizedVectorValues.nextDoc(); - assert rawDocId == quantizedDocId; - return quantizedDocId; + public Bits getAcceptOrds(Bits acceptDocs) { + return rawVectorValues.getAcceptOrds(acceptDocs); } @Override - public int advance(int target) throws IOException { - int rawDocId = rawVectorValues.advance(target); - int quantizedDocId = quantizedVectorValues.advance(target); - assert rawDocId == quantizedDocId; - return quantizedDocId; + public int ordToDoc(int ord) { + return rawVectorValues.ordToDoc(ord); + } + + @Override + public DocIndexIterator iterator() { + return rawVectorValues.iterator(); } @Override @@ -405,7 +405,7 @@ public VectorScorer scorer(float[] query) throws IOException { return quantizedVectorValues.scorer(query); } - protected OffHeapBinarizedVectorValues getQuantizedVectorValues() throws IOException { + protected BinarizedByteVectorValues getQuantizedVectorValues() throws IOException { return quantizedVectorValues; } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsWriter.java index 92837a8ffce45..a7774b850b64c 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsWriter.java @@ -30,6 +30,7 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; @@ -44,7 +45,6 @@ import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.elasticsearch.core.SuppressForbidden; @@ -354,10 +354,11 @@ static DocsWithFieldSet writeBinarizedVectorAndQueryData( int queryCorrectionCount = binaryQuantizer.getSimilarity() != EUCLIDEAN ? 5 : 3; final ByteBuffer queryCorrectionsBuffer = ByteBuffer.allocate(Float.BYTES * queryCorrectionCount + Short.BYTES) .order(ByteOrder.LITTLE_ENDIAN); - for (int docV = floatVectorValues.nextDoc(); docV != NO_MORE_DOCS; docV = floatVectorValues.nextDoc()) { + KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); + for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { // write index vector BinaryQuantizer.QueryAndIndexResults r = binaryQuantizer.quantizeQueryAndIndex( - floatVectorValues.vectorValue(), + floatVectorValues.vectorValue(iterator.index()), toIndex, toQuery, centroid @@ -393,11 +394,12 @@ static DocsWithFieldSet writeBinarizedVectorAndQueryData( static DocsWithFieldSet writeBinarizedVectorData(IndexOutput output, BinarizedByteVectorValues binarizedByteVectorValues) throws IOException { DocsWithFieldSet docsWithField = new DocsWithFieldSet(); - for (int docV = binarizedByteVectorValues.nextDoc(); docV != NO_MORE_DOCS; docV = binarizedByteVectorValues.nextDoc()) { + KnnVectorValues.DocIndexIterator iterator = binarizedByteVectorValues.iterator(); + for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { // write vector - byte[] binaryValue = binarizedByteVectorValues.vectorValue(); + byte[] binaryValue = binarizedByteVectorValues.vectorValue(iterator.index()); output.writeBytes(binaryValue, binaryValue.length); - float[] corrections = binarizedByteVectorValues.getCorrectiveTerms(); + float[] corrections = binarizedByteVectorValues.getCorrectiveTerms(iterator.index()); for (int i = 0; i < corrections.length; i++) { output.writeInt(Float.floatToIntBits(corrections[i])); } @@ -598,8 +600,9 @@ static int calculateCentroid(MergeState mergeState, FieldInfo fieldInfo, float[] if (vectorValues == null) { continue; } - for (int doc = vectorValues.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = vectorValues.nextDoc()) { - float[] vector = vectorValues.vectorValue(); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + for (int doc = iterator.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iterator.nextDoc()) { + float[] vector = vectorValues.vectorValue(iterator.index()); // TODO Panama sum for (int j = 0; j < vector.length; j++) { centroid[j] += vector[j]; @@ -827,23 +830,31 @@ static class BinarizedFloatVectorValues extends BinarizedByteVectorValues { private final float[] centroid; private final FloatVectorValues values; private final BinaryQuantizer quantizer; - private int lastDoc; + private int lastOrd = -1; BinarizedFloatVectorValues(FloatVectorValues delegate, BinaryQuantizer quantizer, float[] centroid) { this.values = delegate; this.quantizer = quantizer; this.binarized = new byte[BQVectorUtils.discretize(delegate.dimension(), 64) / 8]; this.centroid = centroid; - lastDoc = -1; } @Override - public float[] getCorrectiveTerms() { + public float[] getCorrectiveTerms(int ord) { + if (ord != lastOrd) { + throw new IllegalStateException( + "attempt to retrieve corrective terms for different ord " + ord + " than the quantization was done for: " + lastOrd + ); + } return corrections; } @Override - public byte[] vectorValue() throws IOException { + public byte[] vectorValue(int ord) throws IOException { + if (ord != lastOrd) { + binarize(ord); + lastOrd = ord; + } return binarized; } @@ -853,33 +864,43 @@ public int dimension() { } @Override - public int size() { - return values.size(); + public float getCentroidDistance(int vectorOrd) throws IOException { + throw new UnsupportedOperationException(); } @Override - public int docID() { - return values.docID(); + public float getVectorMagnitude(int vectorOrd) throws IOException { + throw new UnsupportedOperationException(); } @Override - public int nextDoc() throws IOException { - int doc = values.nextDoc(); - if (doc != NO_MORE_DOCS) { - binarize(); - } - lastDoc = doc; - return doc; + public float getOOQ(int targetOrd) throws IOException { + throw new UnsupportedOperationException(); } @Override - public int advance(int target) throws IOException { - int doc = values.advance(target); - if (doc != NO_MORE_DOCS) { - binarize(); - } - lastDoc = doc; - return doc; + public float getNormOC(int targetOrd) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public float getODotC(int targetOrd) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public BinaryQuantizer getQuantizer() { + throw new UnsupportedOperationException(); + } + + @Override + public float[] getCentroid() throws IOException { + return centroid; + } + + @Override + public int size() { + return values.size(); } @Override @@ -887,22 +908,32 @@ public VectorScorer scorer(float[] target) throws IOException { throw new UnsupportedOperationException(); } - private void binarize() throws IOException { - if (lastDoc == docID()) return; - corrections = quantizer.quantizeForIndex(values.vectorValue(), binarized, centroid); + @Override + public BinarizedByteVectorValues copy() throws IOException { + return new BinarizedFloatVectorValues(values.copy(), quantizer, centroid); + } + + private void binarize(int ord) throws IOException { + corrections = quantizer.quantizeForIndex(values.vectorValue(ord), binarized, centroid); + } + + @Override + public DocIndexIterator iterator() { + return values.iterator(); + } + + @Override + public int ordToDoc(int ord) { + return values.ordToDoc(ord); } } static class BinarizedCloseableRandomVectorScorerSupplier implements CloseableRandomVectorScorerSupplier { private final RandomVectorScorerSupplier supplier; - private final RandomAccessVectorValues vectorValues; + private final KnnVectorValues vectorValues; private final Closeable onClose; - BinarizedCloseableRandomVectorScorerSupplier( - RandomVectorScorerSupplier supplier, - RandomAccessVectorValues vectorValues, - Closeable onClose - ) { + BinarizedCloseableRandomVectorScorerSupplier(RandomVectorScorerSupplier supplier, KnnVectorValues vectorValues, Closeable onClose) { this.supplier = supplier; this.onClose = onClose; this.vectorValues = vectorValues; @@ -932,7 +963,6 @@ public int totalVectorCount() { static final class NormalizedFloatVectorValues extends FloatVectorValues { private final FloatVectorValues values; private final float[] normalizedVector; - int curDoc = -1; NormalizedFloatVectorValues(FloatVectorValues values) { this.values = values; @@ -950,38 +980,25 @@ public int size() { } @Override - public float[] vectorValue() { - return normalizedVector; + public int ordToDoc(int ord) { + return values.ordToDoc(ord); } @Override - public VectorScorer scorer(float[] query) { - throw new UnsupportedOperationException(); - } - - @Override - public int docID() { - return values.docID(); + public float[] vectorValue(int ord) throws IOException { + System.arraycopy(values.vectorValue(ord), 0, normalizedVector, 0, normalizedVector.length); + VectorUtil.l2normalize(normalizedVector); + return normalizedVector; } @Override - public int nextDoc() throws IOException { - curDoc = values.nextDoc(); - if (curDoc != NO_MORE_DOCS) { - System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length); - VectorUtil.l2normalize(normalizedVector); - } - return curDoc; + public DocIndexIterator iterator() { + return values.iterator(); } @Override - public int advance(int target) throws IOException { - curDoc = values.advance(target); - if (curDoc != NO_MORE_DOCS) { - System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length); - VectorUtil.l2normalize(normalizedVector); - } - return curDoc; + public NormalizedFloatVectorValues copy() throws IOException { + return new NormalizedFloatVectorValues(values.copy()); } } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/OffHeapBinarizedVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/OffHeapBinarizedVectorValues.java index 2a3c3aca60e54..72934750305ce 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/OffHeapBinarizedVectorValues.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/OffHeapBinarizedVectorValues.java @@ -36,7 +36,7 @@ import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; /** Binarized vector values loaded from off-heap */ -public abstract class OffHeapBinarizedVectorValues extends BinarizedByteVectorValues implements RandomAccessBinarizedByteVectorValues { +public abstract class OffHeapBinarizedVectorValues extends BinarizedByteVectorValues { protected final int dimension; protected final int size; @@ -109,7 +109,12 @@ public float getCentroidDP() { } @Override - public float[] getCorrectiveTerms() { + public float[] getCorrectiveTerms(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return correctiveValues; + } + slice.seek(((long) targetOrd * byteSize) + numBytes); + slice.readFloats(correctiveValues, 0, correctionsCount); return correctiveValues; } @@ -173,11 +178,6 @@ public float[] getCentroid() { return centroid; } - @Override - public IndexInput getSlice() { - return slice; - } - @Override public int getVectorByteLength() { return numBytes; @@ -230,8 +230,6 @@ public static OffHeapBinarizedVectorValues load( /** Dense off-heap binarized vector values */ public static class DenseOffHeapVectorValues extends OffHeapBinarizedVectorValues { - private int doc = -1; - public DenseOffHeapVectorValues( int dimension, int size, @@ -245,30 +243,6 @@ public DenseOffHeapVectorValues( super(dimension, size, centroid, centroidDp, binaryQuantizer, similarityFunction, vectorsScorer, slice); } - @Override - public byte[] vectorValue() throws IOException { - return vectorValue(doc); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() { - return advance(doc + 1); - } - - @Override - public int advance(int target) { - assert docID() < target; - if (target >= size) { - return doc = NO_MORE_DOCS; - } - return doc = target; - } - @Override public DenseOffHeapVectorValues copy() throws IOException { return new DenseOffHeapVectorValues( @@ -291,19 +265,25 @@ public Bits getAcceptOrds(Bits acceptDocs) { @Override public VectorScorer scorer(float[] target) throws IOException { DenseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); RandomVectorScorer scorer = vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); return new VectorScorer() { @Override public float score() throws IOException { - return scorer.score(copy.doc); + return scorer.score(iterator.index()); } @Override public DocIdSetIterator iterator() { - return copy; + return iterator; } }; } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } } /** Sparse off-heap binarized vector values */ @@ -333,27 +313,6 @@ private static class SparseOffHeapVectorValues extends OffHeapBinarizedVectorVal this.disi = configuration.getIndexedDISI(dataIn); } - @Override - public byte[] vectorValue() throws IOException { - return vectorValue(disi.index()); - } - - @Override - public int docID() { - return disi.docID(); - } - - @Override - public int nextDoc() throws IOException { - return disi.nextDoc(); - } - - @Override - public int advance(int target) throws IOException { - assert docID() < target; - return disi.advance(target); - } - @Override public SparseOffHeapVectorValues copy() throws IOException { return new SparseOffHeapVectorValues( @@ -393,19 +352,25 @@ public int length() { }; } + @Override + public DocIndexIterator iterator() { + return IndexedDISI.asDocIndexIterator(disi); + } + @Override public VectorScorer scorer(float[] target) throws IOException { SparseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); RandomVectorScorer scorer = vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); return new VectorScorer() { @Override public float score() throws IOException { - return scorer.score(copy.disi.index()); + return scorer.score(iterator.index()); } @Override public DocIdSetIterator iterator() { - return copy; + return iterator; } }; } @@ -419,23 +384,8 @@ private static class EmptyOffHeapVectorValues extends OffHeapBinarizedVectorValu } @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() { - return advance(doc + 1); - } - - @Override - public int advance(int target) { - return doc = NO_MORE_DOCS; - } - - @Override - public byte[] vectorValue() { - throw new UnsupportedOperationException(); + public DocIndexIterator iterator() { + return createDenseIterator(); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/RandomAccessBinarizedByteVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/RandomAccessBinarizedByteVectorValues.java deleted file mode 100644 index 2417353373ba5..0000000000000 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/RandomAccessBinarizedByteVectorValues.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * @notice - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * Modifications copyright (C) 2024 Elasticsearch B.V. - */ -package org.elasticsearch.index.codec.vectors; - -import org.apache.lucene.util.VectorUtil; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; - -import java.io.IOException; - -/** - * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 - */ -public interface RandomAccessBinarizedByteVectorValues extends RandomAccessVectorValues.Bytes { - /** Returns the centroid distance for the vector */ - float getCentroidDistance(int vectorOrd) throws IOException; - - /** Returns the vector magnitude for the vector */ - float getVectorMagnitude(int vectorOrd) throws IOException; - - /** Returns OOQ corrective factor for the given vector ordinal */ - float getOOQ(int targetOrd) throws IOException; - - /** - * Returns the norm of the target vector w the centroid corrective factor for the given vector - * ordinal - */ - float getNormOC(int targetOrd) throws IOException; - - /** - * Returns the target vector dot product the centroid corrective factor for the given vector - * ordinal - */ - float getODotC(int targetOrd) throws IOException; - - /** - * @return the quantizer used to quantize the vectors - */ - BinaryQuantizer getQuantizer(); - - /** - * @return coarse grained centroids for the vectors - */ - float[] getCentroid() throws IOException; - - @Override - RandomAccessBinarizedByteVectorValues copy() throws IOException; - - default float getCentroidDP() throws IOException { - // this only gets executed on-merge - float[] centroid = getCentroid(); - return VectorUtil.dotProduct(centroid, centroid); - } -} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES816BinaryFlatVectorsScorerTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES816BinaryFlatVectorsScorerTests.java index 4ac66a9f63a3f..dc41a58fa2b8a 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES816BinaryFlatVectorsScorerTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES816BinaryFlatVectorsScorerTests.java @@ -20,6 +20,7 @@ package org.elasticsearch.index.codec.vectors; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.VectorUtil; import org.elasticsearch.common.logging.LogConfigurator; @@ -61,7 +62,7 @@ public void testScore() throws IOException { new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, vl, width, normVmC, vDotC) ); - RandomAccessBinarizedByteVectorValues targetVectors = new RandomAccessBinarizedByteVectorValues() { + BinarizedByteVectorValues targetVectors = new BinarizedByteVectorValues() { @Override public float getCentroidDistance(int vectorOrd) throws IOException { return random().nextFloat(0f, 1000f); @@ -99,7 +100,7 @@ public float[] getCentroid() throws IOException { } @Override - public RandomAccessBinarizedByteVectorValues copy() throws IOException { + public BinarizedByteVectorValues copy() throws IOException { return null; } @@ -115,6 +116,16 @@ public int size() { return 1; } + @Override + public VectorScorer scorer(float[] query) throws IOException { + return null; + } + + @Override + public float[] getCorrectiveTerms(int vectorOrd) throws IOException { + return new float[0]; + } + @Override public int dimension() { return dimensions; @@ -209,7 +220,7 @@ public void testScoreEuclidean() throws IOException { new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, vl, width, 0f, 0f) ); - RandomAccessBinarizedByteVectorValues targetVectors = new RandomAccessBinarizedByteVectorValues() { + BinarizedByteVectorValues targetVectors = new BinarizedByteVectorValues() { @Override public float getCentroidDistance(int vectorOrd) { return 355.78073f; @@ -375,7 +386,7 @@ public float[] getCentroid() { } @Override - public RandomAccessBinarizedByteVectorValues copy() { + public BinarizedByteVectorValues copy() { return null; } @@ -389,6 +400,16 @@ public int size() { return 1; } + @Override + public VectorScorer scorer(float[] query) throws IOException { + return null; + } + + @Override + public float[] getCorrectiveTerms(int vectorOrd) throws IOException { + return new float[0]; + } + @Override public int dimension() { return dimensions; @@ -806,7 +827,7 @@ public void testScoreMIP() throws IOException { new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, vl, width, normVmC, vDotC) ); - RandomAccessBinarizedByteVectorValues targetVectors = new RandomAccessBinarizedByteVectorValues() { + BinarizedByteVectorValues targetVectors = new BinarizedByteVectorValues() { @Override public float getCentroidDistance(int vectorOrd) { return 0f; @@ -1617,7 +1638,7 @@ public float[] getCentroid() { } @Override - public RandomAccessBinarizedByteVectorValues copy() { + public BinarizedByteVectorValues copy() { return null; } @@ -1727,6 +1748,16 @@ public int size() { return 1; } + @Override + public VectorScorer scorer(float[] query) throws IOException { + return null; + } + + @Override + public float[] getCorrectiveTerms(int vectorOrd) throws IOException { + return new float[0]; + } + @Override public int dimension() { return dimensions; diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsFormatTests.java index 0892436891ff1..42f2fbb383ac9 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsFormatTests.java @@ -22,7 +22,7 @@ import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; -import org.apache.lucene.codecs.lucene912.Lucene912Codec; +import org.apache.lucene.codecs.lucene100.Lucene100Codec; import org.apache.lucene.document.Document; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.DirectoryReader; @@ -30,6 +30,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; @@ -58,7 +59,7 @@ public class ES816BinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormat @Override protected Codec getCodec() { - return new Lucene912Codec() { + return new Lucene100Codec() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { return new ES816BinaryQuantizedVectorsFormat(); @@ -90,8 +91,8 @@ public void testSearch() throws Exception { float[] queryVector = randomVector(dims); Query q = new KnnFloatVectorQuery(fieldName, queryVector, k); TopDocs collectedDocs = searcher.search(q, k); - assertEquals(k, collectedDocs.totalHits.value); - assertEquals(TotalHits.Relation.EQUAL_TO, collectedDocs.totalHits.relation); + assertEquals(k, collectedDocs.totalHits.value()); + assertEquals(TotalHits.Relation.EQUAL_TO, collectedDocs.totalHits.relation()); } } } @@ -148,7 +149,7 @@ public void testQuantizedVectorsWriteAndRead() throws IOException { LeafReader r = getOnlyLeafReader(reader); FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); assertEquals(vectorValues.size(), numVectors); - OffHeapBinarizedVectorValues qvectorValues = ((ES816BinaryQuantizedVectorsReader.BinarizedVectorValues) vectorValues) + BinarizedByteVectorValues qvectorValues = ((ES816BinaryQuantizedVectorsReader.BinarizedVectorValues) vectorValues) .getQuantizedVectorValues(); float[] centroid = qvectorValues.getCentroid(); assertEquals(centroid.length, dims); @@ -159,13 +160,18 @@ public void testQuantizedVectorsWriteAndRead() throws IOException { if (similarityFunction == VectorSimilarityFunction.COSINE) { vectorValues = new ES816BinaryQuantizedVectorsWriter.NormalizedFloatVectorValues(vectorValues); } - - while (vectorValues.nextDoc() != NO_MORE_DOCS) { - float[] corrections = quantizer.quantizeForIndex(vectorValues.vectorValue(), expectedVector, centroid); - assertArrayEquals(expectedVector, qvectorValues.vectorValue()); - assertEquals(corrections.length, qvectorValues.getCorrectiveTerms().length); + KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); + + while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { + float[] corrections = quantizer.quantizeForIndex( + vectorValues.vectorValue(docIndexIterator.index()), + expectedVector, + centroid + ); + assertArrayEquals(expectedVector, qvectorValues.vectorValue(docIndexIterator.index())); + assertEquals(corrections.length, qvectorValues.getCorrectiveTerms(docIndexIterator.index()).length); for (int i = 0; i < corrections.length; i++) { - assertEquals(corrections[i], qvectorValues.getCorrectiveTerms()[i], 0.00001f); + assertEquals(corrections[i], qvectorValues.getCorrectiveTerms(docIndexIterator.index())[i], 0.00001f); } } } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES816HnswBinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES816HnswBinaryQuantizedVectorsFormatTests.java index f607de57e1fd5..ca96e093b7b28 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES816HnswBinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES816HnswBinaryQuantizedVectorsFormatTests.java @@ -22,7 +22,7 @@ import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; -import org.apache.lucene.codecs.lucene912.Lucene912Codec; +import org.apache.lucene.codecs.lucene100.Lucene100Codec; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; import org.apache.lucene.document.Document; import org.apache.lucene.document.KnnFloatVectorField; @@ -30,6 +30,7 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.TopDocs; @@ -55,7 +56,7 @@ public class ES816HnswBinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFo @Override protected Codec getCodec() { - return new Lucene912Codec() { + return new Lucene100Codec() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { return new ES816HnswBinaryQuantizedVectorsFormat(); @@ -91,12 +92,13 @@ public void testSingleVectorCase() throws Exception { try (IndexReader reader = DirectoryReader.open(w)) { LeafReader r = getOnlyLeafReader(reader); FloatVectorValues vectorValues = r.getFloatVectorValues("f"); + KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); assert (vectorValues.size() == 1); - while (vectorValues.nextDoc() != NO_MORE_DOCS) { - assertArrayEquals(vector, vectorValues.vectorValue(), 0.00001f); + while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { + assertArrayEquals(vector, vectorValues.vectorValue(docIndexIterator.index()), 0.00001f); } TopDocs td = r.searchNearestVectors("f", randomVector(vector.length), 1, null, Integer.MAX_VALUE); - assertEquals(1, td.totalHits.value); + assertEquals(1, td.totalHits.value()); assertTrue(td.scoreDocs[0].score >= 0); } }