@@ -297,8 +297,7 @@ private static void score(
PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput indexInput, float[] target, IntPredicate needsScoring)
throws IOException {
FieldEntry entry = fields.get(fieldInfo.number);
final int maxPostingListSize = indexInput.readVInt();
return new MemorySegmentPostingsVisitor(target, indexInput, entry, fieldInfo, maxPostingListSize, needsScoring);
return new MemorySegmentPostingsVisitor(target, indexInput, entry, fieldInfo, needsScoring);
}

@Override
@@ -341,7 +340,6 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
IndexInput indexInput,
FieldEntry entry,
FieldInfo fieldInfo,
int maxPostingListSize,
IntPredicate needsScoring
) throws IOException {
this.target = target;
@@ -358,7 +356,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
quantizedVectorByteSize = (discretizedDimensions / 8);
quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction(), DEFAULT_LAMBDA, 1);
osqVectorsScorer = ESVectorUtil.getES91OSQVectorsScorer(indexInput, fieldInfo.getVectorDimension());
this.docIdsScratch = new int[maxPostingListSize];
this.docIdsScratch = new int[ES91OSQVectorsScorer.BULK_SIZE];
}

@Override
@@ -369,25 +367,23 @@ public int resetPostingsScorer(long offset) throws IOException {
centroidDp = Float.intBitsToFloat(indexInput.readInt());
vectors = indexInput.readVInt();
// read the doc ids
assert vectors <= docIdsScratch.length;
docIdsWriter.readInts(indexInput, vectors, docIdsScratch);
slicePos = indexInput.getFilePointer();
return vectors;
}

private float scoreIndividually(int offset) throws IOException {
private float scoreIndividually() throws IOException {
float maxScore = Float.NEGATIVE_INFINITY;
// score individually, first the quantized byte chunk
for (int j = 0; j < BULK_SIZE; j++) {
int doc = docIdsScratch[j + offset];
int doc = docIdsScratch[j];
if (doc != -1) {
indexInput.seek(slicePos + (offset * quantizedByteLength) + (j * quantizedVectorByteSize));
float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch);
scores[j] = qcDist;
} else {
indexInput.skipBytes(quantizedVectorByteSize);
}
}
// read in all corrections
indexInput.seek(slicePos + (offset * quantizedByteLength) + (BULK_SIZE * quantizedVectorByteSize));
indexInput.readFloats(correctionsLower, 0, BULK_SIZE);
indexInput.readFloats(correctionsUpper, 0, BULK_SIZE);
for (int j = 0; j < BULK_SIZE; j++) {
@@ -396,7 +392,7 @@ private float scoreIndividually(int offset) throws IOException {
indexInput.readFloats(correctionsAdd, 0, BULK_SIZE);
// Now apply corrections
for (int j = 0; j < BULK_SIZE; j++) {
int doc = docIdsScratch[offset + j];
int doc = docIdsScratch[j];
if (doc != -1) {
scores[j] = osqVectorsScorer.score(
queryCorrections.lowerInterval(),
@@ -419,21 +415,20 @@ private float scoreIndividually(int offset) throws IOException {
return maxScore;
}

private static int docToBulkScore(int[] docIds, int offset, IntPredicate needsScoring) {
private static int docToBulkScore(int[] docIds, IntPredicate needsScoring) {
int docToScore = ES91OSQVectorsScorer.BULK_SIZE;
for (int i = 0; i < ES91OSQVectorsScorer.BULK_SIZE; i++) {
final int idx = offset + i;
if (needsScoring.test(docIds[idx]) == false) {
docIds[idx] = -1;
if (needsScoring.test(docIds[i]) == false) {
docIds[i] = -1;
docToScore--;
}
}
return docToScore;
}

private static void collectBulk(int[] docIds, int offset, KnnCollector knnCollector, float[] scores) {
private static void collectBulk(int[] docIds, KnnCollector knnCollector, float[] scores) {
for (int i = 0; i < ES91OSQVectorsScorer.BULK_SIZE; i++) {
final int doc = docIds[offset + i];
final int doc = docIds[i];
if (doc != -1) {
knnCollector.collect(doc, scores[i]);
}
@@ -442,20 +437,23 @@ private static void collectBulk(int[] docIds, int offset, KnnCollector knnCollector, float[] scores) {

@Override
public int visit(KnnCollector knnCollector) throws IOException {
indexInput.seek(slicePos);
// block processing
int scoredDocs = 0;
int limit = vectors - BULK_SIZE + 1;
int i = 0;

for (; i < limit; i += BULK_SIZE) {
final int docsToBulkScore = docToBulkScore(docIdsScratch, i, needsScoring);
docIdsWriter.readInts(indexInput, BULK_SIZE, docIdsScratch);
final int docsToBulkScore = docToBulkScore(docIdsScratch, needsScoring);
if (docsToBulkScore == 0) {
indexInput.skipBytes(BULK_SIZE * quantizedByteLength);
continue;
}
quantizeQueryIfNecessary();
indexInput.seek(slicePos + i * quantizedByteLength);
final float maxScore;
if (docsToBulkScore < BULK_SIZE / 2) {
maxScore = scoreIndividually(i);
maxScore = scoreIndividually();
} else {
maxScore = osqVectorsScorer.scoreBulk(
quantizedQueryScratch,
@@ -469,16 +467,17 @@ public int visit(KnnCollector knnCollector) throws IOException {
);
}
if (knnCollector.minCompetitiveSimilarity() < maxScore) {
collectBulk(docIdsScratch, i, knnCollector, scores);
collectBulk(docIdsScratch, knnCollector, scores);
}
scoredDocs += docsToBulkScore;
}
// process tail
for (; i < vectors; i++) {
int doc = docIdsScratch[i];
int tailLength = vectors - i;
docIdsWriter.readInts(indexInput, tailLength, docIdsScratch);
for (int j = 0; j < tailLength; j++) {
int doc = docIdsScratch[j];
if (needsScoring.test(doc)) {
quantizeQueryIfNecessary();
indexInput.seek(slicePos + i * quantizedByteLength);
float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch);
indexInput.readFloats(correctiveValues, 0, 3);
final int quantizedComponentSum = Short.toUnsignedInt(indexInput.readShort());
@@ -497,6 +496,8 @@ public int visit(KnnCollector knnCollector) throws IOException {
);
scoredDocs++;
knnCollector.collect(doc, score);
} else {
indexInput.skipBytes(quantizedByteLength);
}
}
if (scoredDocs > 0) {
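Note on the reader change above: it removes the up-front maxPostingListSize and the one-shot read of every doc ID in resetPostingsScorer; instead each BULK_SIZE block of quantized vectors is now preceded by that block's doc IDs, so the scratch array shrinks to a fixed ES91OSQVectorsScorer.BULK_SIZE. A minimal sketch of the resulting access pattern, assuming a block size of 16 and hypothetical helper types (the real code uses Lucene's IndexInput, DocIdsWriter and the OSQ scorers):

import java.io.DataInput;
import java.io.IOException;

final class InterleavedPostingReaderSketch {
    static final int BULK_SIZE = 16;                        // assumed; the real value is ES91OSQVectorsScorer.BULK_SIZE
    private final int[] docIdsScratch = new int[BULK_SIZE]; // fixed-size scratch, independent of posting-list length

    interface BlockScorer {                                 // stand-in for the bulk/individual OSQ scoring paths
        void score(DataInput in, int[] docIds, int count) throws IOException;
    }

    void visit(DataInput in, BlockScorer scorer) throws IOException {
        int vectors = in.readInt();                         // number of entries in this posting list
        int i = 0;
        // full blocks: the block's doc IDs, then its quantized vectors and corrections
        for (; i + BULK_SIZE <= vectors; i += BULK_SIZE) {
            readDocIds(in, BULK_SIZE);
            scorer.score(in, docIdsScratch, BULK_SIZE);
        }
        // tail: the remaining doc IDs, then their vectors one at a time
        int tail = vectors - i;
        if (tail > 0) {
            readDocIds(in, tail);
            scorer.score(in, docIdsScratch, tail);
        }
    }

    private void readDocIds(DataInput in, int count) throws IOException {
        for (int j = 0; j < count; j++) {
            docIdsScratch[j] = in.readInt();                // the real format packs doc IDs with DocIdsWriter
        }
    }
}

As the diff shows, a block is bulk-scored only when at least half of its BULK_SIZE candidates pass the needsScoring filter; otherwise scoreIndividually walks the block sequentially and skips the vectors it does not need.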
@@ -77,11 +77,9 @@ LongValues buildAndWritePostingsLists(
}
}

int maxPostingListSize = 0;
int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
for (int c = 0; c < centroidSupplier.size(); c++) {
int size = centroidVectorCount[c];
maxPostingListSize = Math.max(maxPostingListSize, size);
assignmentsByCluster[c] = new int[size];
}
Arrays.fill(centroidVectorCount, 0);
@@ -97,11 +95,8 @@ LongValues buildAndWritePostingsLists(
}
}
}
// write the max posting list size
postingsOutput.writeVInt(maxPostingListSize);
// write the posting lists
final PackedLongValues.Builder offsets = PackedLongValues.monotonicBuilder(PackedInts.COMPACT);
DocIdsWriter docIdsWriter = new DocIdsWriter();
DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors(
floatVectorValues,
@@ -125,9 +120,8 @@ LongValues buildAndWritePostingsLists(
// TODO we might want to consider putting the docIds in a separate file
// to aid with only having to fetch vectors from slower storage when they are required
// keeping them in the same file indicates we pull the entire file into cache
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
// write vectors
bulkWriter.writeVectors(onHeapQuantizedVectors);
bulkWriter.writeVectors(onHeapQuantizedVectors, j -> floatVectorValues.ordToDoc(cluster[j]));
}

if (logger.isDebugEnabled()) {
@@ -203,12 +197,10 @@ LongValues buildAndWritePostingsLists(
}
}

int maxPostingListSize = 0;
int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
boolean[][] isOverspillByCluster = new boolean[centroidSupplier.size()][];
for (int c = 0; c < centroidSupplier.size(); c++) {
int size = centroidVectorCount[c];
maxPostingListSize = Math.max(maxPostingListSize, size);
assignmentsByCluster[c] = new int[size];
isOverspillByCluster[c] = new boolean[size];
}
@@ -233,11 +225,8 @@ LongValues buildAndWritePostingsLists(
quantizedVectorsInput,
fieldInfo.getVectorDimension()
);
DocIdsWriter docIdsWriter = new DocIdsWriter();
DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
// write the max posting list size
postingsOutput.writeVInt(maxPostingListSize);
// write the posting lists
for (int c = 0; c < centroidSupplier.size(); c++) {
float[] centroid = centroidSupplier.centroid(c);
@@ -256,9 +245,8 @@ LongValues buildAndWritePostingsLists(
// TODO we might want to consider putting the docIds in a separate file
// to aid with only having to fetch vectors from slower storage when they are required
// keeping them in the same file indicates we pull the entire file into cache
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
// write vectors
bulkWriter.writeVectors(offHeapQuantizedVectors);
bulkWriter.writeVectors(offHeapQuantizedVectors, j -> floatVectorValues.ordToDoc(cluster[j]));
}

if (logger.isDebugEnabled()) {
@@ -342,7 +330,7 @@ private void writeCentroidsWithParents(
osq,
globalCentroid
);
bulkWriter.writeVectors(parentQuantizeCentroid);
bulkWriter.writeVectors(parentQuantizeCentroid, null);
int offset = 0;
for (int i = 0; i < centroidGroups.centroids().length; i++) {
centroidOutput.writeInt(offset);
@@ -359,7 +347,7 @@ private void writeCentroidsWithParents(
for (int i = 0; i < centroidGroups.centroids().length; i++) {
final int[] centroidAssignments = centroidGroups.vectors()[i];
childrenQuantizeCentroid.reset(idx -> centroidAssignments[idx], centroidAssignments.length);
bulkWriter.writeVectors(childrenQuantizeCentroid);
bulkWriter.writeVectors(childrenQuantizeCentroid, null);
}
// write the centroid offsets at the end of the file
for (int i = 0; i < centroidGroups.centroids().length; i++) {
@@ -389,7 +377,7 @@ private void writeCentroidsWithoutParents(
osq,
globalCentroid
);
bulkWriter.writeVectors(quantizedCentroids);
bulkWriter.writeVectors(quantizedCentroids, null);
// write the centroid offsets at the end of the file
for (int i = 0; i < centroidSupplier.size(); i++) {
centroidOutput.writeLong(offsets.get(i));
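The writer side above mirrors that layout: buildAndWritePostingsLists no longer tracks or writes maxPostingListSize and no longer emits all doc IDs ahead of the vectors; it passes an ord-to-doc mapping into the bulk writer so each block's doc IDs land immediately before that block's quantized vectors. A rough write-side counterpart of the earlier sketch, again with an assumed block size and a simplified vector source rather than the actual DiskBBQBulkWriter API:

import java.io.DataOutput;
import java.io.IOException;
import java.util.function.IntUnaryOperator;

final class InterleavedPostingWriterSketch {
    static final int BULK_SIZE = 16;   // assumed; the real value is ES91OSQVectorsScorer.BULK_SIZE

    interface QuantizedBlockSource {   // stand-in for OnHeapQuantizedVectors / OffHeapQuantizedVectors
        void writeBlock(DataOutput out, int count) throws IOException;
    }

    void writePostingList(DataOutput out, int size, IntUnaryOperator ordToDoc, QuantizedBlockSource vectors)
        throws IOException {
        out.writeInt(size);                                  // posting-list length
        int i = 0;
        for (; i + BULK_SIZE <= size; i += BULK_SIZE) {
            writeDocIds(out, i, BULK_SIZE, ordToDoc);        // the block's doc IDs first
            vectors.writeBlock(out, BULK_SIZE);              // then its quantized vectors and corrections
        }
        int tail = size - i;
        if (tail > 0) {
            writeDocIds(out, i, tail, ordToDoc);
            vectors.writeBlock(out, tail);
        }
    }

    private void writeDocIds(DataOutput out, int from, int count, IntUnaryOperator ordToDoc) throws IOException {
        for (int j = 0; j < count; j++) {
            out.writeInt(ordToDoc.applyAsInt(from + j));     // real code packs these with DocIdsWriter
        }
    }
}

The centroid writers keep their old behavior and pass null for the doc-ID mapping, since centroid blocks carry no doc IDs.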
@@ -10,6 +10,7 @@
package org.elasticsearch.index.codec.vectors;

import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.hnsw.IntToIntFunction;

import java.io.IOException;

@@ -27,21 +28,24 @@ protected DiskBBQBulkWriter(int bulkSize, IndexOutput out) {
this.out = out;
}

abstract void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOException;
abstract void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv, IntToIntFunction docIds) throws IOException;

static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
private final OptimizedScalarQuantizer.QuantizationResult[] corrections;
protected DocIdsWriter docIdsWriter = new DocIdsWriter();

OneBitDiskBBQBulkWriter(int bulkSize, IndexOutput out) {
super(bulkSize, out);
this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize];
}

@Override
void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOException {
void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv, IntToIntFunction docIds) throws IOException {
int limit = qvv.count() - bulkSize + 1;
int i = 0;
for (; i < limit; i += bulkSize) {
int offset = i;
docIdsWriter.writeDocIds(idx -> docIds.apply(offset + idx), bulkSize, out);
for (int j = 0; j < bulkSize; j++) {
byte[] qv = qvv.next();
corrections[j] = qvv.getCorrections();
@@ -50,6 +54,8 @@ void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOException {
writeCorrections(corrections);
}
// write tail
int offset = i;
docIdsWriter.writeDocIds(idx -> docIds.apply(offset + idx), qvv.count() - i, out);
for (; i < qvv.count(); ++i) {
byte[] qv = qvv.next();
OptimizedScalarQuantizer.QuantizationResult correction = qvv.getCorrections();
@@ -94,7 +100,8 @@ static class SevenBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
}

@Override
void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOException {
void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv, IntToIntFunction docIds) throws IOException {
assert docIds == null;
int limit = qvv.count() - bulkSize + 1;
int i = 0;
for (; i < limit; i += bulkSize) {
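One detail in OneBitDiskBBQBulkWriter.writeVectors above: the loop copies i into a local offset before building the per-block IntToIntFunction. That copy is what lets the lambda compile, since Java lambdas may only capture effectively final locals and the loop variable mutates; the same pattern is repeated for the tail write. A small stand-alone illustration of the capture pattern (names are hypothetical):

import java.util.function.IntUnaryOperator;

final class LambdaCaptureSketch {
    // Emits doc IDs block by block through a per-block view of a global ord-to-doc mapping.
    static void writeBlocks(int count, int bulkSize, IntUnaryOperator ordToDoc) {
        int i = 0;
        for (; i + bulkSize <= count; i += bulkSize) {
            final int offset = i;                             // effectively final copy for the lambda
            IntUnaryOperator block = idx -> ordToDoc.applyAsInt(offset + idx);
            emit(block, bulkSize);                            // e.g. docIdsWriter.writeDocIds(block, bulkSize, out)
        }
        final int offset = i;                                 // tail block
        emit(idx -> ordToDoc.applyAsInt(offset + idx), count - i);
    }

    private static void emit(IntUnaryOperator block, int n) {
        for (int j = 0; j < n; j++) {
            System.out.println(block.applyAsInt(j));          // placeholder for the packed doc-ID encoding
        }
    }
}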
@@ -22,8 +22,10 @@
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
import org.apache.lucene.tests.util.TestUtil;
@@ -145,6 +147,30 @@ public void testSimpleOffHeapSize() throws IOException {
}
}

public void testSameVectorManyTimes() throws IOException {
float[] vector = randomVector(random().nextInt(12, 500));
try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
for (int i = 0; i < 10_000; i++) {
Document doc = new Document();
doc.add(new KnnFloatVectorField("f", vector, VectorSimilarityFunction.EUCLIDEAN));
w.addDocument(doc);
}
w.commit();
if (rarely()) {
w.forceMerge(1);
}
try (IndexReader reader = DirectoryReader.open(w)) {
List<LeafReaderContext> subReaders = reader.leaves();
for (LeafReaderContext r : subReaders) {
LeafReader leafReader = r.reader();
TopDocs topDocs = leafReader.searchNearestVectors("f", vector, 10, leafReader.getLiveDocs(), Integer.MAX_VALUE);
assertEquals(Math.min(leafReader.maxDoc(), 10), topDocs.scoreDocs.length);
}

}
}
}

// this is a modified version of lucene's TestSearchWithThreads test case
public void testWithThreads() throws Exception {
final int numThreads = random().nextInt(2, 5);