Merged. Changes from 1 commit.
@@ -126,7 +126,10 @@ public void scoreFromArray(Blackhole bh) throws IOException {
in.readFloats(corrections, 0, corrections.length);
int addition = Short.toUnsignedInt(in.readShort());
float score = scorer.score(
- result,
+ result.lowerInterval(),
+ result.upperInterval(),
+ result.quantizedComponentSum(),
+ result.additionalCorrection(),
VectorSimilarityFunction.EUCLIDEAN,
centroidDp,
corrections[0],
@@ -150,7 +153,10 @@ public void scoreFromMemorySegmentOnlyVector(Blackhole bh) throws IOException {
in.readFloats(corrections, 0, corrections.length);
int addition = Short.toUnsignedInt(in.readShort());
float score = scorer.score(
- result,
+ result.lowerInterval(),
+ result.upperInterval(),
+ result.quantizedComponentSum(),
+ result.additionalCorrection(),
VectorSimilarityFunction.EUCLIDEAN,
centroidDp,
corrections[0],
@@ -175,7 +181,10 @@ public void scoreFromMemorySegmentOnlyVectorBulk(Blackhole bh) throws IOException {
in.readFloats(corrections, 0, corrections.length);
int addition = Short.toUnsignedInt(in.readShort());
float score = scorer.score(
- result,
+ result.lowerInterval(),
+ result.upperInterval(),
+ result.quantizedComponentSum(),
+ result.additionalCorrection(),
VectorSimilarityFunction.EUCLIDEAN,
centroidDp,
corrections[0],
@@ -196,7 +205,16 @@ public void scoreFromMemorySegmentAllBulk(Blackhole bh) throws IOException {
for (int j = 0; j < numQueries; j++) {
in.seek(0);
for (int i = 0; i < numVectors; i += 16) {
- scorer.scoreBulk(binaryQueries[j], result, VectorSimilarityFunction.EUCLIDEAN, centroidDp, scratchScores);
+ scorer.scoreBulk(
+ binaryQueries[j],
+ result.lowerInterval(),
+ result.upperInterval(),
+ result.quantizedComponentSum(),
+ result.additionalCorrection(),
+ VectorSimilarityFunction.EUCLIDEAN,
+ centroidDp,
+ scratchScores
+ );
bh.consume(scratchScores);
}
}
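Note: the change above repeats at every benchmark call site. Instead of handing the scorer the whole `OptimizedScalarQuantizer.QuantizationResult`, callers now unpack its four fields and pass them as scalars. A minimal sketch of the carrier being unpacked (a hypothetical record mirroring the accessors used in this diff, not code from the PR):

```java
// Hypothetical stand-in for OptimizedScalarQuantizer.QuantizationResult,
// mirroring the four accessors the new call sites unpack.
record QuantizationResult(
    float lowerInterval,       // lower bound of the query's quantization interval
    float upperInterval,       // upper bound of the query's quantization interval
    int quantizedComponentSum, // sum of the quantized query components
    float additionalCorrection // similarity-specific correction term
) {}
```

Passing four primitives keeps the scorer signatures free of the Lucene `OptimizedScalarQuantizer` type, consistent with the import removals later in this diff.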
@@ -95,7 +95,10 @@ public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException {
* Computes the score by applying the necessary corrections to the provided quantized distance.
*/
public float score(
- OptimizedScalarQuantizer.QuantizationResult queryCorrections,
+ float queryLowerInterval,
+ float queryUpperInterval,
+ int queryComponentSum,
+ float queryAdditionalCorrection,
VectorSimilarityFunction similarityFunction,
float centroidDp,
float lowerInterval,
@@ -107,19 +110,19 @@ public float score(
float ax = lowerInterval;
// Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
float lx = upperInterval - ax;
- float ay = queryCorrections.lowerInterval();
- float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
- float y1 = queryCorrections.quantizedComponentSum();
+ float ay = queryLowerInterval;
+ float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
+ float y1 = queryComponentSum;
float score = ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
// For euclidean, we need to invert the score and apply the additional correction, which is
// assumed to be the squared l2norm of the centroid centered vectors.
if (similarityFunction == EUCLIDEAN) {
- score = queryCorrections.additionalCorrection() + additionalCorrection - 2 * score;
+ score = queryAdditionalCorrection + additionalCorrection - 2 * score;
return Math.max(1 / (1f + score), 0);
} else {
// For cosine and max inner product, we need to apply the additional correction, which is
// assumed to be the non-centered dot-product between the vector and the centroid
- score += queryCorrections.additionalCorrection() + additionalCorrection - centroidDp;
+ score += queryAdditionalCorrection + additionalCorrection - centroidDp;
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
return VectorUtil.scaleMaxInnerProductScore(score);
}
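Aside on the math above: `score` reconstructs an approximate float dot product from the two quantized codes. If each document component dequantizes as `ax + lx * bx[i]` and each query component as `ay + ly * by[i]`, expanding the sum `Σ (ax + lx * bx[i]) * (ay + ly * by[i])` yields exactly the four terms in the code, and only the last one needs the code-to-code dot product `qcDist`. A self-contained sketch of that identity (illustrative, not PR code):

```java
// Illustrative expansion of the corrected dot product used by score() above.
// Doc components dequantize as ax + lx * bx[i]; query components as
// ay + ly * by[i]. Only the last term touches the quantized distance qcDist.
static float correctedDotProduct(
    float ax, float lx,      // doc interval: lower bound and width
    float ay, float ly,      // query interval: lower bound and scaled width
    int dimensions,          // vector dimensionality
    int targetComponentSum,  // sum of the doc's quantized components
    float y1,                // sum of the query's quantized components
    float qcDist             // dot product of the two quantized codes
) {
    return ax * ay * dimensions
        + ay * lx * targetComponentSum
        + ax * ly * y1
        + lx * ly * qcDist;
}
```

The Euclidean branch then applies `||x - y||^2 = ||x||^2 + ||y||^2 - 2 * (x . y)`, with the two squared norms carried by the additional corrections, and converts the distance into a similarity with `1 / (1 + d^2)`, clamped at zero.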
@@ -140,7 +143,10 @@ public float score(
*/
public void scoreBulk(
byte[] q,
- OptimizedScalarQuantizer.QuantizationResult queryCorrections,
+ float queryLowerInterval,
+ float queryUpperInterval,
+ int queryComponentSum,
+ float queryAdditionalCorrection,
VectorSimilarityFunction similarityFunction,
float centroidDp,
float[] scores
@@ -154,7 +160,10 @@ public void scoreBulk(
in.readFloats(additionalCorrections, 0, BULK_SIZE);
for (int i = 0; i < BULK_SIZE; i++) {
scores[i] = score(
- queryCorrections,
+ queryLowerInterval,
+ queryUpperInterval,
+ queryComponentSum,
+ queryAdditionalCorrection,
similarityFunction,
centroidDp,
lowerIntervals[i],
@@ -19,7 +19,6 @@
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.VectorUtil;
- import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;

import java.io.IOException;
@@ -298,7 +297,10 @@ private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IOException {
@Override
public void scoreBulk(
byte[] q,
- OptimizedScalarQuantizer.QuantizationResult queryCorrections,
+ float queryLowerInterval,
+ float queryUpperInterval,
+ int queryComponentSum,
+ float queryAdditionalCorrection,
VectorSimilarityFunction similarityFunction,
float centroidDp,
float[] scores
@@ -307,19 +309,49 @@ public void scoreBulk(
// 128 / 8 == 16
if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) {
- score256Bulk(q, queryCorrections, similarityFunction, centroidDp, scores);
+ score256Bulk(
+ q,
+ queryLowerInterval,
+ queryUpperInterval,
+ queryComponentSum,
+ queryAdditionalCorrection,
+ similarityFunction,
+ centroidDp,
+ scores
+ );
return;
} else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) {
- score128Bulk(q, queryCorrections, similarityFunction, centroidDp, scores);
+ score128Bulk(
+ q,
+ queryLowerInterval,
+ queryUpperInterval,
+ queryComponentSum,
+ queryAdditionalCorrection,
+ similarityFunction,
+ centroidDp,
+ scores
+ );
return;
}
}
- super.scoreBulk(q, queryCorrections, similarityFunction, centroidDp, scores);
+ super.scoreBulk(
+ q,
+ queryLowerInterval,
+ queryUpperInterval,
+ queryComponentSum,
+ queryAdditionalCorrection,
+ similarityFunction,
+ centroidDp,
+ scores
+ );
}
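This PR only widens the arguments threaded through the dispatch above; the selection rule itself is unchanged. As a condensed, illustrative restatement (not PR code; `lengthBytes` stands in for the scorer's `length` field, which the `// 128 / 8 == 16` comment suggests is a byte count):

```java
// Illustrative restatement of the scoreBulk dispatch: prefer the 256-bit
// Panama species, fall back to 128-bit, else the scalar superclass path.
static String selectBulkPath(int lengthBytes, boolean hasFastIntegerVectors, int vectorBitSize) {
    if (lengthBytes >= 16 && hasFastIntegerVectors) { // 128 bits / 8 == 16 bytes
        if (vectorBitSize >= 256) {
            return "score256Bulk";
        } else if (vectorBitSize == 128) {
            return "score128Bulk";
        }
    }
    return "super.scoreBulk"; // scalar fallback
}
```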

private void score128Bulk(
byte[] q,
- OptimizedScalarQuantizer.QuantizationResult queryCorrections,
+ float queryLowerInterval,
+ float queryUpperInterval,
+ int queryComponentSum,
+ float queryAdditionalCorrection,
VectorSimilarityFunction similarityFunction,
float centroidDp,
float[] scores
@@ -328,9 +360,9 @@ private void score128Bulk(
int limit = FLOAT_SPECIES_128.loopBound(BULK_SIZE);
int i = 0;
long offset = in.getFilePointer();
- float ay = queryCorrections.lowerInterval();
- float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
- float y1 = queryCorrections.quantizedComponentSum();
+ float ay = queryLowerInterval;
+ float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
+ float y1 = queryComponentSum;
for (; i < limit; i += FLOAT_SPECIES_128.length()) {
var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
var lx = FloatVector.fromMemorySegment(
@@ -362,13 +394,13 @@ private void score128Bulk(
// For euclidean, we need to invert the score and apply the additional correction, which is
// assumed to be the squared l2norm of the centroid centered vectors.
if (similarityFunction == EUCLIDEAN) {
- res = res.mul(-2).add(additionalCorrections).add(queryCorrections.additionalCorrection()).add(1f);
+ res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f);
res = FloatVector.broadcast(FLOAT_SPECIES_128, 1).div(res).max(0);
res.intoArray(scores, i);
} else {
// For cosine and max inner product, we need to apply the additional correction, which is
// assumed to be the non-centered dot-product between the vector and the centroid
- res = res.add(queryCorrections.additionalCorrection()).add(additionalCorrections).sub(centroidDp);
+ res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp);
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
res.intoArray(scores, i);
// not sure how to do it better
@@ -386,7 +418,10 @@ private void score128Bulk(

private void score256Bulk(
byte[] q,
- OptimizedScalarQuantizer.QuantizationResult queryCorrections,
+ float queryLowerInterval,
+ float queryUpperInterval,
+ int queryComponentSum,
+ float queryAdditionalCorrection,
VectorSimilarityFunction similarityFunction,
float centroidDp,
float[] scores
@@ -395,9 +430,9 @@ private void score256Bulk(
int limit = FLOAT_SPECIES_256.loopBound(BULK_SIZE);
int i = 0;
long offset = in.getFilePointer();
- float ay = queryCorrections.lowerInterval();
- float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
- float y1 = queryCorrections.quantizedComponentSum();
+ float ay = queryLowerInterval;
+ float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
+ float y1 = queryComponentSum;
for (; i < limit; i += FLOAT_SPECIES_256.length()) {
var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
var lx = FloatVector.fromMemorySegment(
@@ -429,13 +464,13 @@ private void score256Bulk(
// For euclidean, we need to invert the score and apply the additional correction, which is
// assumed to be the squared l2norm of the centroid centered vectors.
if (similarityFunction == EUCLIDEAN) {
- res = res.mul(-2).add(additionalCorrections).add(queryCorrections.additionalCorrection()).add(1f);
+ res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f);
res = FloatVector.broadcast(FLOAT_SPECIES_256, 1).div(res).max(0);
res.intoArray(scores, i);
} else {
// For cosine and max inner product, we need to apply the additional correction, which is
// assumed to be the non-centered dot-product between the vector and the centroid
- res = res.add(queryCorrections.additionalCorrection()).add(additionalCorrections).sub(centroidDp);
+ res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp);
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
res.intoArray(scores, i);
// not sure how to do it better
@@ -18,7 +18,6 @@
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.NeighborQueue;
- import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
import org.elasticsearch.simdvec.ESVectorUtil;
@@ -31,8 +30,8 @@
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
- import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize;
- import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.transposeHalfByte;
+ import static org.elasticsearch.index.codec.vectors.BQSpaceUtils.transposeHalfByte;
+ import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
import static org.elasticsearch.simdvec.ES91OSQVectorsScorer.BULK_SIZE;

/**
@@ -47,13 +46,8 @@ public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader)
}

@Override
- CentroidQueryScorer getCentroidScorer(
- FieldInfo fieldInfo,
- int numCentroids,
- IndexInput centroids,
- float[] targetQuery,
- IndexInput clusters
- ) throws IOException {
+ CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
+ throws IOException {
FieldEntry fieldEntry = fields.get(fieldInfo.number);
float[] globalCentroid = fieldEntry.globalCentroid();
float globalCentroidDp = fieldEntry.globalCentroidDp();
@@ -259,7 +253,10 @@ void scoreIndividually(int offset) throws IOException {
int doc = docIdsScratch[offset + j];
if (doc != -1) {
scores[j] = osqVectorsScorer.score(
- queryCorrections,
+ queryCorrections.lowerInterval(),
+ queryCorrections.upperInterval(),
+ queryCorrections.quantizedComponentSum(),
+ queryCorrections.additionalCorrection(),
fieldInfo.getVectorSimilarityFunction(),
centroidDp,
correctionsLower[j],
@@ -297,7 +294,10 @@ public int visit(KnnCollector knnCollector) throws IOException {
} else {
osqVectorsScorer.scoreBulk(
quantizedQueryScratch,
- queryCorrections,
+ queryCorrections.lowerInterval(),
+ queryCorrections.upperInterval(),
+ queryCorrections.quantizedComponentSum(),
+ queryCorrections.additionalCorrection(),
fieldInfo.getVectorSimilarityFunction(),
centroidDp,
scores
@@ -321,7 +321,10 @@ public int visit(KnnCollector knnCollector) throws IOException {
indexInput.readFloats(correctiveValues, 0, 3);
final int quantizedComponentSum = Short.toUnsignedInt(indexInput.readShort());
float score = osqVectorsScorer.score(
- queryCorrections,
+ queryCorrections.lowerInterval(),
+ queryCorrections.upperInterval(),
+ queryCorrections.quantizedComponentSum(),
+ queryCorrections.additionalCorrection(),
fieldInfo.getVectorSimilarityFunction(),
centroidDp,
correctiveValues[0],
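Aside: the reads above imply that each quantized vector is followed on disk by a small correction record of three floats plus one unsigned short. A hedged sketch of that record (field names and the order of the floats are assumptions inferred from `score()`'s parameter list, not confirmed by this diff):

```java
import org.apache.lucene.store.IndexInput;

import java.io.IOException;

// Hedged sketch of the per-vector corrections record implied by
// readFloats(correctiveValues, 0, 3) + readShort() above. Field names and
// the order of the three floats are assumptions, not confirmed by the diff.
record VectorCorrections(float lowerInterval, float upperInterval, float additionalCorrection, int quantizedComponentSum) {
    static VectorCorrections read(IndexInput in) throws IOException {
        float lower = in.readFloat();
        float upper = in.readFloat();
        float additional = in.readFloat();
        int componentSum = Short.toUnsignedInt(in.readShort()); // stored as unsigned short
        return new VectorCorrections(lower, upper, additional, componentSum);
    }
}
```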
@@ -18,7 +18,6 @@
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.VectorUtil;
- import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
import org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans;
import org.elasticsearch.index.codec.vectors.cluster.KMeansResult;
import org.elasticsearch.logging.LogManager;
@@ -30,8 +29,8 @@
import java.nio.ByteOrder;

import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS;
- import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize;
- import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.packAsBinary;
+ import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
+ import static org.elasticsearch.index.codec.vectors.BQVectorUtils.packAsBinary;
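The `discretize` and `packAsBinary` helpers now come from `BQVectorUtils` rather than Lucene's `OptimizedScalarQuantizer` (and `transposeHalfByte` from `BQSpaceUtils` in the reader above), removing the remaining compile-time dependency on that Lucene class. For readers unfamiliar with the helpers, a hedged sketch of their usual semantics (inferred from how binary quantization is conventionally packed, not copied from `BQVectorUtils`):

```java
// Hedged sketches, inferred rather than copied from BQVectorUtils.

// discretize: round a dimension count up to a multiple of `bucket` so that
// bit-packed vectors align to whole words (e.g. bucket = 64 for long packing).
static int discretize(int value, int bucket) {
    return ((value + (bucket - 1)) / bucket) * bucket;
}

// packAsBinary: pack one bit per quantized component into `packed`,
// most significant bit first within each byte.
static void packAsBinary(int[] quantized, byte[] packed) {
    for (int i = 0; i < quantized.length; i++) {
        if (quantized[i] != 0) {
            packed[i >> 3] |= (byte) (0x80 >>> (i & 7));
        }
    }
}
```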

/**
* Default implementation of {@link IVFVectorsWriter}. It uses {@link HierarchicalKMeans} algorithm to
@@ -15,7 +15,6 @@
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
- import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;

@@ -109,13 +108,4 @@ public String toString() {
return "IVFVectorsFormat(" + "vectorPerCluster=" + vectorPerCluster + ')';
}

- static IVFVectorsReader getIVFReader(KnnVectorsReader vectorsReader, String fieldName) {
- if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
- vectorsReader = candidateReader.getFieldReader(fieldName);
- }
- if (vectorsReader instanceof IVFVectorsReader reader) {
- return reader;
- }
- return null;
- }
}
@@ -89,13 +89,8 @@ protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader)
}
}

- abstract CentroidQueryScorer getCentroidScorer(
- FieldInfo fieldInfo,
- int numCentroids,
- IndexInput centroids,
- float[] target,
- IndexInput clusters
- ) throws IOException;
+ abstract CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] target)
+ throws IOException;

private static IndexInput openDataInput(
SegmentReadState state,
@@ -249,8 +244,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
fieldInfo,
entry.postingListOffsets.length,
entry.centroidSlice(ivfCentroids),
- target,
- ivfClusters
+ target
);
if (nProbe == DYNAMIC_NPROBE) {
// empirically based, and a good dynamic to get decent recall while scaling a la "efSearch"