diff --git a/server/src/main/java/org/elasticsearch/search/vectors/AccessibleVectorSimilarityFloatValueSource.java b/server/src/main/java/org/elasticsearch/search/vectors/AccessibleVectorSimilarityFloatValueSource.java
new file mode 100644
index 0000000000000..6abb913766000
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/search/vectors/AccessibleVectorSimilarityFloatValueSource.java
@@ -0,0 +1,52 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.search.vectors;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource;
+
+/**
+ * Subclass of {@link VectorSimilarityFloatValueSource} that keeps its own references to the constructor
+ * arguments and exposes them, so that other classes in the same package can read them.
+ * <p>
+ * The {@code target} array is stored and returned by reference; no defensive copy is made.
+ */
+class AccessibleVectorSimilarityFloatValueSource extends VectorSimilarityFloatValueSource {
+
+    // Set once in the constructor; final so they always match what was handed to the superclass.
+    final String field;
+    final float[] target;
+    final VectorSimilarityFunction vectorSimilarityFunction;
+
+    /**
+     * Forwards all three arguments to the superclass constructor and retains them for the accessors below.
+     */
+    AccessibleVectorSimilarityFloatValueSource(String field, float[] target, VectorSimilarityFunction vectorSimilarityFunction) {
+        super(field, target, vectorSimilarityFunction);
+        this.field = field;
+        this.target = target;
+        this.vectorSimilarityFunction = vectorSimilarityFunction;
+    }
+
+    /** Returns the field name given at construction. */
+    public String field() {
+        return field;
+    }
+
+    /** Returns the target vector given at construction (the same array instance). */
+    public float[] target() {
+        return target;
+    }
+
+    /** Returns the similarity function given at construction. */
+    public VectorSimilarityFunction similarityFunction() {
+        return vectorSimilarityFunction;
+    }
+}
diff --git a/server/src/main/java/org/elasticsearch/search/vectors/BatchVectorSimilarity.java b/server/src/main/java/org/elasticsearch/search/vectors/BatchVectorSimilarity.java
new file mode 100644
index 0000000000000..903c8d72bd653
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/search/vectors/BatchVectorSimilarity.java
@@ -0,0 +1,44 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.search.vectors;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+
+import java.util.Map;
+
+/**
+ * Static helpers that score one query vector against a batch of document vectors.
+ */
+public final class BatchVectorSimilarity {
+
+    private BatchVectorSimilarity() {}
+
+    /**
+     * Scores {@code queryVector} against the vector of every document listed in {@code docIds}.
+     *
+     * @param queryVector the query vector, passed as the first argument of {@code function.compare}
+     * @param docVectors  document vectors keyed by document ID
+     * @param docIds      the documents to score
+     * @param function    similarity function applied to every pair
+     * @return one score per entry of {@code docIds}, in the same order
+     * @throws IllegalArgumentException if {@code docVectors} holds no vector for one of the requested documents
+     */
+    public static float[] computeBatchSimilarity(
+        float[] queryVector,
+        Map<Integer, float[]> docVectors,
+        int[] docIds,
+        VectorSimilarityFunction function
+    ) {
+        float[][] data = organizeSIMDVectors(docVectors, docIds);
+        float[] results = new float[data.length];
+
+        for (int i = 0; i < data.length; i++) {
+            float[] docVector = data[i];
+            if (docVector == null) {
+                // Name the offending document rather than handing null to compare()
+                throw new IllegalArgumentException("No vector available for document [" + docIds[i] + "]");
+            }
+            results[i] = function.compare(queryVector, docVector);
+        }
+
+        return results;
+    }
+
+    /**
+     * Gathers the vectors of {@code docIds} into an array aligned with {@code docIds}.
+     *
+     * @param vectorMap document vectors keyed by document ID
+     * @param docIds    the documents whose vectors are wanted
+     * @return an array of the same length as {@code docIds}; an entry is {@code null} when {@code vectorMap}
+     *         has no mapping for that document
+     */
+    public static float[][] organizeSIMDVectors(Map<Integer, float[]> vectorMap, int[] docIds) {
+        float[][] vectors = new float[docIds.length][];
+        for (int i = 0; i < docIds.length; i++) {
+            vectors[i] = vectorMap.get(docIds[i]);
+        }
+        return vectors;
+    }
+}
diff --git a/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQuery.java
new file mode 100644
index 0000000000000..a90681ca28048
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQuery.java
@@ -0,0 +1,88 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.search.vectors;
+
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Weight;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Objects;
+
+/**
+ * Enhanced FunctionScoreQuery that enables bulk vector processing for KNN rescoring.
+ * When provided with a ScoreDoc array, performs bulk vector loading and similarity
+ * computation instead of individual per-document processing.
+ */
+public class BulkVectorFunctionScoreQuery extends Query {
+
+    private final Query subQuery;
+    private final AccessibleVectorSimilarityFloatValueSource valueSource;
+    private final ScoreDoc[] scoreDocs;
+
+    /**
+     * @param subQuery    query whose weight is wrapped by {@link BulkVectorFunctionScoreWeight}
+     * @param valueSource supplies the field name, target vector and similarity function
+     * @param scoreDocs   the documents to process in bulk; stored by reference
+     */
+    public BulkVectorFunctionScoreQuery(Query subQuery, AccessibleVectorSimilarityFloatValueSource valueSource, ScoreDoc[] scoreDocs) {
+        this.subQuery = subQuery;
+        this.valueSource = valueSource;
+        this.scoreDocs = scoreDocs;
+    }
+
+    /**
+     * Creates the sub-query weight with the given score mode and boost and wraps it in a
+     * {@link BulkVectorFunctionScoreWeight}.
+     */
+    @Override
+    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
+        // TODO: take a closer look at ScoreMode
+        Weight subQueryWeight = subQuery.createWeight(searcher, scoreMode, boost);
+        return new BulkVectorFunctionScoreWeight(this, subQueryWeight, valueSource, scoreDocs);
+    }
+
+    /**
+     * Rewrites the sub-query; returns {@code this} when the sub-query rewrote to the same instance,
+     * otherwise a new query around the rewritten sub-query with the same value source and documents.
+     */
+    @Override
+    public Query rewrite(IndexSearcher searcher) throws IOException {
+        Query rewrittenSubQuery = subQuery.rewrite(searcher);
+        if (rewrittenSubQuery != subQuery) {
+            return new BulkVectorFunctionScoreQuery(rewrittenSubQuery, valueSource, scoreDocs);
+        }
+        return this;
+    }
+
+    @Override
+    public String toString(String field) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("bulk_vector_function_score(");
+        sb.append(subQuery.toString(field));
+        sb.append(", vector_similarity=").append(valueSource.toString());
+        if (scoreDocs != null) {
+            sb.append(", bulk_docs=").append(scoreDocs.length);
+        }
+        sb.append(")");
+        return sb.toString();
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj) return true;
+        if (obj == null || getClass() != obj.getClass()) return false;
+
+        BulkVectorFunctionScoreQuery that = (BulkVectorFunctionScoreQuery) obj;
+        // NOTE(review): Arrays.equals compares elements with ScoreDoc.equals - confirm whether that is identity-based
+        return Objects.equals(subQuery, that.subQuery)
+            && Objects.equals(valueSource, that.valueSource)
+            && Arrays.equals(scoreDocs, that.scoreDocs);
+    }
+
+    @Override
+    public int hashCode() {
+        // scoreDocs is an array: hash its elements so that hashCode() agrees with the Arrays.equals() call in equals().
+        // Passing the array to Objects.hash() would hash the array reference instead.
+        return 31 * Objects.hash(subQuery, valueSource) + Arrays.hashCode(scoreDocs);
+    }
+
+    @Override
+    public void visit(org.apache.lucene.search.QueryVisitor visitor) {
+        subQuery.visit(visitor.getSubVisitor(org.apache.lucene.search.BooleanClause.Occur.MUST, this));
+    }
+}
diff --git a/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreWeight.java b/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreWeight.java
new file mode 100644
index 0000000000000..d1dc1fd6c1c99
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreWeight.java
@@ -0,0 +1,124 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.search.vectors;
+
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.BulkScorer;
+import org.apache.lucene.search.Explanation;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.ScorerSupplier;
+import org.apache.lucene.search.Weight;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Weight implementation that enables bulk vector processing for KNN rescoring queries.
+ * Extracts segment-specific documents from ScoreDoc array and creates bulk scorers.
+ */
+public class BulkVectorFunctionScoreWeight extends Weight {
+
+    private final Weight subQueryWeight;
+    private final AccessibleVectorSimilarityFloatValueSource valueSource;
+    private final ScoreDoc[] scoreDocs;
+
+    /**
+     * @param parent         the query this weight belongs to
+     * @param subQueryWeight weight of the wrapped sub-query
+     * @param valueSource    supplies the field name, target vector and similarity function
+     * @param scoreDocs      documents to score in bulk; their {@code doc} values are compared against
+     *                       {@code context.docBase}-based ranges, i.e. treated as index-wide IDs
+     */
+    public BulkVectorFunctionScoreWeight(
+        Query parent,
+        Weight subQueryWeight,
+        AccessibleVectorSimilarityFloatValueSource valueSource,
+        ScoreDoc[] scoreDocs
+    ) {
+        super(parent);
+        this.subQueryWeight = subQueryWeight;
+        this.valueSource = valueSource;
+        this.scoreDocs = scoreDocs;
+    }
+
+    /**
+     * Returns {@code null} when the sub-query has no scorer for the segment or when none of the
+     * {@code scoreDocs} falls into the segment; otherwise a supplier whose bulk scorer is a {@link BulkVectorScorer}.
+     */
+    @Override
+    public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
+        ScorerSupplier subQueryScorerSupplier = subQueryWeight.scorerSupplier(context);
+        if (subQueryScorerSupplier == null) {
+            return null;
+        }
+
+        // Extract documents belonging to this segment
+        int[] segmentDocIds = extractSegmentDocuments(scoreDocs, context);
+        if (segmentDocIds.length == 0) {
+            return null; // No documents in this segment
+        }
+
+        return new ScorerSupplier() {
+            @Override
+            public Scorer get(long leadCost) throws IOException {
+                // if asked for basic Scorer, delegate to the underlying subquery scorer
+                // NOTE(review): this path returns the sub-query's scorer unchanged, so it yields the sub-query's
+                // scores rather than vector similarities - confirm that callers only use bulkScorer()
+                return subQueryScorerSupplier.get(leadCost);
+            }
+
+            @Override
+            public BulkScorer bulkScorer() throws IOException {
+                // Always use BulkScorer when bulk processing is enabled
+                BulkScorer subQueryBulkScorer = subQueryScorerSupplier.bulkScorer();
+                return new BulkVectorScorer(subQueryBulkScorer, segmentDocIds, valueSource, context);
+            }
+
+            @Override
+            public long cost() {
+                return segmentDocIds.length;
+            }
+        };
+    }
+
+    /**
+     * Explains the score of {@code doc} (segment-relative) by loading its vector and comparing it with the target.
+     * Yields a no-match explanation when the document is not part of {@code scoreDocs}, when it has no vector,
+     * or when loading/comparing fails.
+     */
+    @Override
+    public Explanation explain(LeafReaderContext context, int doc) throws IOException {
+        // scoreDocs is matched against docBase-shifted IDs, so translate the segment-relative ID first
+        int globalDocId = doc + context.docBase;
+        for (ScoreDoc scoreDoc : scoreDocs) {
+            if (scoreDoc.doc != globalDocId) {
+                continue;
+            }
+            // Compute explanation for this specific document
+            try {
+                DirectIOVectorBatchLoader batchLoader = new DirectIOVectorBatchLoader();
+                float[] docVector = batchLoader.loadSingleVector(doc, context, valueSource.field());
+                if (docVector == null) {
+                    // loadSingleVector returns null when the iterator does not land on this document;
+                    // report that instead of passing null to compare()
+                    return Explanation.noMatch("No vector found for field [" + valueSource.field() + "]");
+                }
+                float similarity = valueSource.similarityFunction().compare(valueSource.target(), docVector);
+
+                return Explanation.match(
+                    similarity,
+                    "bulk vector similarity score, computed with vector similarity function: " + valueSource.similarityFunction()
+                );
+            } catch (Exception e) {
+                // Best effort: an explanation cannot carry a cause, so only the message is surfaced
+                return Explanation.noMatch("Failed to compute vector similarity: " + e.getMessage());
+            }
+        }
+        return Explanation.noMatch("Document not in bulk processing set");
+    }
+
+    @Override
+    public boolean isCacheable(LeafReaderContext ctx) {
+        return false;
+    }
+
+    /**
+     * Selects the entries of {@code scoreDocs} that fall into {@code [docBase, docBase + maxDoc)} of the given
+     * segment and returns them as segment-relative IDs, in the order they appear in {@code scoreDocs}.
+     */
+    private int[] extractSegmentDocuments(ScoreDoc[] scoreDocs, LeafReaderContext context) {
+        List<Integer> segmentDocs = new ArrayList<>();
+        int docBase = context.docBase;
+        int maxDoc = docBase + context.reader().maxDoc();
+
+        for (ScoreDoc scoreDoc : scoreDocs) {
+            if (scoreDoc.doc >= docBase && scoreDoc.doc < maxDoc) {
+                // Convert to segment-relative document ID
+                segmentDocs.add(scoreDoc.doc - docBase);
+            }
+        }
+
+        return segmentDocs.stream().mapToInt(Integer::intValue).toArray();
+    }
+}
diff --git a/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorProcessingSettings.java b/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorProcessingSettings.java
new file mode 100644
index
0000000000000..21f167ed8e1c6
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorProcessingSettings.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.search.vectors;
+
+import org.elasticsearch.common.util.FeatureFlag;
+
+/**
+ * Feature flag and threshold that gate the bulk vector scoring code path.
+ */
+public final class BulkVectorProcessingSettings {
+
+    /** Smallest document count for which {@link #shouldUseBulkProcessing(int)} can return {@code true}. */
+    public static final int MIN_BULK_PROCESSING_THRESHOLD = 3;
+
+    /** State of the {@code bulk_vector_scoring} feature flag, read once when this class is initialized. */
+    public static final boolean BULK_VECTOR_SCORING = new FeatureFlag("bulk_vector_scoring").isEnabled();
+
+    private BulkVectorProcessingSettings() {
+        // static members only
+    }
+
+    /**
+     * Tells whether bulk processing should be used for the given number of documents.
+     *
+     * @param documentCount number of documents to process
+     * @return {@code true} only when the feature flag is on and {@code documentCount} reaches the threshold
+     */
+    public static boolean shouldUseBulkProcessing(int documentCount) {
+        if (BULK_VECTOR_SCORING == false) {
+            return false;
+        }
+        return documentCount >= MIN_BULK_PROCESSING_THRESHOLD;
+    }
+}
diff --git a/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorScorer.java b/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorScorer.java
new file mode 100644
index 0000000000000..be898e2db6581
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorScorer.java
@@ -0,0 +1,138 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements.
Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.search.vectors;
+
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.BulkScorer;
+import org.apache.lucene.search.LeafCollector;
+import org.apache.lucene.search.Scorable;
+import org.apache.lucene.util.Bits;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Bulk scorer that computes the vector similarity of a fixed set of segment documents once, then lets the
+ * sub-query bulk scorer drive collection while substituting the pre-computed similarities as scores.
+ */
+public class BulkVectorScorer extends BulkScorer {
+
+    private final BulkScorer subQueryBulkScorer;
+    private final int[] segmentDocIds;
+    private final AccessibleVectorSimilarityFloatValueSource valueSource;
+    private final LeafReaderContext context;
+
+    // Bulk processing cache: doc ID -> similarity; stays null until the first score() call has computed it
+    private Map<Integer, Float> precomputedScores;
+
+    /**
+     * @param subQueryBulkScorer bulk scorer of the wrapped sub-query; it decides which documents are collected
+     * @param segmentDocIds      segment-relative IDs of the documents to score in bulk
+     * @param valueSource        supplies the field name, target vector and similarity function
+     * @param context            the segment the vectors are loaded from
+     */
+    public BulkVectorScorer(
+        BulkScorer subQueryBulkScorer,
+        int[] segmentDocIds,
+        AccessibleVectorSimilarityFloatValueSource valueSource,
+        LeafReaderContext context
+    ) {
+        this.subQueryBulkScorer = subQueryBulkScorer;
+        this.segmentDocIds = segmentDocIds;
+        this.valueSource = valueSource;
+        this.context = context;
+    }
+
+    @Override
+    public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
+        // Perform bulk processing once; a failed attempt leaves the cache null so the next call retries
+        if (precomputedScores == null) {
+            precomputedScores = performBulkVectorProcessing();
+        }
+
+        // Delegate to subquery bulk scorer with our bulk-aware collector
+        return subQueryBulkScorer.score(new BulkAwareCollector(collector, precomputedScores), acceptDocs, min, max);
+    }
+
+    @Override
+    public long cost() {
+        return segmentDocIds.length;
+    }
+
+    /**
+     * Loads the vectors of {@code segmentDocIds}, scores them against the target vector and returns the
+     * scores keyed by document ID. Documents for which no vector was loaded are left out, so that
+     * {@code BulkProcessedScorable} falls back to the sub-query score for them.
+     */
+    private Map<Integer, Float> performBulkVectorProcessing() throws IOException {
+        // batch loading
+        DirectIOVectorBatchLoader batchLoader = new DirectIOVectorBatchLoader();
+        Map<Integer, float[]> vectorCache = batchLoader.loadSegmentVectors(segmentDocIds, context, valueSource.field());
+
+        // Only documents with a loaded vector can be compared; a missing one would reach compare() as null
+        int[] loadedDocIds = Arrays.stream(segmentDocIds).filter(docId -> vectorCache.containsKey(docId)).toArray();
+
+        // batch similarity
+        float[] similarities = BatchVectorSimilarity.computeBatchSimilarity(
+            valueSource.target(),
+            vectorCache,
+            loadedDocIds,
+            valueSource.similarityFunction()
+        );
+
+        // batch scoring
+        Map<Integer, Float> scores = new HashMap<>();
+        for (int i = 0; i < loadedDocIds.length; i++) {
+            scores.put(loadedDocIds[i], similarities[i]);
+        }
+        return scores;
+    }
+
+    /**
+     * Collector wrapper that hands the wrapped collector a {@link Scorable} backed by the pre-computed scores.
+     */
+    private static class BulkAwareCollector implements LeafCollector {
+
+        private final LeafCollector delegate;
+        private final Map<Integer, Float> precomputedScores;
+        private final BulkProcessedScorable bulkScorable;
+
+        BulkAwareCollector(LeafCollector delegate, Map<Integer, Float> precomputedScores) {
+            this.delegate = delegate;
+            this.precomputedScores = precomputedScores;
+            this.bulkScorable = new BulkProcessedScorable();
+        }
+
+        @Override
+        public void setScorer(Scorable scorer) throws IOException {
+            // Set our bulk-aware scorer instead of the original
+            bulkScorable.setDelegate(scorer);
+            delegate.setScorer(bulkScorable);
+        }
+
+        @Override
+        public void collect(int doc) throws IOException {
+            // Set the current document for score retrieval
+            bulkScorable.setCurrentDoc(doc);
+            delegate.collect(doc);
+        }
+
+        @Override
+        public void finish() throws IOException {
+            // Forward the end-of-segment notification; without this the wrapped collector never receives it
+            delegate.finish();
+        }
+
+        /**
+         * Scorable that provides pre-computed bulk scores
+         */
+        private class BulkProcessedScorable extends Scorable {
+            private Scorable delegate;
+            private int currentDoc = -1;
+
+            public void setDelegate(Scorable delegate) {
+                this.delegate = delegate;
+            }
+
+            public void setCurrentDoc(int doc) {
+                this.currentDoc = doc;
+            }
+
+            @Override
+            public float score() throws IOException {
+                // Return pre-computed score if available
+                Float precomputedScore = precomputedScores.get(currentDoc);
+                if (precomputedScore != null) {
+                    return precomputedScore;
+                }
+
+                // Fallback to delegate scorer
+                return delegate != null ? delegate.score() : 0.0f;
+            }
+        }
+    }
+}
diff --git a/server/src/main/java/org/elasticsearch/search/vectors/DirectIOVectorBatchLoader.java b/server/src/main/java/org/elasticsearch/search/vectors/DirectIOVectorBatchLoader.java
new file mode 100644
index 0000000000000..9216cf6ab496f
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/search/vectors/DirectIOVectorBatchLoader.java
@@ -0,0 +1,135 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.search.vectors;
+
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.LeafReaderContext;
+import org.elasticsearch.ExceptionsHelper;
+import org.elasticsearch.common.util.Maps;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executor;
+import java.util.concurrent.Executors;
+
+/**
+ * Bulk vector loader that performs optimized I/O operations to load multiple vectors
+ * simultaneously using parallel random access when possible.
+ */
+public class DirectIOVectorBatchLoader {
+
+    private static final int BATCH_PER_THREAD = 8;
+    // TODO: hook into a dedicated thread pool or at least name the virtual threads
+    // Shared by all loaders: a per-instance executor was created on every construction and never shut down
+    private static final Executor VT_EXECUTOR = Executors.newVirtualThreadPerTaskExecutor();
+
+    /**
+     * Loads the vectors of the given segment-relative documents.
+     *
+     * @param docIds  segment-relative document IDs, in any order
+     * @param context the segment to read from
+     * @param field   name of the float vector field
+     * @return vectors keyed by document ID; documents without a vector are absent from the map
+     * @throws IllegalArgumentException if the segment has no float vector values for {@code field}
+     * @throws IOException              if reading the vector values fails on the calling thread
+     */
+    public Map<Integer, float[]> loadSegmentVectors(int[] docIds, LeafReaderContext context, String field) throws IOException {
+        return loadSegmentVectorsParallel(docIds, context, field);
+    }
+
+    /**
+     * Resolves the documents to vector ordinals, splits them into batches of {@code BATCH_PER_THREAD} and loads
+     * each batch in its own task on {@code VT_EXECUTOR}, then merges the per-batch results.
+     */
+    private Map<Integer, float[]> loadSegmentVectorsParallel(int[] docIds, LeafReaderContext context, String field) throws IOException {
+        FloatVectorValues vectorValues = context.reader().getFloatVectorValues(field);
+        if (vectorValues == null) {
+            throw new IllegalArgumentException("No float vector values found for field: " + field);
+        }
+
+        Map<Integer, Integer> docToOrdinal = buildDocToOrdinalMapping(vectorValues, docIds);
+        List<List<Integer>> batches = createBatches(new ArrayList<>(docToOrdinal.keySet()), BATCH_PER_THREAD);
+
+        List<CompletableFuture<Map<Integer, float[]>>> futures = new ArrayList<>(batches.size());
+        for (List<Integer> batch : batches) {
+            // The reader reuses its returned array, so concurrent tasks must not share one instance.
+            // Each task gets its own copy, created here on the calling thread before the task starts.
+            FloatVectorValues batchValues = vectorValues.copy();
+            futures.add(CompletableFuture.supplyAsync(() -> {
+                try {
+                    return loadVectorBatch(batchValues, batch, docToOrdinal);
+                } catch (IOException e) {
+                    throw new RuntimeException("Failed to load vector batch", e);
+                }
+            }, VT_EXECUTOR));
+        }
+
+        Map<Integer, float[]> combinedResult = new HashMap<>();
+        try {
+            for (CompletableFuture<Map<Integer, float[]>> future : futures) {
+                combinedResult.putAll(future.get());
+            }
+        } catch (InterruptedException e) {
+            // Restore the interrupt status before reporting the failure
+            Thread.currentThread().interrupt();
+            throw ExceptionsHelper.convertToElastic(e);
+        } catch (Exception e) {
+            // The converted exception must be thrown; discarding it would return a partial result silently
+            throw ExceptionsHelper.convertToElastic(e);
+        }
+
+        return combinedResult;
+    }
+
+    /**
+     * Reads the vectors of one batch through {@code vectorValues}; documents without an ordinal are skipped.
+     */
+    private Map<Integer, float[]> loadVectorBatch(
+        FloatVectorValues vectorValues,
+        List<Integer> docIdBatch,
+        Map<Integer, Integer> docToOrdinal
+    ) throws IOException {
+
+        Map<Integer, float[]> batchResult = new HashMap<>();
+
+        for (Integer docId : docIdBatch) {
+            Integer ordinal = docToOrdinal.get(docId);
+            if (ordinal != null) {
+                // clone the vector since the reader reuses the array
+                float[] vector = vectorValues.vectorValue(ordinal).clone();
+                batchResult.put(docId, vector);
+            }
+        }
+
+        return batchResult;
+    }
+
+    /**
+     * Maps every target document that has a vector to its vector ordinal. Documents without a vector are left out.
+     */
+    private Map<Integer, Integer> buildDocToOrdinalMapping(FloatVectorValues vectorValues, int[] targetDocIds) throws IOException {
+
+        Map<Integer, Integer> docToOrdinal = Maps.newHashMapWithExpectedSize(targetDocIds.length);
+
+        // The iterator only moves forward, so visit the targets in ascending order; work on a copy to leave
+        // the caller's array untouched
+        int[] sortedDocIds = targetDocIds.clone();
+        java.util.Arrays.sort(sortedDocIds);
+
+        KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
+        int current = -1;
+        for (int target : sortedDocIds) {
+            if (current == KnnVectorValues.DocIndexIterator.NO_MORE_DOCS) {
+                break;
+            }
+            if (current < target) {
+                current = iterator.advance(target);
+            }
+            // A miss (iterator landed past the target) only skips this document, not the remaining ones
+            if (current == target) {
+                docToOrdinal.put(target, iterator.index());
+            }
+        }
+        return docToOrdinal;
+    }
+
+    /**
+     * Splits {@code items} into consecutive sub-list views of at most {@code batchSize} elements.
+     */
+    private List<List<Integer>> createBatches(List<Integer> items, int batchSize) {
+        List<List<Integer>> batches = new ArrayList<>();
+        for (int i = 0; i < items.size(); i += batchSize) {
+            batches.add(items.subList(i, Math.min(i + batchSize, items.size())));
+        }
+        return batches;
+    }
+
+    /**
+     * Loads the vector of a single segment-relative document.
+     * TODO: look into removing this method
+     *
+     * @return a copy of the document's vector, or {@code null} when the iterator does not land on {@code docId}
+     * @throws IllegalArgumentException if the segment has no float vector values for {@code field}
+     */
+    public float[] loadSingleVector(int docId, LeafReaderContext context, String field) throws IOException {
+        FloatVectorValues vectorValues = context.reader().getFloatVectorValues(field);
+        if (vectorValues == null) {
+            throw new IllegalArgumentException("No float vector values found for field: " + field);
+        }
+
+        KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
+        var next = iterator.advance(docId);
+        float[] result = null;
+        if (next != KnnVectorValues.DocIndexIterator.NO_MORE_DOCS && next == docId) {
+            var ordinal = iterator.index();
+            result = vectorValues.vectorValue(ordinal).clone();
+        }
+        return result;
+    }
+}
diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java
index c7346bb9edd75..c2cbac248651a 100644
--- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java
+++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java
@@ -18,7 +18,6 @@
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.QueryVisitor;
 import
