-
Notifications
You must be signed in to change notification settings - Fork 25.6k
Vector rescoring - Simplify code for k == null #118997
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
carlosdelest
merged 8 commits into
elastic:main
from
carlosdelest:non-issue/rescore-vector-use-size-as-k
Jan 9, 2025
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
e03f240
Use request size when k is null to calculate the number of results to…
carlosdelest 1674d50
Merge branch 'main' into non-issue/rescore-vector-use-size-as-k
elasticmachine 72b8779
Merge branch 'main' into non-issue/rescore-vector-use-size-as-k
carlosdelest 89fae7c
Merge branch 'main' into non-issue/rescore-vector-use-size-as-k
elasticmachine 0a2f895
Remove unnecessary override
carlosdelest dcbe0bf
Merge remote-tracking branch 'carlosdelest/non-issue/rescore-vector-u…
carlosdelest 17979d5
Merge branch 'main' into non-issue/rescore-vector-use-size-as-k
carlosdelest f1e2972
Remove request size as it is already provided from the query using k
carlosdelest File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,16 +32,15 @@ public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvide | |
private final String fieldName; | ||
private final float[] floatTarget; | ||
private final VectorSimilarityFunction vectorSimilarityFunction; | ||
private final Integer k; | ||
private final int k; | ||
private final Query innerQuery; | ||
|
||
private QueryProfilerProvider vectorProfiling; | ||
private long vectorOperations = 0; | ||
|
||
public RescoreKnnVectorQuery( | ||
String fieldName, | ||
float[] floatTarget, | ||
VectorSimilarityFunction vectorSimilarityFunction, | ||
Integer k, | ||
int k, | ||
Query innerQuery | ||
) { | ||
this.fieldName = fieldName; | ||
|
@@ -54,19 +53,12 @@ public RescoreKnnVectorQuery( | |
@Override | ||
public Query rewrite(IndexSearcher searcher) throws IOException { | ||
DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction); | ||
// Vector similarity VectorSimilarityFloatValueSource keep track of the compared vectors - we need that in case we don't need | ||
// to calculate top k and return directly the query to understand how many comparisons were done | ||
vectorProfiling = (QueryProfilerProvider) valueSource; | ||
FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource); | ||
Query query = searcher.rewrite(functionScoreQuery); | ||
|
||
if (k == null) { | ||
// No need to calculate top k - let the request size limit the results. | ||
return query; | ||
} | ||
|
||
// Retrieve top k documents from the rescored query | ||
TopDocs topDocs = searcher.search(query, k); | ||
vectorOperations = topDocs.totalHits.value(); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We know in advance the number of comparisons done |
||
ScoreDoc[] scoreDocs = topDocs.scoreDocs; | ||
int[] docIds = new int[scoreDocs.length]; | ||
float[] scores = new float[scoreDocs.length]; | ||
|
@@ -82,7 +74,7 @@ public Query innerQuery() { | |
return innerQuery; | ||
} | ||
|
||
public Integer k() { | ||
public int k() { | ||
return k; | ||
} | ||
|
||
|
@@ -92,10 +84,7 @@ public void profile(QueryProfiler queryProfiler) { | |
queryProfilerProvider.profile(queryProfiler); | ||
} | ||
|
||
if (vectorProfiling == null) { | ||
throw new IllegalStateException("Query should have been rewritten"); | ||
} | ||
vectorProfiling.profile(queryProfiler); | ||
queryProfiler.addVectorOpsCount(vectorOperations); | ||
} | ||
|
||
@Override | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,8 +9,6 @@ | |
|
||
package org.elasticsearch.search.vectors; | ||
|
||
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; | ||
|
||
import org.apache.lucene.document.Document; | ||
import org.apache.lucene.document.KnnFloatVectorField; | ||
import org.apache.lucene.index.DirectoryReader; | ||
|
@@ -33,11 +31,9 @@ | |
|
||
import java.io.IOException; | ||
import java.io.UnsupportedEncodingException; | ||
import java.util.ArrayList; | ||
import java.util.Arrays; | ||
import java.util.Collection; | ||
import java.util.HashSet; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.PriorityQueue; | ||
import java.util.stream.Collectors; | ||
|
@@ -49,21 +45,11 @@ | |
public class RescoreKnnVectorQueryTests extends ESTestCase { | ||
|
||
public static final String FIELD_NAME = "float_vector"; | ||
private final int numDocs; | ||
private final Integer k; | ||
|
||
public RescoreKnnVectorQueryTests(boolean useK) { | ||
this.numDocs = randomIntBetween(10, 100); | ||
this.k = useK ? randomIntBetween(1, numDocs - 1) : null; | ||
} | ||
|
||
public void testRescoreDocs() throws Exception { | ||
int numDocs = randomIntBetween(10, 100); | ||
int numDims = randomIntBetween(5, 100); | ||
|
||
Integer adjustedK = k; | ||
if (k == null) { | ||
adjustedK = numDocs; | ||
} | ||
int k = randomIntBetween(1, numDocs - 1); | ||
|
||
try (Directory d = newDirectory()) { | ||
addRandomDocuments(numDocs, d, numDims); | ||
|
@@ -77,7 +63,7 @@ public void testRescoreDocs() throws Exception { | |
FIELD_NAME, | ||
queryVector, | ||
VectorSimilarityFunction.COSINE, | ||
adjustedK, | ||
k, | ||
new MatchAllDocsQuery() | ||
); | ||
|
||
|
@@ -86,7 +72,7 @@ public void testRescoreDocs() throws Exception { | |
Map<Integer, Float> rescoredDocs = Arrays.stream(docs.scoreDocs) | ||
.collect(Collectors.toMap(scoreDoc -> scoreDoc.doc, scoreDoc -> scoreDoc.score)); | ||
|
||
assertThat(rescoredDocs.size(), equalTo(adjustedK)); | ||
assertThat(rescoredDocs.size(), equalTo(k)); | ||
|
||
Collection<Float> rescoredScores = new HashSet<>(rescoredDocs.values()); | ||
|
||
|
@@ -113,7 +99,7 @@ public void testRescoreDocs() throws Exception { | |
assertThat(rescoredDocs.size(), equalTo(0)); | ||
|
||
// Check top scoring docs are contained in rescored docs | ||
for (int i = 0; i < adjustedK; i++) { | ||
for (int i = 0; i < k; i++) { | ||
Float topScore = topK.poll(); | ||
if (rescoredScores.contains(topScore) == false) { | ||
fail("Top score " + topScore + " not contained in rescored doc scores " + rescoredScores); | ||
|
@@ -124,21 +110,23 @@ public void testRescoreDocs() throws Exception { | |
} | ||
|
||
public void testProfiling() throws Exception { | ||
int numDocs = randomIntBetween(10, 100); | ||
int numDims = randomIntBetween(5, 100); | ||
int k = randomIntBetween(1, numDocs - 1); | ||
|
||
try (Directory d = newDirectory()) { | ||
addRandomDocuments(numDocs, d, numDims); | ||
|
||
try (IndexReader reader = DirectoryReader.open(d)) { | ||
float[] queryVector = randomVector(numDims); | ||
|
||
checkProfiling(queryVector, reader, new MatchAllDocsQuery()); | ||
checkProfiling(queryVector, reader, new MockQueryProfilerProvider(randomIntBetween(1, 100))); | ||
checkProfiling(k, numDocs, queryVector, reader, new MatchAllDocsQuery()); | ||
checkProfiling(k, numDocs, queryVector, reader, new MockQueryProfilerProvider(randomIntBetween(1, 100))); | ||
} | ||
} | ||
} | ||
|
||
private void checkProfiling(float[] queryVector, IndexReader reader, Query innerQuery) throws IOException { | ||
private void checkProfiling(int k, int numDocs, float[] queryVector, IndexReader reader, Query innerQuery) throws IOException { | ||
RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery( | ||
FIELD_NAME, | ||
queryVector, | ||
|
@@ -229,13 +217,4 @@ private static void addRandomDocuments(int numDocs, Directory d, int numDims) th | |
w.forceMerge(1); | ||
} | ||
} | ||
|
||
@ParametersFactory | ||
public static Iterable<Object[]> parameters() { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We always have a specific k, it makes no sense to use parameters. |
||
List<Object[]> params = new ArrayList<>(); | ||
params.add(new Object[] { true }); | ||
params.add(new Object[] { false }); | ||
|
||
return params; | ||
} | ||
} |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Profiling can be simplified, as we always know the number of results to return