@@ -48,7 +48,7 @@ public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVect
     }
 
     @Override
-    CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
+    CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
         throws IOException {
         final FieldEntry fieldEntry = fields.get(fieldInfo.number);
         final float globalCentroidDp = fieldEntry.globalCentroidDp();
@@ -65,90 +65,69 @@ CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, Ind
             quantized[i] = (byte) scratch[i];
         }
         final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());
-        return new CentroidQueryScorer() {
-            int currentCentroid = -1;
-            long postingListOffset;
-            private final float[] centroidCorrectiveValues = new float[3];
-            private final long quantizeCentroidsLength = (long) numCentroids * (fieldInfo.getVectorDimension() + 3 * Float.BYTES
-                + Short.BYTES);
+        NeighborQueue queue = new NeighborQueue(fieldEntry.numCentroids(), true);
+        centroids.seek(0L);
+        final float[] centroidCorrectiveValues = new float[3];
+        for (int i = 0; i < numCentroids; i++) {
+            final float qcDist = scorer.int4DotProduct(quantized);
+            centroids.readFloats(centroidCorrectiveValues, 0, 3);
+            final int quantizedCentroidComponentSum = Short.toUnsignedInt(centroids.readShort());
+            float score = int4QuantizedScore(
+                qcDist,
+                queryParams,
+                fieldInfo.getVectorDimension(),
+                centroidCorrectiveValues,
+                quantizedCentroidComponentSum,
+                globalCentroidDp,
+                fieldInfo.getVectorSimilarityFunction()
+            );
+            queue.add(i, score);
+        }
+        final long offset = centroids.getFilePointer();
+        return new CentroidIterator() {
             @Override
-            public int size() {
-                return numCentroids;
+            public boolean hasNext() {
+                return queue.size() > 0;
             }
 
             @Override
-            public long postingListOffset(int centroidOrdinal) throws IOException {
-                if (centroidOrdinal != currentCentroid) {
-                    centroids.seek(quantizeCentroidsLength + (long) Long.BYTES * centroidOrdinal);
-                    postingListOffset = centroids.readLong();
-                    currentCentroid = centroidOrdinal;
-                }
-                return postingListOffset;
-            }
-
-            public void bulkScore(NeighborQueue queue) throws IOException {
-                // TODO: bulk score centroids like we do with posting lists
-                centroids.seek(0L);
-                for (int i = 0; i < numCentroids; i++) {
-                    queue.add(i, score());
-                }
-            }
-
-            private float score() throws IOException {
-                final float qcDist = scorer.int4DotProduct(quantized);
-                centroids.readFloats(centroidCorrectiveValues, 0, 3);
-                final int quantizedCentroidComponentSum = Short.toUnsignedInt(centroids.readShort());
-                return int4QuantizedScore(
-                    qcDist,
-                    queryParams,
-                    fieldInfo.getVectorDimension(),
-                    centroidCorrectiveValues,
-                    quantizedCentroidComponentSum,
-                    globalCentroidDp,
-                    fieldInfo.getVectorSimilarityFunction()
-                );
-            }
-
-            // TODO can we do this in off-heap blocks?
-            private float int4QuantizedScore(
-                float qcDist,
-                OptimizedScalarQuantizer.QuantizationResult queryCorrections,
-                int dims,
-                float[] targetCorrections,
-                int targetComponentSum,
-                float centroidDp,
-                VectorSimilarityFunction similarityFunction
-            ) {
-                float ax = targetCorrections[0];
-                // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
-                float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE;
-                float ay = queryCorrections.lowerInterval();
-                float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
-                float y1 = queryCorrections.quantizedComponentSum();
-                float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
-                if (similarityFunction == EUCLIDEAN) {
-                    score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score;
-                    return Math.max(1 / (1f + score), 0);
-                } else {
-                    // For cosine and max inner product, we need to apply the additional correction, which is
-                    // assumed to be the non-centered dot-product between the vector and the centroid
-                    score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp;
-                    if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
-                        return VectorUtil.scaleMaxInnerProductScore(score);
-                    }
-                    return Math.max((1f + score) / 2f, 0);
-                }
+            public long nextPostingListOffset() throws IOException {
+                int centroidOrdinal = queue.pop();
+                centroids.seek(offset + (long) Long.BYTES * centroidOrdinal);
+                return centroids.readLong();
             }
         };
     }
 
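A note on the seek arithmetic above (our reading of the code, not documentation from the PR): the scoring loop consumes the quantized-centroid block, so `getFilePointer()` afterwards points at a table of `long` posting-list offsets, one per centroid, and `nextPostingListOffset()` addresses slot `centroidOrdinal` as `offset + Long.BYTES * centroidOrdinal`. A hypothetical helper making that explicit:

// Hypothetical sketch of the addressing used by nextPostingListOffset(); the layout
// (quantized-centroid block followed by one long offset per centroid) is inferred
// from the seek arithmetic above, and the helper name is ours.
static long postingListOffsetSlot(long offsetsTableStart, int centroidOrdinal) {
    // one 8-byte long per centroid, laid out by centroid ordinal
    return offsetsTableStart + (long) Long.BYTES * centroidOrdinal;
}
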
-    @Override
-    NeighborQueue scorePostingLists(FieldInfo fieldInfo, KnnCollector knnCollector, CentroidQueryScorer centroidQueryScorer, int nProbe)
-        throws IOException {
-        NeighborQueue neighborQueue = new NeighborQueue(centroidQueryScorer.size(), true);
-        centroidQueryScorer.bulkScore(neighborQueue);
-        return neighborQueue;
+    // TODO can we do this in off-heap blocks?
+    private float int4QuantizedScore(
+        float qcDist,
+        OptimizedScalarQuantizer.QuantizationResult queryCorrections,
+        int dims,
+        float[] targetCorrections,
+        int targetComponentSum,
+        float centroidDp,
+        VectorSimilarityFunction similarityFunction
+    ) {
+        float ax = targetCorrections[0];
+        // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
+        float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE;

    [Review thread on the line above]
    Contributor: I noticed this while going through the code with @john-wagster: the comment says scaling isn't necessary, but then we do scaling. Either the comment is outdated or we have a bug (I've seen this in other parts of the code).
    Contributor (PR author): copy-paste error in the comment.

+        float ay = queryCorrections.lowerInterval();
+        float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
+        float y1 = queryCorrections.quantizedComponentSum();
+        float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
+        if (similarityFunction == EUCLIDEAN) {
+            score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score;
+            return Math.max(1 / (1f + score), 0);
+        } else {
+            // For cosine and max inner product, we need to apply the additional correction, which is
+            // assumed to be the non-centered dot-product between the vector and the centroid
+            score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp;
+            if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+                return VectorUtil.scaleMaxInnerProductScore(score);
+            }
+            return Math.max((1f + score) / 2f, 0);
+        }
     }

     @Override
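On the review thread above: in optimized scalar quantization the dequantization step size depends on the bit count, so the 4-bit path does need the `FOUR_BIT_SCALE` factor even though the 1-bit path the comment was copied from does not. A minimal sketch of that distinction (our illustration, not code from this PR; the constant's value is assumed to be `1f / 15`):

// Hypothetical sketch: b-bit scalar dequantization is lower + code * (upper - lower) / (2^b - 1).
// With b = 4 the divisor is 15 (presumably what FOUR_BIT_SCALE captures);
// with b = 1 ("bit vectors") the divisor is 1, so no scaling is needed there.
final class IntervalScaleSketch {
    static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1); // assumed value: 1/15

    static float dequantize4Bit(int code, float lower, float upper) {
        return lower + code * (upper - lower) * FOUR_BIT_SCALE;
    }

    static float dequantize1Bit(int bit, float lower, float upper) {
        return lower + bit * (upper - lower); // divisor (2^1 - 1) == 1: "scaling isn't necessary"
    }
}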
@@ -31,7 +31,6 @@
 import org.apache.lucene.util.BitSet;
 import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.FixedBitSet;
-import org.apache.lucene.util.hnsw.NeighborQueue;
 import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.search.vectors.IVFKnnSearchStrategy;

@@ -89,7 +88,7 @@ protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsR
         }
     }
 
-    abstract CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] target)
+    abstract CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] target)
         throws IOException;
 
     private static IndexInput openDataInput(
@@ -236,22 +235,16 @@ public final void search(String field, float[] target, KnnCollector knnCollector
         }
 
         FieldEntry entry = fields.get(fieldInfo.number);
-        CentroidQueryScorer centroidQueryScorer = getCentroidScorer(
-            fieldInfo,
-            entry.numCentroids,
-            entry.centroidSlice(ivfCentroids),
-            target
-        );
         if (nProbe == DYNAMIC_NPROBE) {
             // empirically based, and a good dynamic to get decent recall while scaling a la "efSearch"
             // scaling by the number of centroids vs. the nearest neighbors requested
             // not perfect, but a comparative heuristic.
             // we might want to utilize the total vector count as well, but this is a good start
-            nProbe = (int) Math.round(Math.log10(centroidQueryScorer.size()) * Math.sqrt(knnCollector.k()));
+            nProbe = (int) Math.round(Math.log10(entry.numCentroids) * Math.sqrt(knnCollector.k()));
             // clip to be between 1 and the number of centroids
-            nProbe = Math.max(Math.min(nProbe, centroidQueryScorer.size()), 1);
+            nProbe = Math.max(Math.min(nProbe, entry.numCentroids), 1);
         }
-        final NeighborQueue centroidQueue = scorePostingLists(fieldInfo, knnCollector, centroidQueryScorer, nProbe);
+        CentroidIterator centroidIterator = getCentroidIterator(fieldInfo, entry.numCentroids, entry.centroidSlice(ivfCentroids), target);
         PostingVisitor scorer = getPostingVisitor(fieldInfo, ivfClusters, target, needsScoring);
         int centroidsVisited = 0;
         long expectedDocs = 0;
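For intuition, a rough worked example of the dynamic nProbe heuristic in the hunk above (numbers are ours, purely illustrative):

// Illustrative only: with 10,000 centroids and k = 10,
// log10(10_000) = 4 and sqrt(10) ~= 3.16, so nProbe = round(12.65) = 13,
// then clipped into [1, numCentroids].
int numCentroids = 10_000;
int k = 10;
int nProbe = (int) Math.round(Math.log10(numCentroids) * Math.sqrt(k)); // 13
nProbe = Math.max(Math.min(nProbe, numCentroids), 1);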
@@ -260,22 +253,22 @@ public final void search(String field, float[] target, KnnCollector knnCollector
         // Note, numCollected is doing the bare minimum here.
         // TODO do we need to handle nested doc counts similarly to how we handle
         // filtering? E.g. keep exploring until we hit an expected number of parent documents vs. child vectors?
-        while (centroidQueue.size() > 0 && (centroidsVisited < nProbe || knnCollectorImpl.numCollected() < knnCollector.k())) {
+        while (centroidIterator.hasNext() && (centroidsVisited < nProbe || knnCollectorImpl.numCollected() < knnCollector.k())) {
             ++centroidsVisited;
-            // todo do we actually need to know the score???
-            int centroidOrdinal = centroidQueue.pop();
+            long offset = centroidIterator.nextPostingListOffset();
             // todo do we need direct access to the raw centroid???, this is used for quantizing, maybe hydrating and quantizing
             // is enough?
-            expectedDocs += scorer.resetPostingsScorer(centroidQueryScorer.postingListOffset(centroidOrdinal));
+            expectedDocs += scorer.resetPostingsScorer(offset);
             actualDocs += scorer.visit(knnCollector);
         }
         if (acceptDocs != null) {
             float unfilteredRatioVisited = (float) expectedDocs / numVectors;
             int filteredVectors = (int) Math.ceil(numVectors * percentFiltered);
             float expectedScored = Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f);
-            while (centroidQueue.size() > 0 && (actualDocs < expectedScored || actualDocs < knnCollector.k())) {
-                int centroidOrdinal = centroidQueue.pop();
-                scorer.resetPostingsScorer(centroidQueryScorer.postingListOffset(centroidOrdinal));
+            while (centroidIterator.hasNext() && (actualDocs < expectedScored || actualDocs < knnCollector.k())) {
+                long offset = centroidIterator.nextPostingListOffset();
+                scorer.resetPostingsScorer(offset);
                 actualDocs += scorer.visit(knnCollector);
             }
         }
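To make the filtered-search budget in the hunk above concrete, a worked example (numbers are ours, purely illustrative):

// Illustrative only: suppose 1,000,000 vectors, the unfiltered pass visited
// posting lists covering 50,000 docs (5%), and 10% of docs pass the filter.
int numVectors = 1_000_000;
long expectedDocs = 50_000;
float percentFiltered = 0.10f;
float unfilteredRatioVisited = (float) expectedDocs / numVectors;    // 0.05
int filteredVectors = (int) Math.ceil(numVectors * percentFiltered); // 100,000
// min(2 * 100_000 * 0.05, 50_000 / 2) = min(10_000, 25_000) = 10,000-doc budget,
// so the second loop keeps popping centroids until ~10k filtered docs are scored.
float expectedScored = Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f);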
@@ -294,13 +287,6 @@ public final void search(String field, byte[] target, KnnCollector knnCollector,
         }
     }
 
-    abstract NeighborQueue scorePostingLists(
-        FieldInfo fieldInfo,
-        KnnCollector knnCollector,
-        CentroidQueryScorer centroidQueryScorer,
-        int nProbe
-    ) throws IOException;
-
     @Override
     public void close() throws IOException {
         IOUtils.close(rawVectorsReader, ivfCentroids, ivfClusters);
@@ -323,12 +309,10 @@ IndexInput centroidSlice(IndexInput centroidFile) throws IOException {
     abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postingsLists, float[] target, IntPredicate needsScoring)
         throws IOException;
 
-    interface CentroidQueryScorer {
-        int size();
-
-        long postingListOffset(int centroidOrdinal) throws IOException;
+    interface CentroidIterator {
+        boolean hasNext();
 
-        void bulkScore(NeighborQueue queue) throws IOException;
+        long nextPostingListOffset() throws IOException;
     }
 
     interface PostingVisitor {