Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.vectors;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource;

/**
* Subclass of VectorSimilarityFloatValueSource offering access to its members for other classes
* in the same package.
*/
class AccessibleVectorSimilarityFloatValueSource extends VectorSimilarityFloatValueSource {

String field;
float[] target;
VectorSimilarityFunction vectorSimilarityFunction;

AccessibleVectorSimilarityFloatValueSource(String field, float[] target, VectorSimilarityFunction vectorSimilarityFunction) {
super(field, target, vectorSimilarityFunction);
this.field = field;
this.target = target;
this.vectorSimilarityFunction = vectorSimilarityFunction;
}

public String field() {
return field;
}

public float[] target() {
return target;
}

public VectorSimilarityFunction similarityFunction() {
return vectorSimilarityFunction;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.vectors;

import org.apache.lucene.index.VectorSimilarityFunction;

import java.util.Map;

public final class BatchVectorSimilarity {

private BatchVectorSimilarity() {}

public static float[] computeBatchSimilarity(
float[] queryVector,
Map<Integer, float[]> docVectors,
int[] docIds,
VectorSimilarityFunction function
) {
float[] results = new float[docIds.length];
float[][] data = organizeSIMDVectors(docVectors, docIds);

for (int i = 0, l = data.length; i < l; i++) {
float[] docVector = data[i];
results[i] = function.compare(queryVector, docVector);
}

return results;
}

public static float[][] organizeSIMDVectors(Map<Integer, float[]> vectorMap, int[] docIds) {
float[][] vectors = new float[docIds.length][];
for (int i = 0; i < docIds.length; i++) {
vectors[i] = vectorMap.get(docIds[i]);
}
return vectors;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.vectors;

import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;

/**
* Enhanced FunctionScoreQuery that enables bulk vector processing for KNN rescoring.
* When provided with a ScoreDoc array, performs bulk vector loading and similarity
* computation instead of individual per-document processing.
*/
public class BulkVectorFunctionScoreQuery extends Query {

private final Query subQuery;
private final AccessibleVectorSimilarityFloatValueSource valueSource;
private final ScoreDoc[] scoreDocs;

public BulkVectorFunctionScoreQuery(Query subQuery, AccessibleVectorSimilarityFloatValueSource valueSource, ScoreDoc[] scoreDocs) {
this.subQuery = subQuery;
this.valueSource = valueSource;
this.scoreDocs = scoreDocs;
}

@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
// TODO: take a closer look at ScoreMode
Weight subQueryWeight = subQuery.createWeight(searcher, scoreMode, boost);
return new BulkVectorFunctionScoreWeight(this, subQueryWeight, valueSource, scoreDocs);
}

@Override
public Query rewrite(IndexSearcher searcher) throws IOException {
Query rewrittenSubQuery = subQuery.rewrite(searcher);
if (rewrittenSubQuery != subQuery) {
return new BulkVectorFunctionScoreQuery(rewrittenSubQuery, valueSource, scoreDocs);
}
return this;
}

@Override
public String toString(String field) {
StringBuilder sb = new StringBuilder();
sb.append("bulk_vector_function_score(");
sb.append(subQuery.toString(field));
sb.append(", vector_similarity=").append(valueSource.toString());
if (scoreDocs != null) {
sb.append(", bulk_docs=").append(scoreDocs.length);
}
sb.append(")");
return sb.toString();
}

@Override
public boolean equals(Object obj) {
if (this == obj) return true;
if (obj == null || getClass() != obj.getClass()) return false;

BulkVectorFunctionScoreQuery that = (BulkVectorFunctionScoreQuery) obj;
return Objects.equals(subQuery, that.subQuery)
&& Objects.equals(valueSource, that.valueSource)
&& Arrays.equals(scoreDocs, that.scoreDocs);
}

@Override
public int hashCode() {
return Objects.hash(subQuery, valueSource, scoreDocs);
}

@Override
public void visit(org.apache.lucene.search.QueryVisitor visitor) {
subQuery.visit(visitor.getSubVisitor(org.apache.lucene.search.BooleanClause.Occur.MUST, this));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.vectors;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.Weight;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
* Weight implementation that enables bulk vector processing for KNN rescoring queries.
* Extracts segment-specific documents from ScoreDoc array and creates bulk scorers.
*/
public class BulkVectorFunctionScoreWeight extends Weight {

private final Weight subQueryWeight;
private final AccessibleVectorSimilarityFloatValueSource valueSource;
private final ScoreDoc[] scoreDocs;

public BulkVectorFunctionScoreWeight(
Query parent,
Weight subQueryWeight,
AccessibleVectorSimilarityFloatValueSource valueSource,
ScoreDoc[] scoreDocs
) {
super(parent);
this.subQueryWeight = subQueryWeight;
this.valueSource = valueSource;
this.scoreDocs = scoreDocs;
}

@Override
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
ScorerSupplier subQueryScorerSupplier = subQueryWeight.scorerSupplier(context);
if (subQueryScorerSupplier == null) {
return null;
}

// Extract documents belonging to this segment
int[] segmentDocIds = extractSegmentDocuments(scoreDocs, context);
if (segmentDocIds.length == 0) {
return null; // No documents in this segment
}

return new ScorerSupplier() {
@Override
public Scorer get(long leadCost) throws IOException {
// if asked for basic Scorer, delegate to the underlying subquery scorer
return subQueryScorerSupplier.get(leadCost);
}

@Override
public BulkScorer bulkScorer() throws IOException {
// Always use BulkScorer when bulk processing is enabled
BulkScorer subQueryBulkScorer = subQueryScorerSupplier.bulkScorer();
return new BulkVectorScorer(subQueryBulkScorer, segmentDocIds, valueSource, context);
}

@Override
public long cost() {
return segmentDocIds.length;
}
};
}

@Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
// Find the document in our ScoreDoc array
int globalDocId = doc + context.docBase;
for (ScoreDoc scoreDoc : scoreDocs) {
if (scoreDoc.doc == globalDocId) {
// Compute explanation for this specific document
try {
DirectIOVectorBatchLoader batchLoader = new DirectIOVectorBatchLoader();
float[] docVector = batchLoader.loadSingleVector(doc, context, valueSource.field());
float similarity = valueSource.similarityFunction().compare(valueSource.target(), docVector);

return Explanation.match(
similarity,
"bulk vector similarity score, computed with vector similarity function: " + valueSource.similarityFunction()
);
} catch (Exception e) {
return Explanation.noMatch("Failed to compute vector similarity: " + e.getMessage());
}
}
}
return Explanation.noMatch("Document not in bulk processing set");
}

@Override
public boolean isCacheable(LeafReaderContext ctx) {
return false;
}

private int[] extractSegmentDocuments(ScoreDoc[] scoreDocs, LeafReaderContext context) {
List<Integer> segmentDocs = new ArrayList<>();
int docBase = context.docBase;
int maxDoc = docBase + context.reader().maxDoc();

for (ScoreDoc scoreDoc : scoreDocs) {
if (scoreDoc.doc >= docBase && scoreDoc.doc < maxDoc) {
// Convert to segment-relative document ID
segmentDocs.add(scoreDoc.doc - docBase);
}
}

return segmentDocs.stream().mapToInt(Integer::intValue).toArray();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.vectors;

import org.elasticsearch.common.util.FeatureFlag;

/**
* Feature flags and settings for bulk vector processing optimizations.
*/
public final class BulkVectorProcessingSettings {

public static final boolean BULK_VECTOR_SCORING = new FeatureFlag("bulk_vector_scoring").isEnabled();

public static final int MIN_BULK_PROCESSING_THRESHOLD = 3;

private BulkVectorProcessingSettings() {
// Utility class
}

public static boolean shouldUseBulkProcessing(int documentCount) {
return BULK_VECTOR_SCORING && documentCount >= MIN_BULK_PROCESSING_THRESHOLD;
}
}
Loading