diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java
index 32b0ea496e942..304cc57284227 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java
@@ -68,35 +68,23 @@ CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, Ind
         return new CentroidQueryScorer() {
             int currentCentroid = -1;
             long postingListOffset;
-            private final float[] centroid = new float[fieldInfo.getVectorDimension()];
             private final float[] centroidCorrectiveValues = new float[3];
-            private final long rawCentroidsOffset = (long) numCentroids * (fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES);
-            private final long rawCentroidsByteSize = (long) Float.BYTES * fieldInfo.getVectorDimension() + Long.BYTES;
+            private final long quantizeCentroidsLength = (long) numCentroids * (fieldInfo.getVectorDimension() + 3 * Float.BYTES
+                + Short.BYTES);
 
             @Override
             public int size() {
                 return numCentroids;
             }
 
-            @Override
-            public float[] centroid(int centroidOrdinal) throws IOException {
-                readDataIfNecessary(centroidOrdinal);
-                return centroid;
-            }
-
             @Override
             public long postingListOffset(int centroidOrdinal) throws IOException {
-                readDataIfNecessary(centroidOrdinal);
-                return postingListOffset;
-            }
-
-            private void readDataIfNecessary(int centroidOrdinal) throws IOException {
                 if (centroidOrdinal != currentCentroid) {
-                    centroids.seek(rawCentroidsOffset + rawCentroidsByteSize * centroidOrdinal);
-                    centroids.readFloats(centroid, 0, centroid.length);
+                    centroids.seek(quantizeCentroidsLength + (long) Long.BYTES * centroidOrdinal);
                     postingListOffset = centroids.readLong();
                     currentCentroid = centroidOrdinal;
                 }
+                return postingListOffset;
             }
 
             public void bulkScore(NeighborQueue queue) throws IOException {
@@ -193,7 +181,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
         int vectors;
         boolean quantized = false;
         float centroidDp;
-        float[] centroid;
+        final float[] centroid;
         long slicePos;
         OptimizedScalarQuantizer.QuantizationResult queryCorrections;
         DocIdsWriter docIdsWriter = new DocIdsWriter();
@@ -217,7 +205,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
             this.entry = entry;
             this.fieldInfo = fieldInfo;
             this.needsScoring = needsScoring;
-
+            centroid = new float[fieldInfo.getVectorDimension()];
             scratch = new float[target.length];
             quantizationScratch = new int[target.length];
             final int discretizedDimensions = discretize(fieldInfo.getVectorDimension(), 64);
@@ -229,12 +217,12 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
         }
 
         @Override
-        public int resetPostingsScorer(long offset, float[] centroid) throws IOException {
+        public int resetPostingsScorer(long offset) throws IOException {
             quantized = false;
             indexInput.seek(offset);
-            vectors = indexInput.readVInt();
+            indexInput.readFloats(centroid, 0, centroid.length);
             centroidDp = Float.intBitsToFloat(indexInput.readInt());
-            this.centroid = centroid;
+            vectors = indexInput.readVInt();
             // read the doc ids
             docIdsScratch = vectors > docIdsScratch.length ? new int[vectors] : docIdsScratch;
             docIdsWriter.readInts(indexInput, vectors, docIdsScratch);
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
index f34ad071b6ba0..2260e32187596 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
@@ -92,19 +92,25 @@ LongValues buildAndWritePostingsLists(
             fieldInfo.getVectorDimension(),
             new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction())
         );
+        final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
         for (int c = 0; c < centroidSupplier.size(); c++) {
             float[] centroid = centroidSupplier.centroid(c);
             int[] cluster = assignmentsByCluster[c];
-            // TODO align???
-            offsets.add(postingsOutput.getFilePointer());
+            offsets.add(postingsOutput.alignFilePointer(Float.BYTES));
+            buffer.asFloatBuffer().put(centroid);
+            // write raw centroid for quantizing the query vectors
+            postingsOutput.writeBytes(buffer.array(), buffer.array().length);
+            // write centroid dot product for quantizing the query vectors
+            postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
             int size = cluster.length;
+            // write docIds
             postingsOutput.writeVInt(size);
-            postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
             onHeapQuantizedVectors.reset(centroid, size, ord -> cluster[ord]);
             // TODO we might want to consider putting the docIds in a separate file
             // to aid with only having to fetch vectors from slower storage when they are required
             // keeping them in the same file indicates we pull the entire file into cache
             docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
+            // write vectors
             bulkWriter.writeVectors(onHeapQuantizedVectors);
         }
 
@@ -209,20 +215,26 @@ LongValues buildAndWritePostingsLists(
         );
         DocIdsWriter docIdsWriter = new DocIdsWriter();
         DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
+        final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
         for (int c = 0; c < centroidSupplier.size(); c++) {
             float[] centroid = centroidSupplier.centroid(c);
             int[] cluster = assignmentsByCluster[c];
             boolean[] isOverspill = isOverspillByCluster[c];
-            offsets.add(postingsOutput.getFilePointer());
+            offsets.add(postingsOutput.alignFilePointer(Float.BYTES));
+            // write raw centroid for quantizing the query vectors
+            buffer.asFloatBuffer().put(centroid);
+            postingsOutput.writeBytes(buffer.array(), buffer.array().length);
+            // write centroid dot product for quantizing the query vectors
+            postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
+            // write docIds
             int size = cluster.length;
-            // TODO align???
             postingsOutput.writeVInt(size);
-            postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
             offHeapQuantizedVectors.reset(size, ord -> isOverspill[ord], ord -> cluster[ord]);
             // TODO we might want to consider putting the docIds in a separate file
             // to aid with only having to fetch vectors from slower storage when they are required
             // keeping them in the same file indicates we pull the entire file into cache
             docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
+            // write vectors
             bulkWriter.writeVectors(offHeapQuantizedVectors);
         }
 
@@ -298,13 +310,8 @@ void writeCentroids(
             }
             writeQuantizedValue(centroidOutput, quantized, result);
         }
-        final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
+        // write the centroid offsets at the end of the file
        for (int i = 0; i < centroidSupplier.size(); i++) {
-            float[] centroid = centroidSupplier.centroid(i);
-            buffer.asFloatBuffer().put(centroid);
-            // write the centroids
-            centroidOutput.writeBytes(buffer.array(), buffer.array().length);
-            // write the offset of this posting list
            centroidOutput.writeLong(offsets.get(i));
        }
    }
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java
index f9d70c8d3d8eb..01cced04a9fcc 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java
@@ -266,10 +266,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
             int centroidOrdinal = centroidQueue.pop();
             // todo do we need direct access to the raw centroid???, this is used for quantizing, maybe hydrating and quantizing
             // is enough?
-            expectedDocs += scorer.resetPostingsScorer(
-                centroidQueryScorer.postingListOffset(centroidOrdinal),
-                centroidQueryScorer.centroid(centroidOrdinal)
-            );
+            expectedDocs += scorer.resetPostingsScorer(centroidQueryScorer.postingListOffset(centroidOrdinal));
             actualDocs += scorer.visit(knnCollector);
         }
         if (acceptDocs != null) {
@@ -278,10 +275,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
             float expectedScored = Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f);
             while (centroidQueue.size() > 0 && (actualDocs < expectedScored || actualDocs < knnCollector.k())) {
                 int centroidOrdinal = centroidQueue.pop();
-                scorer.resetPostingsScorer(
-                    centroidQueryScorer.postingListOffset(centroidOrdinal),
-                    centroidQueryScorer.centroid(centroidOrdinal)
-                );
+                scorer.resetPostingsScorer(centroidQueryScorer.postingListOffset(centroidOrdinal));
                 actualDocs += scorer.visit(knnCollector);
             }
         }
@@ -332,8 +326,6 @@ abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postin
     interface CentroidQueryScorer {
         int size();
 
-        float[] centroid(int centroidOrdinal) throws IOException;
-
         long postingListOffset(int centroidOrdinal) throws IOException;
 
         void bulkScore(NeighborQueue queue) throws IOException;
@@ -343,7 +335,7 @@ interface PostingVisitor {
         // TODO maybe we can not specifically pass the centroid...
 
         /** returns the number of documents in the posting list */
-        int resetPostingsScorer(long offset, float[] centroid) throws IOException;
+        int resetPostingsScorer(long offset) throws IOException;
 
         /** returns the number of scored documents */
         int visit(KnnCollector collector) throws IOException;
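
With this change, each posting list written by `DefaultIVFVectorsWriter` starts at a `Float.BYTES`-aligned offset and carries its own header: the raw centroid floats (little-endian), the centroid's self dot product as float bits, a vInt vector count, the doc ids, then the quantized vectors. The centroids file keeps only the quantized centroids followed by one `long` posting-list offset per centroid, which is why the reader can seek to `quantizeCentroidsLength + Long.BYTES * centroidOrdinal` and no longer needs `CentroidQueryScorer.centroid(...)`. The sketch below illustrates how the header portion is laid out, using Lucene's `IndexInput`; the `PostingListHeader` record and `readHeader` helper are hypothetical names for illustration, and doc-id and quantized-vector decoding (handled by `DocIdsWriter` and the bulk scorer) are omitted.

```java
import java.io.IOException;

import org.apache.lucene.store.IndexInput;

/** Hypothetical holder for the fields at the start of each posting list (illustration only). */
record PostingListHeader(float[] centroid, float centroidDp, int vectorCount) {}

final class PostingListHeaderReader {

    /**
     * Minimal sketch of parsing the per-posting-list header this PR introduces:
     * raw centroid floats, centroid dot product (stored as float bits), then a vInt vector count.
     * The doc ids and quantized vectors that follow are read elsewhere.
     */
    static PostingListHeader readHeader(IndexInput postings, long postingListOffset, int dims) throws IOException {
        postings.seek(postingListOffset);                             // offset comes from the centroids file
        float[] centroid = new float[dims];
        postings.readFloats(centroid, 0, dims);                       // raw centroid, used to quantize the query vector
        float centroidDp = Float.intBitsToFloat(postings.readInt());  // centroid self dot product
        int vectorCount = postings.readVInt();                        // number of vectors in this posting list
        return new PostingListHeader(centroid, centroidDp, vectorCount);
    }
}
```

Keeping the raw centroid next to its posting list means a single sequential read serves both query quantization and scoring, at the cost of the small alignment padding added by `alignFilePointer(Float.BYTES)`.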