Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats;
Expand All @@ -24,6 +23,7 @@
import org.elasticsearch.simdvec.ESVectorUtil;

import java.io.IOException;
import java.util.Arrays;
import java.util.Map;
import java.util.function.IntPredicate;

Expand All @@ -32,7 +32,6 @@
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
import static org.elasticsearch.index.codec.vectors.BQSpaceUtils.transposeHalfByte;
import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
import static org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer.DEFAULT_LAMBDA;
import static org.elasticsearch.simdvec.ES91OSQVectorsScorer.BULK_SIZE;

Expand All @@ -48,31 +47,118 @@ public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVect
}

@Override
CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
throws IOException {
CentroidIterator getCentroidIterator(
FieldInfo fieldInfo,
int numCentroids,
int numOversampled,
IndexInput centroids,
float[] targetQuery
) throws IOException {

final float[] centroidCorrectiveValues = new float[3];

// constants
final int discretizedDimensions = BQVectorUtils.discretize(fieldInfo.getVectorDimension(), 64);
final long oneBitQuantizeCentroidByteSize = (long) discretizedDimensions / 8;
final long oneBitQuantizeCentroidsLength = (long) numCentroids * (oneBitQuantizeCentroidByteSize + 3 * Float.BYTES + Short.BYTES);
final long quantizationCentroidByteSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES;
final long quantizeCentroidsLength = (long) numCentroids * quantizationCentroidByteSize;

final FieldEntry fieldEntry = fields.get(fieldInfo.number);
final float globalCentroidDp = fieldEntry.globalCentroidDp();

// quantize the query
final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
final int[] scratch = new int[targetQuery.length];

// FIXME: do l2normalize here?
final float[] scratchTarget = new float[targetQuery.length];
System.arraycopy(targetQuery, 0, scratchTarget, 0, targetQuery.length);
if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
Copy link
Contributor Author

@john-wagster john-wagster Jul 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why we weren't normalizing for COSINE here previously; bug maybe?

VectorUtil.l2normalize(scratchTarget);
}
final OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize(
ArrayUtil.copyArray(targetQuery),
scratchTarget,
scratch,
(byte) 4,
fieldEntry.globalCentroid()
);
final byte[] quantized = new byte[targetQuery.length];
for (int i = 0; i < quantized.length; i++) {
quantized[i] = (byte) scratch[i];

// pack the quantized value the way the one bit scorer expects it as 4 bits
final byte[] oneBitQuantized = new byte[QUERY_BITS * discretizedDimensions / 8];
transposeHalfByte(scratch, oneBitQuantized);

// pack the quantized value the way the four bit scorer expects it as single bytes instead of 4 bits
final byte[] fourBitQuantized = new byte[targetQuery.length];
for (int i = 0; i < fourBitQuantized.length; i++) {
fourBitQuantized[i] = (byte) scratch[i];
}
final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());
NeighborQueue queue = new NeighborQueue(fieldEntry.numCentroids(), true);
centroids.seek(0L);
final float[] centroidCorrectiveValues = new float[3];
for (int i = 0; i < numCentroids; i++) {
final float qcDist = scorer.int4DotProduct(quantized);

// setup to score the centroids
final ES91OSQVectorsScorer oneBitScorer = ESVectorUtil.getES91OSQVectorsScorer(centroids, fieldInfo.getVectorDimension());
final ES91Int4VectorsScorer fourBitScorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());

final NeighborQueue oneBitCentroidQueue = new NeighborQueue(fieldEntry.numCentroids(), true);
final NeighborQueue fourBitCentroidQueue = new NeighborQueue(numOversampled, true);

// FIXME: this adds complexity to this method consider moving it / discuss
OneBitCentroidScorer oneBitCentroidScorer = new OneBitCentroidScorer() {
@Override
public void bulkScore(NeighborQueue queue) throws IOException {
final float[] scores = new float[BULK_SIZE];

centroids.seek(0L);

// block processing
int limit = numCentroids - BULK_SIZE + 1;
int i = 0;
for (; i < limit; i += BULK_SIZE) {
oneBitScorer.scoreBulk(
oneBitQuantized,
queryParams.lowerInterval(),
queryParams.upperInterval(),
queryParams.quantizedComponentSum(),
queryParams.additionalCorrection(),
fieldInfo.getVectorSimilarityFunction(),
globalCentroidDp,
scores
);
for (int j = 0; j < BULK_SIZE; j++) {
queue.add(i + j, scores[j]);
}
}
// process tail
for (; i < numCentroids; i++) {
queue.add(i, score());
}
}

private float score() throws IOException {
final float qcDist = oneBitScorer.quantizeScore(oneBitQuantized);
centroids.readFloats(centroidCorrectiveValues, 0, 3);
final int quantizedCentroidComponentSum = Short.toUnsignedInt(centroids.readShort());
return oneBitScorer.score(
queryParams.lowerInterval(),
queryParams.upperInterval(),
queryParams.quantizedComponentSum(),
queryParams.additionalCorrection(),
fieldInfo.getVectorSimilarityFunction(),
globalCentroidDp,
centroidCorrectiveValues[0],
centroidCorrectiveValues[1],
quantizedCentroidComponentSum,
centroidCorrectiveValues[2],
qcDist
);
}
};

FourBitCentroidScorer fourBitCentroidScorer = centroidOrdinal -> {
centroids.seek(oneBitQuantizeCentroidsLength + quantizationCentroidByteSize * centroidOrdinal);
final float qcDist = fourBitScorer.int4DotProduct(fourBitQuantized);
centroids.readFloats(centroidCorrectiveValues, 0, 3);
final int quantizedCentroidComponentSum = Short.toUnsignedInt(centroids.readShort());
float score = int4QuantizedScore(
return int4QuantizedScore(
qcDist,
queryParams,
fieldInfo.getVectorDimension(),
Expand All @@ -81,24 +167,65 @@ CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, Inde
globalCentroidDp,
fieldInfo.getVectorSimilarityFunction()
);
queue.add(i, score);
}
final long offset = centroids.getFilePointer();
};

// populate the first set of centroids
oneBitCentroidScorer.bulkScore(oneBitCentroidQueue);
populateCentroidQueue(numOversampled, oneBitCentroidQueue, fourBitCentroidQueue, fourBitCentroidScorer);

return new CentroidIterator() {
@Override
public boolean hasNext() {
return queue.size() > 0;
return oneBitCentroidQueue.size() + fourBitCentroidQueue.size() > 0;
}

@Override
public long nextPostingListOffset() throws IOException {
int centroidOrdinal = queue.pop();
centroids.seek(offset + (long) Long.BYTES * centroidOrdinal);
int centroidOrdinal = fourBitCentroidQueue.pop();
if (oneBitCentroidQueue.size() > 0) {
// TODO: it may be more efficient as far as disk reads to pop a set of ordinals,
// sort them, and do a batch read of for instance the next max(0.1f * rescoreSize, 1)
int centroidOrd = oneBitCentroidQueue.pop();
fourBitCentroidQueue.add(centroidOrd, fourBitCentroidScorer.score(centroidOrd));
}

centroids.seek(oneBitQuantizeCentroidsLength + quantizeCentroidsLength + (long) Long.BYTES * centroidOrdinal);
return centroids.readLong();
}
};
}