org.apache.lucene.search.TopDocs;
-import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource;
 import org.elasticsearch.search.profile.query.QueryProfiler;
 
 import java.io.IOException;
@@ -159,12 +158,27 @@ private InlineRescoreQuery(
 
         @Override
         public Query rewrite(IndexSearcher searcher) throws IOException {
-            var valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction);
+            var valueSource = new AccessibleVectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction);
             var functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource);
             // Retrieve top k documents from the function score query
             var topDocs = searcher.search(functionScoreQuery, k);
             vectorOperations = topDocs.totalHits.value();
-            return new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader());
+            return createQuery(searcher, topDocs, valueSource);
+        }
+
+        /**
+         * Wraps the top documents in a KnnScoreDocQuery and, when the bulk scoring feature flag is enabled,
+         * in a BulkVectorFunctionScoreQuery on top of it.
+         */
+        private Query createQuery(IndexSearcher searcher, TopDocs topDocs, AccessibleVectorSimilarityFloatValueSource valueSource) {
+            var topDocsQuery = new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader());
+            Query query = topDocsQuery;
+
+            // Use bulk processing if feature flag is enabled
+            if (BulkVectorProcessingSettings.BULK_VECTOR_SCORING) {
+                // Create bulk-optimized query with ScoreDoc array
+                query = new BulkVectorFunctionScoreQuery(topDocsQuery, valueSource, topDocs.scoreDocs);
+            }
+            return query;
         }
 
         @Override
