Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
import org.apache.lucene.search.DoubleValues;
import org.apache.lucene.search.DoubleValuesSource;
import org.apache.lucene.search.IndexSearcher;
import org.elasticsearch.search.profile.query.QueryProfiler;
import org.elasticsearch.search.vectors.QueryProfilerProvider;

import java.io.IOException;
import java.util.Arrays;
Expand All @@ -29,12 +27,11 @@
* DoubleValuesSource that is used to calculate scores according to a similarity function for a KnnFloatVectorField, using the
* original vector values stored in the index
*/
public class VectorSimilarityFloatValueSource extends DoubleValuesSource implements QueryProfilerProvider {
public class VectorSimilarityFloatValueSource extends DoubleValuesSource {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Profiling can be simplified, as we always know the number of results to return


private final String field;
private final float[] target;
private final VectorSimilarityFunction vectorSimilarityFunction;
private long vectorOpsCount;

public VectorSimilarityFloatValueSource(String field, float[] target, VectorSimilarityFunction vectorSimilarityFunction) {
this.field = field;
Expand All @@ -52,7 +49,6 @@ public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws
return new DoubleValues() {
@Override
public double doubleValue() throws IOException {
vectorOpsCount++;
return vectorSimilarityFunction.compare(target, vectorValues.vectorValue(iterator.index()));
}

Expand All @@ -73,11 +69,6 @@ public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException {
return this;
}

@Override
public void profile(QueryProfiler queryProfiler) {
queryProfiler.addVectorOpsCount(vectorOpsCount);
}

@Override
public int hashCode() {
return Objects.hash(field, Arrays.hashCode(target), vectorSimilarityFunction);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -487,11 +487,6 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException {
return this;
}

@Override
protected QueryBuilder doIndexMetadataRewrite(QueryRewriteContext context) throws IOException {
return super.doIndexMetadataRewrite(context);
}

@Override
protected Query doToQuery(SearchExecutionContext context) throws IOException {
MappedFieldType fieldType = context.getFieldType(fieldName);
Expand Down Expand Up @@ -529,8 +524,8 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
String parentPath = context.nestedLookup().getNestedParent(fieldName);
Float numCandidatesFactor = rescoreVectorBuilder() == null ? null : rescoreVectorBuilder.numCandidatesFactor();

BitSetProducer parentBitSet = null;
if (parentPath != null) {
final BitSetProducer parentBitSet;
final Query parentFilter;
NestedObjectMapper originalObjectMapper = context.nestedScope().getObjectMapper();
if (originalObjectMapper != null) {
Expand Down Expand Up @@ -559,17 +554,17 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
// Now join the filterQuery & parentFilter to provide the matching blocks of children
filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet);
}
return vectorFieldType.createKnnQuery(
queryVector,
k,
adjustedNumCands,
numCandidatesFactor,
filterQuery,
vectorSimilarity,
parentBitSet
);
}
return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, numCandidatesFactor, filterQuery, vectorSimilarity, null);

return vectorFieldType.createKnnQuery(
queryVector,
k,
adjustedNumCands,
numCandidatesFactor,
filterQuery,
vectorSimilarity,
parentBitSet
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,15 @@ public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvide
private final String fieldName;
private final float[] floatTarget;
private final VectorSimilarityFunction vectorSimilarityFunction;
private final Integer k;
private final int k;
private final Query innerQuery;

private QueryProfilerProvider vectorProfiling;
private long vectorOperations = 0;

public RescoreKnnVectorQuery(
String fieldName,
float[] floatTarget,
VectorSimilarityFunction vectorSimilarityFunction,
Integer k,
int k,
Query innerQuery
) {
this.fieldName = fieldName;
Expand All @@ -54,19 +53,12 @@ public RescoreKnnVectorQuery(
@Override
public Query rewrite(IndexSearcher searcher) throws IOException {
DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction);
// Vector similarity VectorSimilarityFloatValueSource keep track of the compared vectors - we need that in case we don't need
// to calculate top k and return directly the query to understand how many comparisons were done
vectorProfiling = (QueryProfilerProvider) valueSource;
FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource);
Query query = searcher.rewrite(functionScoreQuery);

if (k == null) {
// No need to calculate top k - let the request size limit the results.
return query;
}

// Retrieve top k documents from the rescored query
TopDocs topDocs = searcher.search(query, k);
vectorOperations = topDocs.totalHits.value();
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We know in advance the number of comparisons done

ScoreDoc[] scoreDocs = topDocs.scoreDocs;
int[] docIds = new int[scoreDocs.length];
float[] scores = new float[scoreDocs.length];
Expand All @@ -82,7 +74,7 @@ public Query innerQuery() {
return innerQuery;
}

public Integer k() {
public int k() {
return k;
}

Expand All @@ -92,10 +84,7 @@ public void profile(QueryProfiler queryProfiler) {
queryProfilerProvider.profile(queryProfiler);
}

if (vectorProfiling == null) {
throw new IllegalStateException("Query should have been rewritten");
}
vectorProfiling.profile(queryProfiler);
queryProfiler.addVectorOpsCount(vectorOperations);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -456,18 +456,19 @@ public void testRescoreOversampleModifiesNumCandidates() {
);

// Total results is k, internal k is multiplied by oversample
checkRescoreQueryParameters(fieldType, 10, 200, 2.5F, null, 500, 10);
checkRescoreQueryParameters(fieldType, 10, 200, randomInt(), 2.5F, null, 500, 10);
// If numCands < k, update numCands to k
checkRescoreQueryParameters(fieldType, 10, 20, 2.5F, null, 50, 10);
checkRescoreQueryParameters(fieldType, 10, 20, randomInt(), 2.5F, null, 50, 10);
// Oversampling limits for num candidates
checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, null, 10000, 1000);
checkRescoreQueryParameters(fieldType, 5000, 7500, 2.5F, null, 10000, 5000);
checkRescoreQueryParameters(fieldType, 1000, 1000, randomInt(), 11.0F, null, 10000, 1000);
checkRescoreQueryParameters(fieldType, 5000, 7500, randomInt(), 2.5F, null, 10000, 5000);
}

private static void checkRescoreQueryParameters(
DenseVectorFieldType fieldType,
Integer k,
int k,
int candidates,
int requestSize,
float numCandsFactor,
Integer expectedK,
int expectedCandidates,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@

package org.elasticsearch.search.vectors;

import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.DirectoryReader;
Expand All @@ -33,11 +31,9 @@

import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.stream.Collectors;
Expand All @@ -49,21 +45,11 @@
public class RescoreKnnVectorQueryTests extends ESTestCase {

public static final String FIELD_NAME = "float_vector";
private final int numDocs;
private final Integer k;

public RescoreKnnVectorQueryTests(boolean useK) {
this.numDocs = randomIntBetween(10, 100);
this.k = useK ? randomIntBetween(1, numDocs - 1) : null;
}

public void testRescoreDocs() throws Exception {
int numDocs = randomIntBetween(10, 100);
int numDims = randomIntBetween(5, 100);

Integer adjustedK = k;
if (k == null) {
adjustedK = numDocs;
}
int k = randomIntBetween(1, numDocs - 1);

try (Directory d = newDirectory()) {
addRandomDocuments(numDocs, d, numDims);
Expand All @@ -77,7 +63,7 @@ public void testRescoreDocs() throws Exception {
FIELD_NAME,
queryVector,
VectorSimilarityFunction.COSINE,
adjustedK,
k,
new MatchAllDocsQuery()
);

Expand All @@ -86,7 +72,7 @@ public void testRescoreDocs() throws Exception {
Map<Integer, Float> rescoredDocs = Arrays.stream(docs.scoreDocs)
.collect(Collectors.toMap(scoreDoc -> scoreDoc.doc, scoreDoc -> scoreDoc.score));

assertThat(rescoredDocs.size(), equalTo(adjustedK));
assertThat(rescoredDocs.size(), equalTo(k));

Collection<Float> rescoredScores = new HashSet<>(rescoredDocs.values());

Expand All @@ -113,7 +99,7 @@ public void testRescoreDocs() throws Exception {
assertThat(rescoredDocs.size(), equalTo(0));

// Check top scoring docs are contained in rescored docs
for (int i = 0; i < adjustedK; i++) {
for (int i = 0; i < k; i++) {
Float topScore = topK.poll();
if (rescoredScores.contains(topScore) == false) {
fail("Top score " + topScore + " not contained in rescored doc scores " + rescoredScores);
Expand All @@ -124,21 +110,23 @@ public void testRescoreDocs() throws Exception {
}

public void testProfiling() throws Exception {
int numDocs = randomIntBetween(10, 100);
int numDims = randomIntBetween(5, 100);
int k = randomIntBetween(1, numDocs - 1);

try (Directory d = newDirectory()) {
addRandomDocuments(numDocs, d, numDims);

try (IndexReader reader = DirectoryReader.open(d)) {
float[] queryVector = randomVector(numDims);

checkProfiling(queryVector, reader, new MatchAllDocsQuery());
checkProfiling(queryVector, reader, new MockQueryProfilerProvider(randomIntBetween(1, 100)));
checkProfiling(k, numDocs, queryVector, reader, new MatchAllDocsQuery());
checkProfiling(k, numDocs, queryVector, reader, new MockQueryProfilerProvider(randomIntBetween(1, 100)));
}
}
}

private void checkProfiling(float[] queryVector, IndexReader reader, Query innerQuery) throws IOException {
private void checkProfiling(int k, int numDocs, float[] queryVector, IndexReader reader, Query innerQuery) throws IOException {
RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery(
FIELD_NAME,
queryVector,
Expand Down Expand Up @@ -229,13 +217,4 @@ private static void addRandomDocuments(int numDocs, Directory d, int numDims) th
w.forceMerge(1);
}
}

@ParametersFactory
public static Iterable<Object[]> parameters() {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We always have a specific k, it makes no sense to use parameters.

List<Object[]> params = new ArrayList<>();
params.add(new Object[] { true });
params.add(new Object[] { false });

return params;
}
}