@@ -376,7 +376,9 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
final float[] correctionsUpper = new float[BULK_SIZE];
final int[] correctionsSum = new int[BULK_SIZE];
final float[] correctionsAdd = new float[BULK_SIZE];
final int[] docIdsScratch;
final int[] docIdsScratch = new int[BULK_SIZE];
byte docEncoding;
int docBase = 0;

int vectors;
boolean quantized = false;
@@ -415,7 +417,6 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
quantizedVectorByteSize = (discretizedDimensions / 8);
quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction(), DEFAULT_LAMBDA, 1);
osqVectorsScorer = ESVectorUtil.getES91OSQVectorsScorer(indexInput, fieldInfo.getVectorDimension());
this.docIdsScratch = new int[maxPostingListSize];
}

@Override
@@ -425,24 +426,17 @@ public int resetPostingsScorer(long offset) throws IOException {
indexInput.readFloats(centroid, 0, centroid.length);
centroidDp = Float.intBitsToFloat(indexInput.readInt());
vectors = indexInput.readVInt();
// read the doc ids
assert vectors <= docIdsScratch.length;
idsWriter.readInts(indexInput, vectors, docIdsScratch);
// reconstitute from the deltas
int sum = 0;
for (int i = 0; i < vectors; i++) {
sum += docIdsScratch[i];
docIdsScratch[i] = sum;
}
docEncoding = indexInput.readByte();
docBase = 0;
slicePos = indexInput.getFilePointer();
return vectors;
}

private float scoreIndividually(int offset) throws IOException {
private float scoreIndividually() throws IOException {
float maxScore = Float.NEGATIVE_INFINITY;
// score individually, first the quantized byte chunk
for (int j = 0; j < BULK_SIZE; j++) {
int doc = docIdsScratch[j + offset];
int doc = docIdsScratch[j];
if (doc != -1) {
float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch);
scores[j] = qcDist;
@@ -459,7 +453,7 @@ private float scoreIndividually(int offset) throws IOException {
indexInput.readFloats(correctionsAdd, 0, BULK_SIZE);
// Now apply corrections
for (int j = 0; j < BULK_SIZE; j++) {
int doc = docIdsScratch[offset + j];
int doc = docIdsScratch[j];
if (doc != -1) {
scores[j] = osqVectorsScorer.score(
queryCorrections.lowerInterval(),
@@ -482,45 +476,56 @@ private float scoreIndividually(int offset) throws IOException {
return maxScore;
}

private static int docToBulkScore(int[] docIds, int offset, Bits acceptDocs) {
private static int docToBulkScore(int[] docIds, Bits acceptDocs) {
assert acceptDocs != null : "acceptDocs must not be null";
int docToScore = ES91OSQVectorsScorer.BULK_SIZE;
for (int i = 0; i < ES91OSQVectorsScorer.BULK_SIZE; i++) {
final int idx = offset + i;
if (acceptDocs.get(docIds[idx]) == false) {
docIds[idx] = -1;
if (acceptDocs.get(docIds[i]) == false) {
docIds[i] = -1;
docToScore--;
}
}
return docToScore;
}

private static void collectBulk(int[] docIds, int offset, KnnCollector knnCollector, float[] scores) {
private void collectBulk(KnnCollector knnCollector, float[] scores) {
for (int i = 0; i < ES91OSQVectorsScorer.BULK_SIZE; i++) {
final int doc = docIds[offset + i];
final int doc = docIdsScratch[i];
if (doc != -1) {
knnCollector.collect(doc, scores[i]);
}
}
}

private void readDocIds(int count) throws IOException {
idsWriter.readInts(indexInput, count, docEncoding, docIdsScratch);
// reconstitute from the deltas
for (int j = 0; j < count; j++) {
docBase += docIdsScratch[j];
docIdsScratch[j] = docBase;
}
}

@Override
public int visit(KnnCollector knnCollector) throws IOException {
indexInput.seek(slicePos);
// block processing
int scoredDocs = 0;
int limit = vectors - BULK_SIZE + 1;
int i = 0;
// read Docs
for (; i < limit; i += BULK_SIZE) {
final int docsToBulkScore = acceptDocs == null ? BULK_SIZE : docToBulkScore(docIdsScratch, i, acceptDocs);
// read the doc ids
readDocIds(BULK_SIZE);
final int docsToBulkScore = acceptDocs == null ? BULK_SIZE : docToBulkScore(docIdsScratch, acceptDocs);
if (docsToBulkScore == 0) {
indexInput.skipBytes(quantizedByteLength * BULK_SIZE);
continue;
}
quantizeQueryIfNecessary();
final float maxScore;
if (docsToBulkScore < BULK_SIZE / 2) {
maxScore = scoreIndividually(i);
maxScore = scoreIndividually();
} else {
maxScore = osqVectorsScorer.scoreBulk(
quantizedQueryScratch,
@@ -534,13 +539,18 @@ public int visit(KnnCollector knnCollector) throws IOException {
);
}
if (knnCollector.minCompetitiveSimilarity() < maxScore) {
collectBulk(docIdsScratch, i, knnCollector, scores);
collectBulk(knnCollector, scores);
}
scoredDocs += docsToBulkScore;
}
// process tail
// read the doc ids
if (i < vectors) {
readDocIds(vectors - i);
}
int count = 0;
for (; i < vectors; i++) {
int doc = docIdsScratch[i];
int doc = docIdsScratch[count++];
if (acceptDocs == null || acceptDocs.get(doc)) {
quantizeQueryIfNecessary();
float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch);
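The reader-side change above stops materializing a whole posting list of doc ids up front: resetPostingsScorer no longer decodes the ids at all, it just reads the per-list encoding byte and resets docBase, and visit() calls readDocIds(BULK_SIZE) before scoring each block (plus one call for the tail), carrying the delta prefix sum across blocks in docBase so the scratch array can stay at BULK_SIZE instead of maxPostingListSize. A minimal, self-contained sketch of that block-wise decoding over plain arrays; the class name, the BULK_SIZE value and the hard-coded deltas below are illustrative, not the Elasticsearch or Lucene API:

import java.util.Arrays;

/** Illustrative sketch: decode delta-encoded doc ids block by block with a fixed-size scratch buffer. */
public class BlockDeltaDecodeSketch {
    static final int BULK_SIZE = 16; // stand-in for ES91OSQVectorsScorer.BULK_SIZE

    public static void main(String[] args) {
        // Pretend these deltas were read from the postings file, one block at a time.
        int[] deltas = { 3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9, 3, 2, 3, 8, 4 };
        int[] scratch = new int[BULK_SIZE]; // sized to a block, not to the whole posting list
        int docBase = 0;
        for (int i = 0; i < deltas.length; i += BULK_SIZE) {
            int count = Math.min(BULK_SIZE, deltas.length - i);
            // "read" one block of deltas into the scratch buffer
            System.arraycopy(deltas, i, scratch, 0, count);
            // reconstitute absolute doc ids from the deltas, carrying docBase across blocks
            for (int j = 0; j < count; j++) {
                docBase += scratch[j];
                scratch[j] = docBase;
            }
            System.out.println("block doc ids: " + Arrays.toString(Arrays.copyOf(scratch, count)));
        }
    }
}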
@@ -138,12 +138,12 @@ CentroidOffsetAndLength buildAndWritePostingsLists(
docDeltas[j] = j == 0 ? docIds[clusterOrds[j]] : docIds[clusterOrds[j]] - docIds[clusterOrds[j - 1]];
}
onHeapQuantizedVectors.reset(centroid, size, ord -> cluster[clusterOrds[ord]]);
// TODO we might want to consider putting the docIds in a separate file
// to aid with only having to fetch vectors from slower storage when they are required
// keeping them in the same file indicates we pull the entire file into cache
idsWriter.writeDocIds(i -> docDeltas[i], size, postingsOutput);
// write vectors
bulkWriter.writeVectors(onHeapQuantizedVectors);
byte encoding = idsWriter.calculateBlockEncoding(i -> docDeltas[i], size, ES91OSQVectorsScorer.BULK_SIZE);
postingsOutput.writeByte(encoding);
bulkWriter.writeVectors(onHeapQuantizedVectors, i -> {
// for vector i we write `bulk` size docs or the remaining docs
idsWriter.writeDocIds(d -> docDeltas[i + d], Math.min(ES91OSQVectorsScorer.BULK_SIZE, size - i), encoding, postingsOutput);
});
lengths.add(postingsOutput.getFilePointer() - fileOffset - offset);
}
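Note that calculateBlockEncoding is evaluated once over all of the posting list's deltas and the resulting byte is written a single time, even though the ids are then emitted one ES91OSQVectorsScorer.BULK_SIZE block at a time: every per-block writeDocIds call reuses that byte. A hypothetical illustration of the idea, picking one delta width that is safe for every block of the list; this is not the actual IdsWriter implementation:

import java.util.function.IntUnaryOperator;

/** Hypothetical sketch: choose one encoding (here, bits per delta) that every block of a posting list can use. */
public class BlockEncodingSketch {
    static byte pickEncoding(IntUnaryOperator deltas, int size, int blockSize) {
        int maxBits = 1;
        for (int blockStart = 0; blockStart < size; blockStart += blockSize) {
            int end = Math.min(blockStart + blockSize, size);
            for (int i = blockStart; i < end; i++) {
                // the widest delta seen in any block decides the whole list's encoding
                maxBits = Math.max(maxBits, 32 - Integer.numberOfLeadingZeros(Math.max(1, deltas.applyAsInt(i))));
            }
        }
        return (byte) maxBits; // written once per posting list, reused by every block
    }

    public static void main(String[] args) {
        int[] docDeltas = { 1, 2, 7, 1, 130, 3, 2, 1 };
        System.out.println(pickEncoding(i -> docDeltas[i], docDeltas.length, 4)); // prints 8
    }
}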

@@ -287,15 +287,20 @@ CentroidOffsetAndLength buildAndWritePostingsLists(
for (int j = 0; j < size; j++) {
docDeltas[j] = j == 0 ? docIds[clusterOrds[j]] : docIds[clusterOrds[j]] - docIds[clusterOrds[j - 1]];
}
byte encoding = idsWriter.calculateBlockEncoding(i -> docDeltas[i], size, ES91OSQVectorsScorer.BULK_SIZE);
postingsOutput.writeByte(encoding);
offHeapQuantizedVectors.reset(size, ord -> isOverspill[clusterOrds[ord]], ord -> cluster[clusterOrds[ord]]);
// TODO we might want to consider putting the docIds in a separate file
// to aid with only having to fetch vectors from slower storage when they are required
// keeping them in the same file indicates we pull the entire file into cache
idsWriter.writeDocIds(i -> docDeltas[i], size, postingsOutput);
// write vectors
bulkWriter.writeVectors(offHeapQuantizedVectors);
bulkWriter.writeVectors(offHeapQuantizedVectors, i -> {
// for vector i we write `bulk` size docs or the remaining docs
idsWriter.writeDocIds(
d -> docDeltas[d + i],
Math.min(ES91OSQVectorsScorer.BULK_SIZE, size - i),
encoding,
postingsOutput
);
});
lengths.add(postingsOutput.getFilePointer() - fileOffset - offset);
// lengths.add(1);
}

if (logger.isDebugEnabled()) {
@@ -381,7 +386,7 @@ private void writeCentroidsWithParents(
osq,
globalCentroid
);
bulkWriter.writeVectors(parentQuantizeCentroid);
bulkWriter.writeVectors(parentQuantizeCentroid, null);
int offset = 0;
for (int i = 0; i < centroidGroups.centroids().length; i++) {
centroidOutput.writeInt(offset);
@@ -398,7 +403,7 @@ private void writeCentroidsWithParents(
for (int i = 0; i < centroidGroups.centroids().length; i++) {
final int[] centroidAssignments = centroidGroups.vectors()[i];
childrenQuantizeCentroid.reset(idx -> centroidAssignments[idx], centroidAssignments.length);
bulkWriter.writeVectors(childrenQuantizeCentroid);
bulkWriter.writeVectors(childrenQuantizeCentroid, null);
}
// write the centroid offsets at the end of the file
for (int i = 0; i < centroidGroups.centroids().length; i++) {
@@ -429,7 +434,7 @@ private void writeCentroidsWithoutParents(
osq,
globalCentroid
);
bulkWriter.writeVectors(quantizedCentroids);
bulkWriter.writeVectors(quantizedCentroids, null);
// write the centroid offsets at the end of the file
for (int i = 0; i < centroidSupplier.size(); i++) {
centroidOutput.writeLong(centroidOffsetAndLength.offsets().get(i));
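Taken together, the writer changes switch a posting list from [all delta-encoded doc ids][all quantized vectors] to an interleaved layout: after the size and the per-list encoding byte, each BULK_SIZE block of doc-id deltas is written immediately before that block's vectors, and the tail follows the same pattern, so a reader can score one block after a single contiguous read. A rough, self-contained sketch of that layout, assuming a plain DataOutputStream and fixed-width deltas in place of Lucene's IndexOutput and the IdsWriter encodings:

import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;

/** Rough sketch of the interleaved layout: [size][encoding][ids block][vector block]...[ids tail][vector tail]. */
public class InterleavedPostingListSketch {
    static void writePostingList(DataOutputStream out, int[] docIds, byte[][] vectors, int bulkSize) throws IOException {
        out.writeInt(docIds.length); // list size (the real format uses a vint)
        out.writeByte(8);            // per-list doc-id encoding, chosen once (placeholder value)
        int prev = 0;
        for (int i = 0; i < docIds.length; i += bulkSize) {
            int count = Math.min(bulkSize, docIds.length - i);
            // doc-id deltas for this block, written just before the block's vectors
            for (int j = 0; j < count; j++) {
                out.writeInt(docIds[i + j] - prev);
                prev = docIds[i + j];
            }
            // quantized vectors (and, in the real format, their corrections) for the same block
            for (int j = 0; j < count; j++) {
                out.write(vectors[i + j]);
            }
        }
    }

    public static void main(String[] args) throws IOException {
        int[] docIds = { 2, 5, 6, 9, 12, 20 };
        byte[][] vectors = new byte[docIds.length][4]; // dummy 4-byte quantized vectors
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        writePostingList(new DataOutputStream(bytes), docIds, vectors, 4);
        System.out.println("wrote " + bytes.size() + " bytes");
    }
}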
@@ -9,6 +9,7 @@

package org.elasticsearch.index.codec.vectors;

import org.apache.lucene.search.CheckedIntConsumer;
import org.apache.lucene.store.IndexOutput;

import java.io.IOException;
@@ -27,7 +28,8 @@ protected DiskBBQBulkWriter(int bulkSize, IndexOutput out) {
this.out = out;
}

abstract void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOException;
abstract void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv, CheckedIntConsumer<IOException> docsWriter)
throws IOException;

static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
private final OptimizedScalarQuantizer.QuantizationResult[] corrections;
@@ -38,17 +40,24 @@ static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
}

@Override
void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOException {
void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv, CheckedIntConsumer<IOException> docsWriter)
throws IOException {
int limit = qvv.count() - bulkSize + 1;
int i = 0;
for (; i < limit; i += bulkSize) {
if (docsWriter != null) {
docsWriter.accept(i);
}
for (int j = 0; j < bulkSize; j++) {
byte[] qv = qvv.next();
corrections[j] = qvv.getCorrections();
out.writeBytes(qv, qv.length);
}
writeCorrections(corrections);
}
if (i < qvv.count() && docsWriter != null) {
docsWriter.accept(i);
}
// write tail
for (; i < qvv.count(); ++i) {
byte[] qv = qvv.next();
@@ -94,7 +103,8 @@ static class SevenBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
}

@Override
void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOException {
void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv, CheckedIntConsumer<IOException> docsWriter)
throws IOException {
int limit = qvv.count() - bulkSize + 1;
int i = 0;
for (; i < limit; i += bulkSize) {
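In DiskBBQBulkWriter, writeVectors now takes a CheckedIntConsumer<IOException> that is invoked with the first ordinal of each full bulk and once more with the first ordinal of the tail, so the caller can write the matching doc-id block right before those vectors; passing null, as the centroid writers do, keeps the old vectors-only behaviour. A minimal sketch of that contract, with a local functional interface standing in for the Lucene one:

import java.io.IOException;

/** Minimal sketch of the writeVectors(qvv, docsWriter) callback contract. */
public class BulkCallbackSketch {
    /** Same shape as the CheckedIntConsumer<IOException> used above. */
    interface BlockStartCallback {
        void accept(int firstOrdinal) throws IOException;
    }

    /** Fires the callback at the start of every full block and once at the start of the tail; null skips it. */
    static void writeBlocks(int count, int bulkSize, BlockStartCallback docsWriter) throws IOException {
        int limit = count - bulkSize + 1;
        int i = 0;
        for (; i < limit; i += bulkSize) {
            if (docsWriter != null) {
                docsWriter.accept(i); // caller writes doc ids for ordinals [i, i + bulkSize)
            }
            // ... write bulkSize vectors and their corrections here ...
        }
        if (i < count && docsWriter != null) {
            docsWriter.accept(i); // caller writes doc ids for the tail [i, count)
        }
        // ... write the tail vectors here ...
    }

    public static void main(String[] args) throws IOException {
        // With 10 vectors and a bulk size of 4, the callback fires at ordinals 0, 4 and 8.
        writeBlocks(10, 4, first -> System.out.println("doc-id block starts at ordinal " + first));
        // Centroid writers pass null: vectors only, no interleaved doc ids.
        writeBlocks(10, 4, null);
    }
}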