/**
 * Scores centroids with the cheap one-bit quantized scorer, pushing
 * (centroid ordinal, score) entries into the supplied queue in bulk.
 * Single abstract method, so implementations may be lambdas or anonymous classes.
 */
@FunctionalInterface
interface OneBitCentroidScorer {
    /**
     * Bulk-scores centroids into {@code queue}, one entry per centroid ordinal.
     *
     * @param queue destination queue receiving (ordinal, score) pairs
     * @throws IOException if reading the underlying centroid data fails
     */
    void bulkScore(NeighborQueue queue) throws IOException;
}

/**
 * Rescores a single centroid with the higher-fidelity four-bit quantized scorer.
 * Single abstract method, so implementations may be lambdas (and are, at the call site).
 */
@FunctionalInterface
interface FourBitCentroidScorer {
    /**
     * Scores one centroid against the current query.
     *
     * @param centroidOrdinal ordinal of the centroid to score
     * @return the similarity score for that centroid
     * @throws IOException if reading the underlying centroid data fails
     */
    float score(int centroidOrdinal) throws IOException;
}

/**
 * Moves up to {@code rescoreSize} of the best one-bit candidates into the
 * destination queue, rescoring each with the four-bit scorer. Any entries
 * beyond {@code rescoreSize} remain in the one-bit queue for later
 * incremental rescoring.
 *
 * @param rescoreSize          maximum number of candidates to rescore now
 * @param oneBitCentroidQueue  source queue of one-bit-scored candidates
 * @param centroidQueue        destination queue receiving four-bit scores
 * @param centroidQueryScorer  scorer used to recompute each candidate's score
 * @throws IOException if the scorer fails to read centroid data
 */
private static void populateCentroidQueue(
    int rescoreSize,
    NeighborQueue oneBitCentroidQueue,
    NeighborQueue centroidQueue,
    FourBitCentroidScorer centroidQueryScorer
) throws IOException {

    if (oneBitCentroidQueue.size() == 0) {
        return;
    }

    final int rescoreCount = Math.min(rescoreSize, oneBitCentroidQueue.size());
    final int[] ordinals = new int[rescoreCount];
    for (int idx = 0; idx < rescoreCount; idx++) {
        ordinals[idx] = oneBitCentroidQueue.pop();
    }
    // Ascending ordinal order keeps the underlying disk seeks moving forward.
    Arrays.sort(ordinals);

    // TODO: bulk-read in chunks where possible by grouping contiguous ordinals
    for (final int ordinal : ordinals) {
        centroidQueue.add(ordinal, centroidQueryScorer.score(ordinal));
    }
}

