diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java
index f51c550e5292e..ddd89c8c43dcc 100644
--- a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java
+++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java
@@ -52,9 +52,16 @@ record CmdLineArgs(
     int quantizeBits,
     VectorEncoding vectorEncoding,
     int dimensions,
-    boolean earlyTermination
+    boolean earlyTermination,
+    FILTER_KIND filterKind,
+    boolean sortIndex
 ) implements ToXContentObject {
 
+    public enum FILTER_KIND {
+        RANDOM,
+        RANGE
+    }
+
     static final ParseField DOC_VECTORS_FIELD = new ParseField("doc_vectors");
     static final ParseField QUERY_VECTORS_FIELD = new ParseField("query_vectors");
     static final ParseField NUM_DOCS_FIELD = new ParseField("num_docs");
@@ -79,6 +86,8 @@ record CmdLineArgs(
     static final ParseField EARLY_TERMINATION_FIELD = new ParseField("early_termination");
     static final ParseField FILTER_SELECTIVITY_FIELD = new ParseField("filter_selectivity");
     static final ParseField SEED_FIELD = new ParseField("seed");
+    static final ParseField FILTER_KIND_FIELD = new ParseField("filter_kind");
+    static final ParseField SORT_INDEX_FIELD = new ParseField("sort_index");
 
     static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
         Builder builder = PARSER.apply(parser, null);
@@ -112,6 +121,8 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
         PARSER.declareBoolean(Builder::setEarlyTermination, EARLY_TERMINATION_FIELD);
         PARSER.declareFloat(Builder::setFilterSelectivity, FILTER_SELECTIVITY_FIELD);
         PARSER.declareLong(Builder::setSeed, SEED_FIELD);
+        PARSER.declareString(Builder::setFilterKind, FILTER_KIND_FIELD);
+        PARSER.declareBoolean(Builder::setSortIndex, SORT_INDEX_FIELD);
     }
 
     @Override
@@ -179,6 +190,8 @@ static class Builder {
         private boolean earlyTermination;
         private float filterSelectivity = 1f;
         private long seed = 1751900822751L;
+        private FILTER_KIND filterKind = FILTER_KIND.RANDOM;
+        private boolean sortIndex = false;
 
         public Builder setDocVectors(List<Path> docVectors) {
             if (docVectors == null || docVectors.isEmpty()) {
@@ -304,6 +317,16 @@ public Builder setSeed(long seed) {
             return this;
         }
 
+        public Builder setFilterKind(String filterKind) {
+            this.filterKind = FILTER_KIND.valueOf(filterKind.toUpperCase(Locale.ROOT));
+            return this;
+        }
+
+        public Builder setSortIndex(boolean sortIndex) {
+            this.sortIndex = sortIndex;
+            return this;
+        }
+
         public CmdLineArgs build() {
             if (docVectors == null) {
                 throw new IllegalArgumentException("Document vectors path must be provided");
@@ -337,7 +360,9 @@ public CmdLineArgs build() {
                 quantizeBits,
                 vectorEncoding,
                 dimensions,
-                earlyTermination
+                earlyTermination,
+                filterKind,
+                sortIndex
             );
         }
     }
diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java
index c4b0ccdfe35e3..7f1c172bfcd0c 100644
--- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java
+++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java
@@ -205,7 +205,8 @@ public static void main(String[] args) throws Exception {
             cmdLineArgs.vectorEncoding(),
             cmdLineArgs.dimensions(),
             cmdLineArgs.vectorSpace(),
-            cmdLineArgs.numDocs()
+            cmdLineArgs.numDocs(),
+            cmdLineArgs.sortIndex()
         );
         if (cmdLineArgs.reindex() == false && Files.exists(indexPath) == false) {
             throw new IllegalArgumentException("Index path does not exist: " + indexPath);
diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java
index f7d00c9806c8d..f777f05f0d7e8 100644
--- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java
+++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java
@@ -25,6 +25,7 @@
 import org.apache.lucene.document.FieldType;
 import org.apache.lucene.document.KnnByteVectorField;
 import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.document.NumericDocValuesField;
 import org.apache.lucene.document.StoredField;
 import org.apache.lucene.index.ConcurrentMergeScheduler;
 import org.apache.lucene.index.DirectoryReader;
@@ -33,6 +34,8 @@
 import org.apache.lucene.index.IndexWriterConfig;
 import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.Sort;
+import org.apache.lucene.search.SortField;
 import org.apache.lucene.store.FSDirectory;
 import org.apache.lucene.util.PrintStreamInfoStream;
 import org.elasticsearch.common.io.Channels;
@@ -69,6 +72,7 @@ class KnnIndexer {
     private final Codec codec;
     private final int numDocs;
     private final int numIndexThreads;
+    private final boolean sortIndex;
 
     KnnIndexer(
         List<Path> docsPath,
@@ -78,7 +82,8 @@ class KnnIndexer {
         VectorEncoding vectorEncoding,
         int dim,
         VectorSimilarityFunction similarityFunction,
-        int numDocs
+        int numDocs,
+        boolean sortIndex
     ) {
         this.docsPath = docsPath;
         this.indexPath = indexPath;
@@ -88,6 +93,7 @@ class KnnIndexer {
         this.dim = dim;
         this.similarityFunction = similarityFunction;
         this.numDocs = numDocs;
+        this.sortIndex = sortIndex;
     }
 
     void numSegments(KnnIndexTester.Results result) {
@@ -103,7 +109,9 @@ void createIndex(KnnIndexTester.Results result) throws IOException, InterruptedException {
         iwc.setCodec(codec);
         iwc.setRAMBufferSizeMB(WRITER_BUFFER_MB);
         iwc.setUseCompoundFile(false);
-
+        if (sortIndex) {
+            iwc.setIndexSort(new Sort(new SortField(ID_FIELD + "_sort", SortField.Type.LONG, false)));
+        }
         iwc.setMaxFullFlushMergeWaitMillis(0);
         iwc.setInfoStream(new PrintStreamInfoStream(System.out) {
@@ -292,6 +300,7 @@ private void _run() throws IOException {
                     logger.debug("Done indexing " + (id + 1) + " documents.");
                 }
                 doc.add(new StoredField(ID_FIELD, id));
+                doc.add(new NumericDocValuesField(ID_FIELD + "_sort", id));
                 iw.addDocument(doc);
             }
         }
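For context on the new sortIndex option: the indexer now adds a NumericDocValuesField next to each vector and, when sorting is enabled, points IndexWriterConfig#setIndexSort at it, so doc IDs in a merged segment follow insertion order. A minimal, self-contained sketch of that setup (the writeSorted method and the "id_sort" field name are illustrative, not part of this change):

    import org.apache.lucene.document.Document;
    import org.apache.lucene.document.NumericDocValuesField;
    import org.apache.lucene.index.IndexWriter;
    import org.apache.lucene.index.IndexWriterConfig;
    import org.apache.lucene.search.Sort;
    import org.apache.lucene.search.SortField;
    import org.apache.lucene.store.Directory;
    import org.apache.lucene.store.FSDirectory;

    import java.io.IOException;
    import java.nio.file.Path;

    class SortedIndexSketch {
        // Index docs sorted by a numeric doc-values field; after forceMerge(1) the
        // doc IDs of the single segment follow the sort key.
        static void writeSorted(Path indexPath, int numDocs) throws IOException {
            IndexWriterConfig iwc = new IndexWriterConfig();
            iwc.setIndexSort(new Sort(new SortField("id_sort", SortField.Type.LONG, false)));
            try (Directory dir = FSDirectory.open(indexPath); IndexWriter iw = new IndexWriter(dir, iwc)) {
                for (long id = 0; id < numDocs; id++) {
                    Document doc = new Document();
                    // the sort field must be present as doc values on every document
                    doc.add(new NumericDocValuesField("id_sort", id));
                    iw.addDocument(doc);
                }
                iw.forceMerge(1);
            }
        }
    }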
diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java
index bb13dd75a4d9e..483ada75d17ea 100644
--- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java
+++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java
@@ -115,6 +115,7 @@ class KnnSearcher {
     private final float overSamplingFactor;
     private final int searchThreads;
     private final int numSearchers;
+    private final CmdLineArgs.FILTER_KIND filterKind;
 
     KnnSearcher(Path indexPath, CmdLineArgs cmdLineArgs, int nProbe) {
         this.docPath = cmdLineArgs.docVectors();
@@ -137,10 +138,13 @@ class KnnSearcher {
         this.numSearchers = cmdLineArgs.numSearchers();
         this.randomSeed = cmdLineArgs.seed();
         this.selectivity = cmdLineArgs.filterSelectivity();
+        this.filterKind = cmdLineArgs.filterKind();
     }
 
     void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) throws IOException {
-        Query filterQuery = this.selectivity < 1f ? generateRandomQuery(new Random(randomSeed), indexPath, numDocs, selectivity) : null;
+        Query filterQuery = this.selectivity < 1f
+            ? generateRandomQuery(new Random(randomSeed), indexPath, numDocs, selectivity, filterKind)
+            : null;
         TopDocs[] results = new TopDocs[numQueryVectors];
         int[][] resultIds = new int[numQueryVectors][];
         long elapsed, totalCpuTimeMS, totalVisited = 0;
@@ -307,14 +311,27 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) throws IOException {
         finalResults.avgCpuCount = (double) totalCpuTimeMS / elapsed;
     }
 
-    private static Query generateRandomQuery(Random random, Path indexPath, int size, float selectivity) throws IOException {
+    private static Query generateRandomQuery(Random random, Path indexPath, int size, float selectivity, CmdLineArgs.FILTER_KIND filterKind)
+        throws IOException {
         FixedBitSet bitSet = new FixedBitSet(size);
-        for (int i = 0; i < size; i++) {
-            if (random.nextFloat() < selectivity) {
-                bitSet.set(i);
-            } else {
-                bitSet.clear(i);
-            }
-        }
+        switch (filterKind) {
+            case RANDOM:
+                for (int i = 0; i < size; i++) {
+                    if (random.nextFloat() < selectivity) {
+                        bitSet.set(i);
+                    } else {
+                        bitSet.clear(i);
+                    }
+                }
+                break;
+            case RANGE:
+                int rangeBound = (int) (size * selectivity);
+                // set a random range of bits of length rangeBound
+                int start = random.nextInt(size - rangeBound);
+                for (int i = start; i < start + rangeBound; i++) {
+                    bitSet.set(i);
+                }
+                break;
+        }
         try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) {
@@ -346,6 +363,7 @@ private int[][] getOrCalculateExactNN(int vectorFileOffsetBytes, Query filterQuery) {
                     topK,
                     similarityFunction.ordinal(),
                     selectivity,
+                    filterKind,
                     randomSeed
                 ),
                 36
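The two filter kinds the searcher can now generate differ only in how the accepted docs are laid out in doc-id space. A standalone sketch of both, assuming only Lucene's FixedBitSet (randomFilter and rangeFilter are illustrative names, and selectivity is assumed to be strictly below 1):

    import org.apache.lucene.util.FixedBitSet;

    import java.util.Random;

    class FilterBitSetSketch {
        // RANDOM: each doc passes independently with probability `selectivity`,
        // so matching docs are scattered across the whole doc-id space.
        static FixedBitSet randomFilter(Random random, int maxDoc, float selectivity) {
            FixedBitSet bits = new FixedBitSet(maxDoc);
            for (int doc = 0; doc < maxDoc; doc++) {
                if (random.nextFloat() < selectivity) {
                    bits.set(doc);
                }
            }
            return bits;
        }

        // RANGE: roughly the same number of matching docs, but packed into one
        // contiguous doc-id interval.
        static FixedBitSet rangeFilter(Random random, int maxDoc, float selectivity) {
            FixedBitSet bits = new FixedBitSet(maxDoc);
            int length = (int) (maxDoc * selectivity);
            int start = random.nextInt(maxDoc - length);
            bits.set(start, start + length); // set(from, to) marks the half-open range [from, to)
            return bits;
        }
    }

On an index sorted by the id field, the RANGE kind models filters that correlate with the sort key (for example a timestamp range on a time-sorted index), which is the combination the sort_index and filter_kind options are meant to benchmark together.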
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java
index 304cc57284227..c693e5e17023d 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java
@@ -16,6 +16,7 @@
 import org.apache.lucene.search.KnnCollector;
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.GroupVIntUtil;
 import org.apache.lucene.util.VectorUtil;
 import org.apache.lucene.util.hnsw.NeighborQueue;
 import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats;
@@ -177,14 +178,13 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
         final int[] correctionsSum = new int[BULK_SIZE];
         final float[] correctionsAdd = new float[BULK_SIZE];
 
-        int[] docIdsScratch = new int[0];
-        int vectors;
+        int[] docIdsScratch = new int[0], spilledDocIdsScratch = new int[0];
+        int vectors, spilledVectors;
         boolean quantized = false;
         float centroidDp;
         final float[] centroid;
-        long slicePos;
+        private long slicePos;
         OptimizedScalarQuantizer.QuantizationResult queryCorrections;
-        DocIdsWriter docIdsWriter = new DocIdsWriter();
 
         final float[] scratch;
         final int[] quantizationScratch;
@@ -222,18 +222,28 @@ public int resetPostingsScorer(long offset) throws IOException {
             indexInput.seek(offset);
             indexInput.readFloats(centroid, 0, centroid.length);
             centroidDp = Float.intBitsToFloat(indexInput.readInt());
-            vectors = indexInput.readVInt();
+            vectors = indexInput.readInt();
+            spilledVectors = indexInput.readInt();
             // read the doc ids
             docIdsScratch = vectors > docIdsScratch.length ? new int[vectors] : docIdsScratch;
-            docIdsWriter.readInts(indexInput, vectors, docIdsScratch);
+            spilledDocIdsScratch = spilledVectors > spilledDocIdsScratch.length ? new int[spilledVectors] : spilledDocIdsScratch;
+            GroupVIntUtil.readGroupVInts(indexInput, docIdsScratch, vectors);
+            GroupVIntUtil.readGroupVInts(indexInput, spilledDocIdsScratch, spilledVectors);
+            // reconstitute from the deltas
+            for (int i = 1; i < vectors; i++) {
+                docIdsScratch[i] += docIdsScratch[i - 1];
+            }
+            for (int i = 1; i < spilledVectors; i++) {
+                spilledDocIdsScratch[i] += spilledDocIdsScratch[i - 1];
+            }
             slicePos = indexInput.getFilePointer();
             return vectors;
         }
 
-        void scoreIndividually(int offset) throws IOException {
+        private void scoreIndividually(int offset, long slicePos, int[] docIds) throws IOException {
             // score individually, first the quantized byte chunk
             for (int j = 0; j < BULK_SIZE; j++) {
-                int doc = docIdsScratch[j + offset];
+                int doc = docIds[j + offset];
                 if (doc != -1) {
                     indexInput.seek(slicePos + (offset * quantizedByteLength) + (j * quantizedVectorByteSize));
                     float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch);
@@ -250,7 +260,7 @@ void scoreIndividually(int offset) throws IOException {
             indexInput.readFloats(correctionsAdd, 0, BULK_SIZE);
             // Now apply corrections
             for (int j = 0; j < BULK_SIZE; j++) {
-                int doc = docIdsScratch[offset + j];
+                int doc = docIds[offset + j];
                 if (doc != -1) {
                     scores[j] = osqVectorsScorer.score(
                         queryCorrections.lowerInterval(),
@@ -271,16 +281,25 @@ void scoreIndividually(int offset) throws IOException {
 
         @Override
         public int visit(KnnCollector knnCollector) throws IOException {
+            int scored = scoreDocs(knnCollector, vectors, 0, docIdsScratch);
+            if (spilledVectors > 0) {
+                scored += scoreDocs(knnCollector, spilledVectors, vectors * quantizedByteLength, spilledDocIdsScratch);
+            }
+            return scored;
+        }
+
+        private int scoreDocs(KnnCollector knnCollector, int count, long sliceOffset, int[] docIds) throws IOException {
             // block processing
             int scoredDocs = 0;
-            int limit = vectors - BULK_SIZE + 1;
+            int limit = count - BULK_SIZE + 1;
             int i = 0;
+            long slicePos = this.slicePos + (int) sliceOffset;
             for (; i < limit; i += BULK_SIZE) {
                 int docsToScore = BULK_SIZE;
                 for (int j = 0; j < BULK_SIZE; j++) {
-                    int doc = docIdsScratch[i + j];
+                    int doc = docIds[i + j];
                     if (needsScoring.test(doc) == false) {
-                        docIdsScratch[i + j] = -1;
+                        docIds[i + j] = -1;
                         docsToScore--;
                     }
                 }
@@ -290,7 +309,7 @@ public int visit(KnnCollector knnCollector) throws IOException {
                 quantizeQueryIfNecessary();
                 indexInput.seek(slicePos + i * quantizedByteLength);
                 if (docsToScore < BULK_SIZE / 2) {
-                    scoreIndividually(i);
+                    scoreIndividually(i, slicePos, docIds);
                 } else {
                     osqVectorsScorer.scoreBulk(
                         quantizedQueryScratch,
@@ -304,7 +323,7 @@ public int visit(KnnCollector knnCollector) throws IOException {
                     );
                 }
                 for (int j = 0; j < BULK_SIZE; j++) {
-                    int doc = docIdsScratch[i + j];
+                    int doc = docIds[i + j];
                     if (doc != -1) {
                         scoredDocs++;
                         knnCollector.collect(doc, scores[j]);
@@ -312,8 +331,8 @@ public int visit(KnnCollector knnCollector) throws IOException {
                 }
             }
             // process tail
-            for (; i < vectors; i++) {
-                int doc = docIdsScratch[i];
+            for (; i < count; i++) {
+                int doc = docIds[i];
                 if (needsScoring.test(doc)) {
                     quantizeQueryIfNecessary();
                     indexInput.seek(slicePos + i * quantizedByteLength);
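resetPostingsScorer above now reads two group-varint encoded, delta-compressed doc-id lists (regular and overspill) and rebuilds the absolute IDs with a prefix sum. The encode/decode pair, stripped of the I/O, can be sketched as follows (illustrative helper class; the writer side in DefaultIVFVectorsWriter is what actually produces the deltas):

    class DocDeltaSketch {
        // Encode an ascending doc-id list as gaps; the first value is kept as-is.
        static void deltaEncode(int[] docIds, int count, int[] deltas) {
            if (count == 0) {
                return;
            }
            deltas[0] = docIds[0];
            for (int i = 1; i < count; i++) {
                deltas[i] = docIds[i] - docIds[i - 1];
            }
        }

        // Rebuild the absolute ids in place with a prefix sum, mirroring what the
        // reader does after GroupVIntUtil.readGroupVInts fills the scratch array.
        static void deltaDecode(int[] deltas, int count) {
            for (int i = 1; i < count; i++) {
                deltas[i] += deltas[i - 1];
            }
        }
    }

For docIds {3, 7, 20, 21} the deltas are {3, 4, 13, 1}; because each posting list is written in ascending doc-id order, the gaps are small non-negative ints that pack well into group-varint blocks.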
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
index f47ecc549831a..e477ead0e7f10 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
@@ -48,29 +48,62 @@ public DefaultIVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate
         this.vectorPerCluster = vectorPerCluster;
     }
 
-    @Override
-    LongValues buildAndWritePostingsLists(
-        FieldInfo fieldInfo,
-        CentroidSupplier centroidSupplier,
-        FloatVectorValues floatVectorValues,
-        IndexOutput postingsOutput,
+    private static void deltaEncode(int[] vals, int size, int[] deltas) {
+        if (size == 0) {
+            return;
+        }
+        deltas[0] = vals[0];
+        for (int i = 1; i < size; i++) {
+            assert vals[i] >= vals[i - 1] : "vals are not sorted: " + vals[i] + " < " + vals[i - 1];
+            deltas[i] = vals[i] - vals[i - 1];
+        }
+    }
+
+    private static void translateOrdsToDocs(
+        int[] ords,
+        int size,
+        int[] spillOrds,
+        int spillSize,
+        int[] docIds,
+        int[] spillDocIds,
+        IntToIntFunction ordToDoc
+    ) {
+        int ordIdx = 0, spillOrdIdx = 0;
+        while (ordIdx < size || spillOrdIdx < spillSize) {
+            int nextOrd = (ordIdx < size) ? ords[ordIdx] : Integer.MAX_VALUE;
+            int nextSpillOrd = (spillOrdIdx < spillSize) ? spillOrds[spillOrdIdx] : Integer.MAX_VALUE;
+            if (nextOrd < nextSpillOrd) {
+                docIds[ordIdx] = ordToDoc.apply(nextOrd);
+                ordIdx++;
+            } else {
+                spillDocIds[spillOrdIdx] = ordToDoc.apply(nextSpillOrd);
+                spillOrdIdx++;
+            }
+        }
+    }
+
+    private static void pivotAssignments(
+        int centroidCount,
         int[] assignments,
-        int[] overspillAssignments
-    ) throws IOException {
-        int[] centroidVectorCount = new int[centroidSupplier.size()];
+        int[] overspillAssignments,
+        int[][] assignmentsByCluster,
+        int[][] overspillAssignmentsByCluster
+    ) {
+        int[] centroidVectorCount = new int[centroidCount];
+        int[] overspillVectorCount = new int[centroidCount];
         for (int i = 0; i < assignments.length; i++) {
             centroidVectorCount[assignments[i]]++;
             // if soar assignments are present, count them as well
             if (overspillAssignments.length > i && overspillAssignments[i] != -1) {
-                centroidVectorCount[overspillAssignments[i]]++;
+                overspillVectorCount[overspillAssignments[i]]++;
             }
         }
-
-        int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
-        for (int c = 0; c < centroidSupplier.size(); c++) {
+        for (int c = 0; c < centroidCount; c++) {
             assignmentsByCluster[c] = new int[centroidVectorCount[c]];
+            overspillAssignmentsByCluster[c] = new int[overspillVectorCount[c]];
         }
         Arrays.fill(centroidVectorCount, 0);
+        Arrays.fill(overspillVectorCount, 0);
 
         for (int i = 0; i < assignments.length; i++) {
             int c = assignments[i];
@@ -79,15 +112,35 @@ LongValues buildAndWritePostingsLists(
             if (overspillAssignments.length > i) {
                 int s = overspillAssignments[i];
                 if (s != -1) {
-                    assignmentsByCluster[s][centroidVectorCount[s]++] = i;
+                    overspillAssignmentsByCluster[s][overspillVectorCount[s]++] = i;
                 }
             }
         }
+    }
+
+    @Override
+    LongValues buildAndWritePostingsLists(
+        FieldInfo fieldInfo,
+        CentroidSupplier centroidSupplier,
+        FloatVectorValues floatVectorValues,
+        IndexOutput postingsOutput,
+        int[] assignments,
+        int[] overspillAssignments
+    ) throws IOException {
+        // write the posting lists
         final PackedLongValues.Builder offsets = PackedLongValues.monotonicBuilder(PackedInts.COMPACT);
-        DocIdsWriter docIdsWriter = new DocIdsWriter();
         DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
-        OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors(
+        // pivot the assignments into clusters
+        int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
+        int[][] overspillAssignmentsByCluster = new int[centroidSupplier.size()][];
+        pivotAssignments(centroidSupplier.size(), assignments, overspillAssignments, assignmentsByCluster, overspillAssignmentsByCluster);
+
+        int[] docIds = null;
+        int[] docDeltas = null;
+        int[] spillDocIds = null;
+        int[] spillDeltas = null;
+        final OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors(
             floatVectorValues,
             fieldInfo.getVectorDimension(),
             new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction())
@@ -96,26 +149,47 @@ LongValues buildAndWritePostingsLists(
         for (int c = 0; c < centroidSupplier.size(); c++) {
             float[] centroid = centroidSupplier.centroid(c);
             int[] cluster = assignmentsByCluster[c];
+            int[] overspillCluster = overspillAssignmentsByCluster[c];
             offsets.add(postingsOutput.alignFilePointer(Float.BYTES));
             buffer.asFloatBuffer().put(centroid);
             // write raw centroid for quantizing the query vectors
             postingsOutput.writeBytes(buffer.array(), buffer.array().length);
-            // write centroid dot product for quantizing the query vectors
-            postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
             int size = cluster.length;
-            // write docIds
-            postingsOutput.writeVInt(size);
-            onHeapQuantizedVectors.reset(centroid, size, ord -> cluster[ord]);
+            int spillSize = overspillCluster.length;
+            if (docIds == null || docIds.length < size) {
+                docIds = new int[size];
+                docDeltas = new int[size];
+            }
+            if (spillDocIds == null || spillDocIds.length < spillSize) {
+                spillDocIds = new int[spillSize];
+                spillDeltas = new int[spillSize];
+            }
+            translateOrdsToDocs(cluster, size, overspillCluster, spillSize, docIds, spillDocIds, floatVectorValues::ordToDoc);
+            // encode doc deltas
+            if (size > 0) {
+                deltaEncode(docIds, size, docDeltas);
+            }
+            if (spillSize > 0) {
+                deltaEncode(spillDocIds, spillSize, spillDeltas);
+            }
+            postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
+            postingsOutput.writeInt(size);
+            postingsOutput.writeInt(spillSize);
             // TODO we might want to consider putting the docIds in a separate file
             // to aid with only having to fetch vectors from slower storage when they are required
             // keeping them in the same file indicates we pull the entire file into cache
-            docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
-            // write vectors
+            postingsOutput.writeGroupVInts(docDeltas, size);
+            postingsOutput.writeGroupVInts(spillDeltas, spillSize);
+            onHeapQuantizedVectors.reset(centroid, size, j -> cluster[j]);
+            bulkWriter.writeVectors(onHeapQuantizedVectors);
+            // write overspill vectors
+            onHeapQuantizedVectors.reset(centroid, spillSize, j -> overspillCluster[j]);
             bulkWriter.writeVectors(onHeapQuantizedVectors);
         }
         if (logger.isDebugEnabled()) {
             printClusterQualityStatistics(assignmentsByCluster);
+            printClusterQualityStatistics(overspillAssignmentsByCluster);
         }
 
         return offsets.build();
@@ -177,69 +251,73 @@ LongValues buildAndWritePostingsLists(
                 mergeState.segmentInfo.dir.deleteFile(quantizedVectorsTemp.getName());
             }
         }
-        int[] centroidVectorCount = new int[centroidSupplier.size()];
-        for (int i = 0; i < assignments.length; i++) {
-            centroidVectorCount[assignments[i]]++;
-            // if soar assignments are present, count them as well
-            if (overspillAssignments.length > i && overspillAssignments[i] != -1) {
-                centroidVectorCount[overspillAssignments[i]]++;
-            }
-        }
         int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
-        boolean[][] isOverspillByCluster = new boolean[centroidSupplier.size()][];
-        for (int c = 0; c < centroidSupplier.size(); c++) {
-            assignmentsByCluster[c] = new int[centroidVectorCount[c]];
-            isOverspillByCluster[c] = new boolean[centroidVectorCount[c]];
-        }
-        Arrays.fill(centroidVectorCount, 0);
-
-        for (int i = 0; i < assignments.length; i++) {
-            int c = assignments[i];
-            assignmentsByCluster[c][centroidVectorCount[c]++] = i;
-            // if soar assignments are present, add them to the cluster as well
-            if (overspillAssignments.length > i) {
-                int s = overspillAssignments[i];
-                if (s != -1) {
-                    assignmentsByCluster[s][centroidVectorCount[s]] = i;
-                    isOverspillByCluster[s][centroidVectorCount[s]++] = true;
-                }
-            }
-        }
+        int[][] overspillAssignmentsByCluster = new int[centroidSupplier.size()][];
+        // pivot the assignments into clusters
+        pivotAssignments(centroidSupplier.size(), assignments, overspillAssignments, assignmentsByCluster, overspillAssignmentsByCluster);
 
         // now we can read the quantized vectors from the temporary file
         try (IndexInput quantizedVectorsInput = mergeState.segmentInfo.dir.openInput(quantizedVectorsTempName, IOContext.DEFAULT)) {
             final PackedLongValues.Builder offsets = PackedLongValues.monotonicBuilder(PackedInts.COMPACT);
-            OffHeapQuantizedVectors offHeapQuantizedVectors = new OffHeapQuantizedVectors(
+
+            final DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(
+                ES91OSQVectorsScorer.BULK_SIZE,
+                postingsOutput
+            );
+            final OffHeapQuantizedVectors offHeapQuantizedVectors = new OffHeapQuantizedVectors(
                 quantizedVectorsInput,
                 fieldInfo.getVectorDimension()
             );
-            DocIdsWriter docIdsWriter = new DocIdsWriter();
-            DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
+            int[] docIds = null;
+            int[] docDeltas = null;
+            int[] spillDocIds = null;
+            int[] spillDeltas = null;
             final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
             for (int c = 0; c < centroidSupplier.size(); c++) {
                 float[] centroid = centroidSupplier.centroid(c);
                 int[] cluster = assignmentsByCluster[c];
-                boolean[] isOverspill = isOverspillByCluster[c];
+                int[] overspillCluster = overspillAssignmentsByCluster[c];
                 offsets.add(postingsOutput.alignFilePointer(Float.BYTES));
                 // write raw centroid for quantizing the query vectors
                 buffer.asFloatBuffer().put(centroid);
                 postingsOutput.writeBytes(buffer.array(), buffer.array().length);
-                // write centroid dot product for quantizing the query vectors
-                postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
-                // write docIds
                 int size = cluster.length;
-                postingsOutput.writeVInt(size);
-                offHeapQuantizedVectors.reset(size, ord -> isOverspill[ord], ord -> cluster[ord]);
+                int spillSize = overspillCluster.length;
+                if (docIds == null || docIds.length < size) {
+                    docIds = new int[size];
+                    docDeltas = new int[size];
+                }
+                if (spillDocIds == null || spillDocIds.length < spillSize) {
+                    spillDocIds = new int[spillSize];
+                    spillDeltas = new int[spillSize];
+                }
+                // translate ordinals to docIds
+                translateOrdsToDocs(cluster, size, overspillCluster, spillSize, docIds, spillDocIds, floatVectorValues::ordToDoc);
+                // encode doc deltas
+                if (size > 0) {
+                    deltaEncode(docIds, size, docDeltas);
+                }
+                if (spillSize > 0) {
+                    deltaEncode(spillDocIds, spillSize, spillDeltas);
+                }
+                postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
+                postingsOutput.writeInt(size);
+                postingsOutput.writeInt(spillSize);
                 // TODO we might want to consider putting the docIds in a separate file
                 // to aid with only having to fetch vectors from slower storage when they are required
                 // keeping them in the same file indicates we pull the entire file into cache
-                docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
-                // write vectors
+                postingsOutput.writeGroupVInts(docDeltas, size);
+                postingsOutput.writeGroupVInts(spillDeltas, spillSize);
+                // write overspill vectors
+
+                offHeapQuantizedVectors.reset(size, false, j -> cluster[j]);
+                bulkWriter.writeVectors(offHeapQuantizedVectors);
+                offHeapQuantizedVectors.reset(spillSize, true, j -> overspillCluster[j]);
                 bulkWriter.writeVectors(offHeapQuantizedVectors);
             }
             if (logger.isDebugEnabled()) {
                 printClusterQualityStatistics(assignmentsByCluster);
+                printClusterQualityStatistics(overspillAssignmentsByCluster);
             }
             return offsets.build();
         }
@@ -411,10 +489,6 @@ interface QuantizedVectorValues {
         OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException;
     }
 
-    interface IntToBooleanFunction {
-        boolean apply(int ord);
-    }
-
     static class OnHeapQuantizedVectors implements QuantizedVectorValues {
         private final FloatVectorValues vectorValues;
         private final OptimizedScalarQuantizer quantizer;
@@ -436,7 +510,7 @@ static class OnHeapQuantizedVectors implements QuantizedVectorValues {
             this.corrections = null;
         }
 
-        private void reset(float[] centroid, int count, IntToIntFunction ordTransformer) {
+        void reset(float[] centroid, int count, IntToIntFunction ordTransformer) {
             this.currentCentroid = centroid;
             this.ordTransformer = ordTransformer;
             this.currOrd = -1;
@@ -482,7 +556,7 @@ static class OffHeapQuantizedVectors implements QuantizedVectorValues {
         private short bitSum;
         private int currOrd = -1;
         private int count;
-        private IntToBooleanFunction isOverspill = null;
+        private boolean isOverspill;
         private IntToIntFunction ordTransformer = null;
 
         OffHeapQuantizedVectors(IndexInput quantizedVectorsInput, int dimension) {
@@ -491,7 +565,7 @@ static class OffHeapQuantizedVectors implements QuantizedVectorValues {
             this.vectorByteSize = (binaryScratch.length + 3 * Float.BYTES + Short.BYTES);
        }
 
-        private void reset(int count, IntToBooleanFunction isOverspill, IntToIntFunction ordTransformer) {
+        private void reset(int count, boolean isOverspill, IntToIntFunction ordTransformer) {
             this.count = count;
             this.isOverspill = isOverspill;
             this.ordTransformer = ordTransformer;
@@ -510,7 +584,6 @@ public byte[] next() throws IOException {
             }
             currOrd++;
             int ord = ordTransformer.apply(currOrd);
-            boolean isOverspill = this.isOverspill.apply(currOrd);
             return getVector(ord, isOverspill);
         }
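The pivotAssignments helper introduced above is a counting-sort style pivot: one pass counts how many vectors land in each centroid, a second pass fills per-cluster ordinal arrays. A simplified single-list version, ignoring overspill assignments (pivot is an illustrative name, not part of this change):

    import java.util.Arrays;

    class PivotSketch {
        // Bucket vector ordinals by their assigned centroid in two passes: count, then fill.
        static int[][] pivot(int centroidCount, int[] assignments) {
            int[] counts = new int[centroidCount];
            for (int assignment : assignments) {
                counts[assignment]++;
            }
            int[][] byCluster = new int[centroidCount][];
            for (int c = 0; c < centroidCount; c++) {
                byCluster[c] = new int[counts[c]];
            }
            Arrays.fill(counts, 0);
            for (int ord = 0; ord < assignments.length; ord++) {
                int c = assignments[ord];
                byCluster[c][counts[c]++] = ord;
            }
            return byCluster;
        }
    }

For example, pivot(3, new int[] {2, 0, 2, 1}) yields {{1}, {3}, {0, 2}}. Ordinals are appended in increasing order, so each per-cluster list stays sorted, which keeps the subsequent ord-to-doc translation and delta encoding straightforward.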
diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java
index 8499aa9a17320..95d79cc0d929b 100644
--- a/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java
+++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java
@@ -17,13 +17,17 @@
 import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.document.NumericDocValuesField;
 import org.apache.lucene.index.CodecReader;
 import org.apache.lucene.index.DirectoryReader;
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.Sort;
+import org.apache.lucene.search.SortField;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
 import org.apache.lucene.tests.util.TestUtil;
@@ -174,4 +178,55 @@ public void testWithThreads() throws Exception {
             }
         }
     }
+
+    public void testSortedKnnSearch() throws Exception {
+        final String sortField = "sort";
+        final int numThreads = random().nextInt(2, 5);
+        final int numSearches = atLeast(100);
+        final int numDocs = atLeast(1000);
+        final int dimensions = random().nextInt(12, 500);
+        IndexWriterConfig config = newIndexWriterConfig();
+        config.setIndexSort(new Sort(new SortField(sortField, SortField.Type.INT)));
+        try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, config)) {
+            for (int docCount = 0; docCount < numDocs; docCount++) {
+                final Document doc = new Document();
+                doc.add(new KnnFloatVectorField("f", randomVector(dimensions), VectorSimilarityFunction.EUCLIDEAN));
+                doc.add(new NumericDocValuesField(sortField, random().nextInt()));
+                w.addDocument(doc);
+                if (random().nextBoolean()) {
+                    w.commit();
+                }
+            }
+            w.forceMerge(1);
+            try (IndexReader reader = DirectoryReader.open(w)) {
+                final AtomicBoolean failed = new AtomicBoolean();
+                Thread[] threads = new Thread[numThreads];
+                for (int threadID = 0; threadID < numThreads; threadID++) {
+                    threads[threadID] = new Thread(() -> {
+                        try {
+                            long totSearch = 0;
+                            for (; totSearch < numSearches && failed.get() == false; totSearch++) {
+                                float[] vector = randomVector(dimensions);
+                                LeafReader leafReader = getOnlyLeafReader(reader);
+                                leafReader.searchNearestVectors("f", vector, 10, leafReader.getLiveDocs(), Integer.MAX_VALUE);
+                            }
+                            assertTrue(totSearch > 0);
+                        } catch (Exception exc) {
+                            failed.set(true);
+                            throw new RuntimeException(exc);
+                        }
+                    });
+                    threads[threadID].setDaemon(true);
+                }
+
+                for (Thread t : threads) {
+                    t.start();
+                }
+
+                for (Thread t : threads) {
+                    t.join();
+                }
+            }
+        }
+    }
 }