@@ -204,10 +218,14 @@ public Query rewrite(IndexSearcher searcher) throws IOException {
 
             // Retrieve top `k` documents from the top `rescoreK` query
             var topDocsQuery = new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader());
-            var valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction);
-            var rescoreQuery = new FunctionScoreQuery(topDocsQuery, valueSource);
+            var valueSource = new AccessibleVectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction);
+
+            // Use bulk processing if feature flag is enabled
+            var rescoreQuery = BulkVectorProcessingSettings.BULK_VECTOR_SCORING
+                ? new BulkVectorFunctionScoreQuery(topDocsQuery, valueSource, topDocs.scoreDocs)
+                : new FunctionScoreQuery(topDocsQuery, valueSource);
             var rescoreTopDocs = searcher.search(rescoreQuery.rewrite(searcher), k);
             return new KnnScoreDocQuery(rescoreTopDocs.scoreDocs, searcher.getIndexReader());
         }
 
         @Override
diff --git a/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java
new file mode 100644
index 0000000000000..3d87ff24c0797
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java
@@ -0,0 +1,261 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.search.vectors;
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.KnnFloatVectorQuery;
+import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.MMapDirectory;
+import org.elasticsearch.common.CheckedBiConsumer;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
+
+public class BulkVectorFunctionScoreQueryTests extends ESTestCase {
+
+    private static final String VECTOR_FIELD = "vector";
+    private static final int VECTOR_DIMS = 128;
+
+    public void testBulkProcessingWithScoreDocArray() throws IOException {
+        // Create test index with vector documents
+        try (Directory dir = new MMapDirectory(createTempDir())) {
+            createTestIndex(dir, 100);
+
+            try (DirectoryReader reader = DirectoryReader.open(dir)) {
+                IndexSearcher searcher = new IndexSearcher(reader);
+
+                // Create query vector and value source
+                float[] queryVector = randomVector(VECTOR_DIMS);
+                var valueSource = new AccessibleVectorSimilarityFloatValueSource(
+                    VECTOR_FIELD,
+                    queryVector,
+                    VectorSimilarityFunction.COSINE
+                );
+
+                // Get top documents
+                TopDocs topDocs = searcher.search(new MatchAllDocsQuery(), 50);
+
+                // Test bulk vector function score query
+                BulkVectorFunctionScoreQuery bulkQuery = new BulkVectorFunctionScoreQuery(
+                    new KnnScoreDocQuery(topDocs.scoreDocs, reader),
+                    valueSource,
+                    topDocs.scoreDocs
+                );
+
+                TopDocs bulkResults = searcher.search(bulkQuery, 10);
+
+                // Verify results; the count assertion is an exact match, so the message says so
+                assertThat("Should return results", bulkResults.totalHits.value(), greaterThan(0L));
+                assertThat("Should return requested count", bulkResults.scoreDocs.length, equalTo(10));
+
+                // Verify scores are computed
+                for (ScoreDoc scoreDoc : bulkResults.scoreDocs) {
+                    assertTrue("Score should be computed", Float.isFinite(scoreDoc.score));
+                    assertTrue("Score should be non-negative for cosine similarity", scoreDoc.score >= 0.0f);
+                }
+            }
+        }
+    }
+
+    public void testInlineRescoreBulkOptimization() throws IOException {
+        // Test that InlineRescoreQuery uses bulk processing when feature flag is enabled
+        float[] queryVector = randomVector(VECTOR_DIMS);
+
+        try (Directory dir = new MMapDirectory(createTempDir())) {
+            createTestIndex(dir, 20);
+
+            try (DirectoryReader reader = DirectoryReader.open(dir)) {
+                IndexSearcher searcher = new IndexSearcher(reader);
+
+                // Create inline rescoring by using KnnFloatVectorQuery with matching k and rescoreK
+                KnnFloatVectorQuery innerQuery = new KnnFloatVectorQuery(VECTOR_FIELD, queryVector, 10);
+                RescoreKnnVectorQuery rescoreQuery = RescoreKnnVectorQuery.fromInnerQuery(
+                    VECTOR_FIELD,
+                    queryVector,
+                    VectorSimilarityFunction.COSINE,
+                    5,
+                    10,
+                    innerQuery
+                );
+
+                TopDocs results = searcher.search(rescoreQuery, 5);
+
+                assertThat("Should return results from inline rescoring", results.totalHits.value(), greaterThan(0L));
+                assertThat("Should return requested count", results.scoreDocs.length, equalTo(5));
+            }
+        }
+    }
+
+    public void testLateRescoreBulkOptimization() throws IOException {
+        // Test that LateRescoreQuery uses bulk processing when feature flag is enabled
+        float[] queryVector = randomVector(VECTOR_DIMS);
+
+        try (Directory dir = new MMapDirectory(createTempDir())) {
+            createTestIndex(dir, 50);
+
+            try (DirectoryReader reader = DirectoryReader.open(dir)) {
+                IndexSearcher searcher = new IndexSearcher(reader);
+
+                // Create late rescoring by using different k and rescoreK values
+                RescoreKnnVectorQuery rescoreQuery = RescoreKnnVectorQuery.fromInnerQuery(
+                    VECTOR_FIELD,
+                    queryVector,
+                    VectorSimilarityFunction.COSINE,
+                    8,
+                    30,
+                    new MatchAllDocsQuery()
+                );
+
+                TopDocs results = searcher.search(rescoreQuery, 8);
+
+                assertThat("Should return results from late rescoring", results.totalHits.value(), greaterThan(0L));
+                assertThat("Should return requested count", results.scoreDocs.length, equalTo(8));
+            }
+        }
+    }
+
+    public void testScoreDocContextPreservation() throws IOException {
+        // Test that ScoreDoc context is properly maintained through rewrite cycles
+        try (Directory dir = new MMapDirectory(createTempDir())) {
+            createTestIndex(dir, 30);
+
+            try (DirectoryReader reader = DirectoryReader.open(dir)) {
+                IndexSearcher searcher = new IndexSearcher(reader);
+
+                // Create initial ScoreDoc array
+                TopDocs topDocs = searcher.search(new MatchAllDocsQuery(), 15);
+                ScoreDoc[] originalScoreDocs = topDocs.scoreDocs.clone();
+
+                // Create bulk query
+                float[] queryVector = randomVector(VECTOR_DIMS);
+                var valueSource = new AccessibleVectorSimilarityFloatValueSource(
+                    VECTOR_FIELD,
+                    queryVector,
+                    VectorSimilarityFunction.COSINE
+                );
+                BulkVectorFunctionScoreQuery query = new BulkVectorFunctionScoreQuery(
+                    new KnnScoreDocQuery(originalScoreDocs, reader),
+                    valueSource,
+                    originalScoreDocs
+                );
+
+                // Test query rewrite preserves context
+                BulkVectorFunctionScoreQuery rewritten = (BulkVectorFunctionScoreQuery) query.rewrite(searcher);
+                assertNotNull("Rewritten query should not be null", rewritten);
+
+                // Execute rewritten query
+                TopDocs results = searcher.search(rewritten, 10);
+                assertThat("Should return results after rewrite", results.totalHits.value(), greaterThan(0L));
+            }
+        }
+    }
+
+    /**
+     * Writes {@code docCount} documents, each with one random COSINE vector in {@code VECTOR_FIELD}, and commits.
+     */
+    private void createTestIndex(Directory dir, int docCount) throws IOException {
+        IndexWriterConfig config = new IndexWriterConfig();
+        try (IndexWriter writer = new IndexWriter(dir, config)) {
+            for (int i = 0; i < docCount; i++) {
+                Document doc = new Document();
+                float[] vector = randomVector(VECTOR_DIMS);
+                doc.add(new KnnFloatVectorField(VECTOR_FIELD, vector, VectorSimilarityFunction.COSINE));
+                writer.addDocument(doc);
+            }
+            writer.commit();
+        }
+    }
+
+    public void testParallelVectorLoading() throws IOException {
+        // Test parallel vector loading functionality
+
+        try (Directory dir = new MMapDirectory(createTempDir())) {
+            createTestIndex(dir, 50);
+
+            // Load vectors in parallel
+            Map<Integer, float[]> parallelResult = new HashMap<>();
+            loadVectors(dir, (leafReaderContext, docIds) -> {
+                DirectIOVectorBatchLoader batchLoader = new DirectIOVectorBatchLoader();
+                parallelResult.putAll(batchLoader.loadSegmentVectors(docIds, leafReaderContext, VECTOR_FIELD));
+            });
+
+            // Load the same vectors one document at a time
+            Map<Integer, float[]> sequentialResult = new HashMap<>();
+            loadVectors(dir, (leafReaderContext, docIds) -> {
+                DirectIOVectorBatchLoader batchLoader = new DirectIOVectorBatchLoader();
+                for (int docId : docIds) {
+                    sequentialResult.put(docId, batchLoader.loadSingleVector(docId, leafReaderContext, VECTOR_FIELD));
+                }
+            });
+
+            // Verify results are identical
+            assertThat("Parallel and sequential results should have same size", parallelResult.size(), equalTo(sequentialResult.size()));
+
+            for (int docId : sequentialResult.keySet()) {
+                float[] sequentialVector = sequentialResult.get(docId);
+                float[] parallelVector = parallelResult.get(docId);
+
+                assertNotNull("Parallel result should contain vector for doc " + docId, parallelVector);
+                assertNotNull("Sequential result should contain vector for doc " + docId, sequentialVector);
+                assertArrayEquals("Vectors should be identical for doc " + docId, sequentialVector, parallelVector, 0.0001f);
+            }
+        }
+    }
+
+    /**
+     * Opens the index, takes the doc IDs of the top 20 match-all hits and passes them to {@code consumer}
+     * together with the first leaf.
+     */
+    private void loadVectors(Directory dir, CheckedBiConsumer<LeafReaderContext, int[], IOException> consumer) throws IOException {
+        try (DirectoryReader reader = DirectoryReader.open(dir)) {
+            IndexSearcher searcher = new IndexSearcher(reader);
+
+            // Get initial documents
+            TopDocs topDocs = searcher.search(new MatchAllDocsQuery(), 20);
+            int[] docIds = Arrays.stream(topDocs.scoreDocs).mapToInt(scoreDoc -> scoreDoc.doc).toArray();
+
+            // NOTE(review): the IDs come from an index-wide search but are used with the first leaf only;
+            // this presumably relies on the test index having a single segment - confirm
+            var leafReaderContext = reader.leaves().get(0);
+            consumer.accept(leafReaderContext, docIds);
+        }
+    }
+
+    private float[] randomVector(int dimensions) {
+        float[] vector = new float[dimensions];
+        for (int i = 0; i < dimensions; i++) {
+            vector[i] = randomFloatBetween(-1.0f, 1.0f, true);
+        }
+        return vector;
+    }
+}
diff --git a/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorScorerTests.java b/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorScorerTests.java
new file mode 100644
index 0000000000000..bc06870588d6a
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorScorerTests.java
@@ -0,0 +1,191 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.search.vectors;
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.BulkScorer;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.LeafCollector;
+import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.search.Scorable;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.Weight;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.MMapDirectory;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.hamcrest.Matchers.greaterThan;
+
+public class BulkVectorScorerTests extends ESTestCase {
+
+    private static final String VECTOR_FIELD = "vector";
+    private static final int VECTOR_DIMS = 64;
+
+    public void testBulkScorerImplementation() throws IOException {
+        // Enable bulk processing for this test; the property is cleared again in the finally block
+        System.setProperty("es.bulk_vector_scoring", "true");
+
+        try (Directory dir = new MMapDirectory(createTempDir())) {
+            createTestIndex(dir, 30);
+
+            try (DirectoryReader reader = DirectoryReader.open(dir)) {
+                IndexSearcher searcher = new IndexSearcher(reader);
+
+                // Get top documents for bulk processing
+                TopDocs topDocs = searcher.search(new MatchAllDocsQuery(), 20);
+
+                // Create bulk vector query
+                float[] queryVector = randomVector(VECTOR_DIMS);
+                var valueSource = new AccessibleVectorSimilarityFloatValueSource(
+                    VECTOR_FIELD,
+                    queryVector,
+                    VectorSimilarityFunction.COSINE
+                );
+
+                BulkVectorFunctionScoreQuery bulkQuery = new BulkVectorFunctionScoreQuery(
+                    new KnnScoreDocQuery(topDocs.scoreDocs, reader),
+                    valueSource,
+                    topDocs.scoreDocs
+                );
+
+                Weight weight = bulkQuery.createWeight(searcher, org.apache.lucene.search.ScoreMode.COMPLETE, 1.0f);
+
+                LeafReaderContext leafContext = reader.leaves().get(0);
+                BulkScorer bulkScorer = weight.scorerSupplier(leafContext).bulkScorer();
+
+                assertNotNull("BulkScorer should be created", bulkScorer);
+                assertTrue("Should be BulkVectorScorer instance", bulkScorer instanceof BulkVectorScorer);
+
+                // Test bulk scoring execution over the whole doc-id range (return value is not needed)
+                TestCollector collector = new TestCollector();
+                bulkScorer.score(collector, null, 0, Integer.MAX_VALUE);
+
+                // Verify bulk processing occurred
+                assertThat("Should collect documents", collector.collectedDocs.size(), greaterThan(0));
+                assertTrue("Should have precomputed scores", collector.hasValidScores());
+
+            }
+        } finally {
+            System.clearProperty("es.bulk_vector_scoring");
+        }
+    }
+
+    public void testBulkProcessorCollectorInterception() throws IOException {
+        // Test that scores reach a custom LeafCollector through the scorer handed to setScorer
+        System.setProperty("es.bulk_vector_scoring", "true");
+
+        try (Directory dir = new MMapDirectory(createTempDir())) {
+            createTestIndex(dir, 25);
+
+            try (DirectoryReader reader = DirectoryReader.open(dir)) {
+                IndexSearcher searcher = new IndexSearcher(reader);
+
+                // Create bulk query with known documents
+                TopDocs topDocs = searcher.search(new MatchAllDocsQuery(), 15);
+                float[] queryVector = randomVector(VECTOR_DIMS);
+                var valueSource = new AccessibleVectorSimilarityFloatValueSource(
+                    VECTOR_FIELD,
+                    queryVector,
+                    VectorSimilarityFunction.DOT_PRODUCT
+                );
+
+                BulkVectorFunctionScoreQuery bulkQuery = new BulkVectorFunctionScoreQuery(
+                    new KnnScoreDocQuery(topDocs.scoreDocs, reader),
+                    valueSource,
+                    topDocs.scoreDocs
+                );
+
+                // Execute with custom collector to verify interception
+                TestCollector collector = new TestCollector();
+
+                Weight weight = bulkQuery.createWeight(searcher, org.apache.lucene.search.ScoreMode.COMPLETE, 1.0f);
+                LeafReaderContext leafContext = reader.leaves().get(0);
+                BulkScorer bulkScorer = weight.scorerSupplier(leafContext).bulkScorer();
+
+                bulkScorer.score(collector, null, 0, 50);
+
+                // Verify collector received bulk-processed scores
+                assertTrue("Collector should have received documents", collector.collectedDocs.size() > 0);
+
+                for (CollectedDoc doc : collector.collectedDocs) {
+                    assertTrue("All scores should be finite", Float.isFinite(doc.score));
+                    // For DOT_PRODUCT, scores can be negative, so just check they're computed
+                    assertNotEquals("Score should not be exactly zero (indicating computation)", 0.0f, doc.score, 0.001f);
+                }
+            }
+        } finally {
+            System.clearProperty("es.bulk_vector_scoring");
+        }
+    }
+
+    private void createTestIndex(Directory dir, int docCount) throws IOException { // docCount docs, one random vector each
+        IndexWriterConfig config = new IndexWriterConfig();
+        try (IndexWriter writer = new IndexWriter(dir, config)) {
+            for (int i = 0; i < docCount; i++) {
+                Document doc = new Document();
+                float[] vector = randomVector(VECTOR_DIMS);
+                doc.add(new KnnFloatVectorField(VECTOR_FIELD, vector, VectorSimilarityFunction.COSINE));
+                writer.addDocument(doc);
+            }
+            writer.commit();
+        }
+    }
+
+    private float[] randomVector(int dimensions) { // each component is randomFloat() rescaled by * 2 - 1
+        float[] vector = new float[dimensions];
+        for (int i = 0; i < dimensions; i++) {
+            vector[i] = randomFloat() * 2.0f - 1.0f;
+        }
+        return vector;
+    }
+
+    /**
+     * Test collector that records every collected doc id with the score its current {@link Scorable} reports.
+     */
+    private static class TestCollector implements LeafCollector {
+        final List<CollectedDoc> collectedDocs = new ArrayList<>();
+        private Scorable scorer;
+
+        @Override
+        public void setScorer(Scorable scorer) throws IOException {
+            this.scorer = scorer;
+        }
+
+        @Override
+        public void collect(int doc) throws IOException {
+            float score = scorer.score();
+            collectedDocs.add(new CollectedDoc(doc, score));
+        }
+
+        boolean hasValidScores() {
+            return collectedDocs.stream().allMatch(doc -> Float.isFinite(doc.score));
+        }
+    }
+
+    private static class CollectedDoc {
+        final int docId;
+        final float score;
+
+        CollectedDoc(int docId, float score) {
+            this.docId = docId;
+            this.score = score;
+        }
+    }
+}