From 81235e34fc11b3dbe9c6e3975486a72c5131be38 Mon Sep 17 00:00:00 2001 From: Costin Leau Date: Mon, 25 Aug 2025 21:50:00 -0700 Subject: [PATCH 1/6] Bulk vector processing Add processing bulk at various stages of the KNN query: a. BulkVectorFunctionQuery To capture the array of ScoreDocs for bulk processing b. BulkVectorScorer (through dedicated Weight) 1. To load the vectors in bulk through DirectIOVectorBatchLoader 2. Compute the similarity across multiple vectors 3. Store the scores across a batch of docs wip --- ...sibleVectorSimilarityFloatValueSource.java | 45 ++++ .../search/vectors/BatchVectorSimilarity.java | 41 ++++ .../vectors/BulkVectorFunctionScoreQuery.java | 89 ++++++++ .../BulkVectorFunctionScoreWeight.java | 123 +++++++++++ .../vectors/BulkVectorProcessingSettings.java | 30 +++ .../search/vectors/BulkVectorScorer.java | 138 +++++++++++++ .../vectors/DirectIOVectorBatchLoader.java | 83 ++++++++ .../search/vectors/RescoreKnnVectorQuery.java | 28 ++- .../BulkVectorFunctionScoreQueryTests.java | 194 ++++++++++++++++++ .../search/vectors/BulkVectorScorerTests.java | 180 ++++++++++++++++ 10 files changed, 946 insertions(+), 5 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/search/vectors/AccessibleVectorSimilarityFloatValueSource.java create mode 100644 server/src/main/java/org/elasticsearch/search/vectors/BatchVectorSimilarity.java create mode 100644 server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQuery.java create mode 100644 server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreWeight.java create mode 100644 server/src/main/java/org/elasticsearch/search/vectors/BulkVectorProcessingSettings.java create mode 100644 server/src/main/java/org/elasticsearch/search/vectors/BulkVectorScorer.java create mode 100644 server/src/main/java/org/elasticsearch/search/vectors/DirectIOVectorBatchLoader.java create mode 100644 server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java create mode 100644 server/src/test/java/org/elasticsearch/search/vectors/BulkVectorScorerTests.java diff --git a/server/src/main/java/org/elasticsearch/search/vectors/AccessibleVectorSimilarityFloatValueSource.java b/server/src/main/java/org/elasticsearch/search/vectors/AccessibleVectorSimilarityFloatValueSource.java new file mode 100644 index 0000000000000..14453ca1160ec --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/vectors/AccessibleVectorSimilarityFloatValueSource.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; + +/** + * Subclass of VectorSimilarityFloatValueSource offering access to its members for other classes + * in the same package. + */ +class AccessibleVectorSimilarityFloatValueSource extends VectorSimilarityFloatValueSource { + + String field; + float[] target; + VectorSimilarityFunction vectorSimilarityFunction; + + AccessibleVectorSimilarityFloatValueSource(String field, + float[] target, + VectorSimilarityFunction vectorSimilarityFunction) { + super(field, target, vectorSimilarityFunction); + this.field = field; + this.target = target; + this.vectorSimilarityFunction = vectorSimilarityFunction; + } + + public String field() { + return field; + } + + public float[] target() { + return target; + } + + public VectorSimilarityFunction similarityFunction() { + return vectorSimilarityFunction; + } +} diff --git a/server/src/main/java/org/elasticsearch/search/vectors/BatchVectorSimilarity.java b/server/src/main/java/org/elasticsearch/search/vectors/BatchVectorSimilarity.java new file mode 100644 index 0000000000000..66f58b2eaef8b --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/vectors/BatchVectorSimilarity.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import org.apache.lucene.index.VectorSimilarityFunction; + +import java.util.Map; + +public final class BatchVectorSimilarity { + + private BatchVectorSimilarity() { + } + + public static float[] computeBatchSimilarity(float[] queryVector, Map docVectors, + int[] docIds, VectorSimilarityFunction function) { + float[] results = new float[docIds.length]; + float[][] data = organizeSIMDVectors(docVectors, docIds); + + for (int i = 0, l = data.length; i < l; i++) { + float[] docVector = data[i]; + results[i] = function.compare(queryVector, docVector); + } + + return results; + } + + public static float[][] organizeSIMDVectors(Map vectorMap, int[] docIds) { + float[][] vectors = new float[docIds.length][]; + for (int i = 0; i < docIds.length; i++) { + vectors[i] = vectorMap.get(docIds[i]); + } + return vectors; + } +} diff --git a/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQuery.java new file mode 100644 index 0000000000000..2ad2cb1913ba4 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQuery.java @@ -0,0 +1,89 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Weight; +import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; + +/** + * Enhanced FunctionScoreQuery that enables bulk vector processing for KNN rescoring. + * When provided with a ScoreDoc array, performs bulk vector loading and similarity + * computation instead of individual per-document processing. + */ +public class BulkVectorFunctionScoreQuery extends Query { + + private final Query subQuery; + private final AccessibleVectorSimilarityFloatValueSource valueSource; + private final ScoreDoc[] scoreDocs; + + public BulkVectorFunctionScoreQuery(Query subQuery, AccessibleVectorSimilarityFloatValueSource valueSource, ScoreDoc[] scoreDocs) { + this.subQuery = subQuery; + this.valueSource = valueSource; + this.scoreDocs = scoreDocs; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + // TODO: take a closer look at ScoreMode + Weight subQueryWeight = subQuery.createWeight(searcher, scoreMode, boost); + return new BulkVectorFunctionScoreWeight(this, subQueryWeight, valueSource, scoreDocs); + } + + @Override + public Query rewrite(IndexSearcher searcher) throws IOException { + Query rewrittenSubQuery = subQuery.rewrite(searcher); + if (rewrittenSubQuery != subQuery) { + return new BulkVectorFunctionScoreQuery(rewrittenSubQuery, valueSource, scoreDocs); + } + return this; + } + + @Override + public String toString(String field) { + StringBuilder sb = new StringBuilder(); + sb.append("bulk_vector_function_score("); + sb.append(subQuery.toString(field)); + sb.append(", vector_similarity=").append(valueSource.toString()); + if (scoreDocs != null) { + sb.append(", bulk_docs=").append(scoreDocs.length); + } + sb.append(")"); + return sb.toString(); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + + BulkVectorFunctionScoreQuery that = (BulkVectorFunctionScoreQuery) obj; + return Objects.equals(subQuery, that.subQuery) + && Objects.equals(valueSource, that.valueSource) + && Arrays.equals(scoreDocs, that.scoreDocs); + } + + @Override + public int hashCode() { + return Objects.hash(subQuery, valueSource, scoreDocs); + } + + @Override + public void visit(org.apache.lucene.search.QueryVisitor visitor) { + subQuery.visit(visitor.getSubVisitor(org.apache.lucene.search.BooleanClause.Occur.MUST, this)); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreWeight.java b/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreWeight.java new file mode 100644 index 0000000000000..68b954b581458 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreWeight.java @@ -0,0 +1,123 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.BulkScorer; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; +import org.apache.lucene.search.Weight; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** + * Weight implementation that enables bulk vector processing for KNN rescoring queries. + * Extracts segment-specific documents from ScoreDoc array and creates bulk scorers. + */ +public class BulkVectorFunctionScoreWeight extends Weight { + + private final Weight subQueryWeight; + private final AccessibleVectorSimilarityFloatValueSource valueSource; + private final ScoreDoc[] scoreDocs; + + public BulkVectorFunctionScoreWeight( + Query parent, + Weight subQueryWeight, + AccessibleVectorSimilarityFloatValueSource valueSource, + ScoreDoc[] scoreDocs) { + super(parent); + this.subQueryWeight = subQueryWeight; + this.valueSource = valueSource; + this.scoreDocs = scoreDocs; + } + + @Override + public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + ScorerSupplier subQueryScorerSupplier = subQueryWeight.scorerSupplier(context); + if (subQueryScorerSupplier == null) { + return null; + } + + // Extract documents belonging to this segment + int[] segmentDocIds = extractSegmentDocuments(scoreDocs, context); + if (segmentDocIds.length == 0) { + return null; // No documents in this segment + } + + return new ScorerSupplier() { + @Override + public Scorer get(long leadCost) throws IOException { + throw new UnsupportedOperationException( + "Individual Scorer not supported when bulk vector processing is enabled. Use bulkScorer() instead."); + } + + @Override + public BulkScorer bulkScorer() throws IOException { + // Always use BulkScorer when bulk processing is enabled + BulkScorer subQueryBulkScorer = subQueryScorerSupplier.bulkScorer(); + return new BulkVectorScorer(subQueryBulkScorer, segmentDocIds, valueSource, context); + } + + @Override + public long cost() { + return segmentDocIds.length; + } + }; + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) throws IOException { + // Find the document in our ScoreDoc array + int globalDocId = doc + context.docBase; + for (ScoreDoc scoreDoc : scoreDocs) { + if (scoreDoc.doc == globalDocId) { + // Compute explanation for this specific document + try { + DirectIOVectorBatchLoader batchLoader = new DirectIOVectorBatchLoader(); + float[] docVector = batchLoader.loadSingleVector(doc, context, valueSource.field()); + float similarity = valueSource.similarityFunction().compare(valueSource.target(), docVector); + + return Explanation.match( + similarity, + "bulk vector similarity score, computed with vector similarity function: " + valueSource.similarityFunction() + ); + } catch (Exception e) { + return Explanation.noMatch("Failed to compute vector similarity: " + e.getMessage()); + } + } + } + return Explanation.noMatch("Document not in bulk processing set"); + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return false; + } + + private int[] extractSegmentDocuments(ScoreDoc[] scoreDocs, LeafReaderContext context) { + List segmentDocs = new ArrayList<>(); + int docBase = context.docBase; + int maxDoc = docBase + context.reader().maxDoc(); + + for (ScoreDoc scoreDoc : scoreDocs) { + if (scoreDoc.doc >= docBase && scoreDoc.doc < maxDoc) { + // Convert to segment-relative document ID + segmentDocs.add(scoreDoc.doc - docBase); + } + } + + return segmentDocs.stream().mapToInt(Integer::intValue).toArray(); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorProcessingSettings.java b/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorProcessingSettings.java new file mode 100644 index 0000000000000..21f167ed8e1c6 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorProcessingSettings.java @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import org.elasticsearch.common.util.FeatureFlag; + +/** + * Feature flags and settings for bulk vector processing optimizations. + */ +public final class BulkVectorProcessingSettings { + + public static final boolean BULK_VECTOR_SCORING = new FeatureFlag("bulk_vector_scoring").isEnabled(); + + public static final int MIN_BULK_PROCESSING_THRESHOLD = 3; + + private BulkVectorProcessingSettings() { + // Utility class + } + + public static boolean shouldUseBulkProcessing(int documentCount) { + return BULK_VECTOR_SCORING && documentCount >= MIN_BULK_PROCESSING_THRESHOLD; + } +} diff --git a/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorScorer.java b/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorScorer.java new file mode 100644 index 0000000000000..518a011b16403 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorScorer.java @@ -0,0 +1,138 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.BulkScorer; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.Scorable; +import org.apache.lucene.util.Bits; + +import java.io.IOException; +import java.util.Map; + +public class BulkVectorScorer extends BulkScorer { + + private final BulkScorer subQueryBulkScorer; + private final int[] segmentDocIds; + private final AccessibleVectorSimilarityFloatValueSource valueSource; + private final LeafReaderContext context; + + // Bulk processing cache + private Map precomputedScores; + private boolean bulkProcessingCompleted = false; + + public BulkVectorScorer( + BulkScorer subQueryBulkScorer, + int[] segmentDocIds, + AccessibleVectorSimilarityFloatValueSource valueSource, + LeafReaderContext context) { + this.subQueryBulkScorer = subQueryBulkScorer; + this.segmentDocIds = segmentDocIds; + this.valueSource = valueSource; + this.context = context; + } + + @Override + public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException { + // Perform bulk processing once if not already completed + if (bulkProcessingCompleted == false) { + performBulkVectorProcessing(); + bulkProcessingCompleted = true; + } + + // Create bulk-aware collector wrapper + BulkAwareCollector bulkCollector = new BulkAwareCollector(collector, precomputedScores); + + // Delegate to subquery bulk scorer with our bulk-aware collector + return subQueryBulkScorer.score(bulkCollector, acceptDocs, min, max); + } + + @Override + public long cost() { + return segmentDocIds.length; + } + + private void performBulkVectorProcessing() throws IOException { + // batch loading + DirectIOVectorBatchLoader batchLoader = new DirectIOVectorBatchLoader(); + Map vectorCache = batchLoader.loadSegmentVectors( + segmentDocIds, context, valueSource.field()); + + // batch similarity + float[] similarities = BatchVectorSimilarity.computeBatchSimilarity( + valueSource.target(), + vectorCache, + segmentDocIds, + valueSource.similarityFunction() + ); + + // batch scoring + precomputedScores = new java.util.HashMap<>(); + for (int i = 0; i < segmentDocIds.length; i++) { + precomputedScores.put(segmentDocIds[i], similarities[i]); + } + } + + private static class BulkAwareCollector implements LeafCollector { + + private final LeafCollector delegate; + private final Map precomputedScores; + private final BulkProcessedScorable bulkScorable; + + BulkAwareCollector(LeafCollector delegate, Map precomputedScores) { + this.delegate = delegate; + this.precomputedScores = precomputedScores; + this.bulkScorable = new BulkProcessedScorable(); + } + + @Override + public void setScorer(Scorable scorer) throws IOException { + // Set our bulk-aware scorer instead of the original + bulkScorable.setDelegate(scorer); + delegate.setScorer(bulkScorable); + } + + @Override + public void collect(int doc) throws IOException { + // Set the current document for score retrieval + bulkScorable.setCurrentDoc(doc); + delegate.collect(doc); + } + + /** + * Scorable that provides pre-computed bulk scores + */ + private class BulkProcessedScorable extends Scorable { + private Scorable delegate; + private int currentDoc = -1; + + public void setDelegate(Scorable delegate) { + this.delegate = delegate; + } + + public void setCurrentDoc(int doc) { + this.currentDoc = doc; + } + + @Override + public float score() throws IOException { + // Return pre-computed score if available + Float precomputedScore = precomputedScores.get(currentDoc); + if (precomputedScore != null) { + return precomputedScore; + } + + // Fallback to delegate scorer + return delegate != null ? delegate.score() : 0.0f; + } + } + } +} diff --git a/server/src/main/java/org/elasticsearch/search/vectors/DirectIOVectorBatchLoader.java b/server/src/main/java/org/elasticsearch/search/vectors/DirectIOVectorBatchLoader.java new file mode 100644 index 0000000000000..5eb85d9e5fdde --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/vectors/DirectIOVectorBatchLoader.java @@ -0,0 +1,83 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReaderContext; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +/** + * Bulk vector loader that performs optimized I/O operations to load multiple vectors + * simultaneously using direct I/O and sector alignment when possible. + */ +public class DirectIOVectorBatchLoader { + + private static final int SECTOR_SIZE = 4096; // Default sector size for alignment + + /** + * Loads vectors for multiple document IDs in a single bulk operation. + */ + public Map loadSegmentVectors(int[] docIds, LeafReaderContext context, String field) throws IOException { + Map vectorCache = new HashMap<>(); + + // Get vector values for the field + FloatVectorValues vectorValues = context.reader().getFloatVectorValues(field); + if (vectorValues == null) { + throw new IllegalArgumentException("No float vector values found for field: " + field); + } + + // TODO: For now, use sequential access - future optimization can implement true bulk I/O + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + + // Build a lookup of available documents + Map docToIndex = new HashMap<>(); + for (int docId = iterator.nextDoc(); docId != KnnVectorValues.DocIndexIterator.NO_MORE_DOCS; docId = iterator.nextDoc()) { + docToIndex.put(docId, iterator.index()); + } + + // Load vectors for requested documents + for (int docId : docIds) { + Integer vectorIndex = docToIndex.get(docId); + if (vectorIndex != null) { + float[] vector = vectorValues.vectorValue(vectorIndex); + if (vector != null) { + vectorCache.put(docId, vector); + } + } + } + + return vectorCache; + } + + /** + * TODO: look into removing this method + */ + public float[] loadSingleVector(int docId, LeafReaderContext context, String field) throws IOException { + FloatVectorValues vectorValues = context.reader().getFloatVectorValues(field); + if (vectorValues == null) { + throw new IllegalArgumentException("No float vector values found for field: " + field); + } + + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + for (int currentDoc = iterator.nextDoc(); currentDoc != KnnVectorValues.DocIndexIterator.NO_MORE_DOCS; currentDoc = iterator + .nextDoc()) { + if (currentDoc == docId) { + float[] vector = vectorValues.vectorValue(iterator.index()); + return vector != null ? vector: null; + } + } + + throw new IllegalArgumentException("Document " + docId + " not found in vector values"); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java index c7346bb9edd75..22a9296d412ef 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -159,12 +159,25 @@ private InlineRescoreQuery( @Override public Query rewrite(IndexSearcher searcher) throws IOException { - var valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction); + var valueSource = new AccessibleVectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction); + + // Retrieve top k documents from the function score query var functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource); // Retrieve top k documents from the function score query var topDocs = searcher.search(functionScoreQuery, k); - vectorOperations = topDocs.totalHits.value(); - return new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader()); + return createQuery(searcher, topDocs, valueSource); + } + + private Query createQuery(IndexSearcher searcher, TopDocs topDocs, AccessibleVectorSimilarityFloatValueSource valueSource) { + var topDocsQuery = new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader()); + Query query = topDocsQuery; + + // Use bulk processing if feature flag is enabled + if (BulkVectorProcessingSettings.BULK_VECTOR_SCORING) { + // Create bulk-optimized query with ScoreDoc array + query = new BulkVectorFunctionScoreQuery(topDocsQuery, valueSource, topDocs.scoreDocs); + } + return query; } @Override @@ -204,10 +217,15 @@ public Query rewrite(IndexSearcher searcher) throws IOException { // Retrieve top `k` documents from the top `rescoreK` query var topDocsQuery = new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader()); - var valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction); - var rescoreQuery = new FunctionScoreQuery(topDocsQuery, valueSource); + var valueSource = new AccessibleVectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction); + + // Use bulk processing if feature flag is enabled + var rescoreQuery = BulkVectorProcessingSettings.BULK_VECTOR_SCORING + ? new BulkVectorFunctionScoreQuery(topDocsQuery, valueSource, topDocs.scoreDocs) + : new FunctionScoreQuery(topDocsQuery, valueSource); var rescoreTopDocs = searcher.search(rescoreQuery.rewrite(searcher), k); return new KnnScoreDocQuery(rescoreTopDocs.scoreDocs, searcher.getIndexReader()); + } @Override diff --git a/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java new file mode 100644 index 0000000000000..ef9b5b289c7a9 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java @@ -0,0 +1,194 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.MMapDirectory; +import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; + +public class BulkVectorFunctionScoreQueryTests extends ESTestCase { + + private static final String VECTOR_FIELD = "vector"; + private static final int VECTOR_DIMS = 128; + + public void testBulkProcessingWithScoreDocArray() throws IOException { + // Create test index with vector documents + try (Directory dir = new MMapDirectory(createTempDir())) { + IndexWriterConfig config = new IndexWriterConfig(); + try (IndexWriter writer = new IndexWriter(dir, config)) { + // Add documents with random vectors + for (int i = 0; i < 100; i++) { + Document doc = new Document(); + float[] vector = randomVector(VECTOR_DIMS); + doc.add(new KnnFloatVectorField(VECTOR_FIELD, vector, VectorSimilarityFunction.COSINE)); + writer.addDocument(doc); + } + writer.commit(); + } + + try (DirectoryReader reader = DirectoryReader.open(dir)) { + IndexSearcher searcher = new IndexSearcher(reader); + + // Create query vector and value source + float[] queryVector = randomVector(VECTOR_DIMS); + var valueSource = new AccessibleVectorSimilarityFloatValueSource( + VECTOR_FIELD, queryVector, VectorSimilarityFunction.COSINE); + + // Get top documents + TopDocs topDocs = searcher.search(new MatchAllDocsQuery(), 50); + + // Test bulk vector function score query + BulkVectorFunctionScoreQuery bulkQuery = new BulkVectorFunctionScoreQuery( + new KnnScoreDocQuery(topDocs.scoreDocs, reader), valueSource, topDocs.scoreDocs); + + TopDocs bulkResults = searcher.search(bulkQuery, 10); + + // Verify results + assertThat("Should return results", bulkResults.totalHits.value(), greaterThan(0L)); + assertThat("Should not exceed requested count", bulkResults.scoreDocs.length, equalTo(10)); + + // Verify scores are computed + for (ScoreDoc scoreDoc : bulkResults.scoreDocs) { + assertTrue("Score should be computed", Float.isFinite(scoreDoc.score)); + assertTrue("Score should be positive for cosine similarity", scoreDoc.score >= 0.0f); + } + } + } + } + + public void testInlineRescoreBulkOptimization() throws IOException { + // Test that InlineRescoreQuery uses bulk processing when feature flag is enabled + float[] queryVector = randomVector(VECTOR_DIMS); + + // Temporarily enable feature flag for testing + boolean originalFlag = BulkVectorProcessingSettings.BULK_VECTOR_SCORING; + System.setProperty("es.bulk_vector_scoring", "true"); + + try (Directory dir = new MMapDirectory(createTempDir())) { + createTestIndex(dir, 20); + + try (DirectoryReader reader = DirectoryReader.open(dir)) { + IndexSearcher searcher = new IndexSearcher(reader); + + // Create inline rescoring by using KnnFloatVectorQuery with matching k and rescoreK + KnnFloatVectorQuery innerQuery = new KnnFloatVectorQuery(VECTOR_FIELD, queryVector, 10); + RescoreKnnVectorQuery rescoreQuery = RescoreKnnVectorQuery.fromInnerQuery( + VECTOR_FIELD, queryVector, VectorSimilarityFunction.COSINE, 5, 10, innerQuery); + + TopDocs results = searcher.search(rescoreQuery, 5); + + assertThat("Should return results from inline rescoring", results.totalHits.value(), greaterThan(0L)); + assertThat("Should return requested count", results.scoreDocs.length, equalTo(5)); + } + } finally { + // Restore original flag state + if (originalFlag) { + System.setProperty("es.bulk_vector_scoring", "true"); + } else { + System.clearProperty("es.bulk_vector_scoring"); + } + } + } + + public void testLateRescoreBulkOptimization() throws IOException { + // Test that LateRescoreQuery uses bulk processing when feature flag is enabled + float[] queryVector = randomVector(VECTOR_DIMS); + + // Temporarily enable feature flag for testing + System.setProperty("es.bulk_vector_scoring", "true"); + + try (Directory dir = new MMapDirectory(createTempDir())) { + createTestIndex(dir, 50); + + try (DirectoryReader reader = DirectoryReader.open(dir)) { + IndexSearcher searcher = new IndexSearcher(reader); + + // Create late rescoring by using different k and rescoreK values + RescoreKnnVectorQuery rescoreQuery = RescoreKnnVectorQuery.fromInnerQuery( + VECTOR_FIELD, queryVector, VectorSimilarityFunction.COSINE, 8, 30, new MatchAllDocsQuery()); + + TopDocs results = searcher.search(rescoreQuery, 8); + + assertThat("Should return results from late rescoring", results.totalHits.value(), greaterThan(0L)); + assertThat("Should return requested count", results.scoreDocs.length, equalTo(8)); + } + } finally { + System.clearProperty("es.bulk_vector_scoring"); + } + } + + public void testScoreDocContextPreservation() throws IOException { + // Test that ScoreDoc context is properly maintained through rewrite cycles + try (Directory dir = new MMapDirectory(createTempDir())) { + createTestIndex(dir, 30); + + try (DirectoryReader reader = DirectoryReader.open(dir)) { + IndexSearcher searcher = new IndexSearcher(reader); + + // Create initial ScoreDoc array + TopDocs topDocs = searcher.search(new MatchAllDocsQuery(), 15); + ScoreDoc[] originalScoreDocs = topDocs.scoreDocs.clone(); + + // Create bulk query + float[] queryVector = randomVector(VECTOR_DIMS); + var valueSource = new AccessibleVectorSimilarityFloatValueSource( + VECTOR_FIELD, queryVector, VectorSimilarityFunction.COSINE); + BulkVectorFunctionScoreQuery query = new BulkVectorFunctionScoreQuery( + new KnnScoreDocQuery(originalScoreDocs, reader), valueSource, originalScoreDocs); + + // Test query rewrite preserves context + BulkVectorFunctionScoreQuery rewritten = (BulkVectorFunctionScoreQuery) query.rewrite(searcher); + assertNotNull("Rewritten query should not be null", rewritten); + + // Execute rewritten query + TopDocs results = searcher.search(rewritten, 10); + assertThat("Should return results after rewrite", results.totalHits.value(), greaterThan(0L)); + } + } + } + + private void createTestIndex(Directory dir, int docCount) throws IOException { + IndexWriterConfig config = new IndexWriterConfig(); + try (IndexWriter writer = new IndexWriter(dir, config)) { + for (int i = 0; i < docCount; i++) { + Document doc = new Document(); + float[] vector = randomVector(VECTOR_DIMS); + doc.add(new KnnFloatVectorField(VECTOR_FIELD, vector, VectorSimilarityFunction.COSINE)); + writer.addDocument(doc); + } + writer.commit(); + } + } + + private float[] randomVector(int dimensions) { + float[] vector = new float[dimensions]; + for (int i = 0; i < dimensions; i++) { + vector[i] = randomFloat() * 2.0f - 1.0f; // Range [-1, 1] + } + return vector; + } +} diff --git a/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorScorerTests.java b/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorScorerTests.java new file mode 100644 index 0000000000000..6847154138827 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorScorerTests.java @@ -0,0 +1,180 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.BulkScorer; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Scorable; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.Weight; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.MMapDirectory; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static org.hamcrest.Matchers.greaterThan; + +public class BulkVectorScorerTests extends ESTestCase { + + private static final String VECTOR_FIELD = "vector"; + private static final int VECTOR_DIMS = 64; + + public void testBulkScorerImplementation() throws IOException { + // Enable bulk processing for this test + System.setProperty("es.bulk_vector_scoring", "true"); + + try (Directory dir = new MMapDirectory(createTempDir())) { + createTestIndex(dir, 30); + + try (DirectoryReader reader = DirectoryReader.open(dir)) { + IndexSearcher searcher = new IndexSearcher(reader); + + // Get top documents for bulk processing + TopDocs topDocs = searcher.search(new MatchAllDocsQuery(), 20); + + // Create bulk vector query + float[] queryVector = randomVector(VECTOR_DIMS); + var valueSource = new AccessibleVectorSimilarityFloatValueSource( + VECTOR_FIELD, queryVector, VectorSimilarityFunction.COSINE); + + BulkVectorFunctionScoreQuery bulkQuery = new BulkVectorFunctionScoreQuery( + new KnnScoreDocQuery(topDocs.scoreDocs, reader), valueSource, topDocs.scoreDocs); + + Weight weight = bulkQuery.createWeight(searcher, org.apache.lucene.search.ScoreMode.COMPLETE, 1.0f); + + LeafReaderContext leafContext = reader.leaves().get(0); + BulkScorer bulkScorer = weight.scorerSupplier(leafContext).bulkScorer(); + + assertNotNull("BulkScorer should be created", bulkScorer); + assertTrue("Should be BulkVectorScorer instance", + bulkScorer instanceof BulkVectorScorer); + + // Test bulk scoring execution + TestCollector collector = new TestCollector(); + int result = bulkScorer.score(collector, null, 0, Integer.MAX_VALUE); + + // Verify bulk processing occurred + assertThat("Should collect documents", collector.collectedDocs.size(), greaterThan(0)); + assertTrue("Should have precomputed scores", collector.hasValidScores()); + + } + } finally { + System.clearProperty("es.bulk_vector_scoring"); + } + } + + public void testBulkProcessorCollectorInterception() throws IOException { + // Test that BulkVectorScorer properly intercepts collector calls + System.setProperty("es.bulk_vector_scoring", "true"); + + try (Directory dir = new MMapDirectory(createTempDir())) { + createTestIndex(dir, 25); + + try (DirectoryReader reader = DirectoryReader.open(dir)) { + IndexSearcher searcher = new IndexSearcher(reader); + + // Create bulk query with known documents + TopDocs topDocs = searcher.search(new MatchAllDocsQuery(), 15); + float[] queryVector = randomVector(VECTOR_DIMS); + var valueSource = new AccessibleVectorSimilarityFloatValueSource( + VECTOR_FIELD, queryVector, VectorSimilarityFunction.DOT_PRODUCT); + + BulkVectorFunctionScoreQuery bulkQuery = new BulkVectorFunctionScoreQuery( + new KnnScoreDocQuery(topDocs.scoreDocs, reader), valueSource, topDocs.scoreDocs); + + // Execute with custom collector to verify interception + TestCollector collector = new TestCollector(); + + Weight weight = bulkQuery.createWeight(searcher, org.apache.lucene.search.ScoreMode.COMPLETE, 1.0f); + LeafReaderContext leafContext = reader.leaves().get(0); + BulkScorer bulkScorer = weight.scorerSupplier(leafContext).bulkScorer(); + + bulkScorer.score(collector, null, 0, 50); + + // Verify collector received bulk-processed scores + assertTrue("Collector should have received documents", collector.collectedDocs.size() > 0); + + for (CollectedDoc doc : collector.collectedDocs) { + assertTrue("All scores should be finite", Float.isFinite(doc.score)); + // For DOT_PRODUCT, scores can be negative, so just check they're computed + assertNotEquals("Score should not be exactly zero (indicating computation)", 0.0f, doc.score, 0.001f); + } + } + } finally { + System.clearProperty("es.bulk_vector_scoring"); + } + } + + private void createTestIndex(Directory dir, int docCount) throws IOException { + IndexWriterConfig config = new IndexWriterConfig(); + try (IndexWriter writer = new IndexWriter(dir, config)) { + for (int i = 0; i < docCount; i++) { + Document doc = new Document(); + float[] vector = randomVector(VECTOR_DIMS); + doc.add(new KnnFloatVectorField(VECTOR_FIELD, vector, VectorSimilarityFunction.COSINE)); + writer.addDocument(doc); + } + writer.commit(); + } + } + + private float[] randomVector(int dimensions) { + float[] vector = new float[dimensions]; + for (int i = 0; i < dimensions; i++) { + vector[i] = randomFloat() * 2.0f - 1.0f; + } + return vector; + } + + /** + * Test collector that captures documents and scores for verification + */ + private static class TestCollector implements LeafCollector { + final List collectedDocs = new ArrayList<>(); + private Scorable scorer; + + @Override + public void setScorer(Scorable scorer) throws IOException { + this.scorer = scorer; + } + + @Override + public void collect(int doc) throws IOException { + float score = scorer.score(); + collectedDocs.add(new CollectedDoc(doc, score)); + } + + boolean hasValidScores() { + return collectedDocs.stream().allMatch(doc -> Float.isFinite(doc.score)); + } + } + + private static class CollectedDoc { + final int docId; + final float score; + + CollectedDoc(int docId, float score) { + this.docId = docId; + this.score = score; + } + } +} From 83a6902501dada375e2e0a8686ce641bfddcd6e9 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 28 Aug 2025 02:39:32 +0000 Subject: [PATCH 2/6] [CI] Auto commit changes from spotless --- ...sibleVectorSimilarityFloatValueSource.java | 4 +- .../search/vectors/BatchVectorSimilarity.java | 13 ++++--- .../vectors/BulkVectorFunctionScoreQuery.java | 1 - .../BulkVectorFunctionScoreWeight.java | 6 ++- .../search/vectors/BulkVectorScorer.java | 8 ++-- .../vectors/DirectIOVectorBatchLoader.java | 2 +- .../search/vectors/RescoreKnnVectorQuery.java | 1 - .../BulkVectorFunctionScoreQueryTests.java | 37 +++++++++++++++---- .../search/vectors/BulkVectorScorerTests.java | 23 +++++++++--- 9 files changed, 65 insertions(+), 30 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/vectors/AccessibleVectorSimilarityFloatValueSource.java b/server/src/main/java/org/elasticsearch/search/vectors/AccessibleVectorSimilarityFloatValueSource.java index 14453ca1160ec..6abb913766000 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/AccessibleVectorSimilarityFloatValueSource.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/AccessibleVectorSimilarityFloatValueSource.java @@ -22,9 +22,7 @@ class AccessibleVectorSimilarityFloatValueSource extends VectorSimilarityFloatVa float[] target; VectorSimilarityFunction vectorSimilarityFunction; - AccessibleVectorSimilarityFloatValueSource(String field, - float[] target, - VectorSimilarityFunction vectorSimilarityFunction) { + AccessibleVectorSimilarityFloatValueSource(String field, float[] target, VectorSimilarityFunction vectorSimilarityFunction) { super(field, target, vectorSimilarityFunction); this.field = field; this.target = target; diff --git a/server/src/main/java/org/elasticsearch/search/vectors/BatchVectorSimilarity.java b/server/src/main/java/org/elasticsearch/search/vectors/BatchVectorSimilarity.java index 66f58b2eaef8b..903c8d72bd653 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/BatchVectorSimilarity.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/BatchVectorSimilarity.java @@ -15,11 +15,14 @@ public final class BatchVectorSimilarity { - private BatchVectorSimilarity() { - } - - public static float[] computeBatchSimilarity(float[] queryVector, Map docVectors, - int[] docIds, VectorSimilarityFunction function) { + private BatchVectorSimilarity() {} + + public static float[] computeBatchSimilarity( + float[] queryVector, + Map docVectors, + int[] docIds, + VectorSimilarityFunction function + ) { float[] results = new float[docIds.length]; float[][] data = organizeSIMDVectors(docVectors, docIds); diff --git a/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQuery.java index 2ad2cb1913ba4..a90681ca28048 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQuery.java @@ -14,7 +14,6 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Weight; -import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; import java.io.IOException; import java.util.Arrays; diff --git a/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreWeight.java b/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreWeight.java index 68b954b581458..6d2f5b09927a2 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreWeight.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreWeight.java @@ -36,7 +36,8 @@ public BulkVectorFunctionScoreWeight( Query parent, Weight subQueryWeight, AccessibleVectorSimilarityFloatValueSource valueSource, - ScoreDoc[] scoreDocs) { + ScoreDoc[] scoreDocs + ) { super(parent); this.subQueryWeight = subQueryWeight; this.valueSource = valueSource; @@ -60,7 +61,8 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti @Override public Scorer get(long leadCost) throws IOException { throw new UnsupportedOperationException( - "Individual Scorer not supported when bulk vector processing is enabled. Use bulkScorer() instead."); + "Individual Scorer not supported when bulk vector processing is enabled. Use bulkScorer() instead." + ); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorScorer.java b/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorScorer.java index 518a011b16403..be898e2db6581 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorScorer.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorScorer.java @@ -31,9 +31,10 @@ public class BulkVectorScorer extends BulkScorer { public BulkVectorScorer( BulkScorer subQueryBulkScorer, - int[] segmentDocIds, + int[] segmentDocIds, AccessibleVectorSimilarityFloatValueSource valueSource, - LeafReaderContext context) { + LeafReaderContext context + ) { this.subQueryBulkScorer = subQueryBulkScorer; this.segmentDocIds = segmentDocIds; this.valueSource = valueSource; @@ -63,8 +64,7 @@ public long cost() { private void performBulkVectorProcessing() throws IOException { // batch loading DirectIOVectorBatchLoader batchLoader = new DirectIOVectorBatchLoader(); - Map vectorCache = batchLoader.loadSegmentVectors( - segmentDocIds, context, valueSource.field()); + Map vectorCache = batchLoader.loadSegmentVectors(segmentDocIds, context, valueSource.field()); // batch similarity float[] similarities = BatchVectorSimilarity.computeBatchSimilarity( diff --git a/server/src/main/java/org/elasticsearch/search/vectors/DirectIOVectorBatchLoader.java b/server/src/main/java/org/elasticsearch/search/vectors/DirectIOVectorBatchLoader.java index 5eb85d9e5fdde..254acfb762e86 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/DirectIOVectorBatchLoader.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/DirectIOVectorBatchLoader.java @@ -74,7 +74,7 @@ public float[] loadSingleVector(int docId, LeafReaderContext context, String fie .nextDoc()) { if (currentDoc == docId) { float[] vector = vectorValues.vectorValue(iterator.index()); - return vector != null ? vector: null; + return vector != null ? vector : null; } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java index 22a9296d412ef..c2cbac248651a 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -18,7 +18,6 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.TopDocs; -import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; import org.elasticsearch.search.profile.query.QueryProfiler; import java.io.IOException; diff --git a/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java index ef9b5b289c7a9..71dd3b2713eec 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java @@ -22,7 +22,6 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.store.MMapDirectory; -import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; import org.elasticsearch.test.ESTestCase; import java.io.IOException; @@ -56,14 +55,20 @@ public void testBulkProcessingWithScoreDocArray() throws IOException { // Create query vector and value source float[] queryVector = randomVector(VECTOR_DIMS); var valueSource = new AccessibleVectorSimilarityFloatValueSource( - VECTOR_FIELD, queryVector, VectorSimilarityFunction.COSINE); + VECTOR_FIELD, + queryVector, + VectorSimilarityFunction.COSINE + ); // Get top documents TopDocs topDocs = searcher.search(new MatchAllDocsQuery(), 50); // Test bulk vector function score query BulkVectorFunctionScoreQuery bulkQuery = new BulkVectorFunctionScoreQuery( - new KnnScoreDocQuery(topDocs.scoreDocs, reader), valueSource, topDocs.scoreDocs); + new KnnScoreDocQuery(topDocs.scoreDocs, reader), + valueSource, + topDocs.scoreDocs + ); TopDocs bulkResults = searcher.search(bulkQuery, 10); @@ -97,7 +102,13 @@ public void testInlineRescoreBulkOptimization() throws IOException { // Create inline rescoring by using KnnFloatVectorQuery with matching k and rescoreK KnnFloatVectorQuery innerQuery = new KnnFloatVectorQuery(VECTOR_FIELD, queryVector, 10); RescoreKnnVectorQuery rescoreQuery = RescoreKnnVectorQuery.fromInnerQuery( - VECTOR_FIELD, queryVector, VectorSimilarityFunction.COSINE, 5, 10, innerQuery); + VECTOR_FIELD, + queryVector, + VectorSimilarityFunction.COSINE, + 5, + 10, + innerQuery + ); TopDocs results = searcher.search(rescoreQuery, 5); @@ -129,7 +140,13 @@ public void testLateRescoreBulkOptimization() throws IOException { // Create late rescoring by using different k and rescoreK values RescoreKnnVectorQuery rescoreQuery = RescoreKnnVectorQuery.fromInnerQuery( - VECTOR_FIELD, queryVector, VectorSimilarityFunction.COSINE, 8, 30, new MatchAllDocsQuery()); + VECTOR_FIELD, + queryVector, + VectorSimilarityFunction.COSINE, + 8, + 30, + new MatchAllDocsQuery() + ); TopDocs results = searcher.search(rescoreQuery, 8); @@ -156,9 +173,15 @@ public void testScoreDocContextPreservation() throws IOException { // Create bulk query float[] queryVector = randomVector(VECTOR_DIMS); var valueSource = new AccessibleVectorSimilarityFloatValueSource( - VECTOR_FIELD, queryVector, VectorSimilarityFunction.COSINE); + VECTOR_FIELD, + queryVector, + VectorSimilarityFunction.COSINE + ); BulkVectorFunctionScoreQuery query = new BulkVectorFunctionScoreQuery( - new KnnScoreDocQuery(originalScoreDocs, reader), valueSource, originalScoreDocs); + new KnnScoreDocQuery(originalScoreDocs, reader), + valueSource, + originalScoreDocs + ); // Test query rewrite preserves context BulkVectorFunctionScoreQuery rewritten = (BulkVectorFunctionScoreQuery) query.rewrite(searcher); diff --git a/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorScorerTests.java b/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorScorerTests.java index 6847154138827..bc06870588d6a 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorScorerTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorScorerTests.java @@ -54,10 +54,16 @@ public void testBulkScorerImplementation() throws IOException { // Create bulk vector query float[] queryVector = randomVector(VECTOR_DIMS); var valueSource = new AccessibleVectorSimilarityFloatValueSource( - VECTOR_FIELD, queryVector, VectorSimilarityFunction.COSINE); + VECTOR_FIELD, + queryVector, + VectorSimilarityFunction.COSINE + ); BulkVectorFunctionScoreQuery bulkQuery = new BulkVectorFunctionScoreQuery( - new KnnScoreDocQuery(topDocs.scoreDocs, reader), valueSource, topDocs.scoreDocs); + new KnnScoreDocQuery(topDocs.scoreDocs, reader), + valueSource, + topDocs.scoreDocs + ); Weight weight = bulkQuery.createWeight(searcher, org.apache.lucene.search.ScoreMode.COMPLETE, 1.0f); @@ -65,8 +71,7 @@ public void testBulkScorerImplementation() throws IOException { BulkScorer bulkScorer = weight.scorerSupplier(leafContext).bulkScorer(); assertNotNull("BulkScorer should be created", bulkScorer); - assertTrue("Should be BulkVectorScorer instance", - bulkScorer instanceof BulkVectorScorer); + assertTrue("Should be BulkVectorScorer instance", bulkScorer instanceof BulkVectorScorer); // Test bulk scoring execution TestCollector collector = new TestCollector(); @@ -96,10 +101,16 @@ public void testBulkProcessorCollectorInterception() throws IOException { TopDocs topDocs = searcher.search(new MatchAllDocsQuery(), 15); float[] queryVector = randomVector(VECTOR_DIMS); var valueSource = new AccessibleVectorSimilarityFloatValueSource( - VECTOR_FIELD, queryVector, VectorSimilarityFunction.DOT_PRODUCT); + VECTOR_FIELD, + queryVector, + VectorSimilarityFunction.DOT_PRODUCT + ); BulkVectorFunctionScoreQuery bulkQuery = new BulkVectorFunctionScoreQuery( - new KnnScoreDocQuery(topDocs.scoreDocs, reader), valueSource, topDocs.scoreDocs); + new KnnScoreDocQuery(topDocs.scoreDocs, reader), + valueSource, + topDocs.scoreDocs + ); // Execute with custom collector to verify interception TestCollector collector = new TestCollector(); From 11b94cc3c975c7f17332af85d8c4860b5b9fe657 Mon Sep 17 00:00:00 2001 From: Costin Leau Date: Tue, 2 Sep 2025 12:44:39 -0700 Subject: [PATCH 3/6] Add vector parallel loading The query test fails with a race condition, might be related to the vector source --- .../BulkVectorFunctionScoreWeight.java | 5 +- .../vectors/DirectIOVectorBatchLoader.java | 109 ++++++++++++++---- .../BulkVectorFunctionScoreQueryTests.java | 69 ++++++++--- 3 files changed, 142 insertions(+), 41 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreWeight.java b/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreWeight.java index 6d2f5b09927a2..d1dc1fd6c1c99 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreWeight.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreWeight.java @@ -60,9 +60,8 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti return new ScorerSupplier() { @Override public Scorer get(long leadCost) throws IOException { - throw new UnsupportedOperationException( - "Individual Scorer not supported when bulk vector processing is enabled. Use bulkScorer() instead." - ); + // if asked for basic Scorer, delegate to the underlying subquery scorer + return subQueryScorerSupplier.get(leadCost); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/DirectIOVectorBatchLoader.java b/server/src/main/java/org/elasticsearch/search/vectors/DirectIOVectorBatchLoader.java index 254acfb762e86..03b23b70f4148 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/DirectIOVectorBatchLoader.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/DirectIOVectorBatchLoader.java @@ -12,52 +12,117 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReaderContext; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.common.util.set.Sets; import java.io.IOException; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; /** * Bulk vector loader that performs optimized I/O operations to load multiple vectors - * simultaneously using direct I/O and sector alignment when possible. + * simultaneously using parallel random access when possible. */ public class DirectIOVectorBatchLoader { - private static final int SECTOR_SIZE = 4096; // Default sector size for alignment + private static final int BATCH_PER_THREAD = 8; + // TODO: hook into a dedicated thread pool or at least name the virtual threads + private Executor vtExecutor = Executors.newVirtualThreadPerTaskExecutor(); - /** - * Loads vectors for multiple document IDs in a single bulk operation. - */ public Map loadSegmentVectors(int[] docIds, LeafReaderContext context, String field) throws IOException { - Map vectorCache = new HashMap<>(); + return loadSegmentVectorsParallel(docIds, context, field); + } - // Get vector values for the field + private Map loadSegmentVectorsParallel(int[] docIds, LeafReaderContext context, String field) throws IOException { FloatVectorValues vectorValues = context.reader().getFloatVectorValues(field); if (vectorValues == null) { throw new IllegalArgumentException("No float vector values found for field: " + field); } - // TODO: For now, use sequential access - future optimization can implement true bulk I/O - KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + Map docToOrdinal = buildDocToOrdinalMapping(vectorValues, docIds); + List> batches = createBatches(new ArrayList<>(docToOrdinal.keySet()), BATCH_PER_THREAD); - // Build a lookup of available documents - Map docToIndex = new HashMap<>(); - for (int docId = iterator.nextDoc(); docId != KnnVectorValues.DocIndexIterator.NO_MORE_DOCS; docId = iterator.nextDoc()) { - docToIndex.put(docId, iterator.index()); + List>> futures = new ArrayList<>(); + for (List batch : batches) { + futures.add(CompletableFuture.supplyAsync(() -> { + try { + return loadVectorBatch(vectorValues, batch, docToOrdinal); + } catch (IOException e) { + throw new RuntimeException("Failed to load vector batch", e); + } + }, vtExecutor)); + } + + Map combinedResult = new HashMap<>(); + try { + for (CompletableFuture> future : futures) { + var results = future.get(); + combinedResult.putAll(results); + } + } catch (Exception e) { + ExceptionsHelper.convertToElastic(e); + } + + return combinedResult; + } + + private Map loadVectorBatch( + FloatVectorValues vectorValues, + List docIdBatch, + Map docToOrdinal) throws IOException { + + Map batchResult = new HashMap<>(); + + for (Integer docId : docIdBatch) { + Integer ordinal = docToOrdinal.get(docId); + if (ordinal != null) { + // clone the vector since the reader reuses the array + float[] vector = vectorValues.vectorValue(ordinal).clone(); + batchResult.put(docId, vector); + } + } + + return batchResult; + } + + private Map buildDocToOrdinalMapping( + FloatVectorValues vectorValues, + int[] targetDocIds) throws IOException { + + Map docToOrdinal = new HashMap<>(); + + Set targetDocSet = Sets.newHashSetWithExpectedSize(targetDocIds.length); + for (int docId : targetDocIds) { + targetDocSet.add(docId); } - // Load vectors for requested documents - for (int docId : docIds) { - Integer vectorIndex = docToIndex.get(docId); - if (vectorIndex != null) { - float[] vector = vectorValues.vectorValue(vectorIndex); - if (vector != null) { - vectorCache.put(docId, vector); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + for (int docId = iterator.nextDoc(); docId != KnnVectorValues.DocIndexIterator.NO_MORE_DOCS; docId = iterator.nextDoc()) { + if (targetDocSet.contains(docId)) { // Only map docs we actually need + docToOrdinal.put(docId, iterator.index()); + + // Early termination when all target docs found + if (docToOrdinal.size() == targetDocSet.size()) { + break; } } } - return vectorCache; + return docToOrdinal; + } + + private List> createBatches(List items, int batchSize) { + List> batches = new ArrayList<>(); + for (int i = 0; i < items.size(); i += batchSize) { + batches.add(items.subList(i, Math.min(i + batchSize, items.size()))); + } + return batches; } /** @@ -73,7 +138,7 @@ public float[] loadSingleVector(int docId, LeafReaderContext context, String fie for (int currentDoc = iterator.nextDoc(); currentDoc != KnnVectorValues.DocIndexIterator.NO_MORE_DOCS; currentDoc = iterator .nextDoc()) { if (currentDoc == docId) { - float[] vector = vectorValues.vectorValue(iterator.index()); + float[] vector = vectorValues.vectorValue(iterator.index()).clone(); return vector != null ? vector : null; } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java index 71dd3b2713eec..80d7db6639161 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java @@ -25,6 +25,9 @@ import org.elasticsearch.test.ESTestCase; import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; @@ -33,6 +36,7 @@ public class BulkVectorFunctionScoreQueryTests extends ESTestCase { private static final String VECTOR_FIELD = "vector"; private static final int VECTOR_DIMS = 128; + public static final String BULK_VECTOR_SCORING = "es.bulk_vector_scoring"; public void testBulkProcessingWithScoreDocArray() throws IOException { // Create test index with vector documents @@ -89,10 +93,6 @@ public void testInlineRescoreBulkOptimization() throws IOException { // Test that InlineRescoreQuery uses bulk processing when feature flag is enabled float[] queryVector = randomVector(VECTOR_DIMS); - // Temporarily enable feature flag for testing - boolean originalFlag = BulkVectorProcessingSettings.BULK_VECTOR_SCORING; - System.setProperty("es.bulk_vector_scoring", "true"); - try (Directory dir = new MMapDirectory(createTempDir())) { createTestIndex(dir, 20); @@ -115,13 +115,6 @@ public void testInlineRescoreBulkOptimization() throws IOException { assertThat("Should return results from inline rescoring", results.totalHits.value(), greaterThan(0L)); assertThat("Should return requested count", results.scoreDocs.length, equalTo(5)); } - } finally { - // Restore original flag state - if (originalFlag) { - System.setProperty("es.bulk_vector_scoring", "true"); - } else { - System.clearProperty("es.bulk_vector_scoring"); - } } } @@ -129,9 +122,6 @@ public void testLateRescoreBulkOptimization() throws IOException { // Test that LateRescoreQuery uses bulk processing when feature flag is enabled float[] queryVector = randomVector(VECTOR_DIMS); - // Temporarily enable feature flag for testing - System.setProperty("es.bulk_vector_scoring", "true"); - try (Directory dir = new MMapDirectory(createTempDir())) { createTestIndex(dir, 50); @@ -153,8 +143,6 @@ public void testLateRescoreBulkOptimization() throws IOException { assertThat("Should return results from late rescoring", results.totalHits.value(), greaterThan(0L)); assertThat("Should return requested count", results.scoreDocs.length, equalTo(8)); } - } finally { - System.clearProperty("es.bulk_vector_scoring"); } } @@ -207,6 +195,55 @@ private void createTestIndex(Directory dir, int docCount) throws IOException { } } + public void testParallelVectorLoading() throws IOException { + // Test parallel vector loading functionality + float[] queryVector = randomVector(VECTOR_DIMS); + + try (Directory dir = new MMapDirectory(createTempDir())) { + createTestIndex(dir, 50); + + try (DirectoryReader reader = DirectoryReader.open(dir)) { + IndexSearcher searcher = new IndexSearcher(reader); + + // Get initial documents + TopDocs topDocs = searcher.search(new MatchAllDocsQuery(), 20); + int[] docIds = Arrays.stream(topDocs.scoreDocs) + .mapToInt(scoreDoc -> scoreDoc.doc) + .toArray(); + + // Test parallel loading + DirectIOVectorBatchLoader batchLoader = new DirectIOVectorBatchLoader(); + + Map parallelResult = batchLoader.loadSegmentVectors( + docIds, + reader.leaves().get(0), + VECTOR_FIELD + ); + + // use regular vector loader + Map sequentialResult = new HashMap<>(); + for (int docId : docIds) { + sequentialResult.put(docId, batchLoader.loadSingleVector(docId, reader.leaves().get(0), VECTOR_FIELD)); + } + + // Verify results are identical + assertThat( + "Parallel and sequential results should have same size", + parallelResult.size(), equalTo(sequentialResult.size())); + + for (int docId : docIds) { + float[] parallelVector = parallelResult.get(docId); + float[] sequentialVector = sequentialResult.get(docId); + + assertNotNull("Parallel result should contain vector for doc " + docId, parallelVector); + assertNotNull("Sequential result should contain vector for doc " + docId, sequentialVector); + assertArrayEquals("Vectors should be identical for doc " + docId, + sequentialVector, parallelVector, 0.0001f); + } + } + } + } + private float[] randomVector(int dimensions) { float[] vector = new float[dimensions]; for (int i = 0; i < dimensions; i++) { From f1243559813b37bc437eddc832d08eebd3c36f79 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Wed, 3 Sep 2025 05:00:08 +0000 Subject: [PATCH 4/6] [CI] Auto commit changes from spotless --- .../vectors/DirectIOVectorBatchLoader.java | 11 +++++------ .../BulkVectorFunctionScoreQueryTests.java | 17 ++++++----------- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/vectors/DirectIOVectorBatchLoader.java b/server/src/main/java/org/elasticsearch/search/vectors/DirectIOVectorBatchLoader.java index 03b23b70f4148..52d03bfa8a2ff 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/DirectIOVectorBatchLoader.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/DirectIOVectorBatchLoader.java @@ -73,9 +73,10 @@ private Map loadSegmentVectorsParallel(int[] docIds, LeafReade } private Map loadVectorBatch( - FloatVectorValues vectorValues, - List docIdBatch, - Map docToOrdinal) throws IOException { + FloatVectorValues vectorValues, + List docIdBatch, + Map docToOrdinal + ) throws IOException { Map batchResult = new HashMap<>(); @@ -91,9 +92,7 @@ private Map loadVectorBatch( return batchResult; } - private Map buildDocToOrdinalMapping( - FloatVectorValues vectorValues, - int[] targetDocIds) throws IOException { + private Map buildDocToOrdinalMapping(FloatVectorValues vectorValues, int[] targetDocIds) throws IOException { Map docToOrdinal = new HashMap<>(); diff --git a/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java index 80d7db6639161..d053f14f4399c 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java @@ -207,18 +207,12 @@ public void testParallelVectorLoading() throws IOException { // Get initial documents TopDocs topDocs = searcher.search(new MatchAllDocsQuery(), 20); - int[] docIds = Arrays.stream(topDocs.scoreDocs) - .mapToInt(scoreDoc -> scoreDoc.doc) - .toArray(); + int[] docIds = Arrays.stream(topDocs.scoreDocs).mapToInt(scoreDoc -> scoreDoc.doc).toArray(); // Test parallel loading DirectIOVectorBatchLoader batchLoader = new DirectIOVectorBatchLoader(); - Map parallelResult = batchLoader.loadSegmentVectors( - docIds, - reader.leaves().get(0), - VECTOR_FIELD - ); + Map parallelResult = batchLoader.loadSegmentVectors(docIds, reader.leaves().get(0), VECTOR_FIELD); // use regular vector loader Map sequentialResult = new HashMap<>(); @@ -229,7 +223,9 @@ public void testParallelVectorLoading() throws IOException { // Verify results are identical assertThat( "Parallel and sequential results should have same size", - parallelResult.size(), equalTo(sequentialResult.size())); + parallelResult.size(), + equalTo(sequentialResult.size()) + ); for (int docId : docIds) { float[] parallelVector = parallelResult.get(docId); @@ -237,8 +233,7 @@ public void testParallelVectorLoading() throws IOException { assertNotNull("Parallel result should contain vector for doc " + docId, parallelVector); assertNotNull("Sequential result should contain vector for doc " + docId, sequentialVector); - assertArrayEquals("Vectors should be identical for doc " + docId, - sequentialVector, parallelVector, 0.0001f); + assertArrayEquals("Vectors should be identical for doc " + docId, sequentialVector, parallelVector, 0.0001f); } } } From 7517625e5c7b3044f7bad88a50f164f3c018c1fe Mon Sep 17 00:00:00 2001 From: Costin Leau Date: Wed, 3 Sep 2025 17:52:06 -0700 Subject: [PATCH 5/6] Loop only over needed documents --- .../vectors/DirectIOVectorBatchLoader.java | 38 ++++------- .../BulkVectorFunctionScoreQueryTests.java | 65 ++++++++++++------- 2 files changed, 55 insertions(+), 48 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/vectors/DirectIOVectorBatchLoader.java b/server/src/main/java/org/elasticsearch/search/vectors/DirectIOVectorBatchLoader.java index 52d03bfa8a2ff..9216cf6ab496f 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/DirectIOVectorBatchLoader.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/DirectIOVectorBatchLoader.java @@ -13,14 +13,13 @@ import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReaderContext; import org.elasticsearch.ExceptionsHelper; -import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.common.util.Maps; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import java.util.concurrent.Executors; @@ -94,25 +93,16 @@ private Map loadVectorBatch( private Map buildDocToOrdinalMapping(FloatVectorValues vectorValues, int[] targetDocIds) throws IOException { - Map docToOrdinal = new HashMap<>(); - - Set targetDocSet = Sets.newHashSetWithExpectedSize(targetDocIds.length); - for (int docId : targetDocIds) { - targetDocSet.add(docId); - } + Map docToOrdinal = Maps.newHashMapWithExpectedSize(targetDocIds.length); KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); - for (int docId = iterator.nextDoc(); docId != KnnVectorValues.DocIndexIterator.NO_MORE_DOCS; docId = iterator.nextDoc()) { - if (targetDocSet.contains(docId)) { // Only map docs we actually need - docToOrdinal.put(docId, iterator.index()); - - // Early termination when all target docs found - if (docToOrdinal.size() == targetDocSet.size()) { - break; - } + for (int i = 0; i < targetDocIds.length; i++) { + var next = iterator.advance(targetDocIds[i]); + if (next == KnnVectorValues.DocIndexIterator.NO_MORE_DOCS || next != targetDocIds[i]) { + break; } + docToOrdinal.put(next, iterator.index()); } - return docToOrdinal; } @@ -134,14 +124,12 @@ public float[] loadSingleVector(int docId, LeafReaderContext context, String fie } KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); - for (int currentDoc = iterator.nextDoc(); currentDoc != KnnVectorValues.DocIndexIterator.NO_MORE_DOCS; currentDoc = iterator - .nextDoc()) { - if (currentDoc == docId) { - float[] vector = vectorValues.vectorValue(iterator.index()).clone(); - return vector != null ? vector : null; - } + var next = iterator.advance(docId); + float[] result = null; + if (next != KnnVectorValues.DocIndexIterator.NO_MORE_DOCS && next == docId) { + var ordinal = iterator.index(); + result = vectorValues.vectorValue(ordinal).clone(); } - - throw new IllegalArgumentException("Document " + docId + " not found in vector values"); + return result; } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java index d053f14f4399c..80de1bd41a8bd 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java @@ -14,6 +14,7 @@ import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnFloatVectorQuery; @@ -22,6 +23,7 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.store.MMapDirectory; +import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.test.ESTestCase; import java.io.IOException; @@ -36,7 +38,6 @@ public class BulkVectorFunctionScoreQueryTests extends ESTestCase { private static final String VECTOR_FIELD = "vector"; private static final int VECTOR_DIMS = 128; - public static final String BULK_VECTOR_SCORING = "es.bulk_vector_scoring"; public void testBulkProcessingWithScoreDocArray() throws IOException { // Create test index with vector documents @@ -195,54 +196,72 @@ private void createTestIndex(Directory dir, int docCount) throws IOException { } } + @SuppressWarnings("unchecked") public void testParallelVectorLoading() throws IOException { // Test parallel vector loading functionality - float[] queryVector = randomVector(VECTOR_DIMS); try (Directory dir = new MMapDirectory(createTempDir())) { createTestIndex(dir, 50); - try (DirectoryReader reader = DirectoryReader.open(dir)) { - IndexSearcher searcher = new IndexSearcher(reader); - - // Get initial documents - TopDocs topDocs = searcher.search(new MatchAllDocsQuery(), 20); - int[] docIds = Arrays.stream(topDocs.scoreDocs).mapToInt(scoreDoc -> scoreDoc.doc).toArray(); + Object[] results = new Object[2]; - // Test parallel loading + loadVectors(dir, (leafReaderContext, docIds) -> { + // Load vectors in parallel DirectIOVectorBatchLoader batchLoader = new DirectIOVectorBatchLoader(); + results[0] = batchLoader.loadSegmentVectors(docIds, leafReaderContext, VECTOR_FIELD); + }); - Map parallelResult = batchLoader.loadSegmentVectors(docIds, reader.leaves().get(0), VECTOR_FIELD); - - // use regular vector loader + loadVectors(dir, (leafReaderContext, docIds) -> { + // Load vectors in parallel + DirectIOVectorBatchLoader batchLoader = new DirectIOVectorBatchLoader(); Map sequentialResult = new HashMap<>(); for (int docId : docIds) { - sequentialResult.put(docId, batchLoader.loadSingleVector(docId, reader.leaves().get(0), VECTOR_FIELD)); + sequentialResult.put(docId, batchLoader.loadSingleVector(docId, leafReaderContext, VECTOR_FIELD)); } + results[1] = sequentialResult; + }); - // Verify results are identical - assertThat( - "Parallel and sequential results should have same size", + var sequentialResult = (Map) results[1]; + var parallelResult = (Map) results[0]; + + // Verify results are identical + assertThat( + "Parallel and sequential results should have same size", parallelResult.size(), equalTo(sequentialResult.size()) ); - for (int docId : docIds) { - float[] parallelVector = parallelResult.get(docId); - float[] sequentialVector = sequentialResult.get(docId); - assertNotNull("Parallel result should contain vector for doc " + docId, parallelVector); - assertNotNull("Sequential result should contain vector for doc " + docId, sequentialVector); + for (int docId : sequentialResult.keySet()) { + float[] sequentialVector = sequentialResult.get(docId); + float[] parallelVector = parallelResult.get(docId); + + assertNotNull("Parallel result should contain vector for doc " + docId, parallelVector); + assertNotNull("Sequential result should contain vector for doc " + docId, sequentialVector); assertArrayEquals("Vectors should be identical for doc " + docId, sequentialVector, parallelVector, 0.0001f); - } } } } + private void loadVectors(Directory dir, CheckedBiConsumer consumer) throws IOException { + try (DirectoryReader reader = DirectoryReader.open(dir)) { + IndexSearcher searcher = new IndexSearcher(reader); + + // Get initial documents + TopDocs topDocs = searcher.search(new MatchAllDocsQuery(), 20); + int[] docIds = Arrays.stream(topDocs.scoreDocs) + .mapToInt(scoreDoc -> scoreDoc.doc) + .toArray(); + + var leafReaderContext = reader.leaves().get(0); + consumer.accept(leafReaderContext, docIds); + } + } + private float[] randomVector(int dimensions) { float[] vector = new float[dimensions]; for (int i = 0; i < dimensions; i++) { - vector[i] = randomFloat() * 2.0f - 1.0f; // Range [-1, 1] + vector[i] = randomFloatBetween(-1.0f, 1.0f, true); } return vector; } From a70988439f71aa799f48584a95bb5689a23c2045 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 4 Sep 2025 01:06:35 +0000 Subject: [PATCH 6/6] [CI] Auto commit changes from spotless --- .../vectors/BulkVectorFunctionScoreQueryTests.java | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java index 80de1bd41a8bd..3d87ff24c0797 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/BulkVectorFunctionScoreQueryTests.java @@ -225,12 +225,7 @@ public void testParallelVectorLoading() throws IOException { var parallelResult = (Map) results[0]; // Verify results are identical - assertThat( - "Parallel and sequential results should have same size", - parallelResult.size(), - equalTo(sequentialResult.size()) - ); - + assertThat("Parallel and sequential results should have same size", parallelResult.size(), equalTo(sequentialResult.size())); for (int docId : sequentialResult.keySet()) { float[] sequentialVector = sequentialResult.get(docId); @@ -238,7 +233,7 @@ public void testParallelVectorLoading() throws IOException { assertNotNull("Parallel result should contain vector for doc " + docId, parallelVector); assertNotNull("Sequential result should contain vector for doc " + docId, sequentialVector); - assertArrayEquals("Vectors should be identical for doc " + docId, sequentialVector, parallelVector, 0.0001f); + assertArrayEquals("Vectors should be identical for doc " + docId, sequentialVector, parallelVector, 0.0001f); } } } @@ -249,9 +244,7 @@ private void loadVectors(Directory dir, CheckedBiConsumer scoreDoc.doc) - .toArray(); + int[] docIds = Arrays.stream(topDocs.scoreDocs).mapToInt(scoreDoc -> scoreDoc.doc).toArray(); var leafReaderContext = reader.leaves().get(0); consumer.accept(leafReaderContext, docIds);