diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java index ac95f3c8ad0af..32b0ea496e942 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java @@ -67,10 +67,11 @@ CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, Ind final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension()); return new CentroidQueryScorer() { int currentCentroid = -1; + long postingListOffset; private final float[] centroid = new float[fieldInfo.getVectorDimension()]; private final float[] centroidCorrectiveValues = new float[3]; private final long rawCentroidsOffset = (long) numCentroids * (fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES); - private final long rawCentroidsByteSize = (long) Float.BYTES * fieldInfo.getVectorDimension(); + private final long rawCentroidsByteSize = (long) Float.BYTES * fieldInfo.getVectorDimension() + Long.BYTES; @Override public int size() { @@ -79,12 +80,23 @@ public int size() { @Override public float[] centroid(int centroidOrdinal) throws IOException { + readDataIfNecessary(centroidOrdinal); + return centroid; + } + + @Override + public long postingListOffset(int centroidOrdinal) throws IOException { + readDataIfNecessary(centroidOrdinal); + return postingListOffset; + } + + private void readDataIfNecessary(int centroidOrdinal) throws IOException { if (centroidOrdinal != currentCentroid) { centroids.seek(rawCentroidsOffset + rawCentroidsByteSize * centroidOrdinal); centroids.readFloats(centroid, 0, centroid.length); + postingListOffset = centroids.readLong(); currentCentroid = centroidOrdinal; } - return centroid; } public void bulkScore(NeighborQueue queue) throws IOException { @@ -217,9 +229,9 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor { } @Override - public int resetPostingsScorer(int centroidOrdinal, float[] centroid) throws IOException { + public int resetPostingsScorer(long offset, float[] centroid) throws IOException { quantized = false; - indexInput.seek(entry.postingListOffsets()[centroidOrdinal]); + indexInput.seek(offset); vectors = indexInput.readVInt(); centroidDp = Float.intBitsToFloat(indexInput.readInt()); this.centroid = centroid; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java index e94b728f934e4..726e84605157a 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java @@ -265,8 +265,15 @@ CentroidSupplier createCentroidSupplier(IndexInput centroidsInput, int numCentro return new OffHeapCentroidSupplier(centroidsInput, numCentroids, fieldInfo); } - static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] globalCentroid, IndexOutput centroidOutput) - throws IOException { + @Override + void writeCentroids( + FieldInfo fieldInfo, + CentroidSupplier centroidSupplier, + float[] globalCentroid, + long[] offsets, + IndexOutput centroidOutput + ) throws IOException { + final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); int[] quantizedScratch = new int[fieldInfo.getVectorDimension()]; float[] centroidScratch = new float[fieldInfo.getVectorDimension()]; @@ -274,7 +281,8 @@ static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] glo // TODO do we want to store these distances as well for future use? // TODO: sort centroids by global centroid (was doing so previously here) // TODO: sorting tanks recall possibly because centroids ordinals no longer are aligned - for (float[] centroid : centroids) { + for (int i = 0; i < centroidSupplier.size(); i++) { + float[] centroid = centroidSupplier.centroid(i); System.arraycopy(centroid, 0, centroidScratch, 0, centroid.length); OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize( centroidScratch, @@ -282,54 +290,41 @@ static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] glo (byte) 4, globalCentroid ); - for (int i = 0; i < quantizedScratch.length; i++) { - quantized[i] = (byte) quantizedScratch[i]; + for (int j = 0; j < quantizedScratch.length; j++) { + quantized[j] = (byte) quantizedScratch[j]; } writeQuantizedValue(centroidOutput, quantized, result); } final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); - for (float[] centroid : centroids) { + for (int i = 0; i < centroidSupplier.size(); i++) { + float[] centroid = centroidSupplier.centroid(i); buffer.asFloatBuffer().put(centroid); + // write the centroids centroidOutput.writeBytes(buffer.array(), buffer.array().length); + // write the offset of this posting list + centroidOutput.writeLong(offsets[i]); } } - @Override - CentroidAssignments calculateAndWriteCentroids( - FieldInfo fieldInfo, - FloatVectorValues floatVectorValues, - IndexOutput centroidOutput, - MergeState mergeState, - float[] globalCentroid - ) throws IOException { - // TODO: take advantage of prior generated clusters from mergeState in the future - return calculateAndWriteCentroids(fieldInfo, floatVectorValues, centroidOutput, globalCentroid); - } - /** - * Calculate the centroids for the given field and write them to the given centroid output. + * Calculate the centroids for the given field. * We use the {@link HierarchicalKMeans} algorithm to partition the space of all vectors across merging segments * * @param fieldInfo merging field info * @param floatVectorValues the float vector values to merge - * @param centroidOutput the centroid output * @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids * @return the vector assignments, soar assignments, and if asked the centroids themselves that were computed * @throws IOException if an I/O error occurs */ @Override - CentroidAssignments calculateAndWriteCentroids( - FieldInfo fieldInfo, - FloatVectorValues floatVectorValues, - IndexOutput centroidOutput, - float[] globalCentroid - ) throws IOException { + CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, float[] globalCentroid) + throws IOException { long nanoTime = System.nanoTime(); // TODO: consider hinting / bootstrapping hierarchical kmeans with the prior segments centroids - KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster); - float[][] centroids = kMeansResult.centroids(); + CentroidAssignments centroidAssignments = buildCentroidAssignments(floatVectorValues, vectorPerCluster); + float[][] centroids = centroidAssignments.centroids(); // TODO: for flush we are doing this over the vectors and here centroids which seems duplicative // preliminary tests suggest recall is good using only centroids but need to do further evaluation // TODO: push this logic into vector util? @@ -342,17 +337,15 @@ CentroidAssignments calculateAndWriteCentroids( globalCentroid[j] /= centroids.length; } - // write centroids - writeCentroids(centroids, fieldInfo, globalCentroid, centroidOutput); - if (logger.isDebugEnabled()) { logger.debug("calculate centroids and assign vectors time ms: {}", (System.nanoTime() - nanoTime) / 1000000.0); logger.debug("final centroid count: {}", centroids.length); } - return buildCentroidAssignments(kMeansResult); + return centroidAssignments; } - static CentroidAssignments buildCentroidAssignments(KMeansResult kMeansResult) { + static CentroidAssignments buildCentroidAssignments(FloatVectorValues floatVectorValues, int vectorPerCluster) throws IOException { + KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster); float[][] centroids = kMeansResult.centroids(); int[] assignments = kMeansResult.assignments(); int[] soarAssignments = kMeansResult.soarAssignments(); @@ -374,7 +367,6 @@ static class OffHeapCentroidSupplier implements CentroidSupplier { private final int numCentroids; private final int dimension; private final float[] scratch; - private final long rawCentroidOffset; private int currOrd = -1; OffHeapCentroidSupplier(IndexInput centroidsInput, int numCentroids, FieldInfo info) { @@ -382,7 +374,6 @@ static class OffHeapCentroidSupplier implements CentroidSupplier { this.numCentroids = numCentroids; this.dimension = info.getVectorDimension(); this.scratch = new float[dimension]; - this.rawCentroidOffset = (dimension + 3 * Float.BYTES + Short.BYTES) * numCentroids; } @Override @@ -395,7 +386,7 @@ public float[] centroid(int centroidOrdinal) throws IOException { if (centroidOrdinal == currOrd) { return scratch; } - centroidsInput.seek(rawCentroidOffset + (long) centroidOrdinal * dimension * Float.BYTES); + centroidsInput.seek((long) centroidOrdinal * dimension * Float.BYTES); centroidsInput.readFloats(scratch, 0, dimension); this.currOrd = centroidOrdinal; return scratch; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java index dbcdfd451df95..f9d70c8d3d8eb 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java @@ -140,19 +140,6 @@ private void readFields(ChecksumIndexInput meta) throws IOException { private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException { final VectorEncoding vectorEncoding = readVectorEncoding(input); final VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); - final long centroidOffset = input.readLong(); - final long centroidLength = input.readLong(); - final int numPostingLists = input.readVInt(); - final long[] postingListOffsets = new long[numPostingLists]; - for (int i = 0; i < numPostingLists; i++) { - postingListOffsets[i] = input.readLong(); - } - final float[] globalCentroid = new float[info.getVectorDimension()]; - float globalCentroidDp = 0; - if (numPostingLists > 0) { - input.readFloats(globalCentroid, 0, globalCentroid.length); - globalCentroidDp = Float.intBitsToFloat(input.readInt()); - } if (similarityFunction != info.getVectorSimilarityFunction()) { throw new IllegalStateException( "Inconsistent vector similarity function for field=\"" @@ -163,12 +150,21 @@ private FieldEntry readField(IndexInput input, FieldInfo info) throws IOExceptio + info.getVectorSimilarityFunction() ); } + final int numCentroids = input.readInt(); + final long centroidOffset = input.readLong(); + final long centroidLength = input.readLong(); + final float[] globalCentroid = new float[info.getVectorDimension()]; + float globalCentroidDp = 0; + if (centroidLength > 0) { + input.readFloats(globalCentroid, 0, globalCentroid.length); + globalCentroidDp = Float.intBitsToFloat(input.readInt()); + } return new FieldEntry( similarityFunction, vectorEncoding, + numCentroids, centroidOffset, centroidLength, - postingListOffsets, globalCentroid, globalCentroidDp ); @@ -242,7 +238,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector FieldEntry entry = fields.get(fieldInfo.number); CentroidQueryScorer centroidQueryScorer = getCentroidScorer( fieldInfo, - entry.postingListOffsets.length, + entry.numCentroids, entry.centroidSlice(ivfCentroids), target ); @@ -270,7 +266,10 @@ public final void search(String field, float[] target, KnnCollector knnCollector int centroidOrdinal = centroidQueue.pop(); // todo do we need direct access to the raw centroid???, this is used for quantizing, maybe hydrating and quantizing // is enough? - expectedDocs += scorer.resetPostingsScorer(centroidOrdinal, centroidQueryScorer.centroid(centroidOrdinal)); + expectedDocs += scorer.resetPostingsScorer( + centroidQueryScorer.postingListOffset(centroidOrdinal), + centroidQueryScorer.centroid(centroidOrdinal) + ); actualDocs += scorer.visit(knnCollector); } if (acceptDocs != null) { @@ -279,7 +278,10 @@ public final void search(String field, float[] target, KnnCollector knnCollector float expectedScored = Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f); while (centroidQueue.size() > 0 && (actualDocs < expectedScored || actualDocs < knnCollector.k())) { int centroidOrdinal = centroidQueue.pop(); - scorer.resetPostingsScorer(centroidOrdinal, centroidQueryScorer.centroid(centroidOrdinal)); + scorer.resetPostingsScorer( + centroidQueryScorer.postingListOffset(centroidOrdinal), + centroidQueryScorer.centroid(centroidOrdinal) + ); actualDocs += scorer.visit(knnCollector); } } @@ -313,9 +315,9 @@ public void close() throws IOException { protected record FieldEntry( VectorSimilarityFunction similarityFunction, VectorEncoding vectorEncoding, + int numCentroids, long centroidOffset, long centroidLength, - long[] postingListOffsets, float[] globalCentroid, float globalCentroidDp ) { @@ -332,6 +334,8 @@ interface CentroidQueryScorer { float[] centroid(int centroidOrdinal) throws IOException; + long postingListOffset(int centroidOrdinal) throws IOException; + void bulkScore(NeighborQueue queue) throws IOException; } @@ -339,7 +343,7 @@ interface PostingVisitor { // TODO maybe we can not specifically pass the centroid... /** returns the number of documents in the posting list */ - int resetPostingsScorer(int centroidOrdinal, float[] centroid) throws IOException; + int resetPostingsScorer(long offset, float[] centroid) throws IOException; /** returns the number of scored documents */ int visit(KnnCollector collector) throws IOException; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java index be7a60a3db893..f828c96f7a9e1 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java @@ -119,19 +119,15 @@ public final KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOExc return rawVectorDelegate; } - abstract CentroidAssignments calculateAndWriteCentroids( - FieldInfo fieldInfo, - FloatVectorValues floatVectorValues, - IndexOutput centroidOutput, - MergeState mergeState, - float[] globalCentroid - ) throws IOException; + abstract CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, float[] globalCentroid) + throws IOException; - abstract CentroidAssignments calculateAndWriteCentroids( + abstract void writeCentroids( FieldInfo fieldInfo, - FloatVectorValues floatVectorValues, - IndexOutput centroidOutput, - float[] globalCentroid + CentroidSupplier centroidSupplier, + float[] globalCentroid, + long[] centroidOffset, + IndexOutput centroidOutput ) throws IOException; abstract long[] buildAndWritePostingsLists( @@ -168,18 +164,10 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { // build a float vector values with random access final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldWriter.fieldInfo, fieldWriter.delegate, maxDoc); // build centroids - long centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES); - - final CentroidAssignments centroidAssignments = calculateAndWriteCentroids( - fieldWriter.fieldInfo, - floatVectorValues, - ivfCentroids, - globalCentroid - ); - - CentroidSupplier centroidSupplier = new OnHeapCentroidSupplier(centroidAssignments.centroids()); - - long centroidLength = ivfCentroids.getFilePointer() - centroidOffset; + final CentroidAssignments centroidAssignments = calculateCentroids(fieldWriter.fieldInfo, floatVectorValues, globalCentroid); + // wrap centroids with a supplier + final CentroidSupplier centroidSupplier = new OnHeapCentroidSupplier(centroidAssignments.centroids()); + // write posting lists final long[] offsets = buildAndWritePostingsLists( fieldWriter.fieldInfo, centroidSupplier, @@ -188,8 +176,13 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { centroidAssignments.assignments(), centroidAssignments.overspillAssignments() ); - // write posting lists - writeMeta(fieldWriter.fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid); + assert offsets.length == centroidSupplier.size(); + final long centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES); + // write centroids + writeCentroids(fieldWriter.fieldInfo, centroidSupplier, globalCentroid, offsets, ivfCentroids); + final long centroidLength = ivfCentroids.getFilePointer() - centroidOffset; + // write meta file + writeMeta(fieldWriter.fieldInfo, centroidSupplier.size(), centroidOffset, centroidLength, globalCentroid); } } @@ -305,13 +298,17 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws try { centroidTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "civf_", IOContext.DEFAULT); centroidTempName = centroidTemp.getName(); - CentroidAssignments centroidAssignments = calculateAndWriteCentroids( + CentroidAssignments centroidAssignments = calculateCentroids( fieldInfo, getFloatVectorValues(fieldInfo, docs, vectors, numVectors), - centroidTemp, - mergeState, calculatedGlobalCentroid ); + // write the centroids to a temporary file so we are not holding them on heap + final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (float[] centroid : centroidAssignments.centroids()) { + buffer.asFloatBuffer().put(centroid); + centroidTemp.writeBytes(buffer.array(), buffer.array().length); + } numCentroids = centroidAssignments.numCentroids(); assignments = centroidAssignments.assignments(); overspillAssignments = centroidAssignments.overspillAssignments(); @@ -325,27 +322,22 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws try { if (numCentroids == 0) { centroidOffset = ivfCentroids.getFilePointer(); - writeMeta(fieldInfo, centroidOffset, 0, new long[0], null); + writeMeta(fieldInfo, 0, centroidOffset, 0, null); CodecUtil.writeFooter(centroidTemp); IOUtils.close(centroidTemp); return; } CodecUtil.writeFooter(centroidTemp); IOUtils.close(centroidTemp); - centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES); - try (IndexInput centroidsInput = mergeState.segmentInfo.dir.openInput(centroidTempName, IOContext.DEFAULT)) { - ivfCentroids.copyBytes(centroidsInput, centroidsInput.length() - CodecUtil.footerLength()); - centroidLength = ivfCentroids.getFilePointer() - centroidOffset; + try (IndexInput centroidsInput = mergeState.segmentInfo.dir.openInput(centroidTempName, IOContext.DEFAULT)) { CentroidSupplier centroidSupplier = createCentroidSupplier( centroidsInput, numCentroids, fieldInfo, calculatedGlobalCentroid ); - - // build a float vector values with random access - // build centroids + // write posting lists final long[] offsets = buildAndWritePostingsLists( fieldInfo, centroidSupplier, @@ -356,7 +348,12 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws overspillAssignments ); assert offsets.length == centroidSupplier.size(); - writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, calculatedGlobalCentroid); + centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES); + // write centroids + writeCentroids(fieldInfo, centroidSupplier, calculatedGlobalCentroid, offsets, ivfCentroids); + centroidLength = ivfCentroids.getFilePointer() - centroidOffset; + // write meta + writeMeta(fieldInfo, centroidSupplier.size(), centroidOffset, centroidLength, calculatedGlobalCentroid); } } finally { org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, centroidTempName); @@ -439,18 +436,15 @@ private static int writeFloatVectorValues( return numVectors; } - private void writeMeta(FieldInfo field, long centroidOffset, long centroidLength, long[] offsets, float[] globalCentroid) + private void writeMeta(FieldInfo field, int numCentroids, long centroidOffset, long centroidLength, float[] globalCentroid) throws IOException { ivfMeta.writeInt(field.number); ivfMeta.writeInt(field.getVectorEncoding().ordinal()); ivfMeta.writeInt(distFuncToOrd(field.getVectorSimilarityFunction())); + ivfMeta.writeInt(numCentroids); ivfMeta.writeLong(centroidOffset); ivfMeta.writeLong(centroidLength); - ivfMeta.writeVInt(offsets.length); - for (long offset : offsets) { - ivfMeta.writeLong(offset); - } - if (offsets.length > 0) { + if (centroidLength > 0) { final ByteBuffer buffer = ByteBuffer.allocate(globalCentroid.length * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); buffer.asFloatBuffer().put(globalCentroid); ivfMeta.writeBytes(buffer.array(), buffer.array().length);