Skip to content

Commit 81235e3

Browse files
committed
Bulk vector processing
Add processing bulk at various stages of the KNN query: a. BulkVectorFunctionQuery To capture the array of ScoreDocs for bulk processing b. BulkVectorScorer (through dedicated Weight) 1. To load the vectors in bulk through DirectIOVectorBatchLoader 2. Compute the similarity across multiple vectors 3. Store the scores across a batch of docs wip
1 parent 19c035f commit 81235e3

10 files changed

+946
-5
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.vectors;
11+
12+
import org.apache.lucene.index.VectorSimilarityFunction;
13+
import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource;
14+
15+
/**
16+
* Subclass of VectorSimilarityFloatValueSource offering access to its members for other classes
17+
* in the same package.
18+
*/
19+
class AccessibleVectorSimilarityFloatValueSource extends VectorSimilarityFloatValueSource {
20+
21+
String field;
22+
float[] target;
23+
VectorSimilarityFunction vectorSimilarityFunction;
24+
25+
AccessibleVectorSimilarityFloatValueSource(String field,
26+
float[] target,
27+
VectorSimilarityFunction vectorSimilarityFunction) {
28+
super(field, target, vectorSimilarityFunction);
29+
this.field = field;
30+
this.target = target;
31+
this.vectorSimilarityFunction = vectorSimilarityFunction;
32+
}
33+
34+
public String field() {
35+
return field;
36+
}
37+
38+
public float[] target() {
39+
return target;
40+
}
41+
42+
public VectorSimilarityFunction similarityFunction() {
43+
return vectorSimilarityFunction;
44+
}
45+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.vectors;
11+
12+
import org.apache.lucene.index.VectorSimilarityFunction;
13+
14+
import java.util.Map;
15+
16+
public final class BatchVectorSimilarity {
17+
18+
private BatchVectorSimilarity() {
19+
}
20+
21+
public static float[] computeBatchSimilarity(float[] queryVector, Map<Integer, float[]> docVectors,
22+
int[] docIds, VectorSimilarityFunction function) {
23+
float[] results = new float[docIds.length];
24+
float[][] data = organizeSIMDVectors(docVectors, docIds);
25+
26+
for (int i = 0, l = data.length; i < l; i++) {
27+
float[] docVector = data[i];
28+
results[i] = function.compare(queryVector, docVector);
29+
}
30+
31+
return results;
32+
}
33+
34+
public static float[][] organizeSIMDVectors(Map<Integer, float[]> vectorMap, int[] docIds) {
35+
float[][] vectors = new float[docIds.length][];
36+
for (int i = 0; i < docIds.length; i++) {
37+
vectors[i] = vectorMap.get(docIds[i]);
38+
}
39+
return vectors;
40+
}
41+
}
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.vectors;
11+
12+
import org.apache.lucene.search.IndexSearcher;
13+
import org.apache.lucene.search.Query;
14+
import org.apache.lucene.search.ScoreDoc;
15+
import org.apache.lucene.search.ScoreMode;
16+
import org.apache.lucene.search.Weight;
17+
import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource;
18+
19+
import java.io.IOException;
20+
import java.util.Arrays;
21+
import java.util.Objects;
22+
23+
/**
24+
* Enhanced FunctionScoreQuery that enables bulk vector processing for KNN rescoring.
25+
* When provided with a ScoreDoc array, performs bulk vector loading and similarity
26+
* computation instead of individual per-document processing.
27+
*/
28+
public class BulkVectorFunctionScoreQuery extends Query {
29+
30+
private final Query subQuery;
31+
private final AccessibleVectorSimilarityFloatValueSource valueSource;
32+
private final ScoreDoc[] scoreDocs;
33+
34+
public BulkVectorFunctionScoreQuery(Query subQuery, AccessibleVectorSimilarityFloatValueSource valueSource, ScoreDoc[] scoreDocs) {
35+
this.subQuery = subQuery;
36+
this.valueSource = valueSource;
37+
this.scoreDocs = scoreDocs;
38+
}
39+
40+
@Override
41+
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
42+
// TODO: take a closer look at ScoreMode
43+
Weight subQueryWeight = subQuery.createWeight(searcher, scoreMode, boost);
44+
return new BulkVectorFunctionScoreWeight(this, subQueryWeight, valueSource, scoreDocs);
45+
}
46+
47+
@Override
48+
public Query rewrite(IndexSearcher searcher) throws IOException {
49+
Query rewrittenSubQuery = subQuery.rewrite(searcher);
50+
if (rewrittenSubQuery != subQuery) {
51+
return new BulkVectorFunctionScoreQuery(rewrittenSubQuery, valueSource, scoreDocs);
52+
}
53+
return this;
54+
}
55+
56+
@Override
57+
public String toString(String field) {
58+
StringBuilder sb = new StringBuilder();
59+
sb.append("bulk_vector_function_score(");
60+
sb.append(subQuery.toString(field));
61+
sb.append(", vector_similarity=").append(valueSource.toString());
62+
if (scoreDocs != null) {
63+
sb.append(", bulk_docs=").append(scoreDocs.length);
64+
}
65+
sb.append(")");
66+
return sb.toString();
67+
}
68+
69+
@Override
70+
public boolean equals(Object obj) {
71+
if (this == obj) return true;
72+
if (obj == null || getClass() != obj.getClass()) return false;
73+
74+
BulkVectorFunctionScoreQuery that = (BulkVectorFunctionScoreQuery) obj;
75+
return Objects.equals(subQuery, that.subQuery)
76+
&& Objects.equals(valueSource, that.valueSource)
77+
&& Arrays.equals(scoreDocs, that.scoreDocs);
78+
}
79+
80+
@Override
81+
public int hashCode() {
82+
return Objects.hash(subQuery, valueSource, scoreDocs);
83+
}
84+
85+
@Override
86+
public void visit(org.apache.lucene.search.QueryVisitor visitor) {
87+
subQuery.visit(visitor.getSubVisitor(org.apache.lucene.search.BooleanClause.Occur.MUST, this));
88+
}
89+
}
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.vectors;
11+
12+
import org.apache.lucene.index.LeafReaderContext;
13+
import org.apache.lucene.search.BulkScorer;
14+
import org.apache.lucene.search.Explanation;
15+
import org.apache.lucene.search.Query;
16+
import org.apache.lucene.search.ScoreDoc;
17+
import org.apache.lucene.search.Scorer;
18+
import org.apache.lucene.search.ScorerSupplier;
19+
import org.apache.lucene.search.Weight;
20+
21+
import java.io.IOException;
22+
import java.util.ArrayList;
23+
import java.util.List;
24+
25+
/**
26+
* Weight implementation that enables bulk vector processing for KNN rescoring queries.
27+
* Extracts segment-specific documents from ScoreDoc array and creates bulk scorers.
28+
*/
29+
public class BulkVectorFunctionScoreWeight extends Weight {
30+
31+
private final Weight subQueryWeight;
32+
private final AccessibleVectorSimilarityFloatValueSource valueSource;
33+
private final ScoreDoc[] scoreDocs;
34+
35+
public BulkVectorFunctionScoreWeight(
36+
Query parent,
37+
Weight subQueryWeight,
38+
AccessibleVectorSimilarityFloatValueSource valueSource,
39+
ScoreDoc[] scoreDocs) {
40+
super(parent);
41+
this.subQueryWeight = subQueryWeight;
42+
this.valueSource = valueSource;
43+
this.scoreDocs = scoreDocs;
44+
}
45+
46+
@Override
47+
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
48+
ScorerSupplier subQueryScorerSupplier = subQueryWeight.scorerSupplier(context);
49+
if (subQueryScorerSupplier == null) {
50+
return null;
51+
}
52+
53+
// Extract documents belonging to this segment
54+
int[] segmentDocIds = extractSegmentDocuments(scoreDocs, context);
55+
if (segmentDocIds.length == 0) {
56+
return null; // No documents in this segment
57+
}
58+
59+
return new ScorerSupplier() {
60+
@Override
61+
public Scorer get(long leadCost) throws IOException {
62+
throw new UnsupportedOperationException(
63+
"Individual Scorer not supported when bulk vector processing is enabled. Use bulkScorer() instead.");
64+
}
65+
66+
@Override
67+
public BulkScorer bulkScorer() throws IOException {
68+
// Always use BulkScorer when bulk processing is enabled
69+
BulkScorer subQueryBulkScorer = subQueryScorerSupplier.bulkScorer();
70+
return new BulkVectorScorer(subQueryBulkScorer, segmentDocIds, valueSource, context);
71+
}
72+
73+
@Override
74+
public long cost() {
75+
return segmentDocIds.length;
76+
}
77+
};
78+
}
79+
80+
@Override
81+
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
82+
// Find the document in our ScoreDoc array
83+
int globalDocId = doc + context.docBase;
84+
for (ScoreDoc scoreDoc : scoreDocs) {
85+
if (scoreDoc.doc == globalDocId) {
86+
// Compute explanation for this specific document
87+
try {
88+
DirectIOVectorBatchLoader batchLoader = new DirectIOVectorBatchLoader();
89+
float[] docVector = batchLoader.loadSingleVector(doc, context, valueSource.field());
90+
float similarity = valueSource.similarityFunction().compare(valueSource.target(), docVector);
91+
92+
return Explanation.match(
93+
similarity,
94+
"bulk vector similarity score, computed with vector similarity function: " + valueSource.similarityFunction()
95+
);
96+
} catch (Exception e) {
97+
return Explanation.noMatch("Failed to compute vector similarity: " + e.getMessage());
98+
}
99+
}
100+
}
101+
return Explanation.noMatch("Document not in bulk processing set");
102+
}
103+
104+
@Override
105+
public boolean isCacheable(LeafReaderContext ctx) {
106+
return false;
107+
}
108+
109+
private int[] extractSegmentDocuments(ScoreDoc[] scoreDocs, LeafReaderContext context) {
110+
List<Integer> segmentDocs = new ArrayList<>();
111+
int docBase = context.docBase;
112+
int maxDoc = docBase + context.reader().maxDoc();
113+
114+
for (ScoreDoc scoreDoc : scoreDocs) {
115+
if (scoreDoc.doc >= docBase && scoreDoc.doc < maxDoc) {
116+
// Convert to segment-relative document ID
117+
segmentDocs.add(scoreDoc.doc - docBase);
118+
}
119+
}
120+
121+
return segmentDocs.stream().mapToInt(Integer::intValue).toArray();
122+
}
123+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.vectors;
11+
12+
import org.elasticsearch.common.util.FeatureFlag;
13+
14+
/**
15+
* Feature flags and settings for bulk vector processing optimizations.
16+
*/
17+
public final class BulkVectorProcessingSettings {
18+
19+
public static final boolean BULK_VECTOR_SCORING = new FeatureFlag("bulk_vector_scoring").isEnabled();
20+
21+
public static final int MIN_BULK_PROCESSING_THRESHOLD = 3;
22+
23+
private BulkVectorProcessingSettings() {
24+
// Utility class
25+
}
26+
27+
public static boolean shouldUseBulkProcessing(int documentCount) {
28+
return BULK_VECTOR_SCORING && documentCount >= MIN_BULK_PROCESSING_THRESHOLD;
29+
}
30+
}

0 commit comments

Comments
 (0)