@@ -297,7 +297,8 @@ private static void score(
     PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput indexInput, float[] target, IntPredicate needsScoring)
         throws IOException {
         FieldEntry entry = fields.get(fieldInfo.number);
-        return new MemorySegmentPostingsVisitor(target, indexInput.clone(), entry, fieldInfo, needsScoring);
+        final int maxPostingListSize = indexInput.readVInt();
+        return new MemorySegmentPostingsVisitor(target, indexInput, entry, fieldInfo, maxPostingListSize, needsScoring);
     }
 
     @Override
@@ -318,8 +319,8 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
         final float[] correctionsUpper = new float[BULK_SIZE];
         final int[] correctionsSum = new int[BULK_SIZE];
         final float[] correctionsAdd = new float[BULK_SIZE];
+        final int[] docIdsScratch;
 
-        int[] docIdsScratch = new int[0];
         int vectors;
         boolean quantized = false;
         float centroidDp;
@@ -340,6 +341,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
             IndexInput indexInput,
             FieldEntry entry,
             FieldInfo fieldInfo,
+            int maxPostingListSize,
             IntPredicate needsScoring
         ) throws IOException {
             this.target = target;
@@ -356,6 +358,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
             quantizedVectorByteSize = (discretizedDimensions / 8);
             quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction(), DEFAULT_LAMBDA, 1);
             osqVectorsScorer = ESVectorUtil.getES91OSQVectorsScorer(indexInput, fieldInfo.getVectorDimension());
+            this.docIdsScratch = new int[maxPostingListSize];
         }
 
         @Override
@@ -366,7 +369,7 @@ public int resetPostingsScorer(long offset) throws IOException {
             centroidDp = Float.intBitsToFloat(indexInput.readInt());
             vectors = indexInput.readVInt();
             // read the doc ids
-            docIdsScratch = vectors > docIdsScratch.length ? new int[vectors] : docIdsScratch;
+            assert vectors <= docIdsScratch.length;
             docIdsWriter.readInts(indexInput, vectors, docIdsScratch);
             slicePos = indexInput.getFilePointer();
             return vectors;
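The reader-side effect of this change: the posting-list region now starts with a VInt recording the largest posting list in the segment, so `docIdsScratch` can be sized once in the constructor and `resetPostingsScorer` only asserts the bound instead of reallocating. A minimal sketch of the pattern, assuming a hypothetical `PostingsReader` in place of the real `MemorySegmentPostingsVisitor`:

```java
import java.io.IOException;
import org.apache.lucene.store.IndexInput;

// Hypothetical illustration of the preallocation pattern; the real visitor
// also reads centroids, quantization corrections, and scores vectors.
class PostingsReader {
    private final int[] docIdsScratch; // allocated once, reused for every posting list

    PostingsReader(IndexInput postings) throws IOException {
        // The header VInt is the size of the largest posting list, so a single
        // allocation is guaranteed to fit every list in this segment.
        this.docIdsScratch = new int[postings.readVInt()];
    }

    int resetPostingsScorer(IndexInput postings) throws IOException {
        int vectors = postings.readVInt();
        assert vectors <= docIdsScratch.length; // bound is upheld by the writer
        // ... read `vectors` doc ids into docIdsScratch, then score as before
        return vectors;
    }
}
```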
@@ -64,6 +64,7 @@ LongValues buildAndWritePostingsLists(
         CentroidSupplier centroidSupplier,
         FloatVectorValues floatVectorValues,
         IndexOutput postingsOutput,
+        long fileOffset,
         int[] assignments,
         int[] overspillAssignments
     ) throws IOException {
@@ -76,9 +77,12 @@ LongValues buildAndWritePostingsLists(
             }
         }
 
+        int maxPostingListSize = 0;
         int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
         for (int c = 0; c < centroidSupplier.size(); c++) {
-            assignmentsByCluster[c] = new int[centroidVectorCount[c]];
+            int size = centroidVectorCount[c];
+            maxPostingListSize = Math.max(maxPostingListSize, size);
+            assignmentsByCluster[c] = new int[size];
         }
         Arrays.fill(centroidVectorCount, 0);
 
@@ -93,6 +97,8 @@ LongValues buildAndWritePostingsLists(
                 }
             }
         }
+        // write the max posting list size
+        postingsOutput.writeVInt(maxPostingListSize);
         // write the posting lists
         final PackedLongValues.Builder offsets = PackedLongValues.monotonicBuilder(PackedInts.COMPACT);
         DocIdsWriter docIdsWriter = new DocIdsWriter();
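Ordering matters here: the max-posting-list-size header has to be written before the first posting list so the reader can consume it before seeking to any cluster. A hedged sketch of the header write (hypothetical helper; the real code folds the max into the array-sizing loop above):

```java
import java.io.IOException;
import org.apache.lucene.store.IndexOutput;

// Hypothetical helper: derive and write the max posting list size header.
static int writeMaxPostingListSize(IndexOutput postingsOutput, int[] centroidVectorCount) throws IOException {
    int maxPostingListSize = 0;
    for (int count : centroidVectorCount) {
        maxPostingListSize = Math.max(maxPostingListSize, count);
    }
    postingsOutput.writeVInt(maxPostingListSize); // must precede every posting list
    return maxPostingListSize;
}
```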
@@ -106,7 +112,7 @@ LongValues buildAndWritePostingsLists(
         for (int c = 0; c < centroidSupplier.size(); c++) {
             float[] centroid = centroidSupplier.centroid(c);
             int[] cluster = assignmentsByCluster[c];
-            offsets.add(postingsOutput.alignFilePointer(Float.BYTES));
+            offsets.add(postingsOutput.alignFilePointer(Float.BYTES) - fileOffset);
             buffer.asFloatBuffer().put(centroid);
             // write raw centroid for quantizing the query vectors
             postingsOutput.writeBytes(buffer.array(), buffer.array().length);
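Subtracting `fileOffset` makes the stored offsets slice-relative: each entry is the distance from the start of the posting-list region rather than an absolute position in the file. A small sketch of the round trip under that assumption (helper names are illustrative):

```java
import java.io.IOException;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;

// Illustrative round trip for slice-relative offsets.
final class RelativeOffsets {
    // Write side: align, then record the position relative to the region start.
    static long record(IndexOutput postingsOutput, long fileOffset) throws IOException {
        return postingsOutput.alignFilePointer(Float.BYTES) - fileOffset;
    }

    // Read side: position 0 of the slice is fileOffset in the underlying file,
    // so a relative offset can be passed straight to seek().
    static void seekToCluster(IndexInput postingListSlice, long relativeOffset) throws IOException {
        postingListSlice.seek(relativeOffset);
    }
}
```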
@@ -137,6 +143,7 @@ LongValues buildAndWritePostingsLists(
         CentroidSupplier centroidSupplier,
         FloatVectorValues floatVectorValues,
         IndexOutput postingsOutput,
+        long fileOffset,
         MergeState mergeState,
         int[] assignments,
         int[] overspillAssignments
@@ -196,11 +203,14 @@ LongValues buildAndWritePostingsLists(
             }
         }
 
+        int maxPostingListSize = 0;
         int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
         boolean[][] isOverspillByCluster = new boolean[centroidSupplier.size()][];
         for (int c = 0; c < centroidSupplier.size(); c++) {
-            assignmentsByCluster[c] = new int[centroidVectorCount[c]];
-            isOverspillByCluster[c] = new boolean[centroidVectorCount[c]];
+            int size = centroidVectorCount[c];
+            maxPostingListSize = Math.max(maxPostingListSize, size);
+            assignmentsByCluster[c] = new int[size];
+            isOverspillByCluster[c] = new boolean[size];
         }
         Arrays.fill(centroidVectorCount, 0);
 
@@ -226,11 +236,14 @@ LongValues buildAndWritePostingsLists(
         DocIdsWriter docIdsWriter = new DocIdsWriter();
         DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
         final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
+        // write the max posting list size
+        postingsOutput.writeVInt(maxPostingListSize);
+        // write the posting lists
         for (int c = 0; c < centroidSupplier.size(); c++) {
             float[] centroid = centroidSupplier.centroid(c);
             int[] cluster = assignmentsByCluster[c];
             boolean[] isOverspill = isOverspillByCluster[c];
-            offsets.add(postingsOutput.alignFilePointer(Float.BYTES));
+            offsets.add(postingsOutput.alignFilePointer(Float.BYTES) - fileOffset);
             // write raw centroid for quantizing the query vectors
             buffer.asFloatBuffer().put(centroid);
             postingsOutput.writeBytes(buffer.array(), buffer.array().length);
@@ -153,8 +153,12 @@ private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
         final long centroidOffset = input.readLong();
         final long centroidLength = input.readLong();
         final float[] globalCentroid = new float[info.getVectorDimension()];
+        long postingListOffset = -1;
+        long postingListLength = -1;
         float globalCentroidDp = 0;
         if (centroidLength > 0) {
+            postingListOffset = input.readLong();
+            postingListLength = input.readLong();
             input.readFloats(globalCentroid, 0, globalCentroid.length);
             globalCentroidDp = Float.intBitsToFloat(input.readInt());
         }
@@ -164,6 +168,8 @@ private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
             numCentroids,
             centroidOffset,
             centroidLength,
+            postingListOffset,
+            postingListLength,
             globalCentroid,
             globalCentroidDp
         );
@@ -245,7 +251,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
             nProbe = Math.max(Math.min(nProbe, entry.numCentroids), 1);
         }
         CentroidIterator centroidIterator = getCentroidIterator(fieldInfo, entry.numCentroids, entry.centroidSlice(ivfCentroids), target);
-        PostingVisitor scorer = getPostingVisitor(fieldInfo, ivfClusters, target, needsScoring);
+        PostingVisitor scorer = getPostingVisitor(fieldInfo, entry.postingListSlice(ivfClusters), target, needsScoring);
         int centroidsVisited = 0;
         long expectedDocs = 0;
         long actualDocs = 0;
@@ -298,12 +304,18 @@ protected record FieldEntry(
         int numCentroids,
         long centroidOffset,
         long centroidLength,
+        long postingListOffset,
+        long postingListLength,
         float[] globalCentroid,
         float globalCentroidDp
     ) {
         IndexInput centroidSlice(IndexInput centroidFile) throws IOException {
             return centroidFile.slice("centroids", centroidOffset, centroidLength);
         }
+
+        IndexInput postingListSlice(IndexInput postingListFile) throws IOException {
+            return postingListFile.slice("postingLists", postingListOffset, postingListLength);
+        }
     }
 
     abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postingsLists, float[] target, IntPredicate needsScoring)
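With `postingListOffset`/`postingListLength` recorded in the meta file, the search path hands the posting visitor a bounded window over the clusters file instead of the whole file. Roughly, reusing the names from `FieldEntry` above:

```java
import java.io.IOException;
import org.apache.lucene.store.IndexInput;

// Sketch (hypothetical helper): slice() yields an IndexInput over
// [postingListOffset, postingListOffset + postingListLength), so position 0
// is the region's VInt max-posting-list-size header and any read past the
// region fails fast instead of wandering into unrelated data.
static IndexInput openPostingLists(IndexInput ivfClusters, long postingListOffset, long postingListLength)
    throws IOException {
    return ivfClusters.slice("postingLists", postingListOffset, postingListLength);
}
```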
@@ -136,6 +136,7 @@ abstract LongValues buildAndWritePostingsLists(
         CentroidSupplier centroidSupplier,
         FloatVectorValues floatVectorValues,
         IndexOutput postingsOutput,
+        long fileOffset,
         int[] assignments,
         int[] overspillAssignments
     ) throws IOException;
@@ -145,6 +146,7 @@ abstract LongValues buildAndWritePostingsLists(
         CentroidSupplier centroidSupplier,
         FloatVectorValues floatVectorValues,
         IndexOutput postingsOutput,
+        long fileOffset,
         MergeState mergeState,
         int[] assignments,
         int[] overspillAssignments
@@ -169,20 +171,31 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
             // wrap centroids with a supplier
             final CentroidSupplier centroidSupplier = new OnHeapCentroidSupplier(centroidAssignments.centroids());
             // write posting lists
+            final long postingListOffset = ivfClusters.alignFilePointer(Float.BYTES);
             final LongValues offsets = buildAndWritePostingsLists(
                 fieldWriter.fieldInfo,
                 centroidSupplier,
                 floatVectorValues,
                 ivfClusters,
+                postingListOffset,
                 centroidAssignments.assignments(),
                 centroidAssignments.overspillAssignments()
             );
+            final long postingListLength = ivfClusters.getFilePointer() - postingListOffset;
             // write centroids
             final long centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
             writeCentroids(fieldWriter.fieldInfo, centroidSupplier, globalCentroid, offsets, ivfCentroids);
             final long centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
             // write meta file
-            writeMeta(fieldWriter.fieldInfo, centroidSupplier.size(), centroidOffset, centroidLength, globalCentroid);
+            writeMeta(
+                fieldWriter.fieldInfo,
+                centroidSupplier.size(),
+                centroidOffset,
+                centroidLength,
+                postingListOffset,
+                postingListLength,
+                globalCentroid
+            );
         }
     }
 
@@ -288,6 +301,8 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
 
         final long centroidOffset;
         final long centroidLength;
+        final long postingListOffset;
+        final long postingListLength;
         final int numCentroids;
         final int[] assignments;
         final int[] overspillAssignments;
@@ -322,7 +337,7 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
         try {
             if (numCentroids == 0) {
                 centroidOffset = ivfCentroids.getFilePointer();
-                writeMeta(fieldInfo, 0, centroidOffset, 0, null);
+                writeMeta(fieldInfo, 0, centroidOffset, 0, 0, 0, null);
                 CodecUtil.writeFooter(centroidTemp);
                 IOUtils.close(centroidTemp);
                 return;
@@ -338,21 +353,32 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
                     calculatedGlobalCentroid
                 );
                 // write posting lists
+                postingListOffset = ivfClusters.alignFilePointer(Float.BYTES);
                 final LongValues offsets = buildAndWritePostingsLists(
                     fieldInfo,
                     centroidSupplier,
                     floatVectorValues,
                     ivfClusters,
+                    postingListOffset,
                     mergeState,
                     assignments,
                     overspillAssignments
                 );
+                postingListLength = ivfClusters.getFilePointer() - postingListOffset;
                 // write centroids
                 centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
                 writeCentroids(fieldInfo, centroidSupplier, calculatedGlobalCentroid, offsets, ivfCentroids);
                 centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
                 // write meta
-                writeMeta(fieldInfo, centroidSupplier.size(), centroidOffset, centroidLength, calculatedGlobalCentroid);
+                writeMeta(
+                    fieldInfo,
+                    centroidSupplier.size(),
+                    centroidOffset,
+                    centroidLength,
+                    postingListOffset,
+                    postingListLength,
+                    calculatedGlobalCentroid
+                );
             }
         } finally {
             org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, centroidTempName);
@@ -435,15 +461,24 @@ private static int writeFloatVectorValues(
         return numVectors;
     }
 
-    private void writeMeta(FieldInfo field, int numCentroids, long centroidOffset, long centroidLength, float[] globalCentroid)
-        throws IOException {
+    private void writeMeta(
+        FieldInfo field,
+        int numCentroids,
+        long centroidOffset,
+        long centroidLength,
+        long postingListOffset,
+        long postingListLength,
+        float[] globalCentroid
+    ) throws IOException {
         ivfMeta.writeInt(field.number);
         ivfMeta.writeInt(field.getVectorEncoding().ordinal());
         ivfMeta.writeInt(distFuncToOrd(field.getVectorSimilarityFunction()));
         ivfMeta.writeInt(numCentroids);
         ivfMeta.writeLong(centroidOffset);
         ivfMeta.writeLong(centroidLength);
         if (centroidLength > 0) {
+            ivfMeta.writeLong(postingListOffset);
+            ivfMeta.writeLong(postingListLength);
             final ByteBuffer buffer = ByteBuffer.allocate(globalCentroid.length * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
             buffer.asFloatBuffer().put(globalCentroid);
             ivfMeta.writeBytes(buffer.array(), buffer.array().length);
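Putting the metadata changes together: a hedged reconstruction of the per-field meta entry as it stands after this change, mirroring `writeMeta` above and `readField` in the reader (field framing, codec header/footer, and validation omitted; names assumed):

```java
import java.io.IOException;
import org.apache.lucene.store.IndexInput;

// Layout, as written by writeMeta:
//   int32  field number, vector encoding ordinal, similarity function ordinal
//   int32  numCentroids
//   int64  centroidOffset, centroidLength
//   if centroidLength > 0:
//     int64   postingListOffset, postingListLength
//     float[] globalCentroid (one float per dimension, little endian)
//     float32 globalCentroidDp
static void skimMetaEntry(IndexInput meta, int dimension) throws IOException {
    meta.readInt(); // field number
    meta.readInt(); // vector encoding ordinal
    meta.readInt(); // similarity function ordinal
    int numCentroids = meta.readInt();
    long centroidOffset = meta.readLong();
    long centroidLength = meta.readLong();
    if (centroidLength > 0) {
        long postingListOffset = meta.readLong();
        long postingListLength = meta.readLong();
        float[] globalCentroid = new float[dimension];
        meta.readFloats(globalCentroid, 0, dimension);
        float globalCentroidDp = Float.intBitsToFloat(meta.readInt());
    }
}
```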