diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java
index 39daa3c64dc12..445d572618d87 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java
@@ -297,7 +297,8 @@ private static void score(
     PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput indexInput, float[] target, IntPredicate needsScoring)
         throws IOException {
         FieldEntry entry = fields.get(fieldInfo.number);
-        return new MemorySegmentPostingsVisitor(target, indexInput.clone(), entry, fieldInfo, needsScoring);
+        final int maxPostingListSize = indexInput.readVInt();
+        return new MemorySegmentPostingsVisitor(target, indexInput, entry, fieldInfo, maxPostingListSize, needsScoring);
     }
 
     @Override
@@ -318,8 +319,8 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
         final float[] correctionsUpper = new float[BULK_SIZE];
         final int[] correctionsSum = new int[BULK_SIZE];
         final float[] correctionsAdd = new float[BULK_SIZE];
+        final int[] docIdsScratch;
 
-        int[] docIdsScratch = new int[0];
         int vectors;
         boolean quantized = false;
         float centroidDp;
@@ -340,6 +341,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
             IndexInput indexInput,
             FieldEntry entry,
             FieldInfo fieldInfo,
+            int maxPostingListSize,
             IntPredicate needsScoring
         ) throws IOException {
             this.target = target;
@@ -356,6 +358,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
             quantizedVectorByteSize = (discretizedDimensions / 8);
             quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction(), DEFAULT_LAMBDA, 1);
             osqVectorsScorer = ESVectorUtil.getES91OSQVectorsScorer(indexInput, fieldInfo.getVectorDimension());
+            this.docIdsScratch = new int[maxPostingListSize];
         }
 
         @Override
@@ -366,7 +369,7 @@ public int resetPostingsScorer(long offset) throws IOException {
             centroidDp = Float.intBitsToFloat(indexInput.readInt());
             vectors = indexInput.readVInt();
             // read the doc ids
-            docIdsScratch = vectors > docIdsScratch.length ? new int[vectors] : docIdsScratch;
+            assert vectors <= docIdsScratch.length;
             docIdsWriter.readInts(indexInput, vectors, docIdsScratch);
             slicePos = indexInput.getFilePointer();
             return vectors;
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
index bd8930d13ba48..d16163d6934e8 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
@@ -64,6 +64,7 @@ LongValues buildAndWritePostingsLists(
         CentroidSupplier centroidSupplier,
         FloatVectorValues floatVectorValues,
         IndexOutput postingsOutput,
+        long fileOffset,
         int[] assignments,
         int[] overspillAssignments
     ) throws IOException {
@@ -76,9 +77,12 @@ LongValues buildAndWritePostingsLists(
             }
         }
 
+        int maxPostingListSize = 0;
         int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
         for (int c = 0; c < centroidSupplier.size(); c++) {
-            assignmentsByCluster[c] = new int[centroidVectorCount[c]];
+            int size = centroidVectorCount[c];
+            maxPostingListSize = Math.max(maxPostingListSize, size);
+            assignmentsByCluster[c] = new int[size];
         }
         Arrays.fill(centroidVectorCount, 0);
 
@@ -93,6 +97,8 @@ LongValues buildAndWritePostingsLists(
                 }
             }
         }
+        // write the max posting list size
+        postingsOutput.writeVInt(maxPostingListSize);
         // write the posting lists
         final PackedLongValues.Builder offsets = PackedLongValues.monotonicBuilder(PackedInts.COMPACT);
         DocIdsWriter docIdsWriter = new DocIdsWriter();
@@ -106,7 +112,7 @@ LongValues buildAndWritePostingsLists(
         for (int c = 0; c < centroidSupplier.size(); c++) {
             float[] centroid = centroidSupplier.centroid(c);
             int[] cluster = assignmentsByCluster[c];
-            offsets.add(postingsOutput.alignFilePointer(Float.BYTES));
+            offsets.add(postingsOutput.alignFilePointer(Float.BYTES) - fileOffset);
             buffer.asFloatBuffer().put(centroid);
             // write raw centroid for quantizing the query vectors
             postingsOutput.writeBytes(buffer.array(), buffer.array().length);
@@ -137,6 +143,7 @@ LongValues buildAndWritePostingsLists(
         CentroidSupplier centroidSupplier,
         FloatVectorValues floatVectorValues,
         IndexOutput postingsOutput,
+        long fileOffset,
         MergeState mergeState,
         int[] assignments,
         int[] overspillAssignments
@@ -196,11 +203,14 @@ LongValues buildAndWritePostingsLists(
             }
         }
 
+        int maxPostingListSize = 0;
         int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
         boolean[][] isOverspillByCluster = new boolean[centroidSupplier.size()][];
         for (int c = 0; c < centroidSupplier.size(); c++) {
-            assignmentsByCluster[c] = new int[centroidVectorCount[c]];
-            isOverspillByCluster[c] = new boolean[centroidVectorCount[c]];
+            int size = centroidVectorCount[c];
+            maxPostingListSize = Math.max(maxPostingListSize, size);
+            assignmentsByCluster[c] = new int[size];
+            isOverspillByCluster[c] = new boolean[size];
         }
         Arrays.fill(centroidVectorCount, 0);
 
@@ -226,11 +236,14 @@ LongValues buildAndWritePostingsLists(
         DocIdsWriter docIdsWriter = new DocIdsWriter();
         DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
         final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
+        // write the max posting list size
+        postingsOutput.writeVInt(maxPostingListSize);
+        // write the posting lists
         for (int c = 0; c < centroidSupplier.size(); c++) {
             float[] centroid = centroidSupplier.centroid(c);
             int[] cluster = assignmentsByCluster[c];
             boolean[] isOverspill = isOverspillByCluster[c];
-            offsets.add(postingsOutput.alignFilePointer(Float.BYTES));
+            offsets.add(postingsOutput.alignFilePointer(Float.BYTES) - fileOffset);
             // write raw centroid for quantizing the query vectors
             buffer.asFloatBuffer().put(centroid);
             postingsOutput.writeBytes(buffer.array(), buffer.array().length);
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java
index b570bd83f36e4..bde7b1d5b60c0 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java
@@ -153,8 +153,12 @@ private FieldEntry readField(IndexInput input, FieldInfo info) throws IOExceptio
         final long centroidOffset = input.readLong();
         final long centroidLength = input.readLong();
         final float[] globalCentroid = new float[info.getVectorDimension()];
+        long postingListOffset = -1;
+        long postingListLength = -1;
         float globalCentroidDp = 0;
         if (centroidLength > 0) {
+            postingListOffset = input.readLong();
+            postingListLength = input.readLong();
             input.readFloats(globalCentroid, 0, globalCentroid.length);
             globalCentroidDp = Float.intBitsToFloat(input.readInt());
         }
@@ -164,6 +168,8 @@ private FieldEntry readField(IndexInput input, FieldInfo info) throws IOExceptio
             numCentroids,
             centroidOffset,
             centroidLength,
+            postingListOffset,
+            postingListLength,
             globalCentroid,
             globalCentroidDp
         );
@@ -245,7 +251,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
             nProbe = Math.max(Math.min(nProbe, entry.numCentroids), 1);
         }
         CentroidIterator centroidIterator = getCentroidIterator(fieldInfo, entry.numCentroids, entry.centroidSlice(ivfCentroids), target);
-        PostingVisitor scorer = getPostingVisitor(fieldInfo, ivfClusters, target, needsScoring);
+        PostingVisitor scorer = getPostingVisitor(fieldInfo, entry.postingListSlice(ivfClusters), target, needsScoring);
         int centroidsVisited = 0;
         long expectedDocs = 0;
         long actualDocs = 0;
@@ -298,12 +304,18 @@ protected record FieldEntry(
         int numCentroids,
         long centroidOffset,
         long centroidLength,
+        long postingListOffset,
+        long postingListLength,
         float[] globalCentroid,
         float globalCentroidDp
     ) {
         IndexInput centroidSlice(IndexInput centroidFile) throws IOException {
             return centroidFile.slice("centroids", centroidOffset, centroidLength);
         }
+
+        IndexInput postingListSlice(IndexInput postingListFile) throws IOException {
+            return postingListFile.slice("postingLists", postingListOffset, postingListLength);
+        }
     }
 
     abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postingsLists, float[] target, IntPredicate needsScoring)
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java
index 149db2eb96b83..308ee391b5f4a 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java
@@ -136,6 +136,7 @@ abstract LongValues buildAndWritePostingsLists(
         CentroidSupplier centroidSupplier,
         FloatVectorValues floatVectorValues,
         IndexOutput postingsOutput,
+        long fileOffset,
         int[] assignments,
         int[] overspillAssignments
     ) throws IOException;
@@ -145,6 +146,7 @@ abstract LongValues buildAndWritePostingsLists(
        CentroidSupplier centroidSupplier,
        FloatVectorValues floatVectorValues,
        IndexOutput postingsOutput,
+       long fileOffset,
        MergeState mergeState,
        int[] assignments,
        int[] overspillAssignments
@@ -169,20 +171,31 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
             // wrap centroids with a supplier
             final CentroidSupplier centroidSupplier = new OnHeapCentroidSupplier(centroidAssignments.centroids());
             // write posting lists
+            final long postingListOffset = ivfClusters.alignFilePointer(Float.BYTES);
             final LongValues offsets = buildAndWritePostingsLists(
                 fieldWriter.fieldInfo,
                 centroidSupplier,
                 floatVectorValues,
                 ivfClusters,
+                postingListOffset,
                 centroidAssignments.assignments(),
                 centroidAssignments.overspillAssignments()
             );
+            final long postingListLength = ivfClusters.getFilePointer() - postingListOffset;
             // write centroids
             final long centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
             writeCentroids(fieldWriter.fieldInfo, centroidSupplier, globalCentroid, offsets, ivfCentroids);
             final long centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
             // write meta file
-            writeMeta(fieldWriter.fieldInfo, centroidSupplier.size(), centroidOffset, centroidLength, globalCentroid);
+            writeMeta(
+                fieldWriter.fieldInfo,
+                centroidSupplier.size(),
+                centroidOffset,
+                centroidLength,
+                postingListOffset,
+                postingListLength,
+                globalCentroid
+            );
         }
     }
@@ -288,6 +301,8 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
         final long centroidOffset;
         final long centroidLength;
+        final long postingListOffset;
+        final long postingListLength;
         final int numCentroids;
         final int[] assignments;
         final int[] overspillAssignments;
@@ -322,7 +337,7 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
         try {
             if (numCentroids == 0) {
                 centroidOffset = ivfCentroids.getFilePointer();
-                writeMeta(fieldInfo, 0, centroidOffset, 0, null);
+                writeMeta(fieldInfo, 0, centroidOffset, 0, 0, 0, null);
                 CodecUtil.writeFooter(centroidTemp);
                 IOUtils.close(centroidTemp);
                 return;
@@ -338,21 +353,32 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
                     calculatedGlobalCentroid
                 );
                 // write posting lists
+                postingListOffset = ivfClusters.alignFilePointer(Float.BYTES);
                 final LongValues offsets = buildAndWritePostingsLists(
                     fieldInfo,
                     centroidSupplier,
                     floatVectorValues,
                     ivfClusters,
+                    postingListOffset,
                     mergeState,
                     assignments,
                     overspillAssignments
                 );
+                postingListLength = ivfClusters.getFilePointer() - postingListOffset;
                 // write centroids
                 centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
                 writeCentroids(fieldInfo, centroidSupplier, calculatedGlobalCentroid, offsets, ivfCentroids);
                 centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
                 // write meta
-                writeMeta(fieldInfo, centroidSupplier.size(), centroidOffset, centroidLength, calculatedGlobalCentroid);
+                writeMeta(
+                    fieldInfo,
+                    centroidSupplier.size(),
+                    centroidOffset,
+                    centroidLength,
+                    postingListOffset,
+                    postingListLength,
+                    calculatedGlobalCentroid
+                );
             }
         } finally {
             org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, centroidTempName);
@@ -435,8 +461,15 @@ private static int writeFloatVectorValues(
         return numVectors;
     }
 
-    private void writeMeta(FieldInfo field, int numCentroids, long centroidOffset, long centroidLength, float[] globalCentroid)
-        throws IOException {
+    private void writeMeta(
+        FieldInfo field,
+        int numCentroids,
+        long centroidOffset,
+        long centroidLength,
+        long postingListOffset,
+        long postingListLength,
+        float[] globalCentroid
+    ) throws IOException {
         ivfMeta.writeInt(field.number);
         ivfMeta.writeInt(field.getVectorEncoding().ordinal());
         ivfMeta.writeInt(distFuncToOrd(field.getVectorSimilarityFunction()));
@@ -444,6 +477,8 @@ private void writeMeta(FieldInfo field, int numCentroids, long centroidOffset, l
         ivfMeta.writeLong(centroidOffset);
         ivfMeta.writeLong(centroidLength);
         if (centroidLength > 0) {
+            ivfMeta.writeLong(postingListOffset);
+            ivfMeta.writeLong(postingListLength);
             final ByteBuffer buffer = ByteBuffer.allocate(globalCentroid.length * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
             buffer.asFloatBuffer().put(globalCentroid);
             ivfMeta.writeBytes(buffer.array(), buffer.array().length);
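The patch carries no prose, so below is a minimal, self-contained Java sketch of the two ideas it encodes: the writer records the maximum posting-list size once at the start of the postings section, so the reader can allocate its doc-id scratch buffer a single time instead of regrowing it per list, and posting-list offsets are stored relative to the start of that section so they can be resolved against a bounded slice. The PostingListSketch class and its plain DataOutputStream/DataInputStream fixed-width encoding are hypothetical stand-ins for illustration only, not the Lucene IndexOutput/IndexInput codec layout used by DefaultIVFVectorsWriter.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;

class PostingListSketch {

    // Writer side: record the largest posting list before any list is written, and keep
    // each list's offset relative to the start of the postings section.
    static long[] writeSection(DataOutputStream out, List<int[]> postingLists) throws IOException {
        int maxPostingListSize = 0;
        for (int[] list : postingLists) {
            maxPostingListSize = Math.max(maxPostingListSize, list.length);
        }
        out.writeInt(maxPostingListSize);                  // header: max size first
        long[] relativeOffsets = new long[postingLists.size()];
        long bytesWritten = Integer.BYTES;
        for (int i = 0; i < postingLists.size(); i++) {
            relativeOffsets[i] = bytesWritten;             // relative to the section, not the file
            int[] list = postingLists.get(i);
            out.writeInt(list.length);
            for (int doc : list) {
                out.writeInt(doc);
            }
            bytesWritten += (long) Integer.BYTES * (1 + list.length);
        }
        return relativeOffsets;
    }

    // Reader side: allocate the scratch array once from the header, then reuse it for every list.
    static void readSection(byte[] section, long[] relativeOffsets) throws IOException {
        int maxPostingListSize = new DataInputStream(new ByteArrayInputStream(section)).readInt();
        int[] docIdsScratch = new int[maxPostingListSize]; // sized once, as in the reader change
        for (long offset : relativeOffsets) {
            DataInputStream in = new DataInputStream(
                new ByteArrayInputStream(section, (int) offset, section.length - (int) offset)
            );
            int vectors = in.readInt();
            assert vectors <= docIdsScratch.length;        // mirrors the new assert in resetPostingsScorer
            for (int i = 0; i < vectors; i++) {
                docIdsScratch[i] = in.readInt();
            }
            System.out.println(Arrays.toString(Arrays.copyOf(docIdsScratch, vectors)));
        }
    }

    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        long[] offsets = writeSection(new DataOutputStream(bytes), List.of(new int[] { 1, 5, 9 }, new int[] { 2 }));
        readSection(bytes.toByteArray(), offsets);
    }
}
```

In the diff itself the same pattern appears as postingsOutput.writeVInt(maxPostingListSize) written before the posting lists, offsets.add(postingsOutput.alignFilePointer(Float.BYTES) - fileOffset) for the section-relative offsets, and the meta file gaining postingListOffset/postingListLength so FieldEntry.postingListSlice can hand the posting visitor a dedicated slice.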