Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ private static String formatIndexPath(CmdLineArgs args) {
static Codec createCodec(CmdLineArgs args) {
final KnnVectorsFormat format;
if (args.indexType() == IndexType.IVF) {
format = new IVFVectorsFormat(args.ivfClusterSize());
format = new IVFVectorsFormat(args.ivfClusterSize(), IVFVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER);
} else {
if (args.quantizeBits() == 1) {
if (args.indexType() == IndexType.FLAT) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@

import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS;
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
import static org.elasticsearch.index.codec.vectors.BQSpaceUtils.transposeHalfByte;
import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
import static org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer.DEFAULT_LAMBDA;
Expand All @@ -41,7 +39,9 @@
* brute force and then scores the top ones using the posting list.
*/
public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeapStats {
private static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1);

// The percentage of centroids that are scored to keep recall
public static final double CENTROID_SAMPLING_PERCENTAGE = 0.2;

public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException {
super(state, rawVectorsReader);
Expand All @@ -54,8 +54,12 @@ CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, Inde
final float globalCentroidDp = fieldEntry.globalCentroidDp();
final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
final int[] scratch = new int[targetQuery.length];
float[] targetQueryCopy = ArrayUtil.copyArray(targetQuery);
if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
VectorUtil.l2normalize(targetQueryCopy);
}
final OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize(
ArrayUtil.copyArray(targetQuery),
targetQueryCopy,
scratch,
(byte) 4,
fieldEntry.globalCentroid()
Expand All @@ -65,67 +69,227 @@ CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, Inde
quantized[i] = (byte) scratch[i];
}
final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());
NeighborQueue queue = new NeighborQueue(fieldEntry.numCentroids(), true);
centroids.seek(0L);
final float[] centroidCorrectiveValues = new float[3];
for (int i = 0; i < numCentroids; i++) {
final float qcDist = scorer.int4DotProduct(quantized);
centroids.readFloats(centroidCorrectiveValues, 0, 3);
final int quantizedCentroidComponentSum = Short.toUnsignedInt(centroids.readShort());
float score = int4QuantizedScore(
qcDist,
int numParents = centroids.readVInt();
if (numParents > 0) {
return getCentroidIteratorWithParents(
fieldInfo,
centroids,
numParents,
numCentroids,
scorer,
quantized,
queryParams,
fieldInfo.getVectorDimension(),
centroidCorrectiveValues,
quantizedCentroidComponentSum,
globalCentroidDp,
fieldInfo.getVectorSimilarityFunction()
globalCentroidDp
);
queue.add(i, score);
}
final long offset = centroids.getFilePointer();
return getCentroidIteratorNoParent(fieldInfo, centroids, numCentroids, scorer, quantized, queryParams, globalCentroidDp);
}

/**
 * Scores every centroid (flat layout, no parent level) against the quantized query and returns
 * an iterator yielding posting-list file offsets in best-score-first order.
 * <p>
 * NOTE(review): this span was diff residue mixing removed ({@code queue}) and added
 * ({@code neighborQueue}) lines; reconstructed here as the coherent added version.
 *
 * @param fieldInfo       field being searched; supplies the similarity function
 * @param centroids       input positioned at the first quantized centroid; after scoring, the
 *                        file pointer is at the posting-list offset table (one long per centroid)
 * @param numCentroids    number of centroids to score
 * @param scorer          int4 scorer reading quantized centroids sequentially from {@code centroids}
 * @param quantizeQuery   4-bit-quantized query vector
 * @param queryParams     quantization corrective values for the query
 * @param globalCentroidDp dot product of the global centroid with itself (similarity correction)
 * @throws IOException on read failure
 */
private static CentroidIterator getCentroidIteratorNoParent(
    FieldInfo fieldInfo,
    IndexInput centroids,
    int numCentroids,
    ES91Int4VectorsScorer scorer,
    byte[] quantizeQuery,
    OptimizedScalarQuantizer.QuantizationResult queryParams,
    float globalCentroidDp
) throws IOException {
    final NeighborQueue neighborQueue = new NeighborQueue(numCentroids, true);
    score(
        neighborQueue,
        numCentroids,
        0,
        scorer,
        quantizeQuery,
        queryParams,
        globalCentroidDp,
        fieldInfo.getVectorSimilarityFunction(),
        new float[ES91Int4VectorsScorer.BULK_SIZE]
    );
    // The scorer has consumed all quantized centroids; the offset table starts here.
    final long offset = centroids.getFilePointer();
    return new CentroidIterator() {
        @Override
        public boolean hasNext() {
            return neighborQueue.size() > 0;
        }

        @Override
        public long nextPostingListOffset() throws IOException {
            final int centroidOrdinal = neighborQueue.pop();
            // Offsets are a contiguous array of longs indexed by centroid ordinal.
            centroids.seek(offset + (long) Long.BYTES * centroidOrdinal);
            return centroids.readLong();
        }
    };
}

// TODO can we do this in off-heap blocks?
private float int4QuantizedScore(
float qcDist,
/**
 * Hierarchical (two-level) centroid search: parents are scored first, then children of the
 * best parents are scored lazily. Only about {@code numCentroids * CENTROID_SAMPLING_PERCENTAGE}
 * children are buffered up-front; the rest are pulled in as the iterator is consumed.
 *
 * @param fieldInfo        field being searched; supplies dimension and similarity function
 * @param centroids        input positioned just after the parent/children counts
 * @param numParents       number of parent centroids
 * @param numCentroids     total number of child centroids
 * @param scorer           int4 scorer reading quantized centroids from {@code centroids}
 * @param quantizeQuery    4-bit-quantized query vector
 * @param queryParams      quantization corrective values for the query
 * @param globalCentroidDp dot product of the global centroid with itself
 * @throws IOException on read failure
 */
private static CentroidIterator getCentroidIteratorWithParents(
FieldInfo fieldInfo,
IndexInput centroids,
int numParents,
int numCentroids,
ES91Int4VectorsScorer scorer,
byte[] quantizeQuery,
OptimizedScalarQuantizer.QuantizationResult queryParams,
float globalCentroidDp
) throws IOException {
// build the three queues we are going to use:
// - parentsQueue: every parent, ordered by score
// - currentParentQueue: children of the parent currently being drained
// - neighborQueue: global best children found so far (the order the iterator yields)
final NeighborQueue parentsQueue = new NeighborQueue(numParents, true);
// largest child count of any single parent — upper bound for currentParentQueue
final int maxChildrenSize = centroids.readVInt();
final NeighborQueue currentParentQueue = new NeighborQueue(maxChildrenSize, true);
// buffer only a sample of all centroids; updateQueue() below tops it up on demand
final int bufferSize = (int) Math.max(numCentroids * CENTROID_SAMPLING_PERCENTAGE, 1);
final NeighborQueue neighborQueue = new NeighborQueue(bufferSize, true);
// score the parents against the quantized query
final float[] scores = new float[ES91Int4VectorsScorer.BULK_SIZE];
score(
parentsQueue,
numParents,
0,
scorer,
quantizeQuery,
queryParams,
globalCentroidDp,
fieldInfo.getVectorSimilarityFunction(),
scores
);
// on-disk size of one quantized centroid: dimension bytes + 3 float corrections + 1 short sum
final long centroidQuantizeSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES;
final long offset = centroids.getFilePointer();
// each parent entry is 2 ints (== Long.BYTES), hence the Long.BYTES stride below
final long childrenOffset = offset + (long) Long.BYTES * numParents;
// populate the children's queue by reading parents one by one, best parent first
while (parentsQueue.size() > 0 && neighborQueue.size() < bufferSize) {
final int pop = parentsQueue.pop();
populateOneChildrenGroup(
currentParentQueue,
centroids,
offset + 2L * Integer.BYTES * pop,
childrenOffset,
centroidQuantizeSize,
fieldInfo,
scorer,
quantizeQuery,
queryParams,
globalCentroidDp,
scores
);
while (currentParentQueue.size() > 0 && neighborQueue.size() < bufferSize) {
final float score = currentParentQueue.topScore();
final int children = currentParentQueue.pop();
neighborQueue.add(children, score);
}
}
// posting-list offsets (one long per child centroid) follow all quantized children on disk
final long childrenFileOffsets = childrenOffset + centroidQuantizeSize * numCentroids;
return new CentroidIterator() {
@Override
public boolean hasNext() {
return neighborQueue.size() > 0;
}

@Override
public long nextPostingListOffset() throws IOException {
int centroidOrdinal = neighborQueue.pop();
updateQueue(); // add one child if available so the queue remains fully populated
centroids.seek(childrenFileOffsets + (long) Long.BYTES * centroidOrdinal);
return centroids.readLong();
}

// Refill neighborQueue with the next-best child, scoring a new parent's children if the
// current parent's queue is exhausted; recurses at most once per newly scored parent.
private void updateQueue() throws IOException {
if (currentParentQueue.size() > 0) {
// add a child from the current parent's queue
float score = currentParentQueue.topScore();
int children = currentParentQueue.pop();
neighborQueue.add(children, score);
} else if (parentsQueue.size() > 0) {
// drain the next-best parent: score its children, then retry
int pop = parentsQueue.pop();
populateOneChildrenGroup(
currentParentQueue,
centroids,
offset + 2L * Integer.BYTES * pop,
childrenOffset,
centroidQuantizeSize,
fieldInfo,
scorer,
quantizeQuery,
queryParams,
globalCentroidDp,
scores
);
updateQueue();
}
}
};
}

/**
 * Seeks to one parent's metadata entry (first-child ordinal + child count), positions the input
 * at that parent's first quantized child, and scores the whole contiguous group of children
 * into {@code neighborQueue}.
 *
 * @param parentOffset        absolute file offset of the parent's (childrenOrdinal, numChildren) ints
 * @param childrenOffset      absolute file offset where the flat quantized-children array begins
 * @param centroidQuantizeSize on-disk byte size of one quantized centroid entry
 * @throws IOException on read failure
 */
private static void populateOneChildrenGroup(
NeighborQueue neighborQueue,
IndexInput centroids,
long parentOffset,
long childrenOffset,
long centroidQuantizeSize,
FieldInfo fieldInfo,
ES91Int4VectorsScorer scorer,
byte[] quantizeQuery,
OptimizedScalarQuantizer.QuantizationResult queryParams,
float globalCentroidDp,
float[] scores
) throws IOException {
centroids.seek(parentOffset);
// ordinal of this parent's first child within the flat children array
int childrenOrdinal = centroids.readInt();
int numChildren = centroids.readInt();
// position at the first child; the scorer (backed by this input) then reads children sequentially
centroids.seek(childrenOffset + centroidQuantizeSize * childrenOrdinal);
score(
neighborQueue,
numChildren,
childrenOrdinal,
scorer,
quantizeQuery,
queryParams,
globalCentroidDp,
fieldInfo.getVectorSimilarityFunction(),
scores
);
}

private static void score(
NeighborQueue neighborQueue,
int size,
int scoresOffset,
ES91Int4VectorsScorer scorer,
byte[] quantizeQuery,
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
int dims,
float[] targetCorrections,
int targetComponentSum,
float centroidDp,
VectorSimilarityFunction similarityFunction
) {
float ax = targetCorrections[0];
float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE;
float ay = queryCorrections.lowerInterval();
float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
float y1 = queryCorrections.quantizedComponentSum();
float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
if (similarityFunction == EUCLIDEAN) {
score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score;
return Math.max(1 / (1f + score), 0);
} else {
// For cosine and max inner product, we need to apply the additional correction, which is
// assumed to be the non-centered dot-product between the vector and the centroid
score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp;
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
return VectorUtil.scaleMaxInnerProductScore(score);
VectorSimilarityFunction similarityFunction,
float[] scores
) throws IOException {
int limit = size - ES91Int4VectorsScorer.BULK_SIZE + 1;
int i = 0;
for (; i < limit; i += ES91Int4VectorsScorer.BULK_SIZE) {
scorer.scoreBulk(
quantizeQuery,
queryCorrections.lowerInterval(),
queryCorrections.upperInterval(),
queryCorrections.quantizedComponentSum(),
queryCorrections.additionalCorrection(),
similarityFunction,
centroidDp,
scores
);
for (int j = 0; j < ES91Int4VectorsScorer.BULK_SIZE; j++) {
neighborQueue.add(scoresOffset + i + j, scores[j]);
}
return Math.max((1f + score) / 2f, 0);
}

for (; i < size; i++) {
float score = scorer.score(
quantizeQuery,
queryCorrections.lowerInterval(),
queryCorrections.upperInterval(),
queryCorrections.quantizedComponentSum(),
queryCorrections.additionalCorrection(),
similarityFunction,
centroidDp
);
neighborQueue.add(scoresOffset + i, score);
}
}

Expand Down
Loading