// TODO can we do this in off-heap blocks?
private float int4QuantizedScore(
float qcDist,
Expand Down Expand Up @@ -186,7 +313,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
centroid = new float[fieldInfo.getVectorDimension()];
scratch = new float[target.length];
quantizationScratch = new int[target.length];
final int discretizedDimensions = discretize(fieldInfo.getVectorDimension(), 64);
final int discretizedDimensions = BQVectorUtils.discretize(fieldInfo.getVectorDimension(), 64);
quantizedQueryScratch = new byte[QUERY_BITS * discretizedDimensions / 8];
quantizedByteLength = discretizedDimensions / 8 + (Float.BYTES * 3) + Short.BYTES;
quantizedVectorByteSize = (discretizedDimensions / 8);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,14 +288,44 @@ void writeCentroids(
LongValues offsets,
IndexOutput centroidOutput
) throws IOException {
// TODO do we want to store these distances as well for future use?
// TODO: sort centroids by global centroid (was doing so previously here)

final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
int[] quantizedScratch = new int[fieldInfo.getVectorDimension()];
float[] centroidScratch = new float[fieldInfo.getVectorDimension()];
final byte[] quantized = new byte[fieldInfo.getVectorDimension()];
// TODO do we want to store these distances as well for future use?
// TODO: sort centroids by global centroid (was doing so previously here)
// TODO: sorting tanks recall possibly because centroids ordinals no longer are aligned
final byte[] oneBitQuantized = new byte[BQVectorUtils.discretize(fieldInfo.getVectorDimension(), 64) / 8];
final byte[] fourBitquantized = new byte[fieldInfo.getVectorDimension()];

DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, centroidOutput);
bulkWriter.writeVectors(new QuantizedVectorValues() {
int currOrd = -1;
OptimizedScalarQuantizer.QuantizationResult corrections;

@Override
public int count() {
return centroidSupplier.size();
}

@Override
public byte[] next() throws IOException {
currOrd++;
float[] centroid = centroidSupplier.centroid(currOrd);
System.arraycopy(centroid, 0, centroidScratch, 0, centroid.length);
this.corrections = osq.scalarQuantize(centroidScratch, quantizedScratch, (byte) 1, globalCentroid);
BQVectorUtils.packAsBinary(quantizedScratch, oneBitQuantized);
return oneBitQuantized;
}

@Override
public OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException {
if (currOrd == -1) {
throw new IllegalStateException("No centroid read yet, call next first");
}
return corrections;
}
});

for (int i = 0; i < centroidSupplier.size(); i++) {
float[] centroid = centroidSupplier.centroid(i);
System.arraycopy(centroid, 0, centroidScratch, 0, centroid.length);
Expand All @@ -306,9 +336,9 @@ void writeCentroids(
globalCentroid
);
for (int j = 0; j < quantizedScratch.length; j++) {
quantized[j] = (byte) quantizedScratch[j];
fourBitquantized[j] = (byte) quantizedScratch[j];
}
writeQuantizedValue(centroidOutput, quantized, result);
writeQuantizedValue(centroidOutput, fourBitquantized, result);
}
// write the centroid offsets at the end of the file
for (int i = 0; i < centroidSupplier.size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ public class IVFVectorsFormat extends KnnVectorsFormat {
public static final int DEFAULT_VECTORS_PER_CLUSTER = 384;
public static final int MIN_VECTORS_PER_CLUSTER = 64;
public static final int MAX_VECTORS_PER_CLUSTER = 1 << 16; // 65536
// TODO: expose oversampling as a param?
public static final float NPROBE_OVERSAMPLE = 1f;

private final int vectorPerCluster;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.DYNAMIC_NPROBE;
import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.NPROBE_OVERSAMPLE;

/**
* Reader for IVF vectors. This reader is used to read the IVF vectors from the index.
Expand Down Expand Up @@ -88,8 +89,13 @@ protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsR
}
}

abstract CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] target)
throws IOException;
abstract CentroidIterator getCentroidIterator(
FieldInfo fieldInfo,
int numCentroids,
int numOversampled,
IndexInput centroids,
float[] target
) throws IOException;

private static IndexInput openDataInput(
SegmentReadState state,
Expand Down Expand Up @@ -244,7 +250,15 @@ public final void search(String field, float[] target, KnnCollector knnCollector
// clip to be between 1 and the number of centroids
nProbe = Math.max(Math.min(nProbe, entry.numCentroids), 1);
}
CentroidIterator centroidIterator = getCentroidIterator(fieldInfo, entry.numCentroids, entry.centroidSlice(ivfCentroids), target);

final int numOversampled = Math.min((int) (nProbe * NPROBE_OVERSAMPLE), entry.numCentroids());
Copy link
Contributor

@iverase iverase Jul 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add this computation to DefaultIVFVectorsReader and pass nProbe to the getCentroidIterator method? This is an implementation detail of the strategy you are implementing and should not leak here.

CentroidIterator centroidIterator = getCentroidIterator(
fieldInfo,
entry.numCentroids,
numOversampled,
entry.centroidSlice(ivfCentroids),
target
);
PostingVisitor scorer = getPostingVisitor(fieldInfo, ivfClusters, target, needsScoring);
int centroidsVisited = 0;
long expectedDocs = 0;
Expand Down