Skip to content

Commit e03f240

Browse files
committed
Use request size when k is null to calculate the number of results to retrieve from each shard
1 parent e0763c2 commit e03f240

File tree

7 files changed

+89
-87
lines changed

7 files changed

+89
-87
lines changed

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2008,6 +2008,7 @@ public Query createKnnQuery(
20082008
VectorData queryVector,
20092009
Integer k,
20102010
int numCands,
2011+
int requestSize,
20112012
Float numCandsFactor,
20122013
Query filter,
20132014
Float similarityThreshold,
@@ -2024,6 +2025,7 @@ public Query createKnnQuery(
20242025
queryVector.asFloatVector(),
20252026
k,
20262027
numCands,
2028+
requestSize,
20272029
numCandsFactor,
20282030
filter,
20292031
similarityThreshold,
@@ -2090,6 +2092,7 @@ private Query createKnnFloatQuery(
20902092
float[] queryVector,
20912093
Integer k,
20922094
int numCands,
2095+
int requestSize,
20932096
Float numCandsFactor,
20942097
Query filter,
20952098
Float similarityThreshold,
@@ -2127,7 +2130,7 @@ && isNotUnitVector(squaredMagnitude)) {
21272130
name(),
21282131
queryVector,
21292132
similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT),
2130-
k,
2133+
k == null ? requestSize : k,
21312134
knnQuery
21322135
);
21332136
}

server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
import org.apache.lucene.search.DoubleValues;
1919
import org.apache.lucene.search.DoubleValuesSource;
2020
import org.apache.lucene.search.IndexSearcher;
21-
import org.elasticsearch.search.profile.query.QueryProfiler;
22-
import org.elasticsearch.search.vectors.QueryProfilerProvider;
2321

2422
import java.io.IOException;
2523
import java.util.Arrays;
@@ -29,12 +27,11 @@
2927
* DoubleValuesSource that is used to calculate scores according to a similarity function for a KnnFloatVectorField, using the
3028
* original vector values stored in the index
3129
*/
32-
public class VectorSimilarityFloatValueSource extends DoubleValuesSource implements QueryProfilerProvider {
30+
public class VectorSimilarityFloatValueSource extends DoubleValuesSource {
3331

3432
private final String field;
3533
private final float[] target;
3634
private final VectorSimilarityFunction vectorSimilarityFunction;
37-
private long vectorOpsCount;
3835

3936
public VectorSimilarityFloatValueSource(String field, float[] target, VectorSimilarityFunction vectorSimilarityFunction) {
4037
this.field = field;
@@ -52,7 +49,6 @@ public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws
5249
return new DoubleValues() {
5350
@Override
5451
public double doubleValue() throws IOException {
55-
vectorOpsCount++;
5652
return vectorSimilarityFunction.compare(target, vectorValues.vectorValue(iterator.index()));
5753
}
5854

@@ -73,11 +69,6 @@ public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException {
7369
return this;
7470
}
7571

76-
@Override
77-
public void profile(QueryProfiler queryProfiler) {
78-
queryProfiler.addVectorOpsCount(vectorOpsCount);
79-
}
80-
8172
@Override
8273
public int hashCode() {
8374
return Objects.hash(field, Arrays.hashCode(target), vectorSimilarityFunction);

server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -528,8 +528,8 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
528528
String parentPath = context.nestedLookup().getNestedParent(fieldName);
529529
Float numCandidatesFactor = rescoreVectorBuilder() == null ? null : rescoreVectorBuilder.numCandidatesFactor();
530530

531+
BitSetProducer parentBitSet = null;
531532
if (parentPath != null) {
532-
final BitSetProducer parentBitSet;
533533
final Query parentFilter;
534534
NestedObjectMapper originalObjectMapper = context.nestedScope().getObjectMapper();
535535
if (originalObjectMapper != null) {
@@ -558,17 +558,18 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
558558
// Now join the filterQuery & parentFilter to provide the matching blocks of children
559559
filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet);
560560
}
561-
return vectorFieldType.createKnnQuery(
562-
queryVector,
563-
k,
564-
adjustedNumCands,
565-
numCandidatesFactor,
566-
filterQuery,
567-
vectorSimilarity,
568-
parentBitSet
569-
);
570561
}
571-
return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, numCandidatesFactor, filterQuery, vectorSimilarity, null);
562+
563+
return vectorFieldType.createKnnQuery(
564+
queryVector,
565+
k,
566+
adjustedNumCands,
567+
requestSize,
568+
numCandidatesFactor,
569+
filterQuery,
570+
vectorSimilarity,
571+
parentBitSet
572+
);
572573
}
573574

574575
@Override

server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,15 @@ public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvide
3232
private final String fieldName;
3333
private final float[] floatTarget;
3434
private final VectorSimilarityFunction vectorSimilarityFunction;
35-
private final Integer k;
35+
private final int k;
3636
private final Query innerQuery;
37-
38-
private QueryProfilerProvider vectorProfiling;
37+
private long vectorOperations = 0;
3938

4039
public RescoreKnnVectorQuery(
4140
String fieldName,
4241
float[] floatTarget,
4342
VectorSimilarityFunction vectorSimilarityFunction,
44-
Integer k,
43+
int k,
4544
Query innerQuery
4645
) {
4746
this.fieldName = fieldName;
@@ -54,19 +53,12 @@ public RescoreKnnVectorQuery(
5453
@Override
5554
public Query rewrite(IndexSearcher searcher) throws IOException {
5655
DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction);
57-
// Vector similarity VectorSimilarityFloatValueSource keep track of the compared vectors - we need that in case we don't need
58-
// to calculate top k and return directly the query to understand how many comparisons were done
59-
vectorProfiling = (QueryProfilerProvider) valueSource;
6056
FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource);
6157
Query query = searcher.rewrite(functionScoreQuery);
6258

63-
if (k == null) {
64-
// No need to calculate top k - let the request size limit the results.
65-
return query;
66-
}
67-
6859
// Retrieve top k documents from the rescored query
6960
TopDocs topDocs = searcher.search(query, k);
61+
vectorOperations = topDocs.totalHits.value();
7062
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
7163
int[] docIds = new int[scoreDocs.length];
7264
float[] scores = new float[scoreDocs.length];
@@ -82,7 +74,7 @@ public Query innerQuery() {
8274
return innerQuery;
8375
}
8476

85-
public Integer k() {
77+
public int k() {
8678
return k;
8779
}
8880

@@ -92,10 +84,7 @@ public void profile(QueryProfiler queryProfiler) {
9284
queryProfilerProvider.profile(queryProfiler);
9385
}
9486

95-
if (vectorProfiling == null) {
96-
throw new IllegalStateException("Query should have been rewritten");
97-
}
98-
vectorProfiling.profile(queryProfiler);
87+
queryProfiler.addVectorOpsCount(vectorOperations);
9988
}
10089

10190
@Override

server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1674,7 +1674,7 @@ public void testByteVectorQueryBoundaries() throws IOException {
16741674

16751675
Exception e = expectThrows(
16761676
IllegalArgumentException.class,
1677-
() -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 128, 0, 0 }), 3, 3, null, null, null, null)
1677+
() -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 128, 0, 0 }), 3, 3, 3, null, null, null, null)
16781678
);
16791679
assertThat(
16801680
e.getMessage(),
@@ -1687,6 +1687,7 @@ public void testByteVectorQueryBoundaries() throws IOException {
16871687
VectorData.fromFloats(new float[] { 0.0f, 0f, -129.0f }),
16881688
3,
16891689
3,
1690+
3,
16901691
null,
16911692
null,
16921693
null,
@@ -1700,7 +1701,16 @@ public void testByteVectorQueryBoundaries() throws IOException {
17001701

17011702
e = expectThrows(
17021703
IllegalArgumentException.class,
1703-
() -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }), 3, 3, null, null, null, null)
1704+
() -> denseVectorFieldType.createKnnQuery(
1705+
VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }),
1706+
3,
1707+
3,
1708+
3,
1709+
null,
1710+
null,
1711+
null,
1712+
null
1713+
)
17041714
);
17051715
assertThat(
17061716
e.getMessage(),
@@ -1709,7 +1719,16 @@ public void testByteVectorQueryBoundaries() throws IOException {
17091719

17101720
e = expectThrows(
17111721
IllegalArgumentException.class,
1712-
() -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }), 3, 3, null, null, null, null)
1722+
() -> denseVectorFieldType.createKnnQuery(
1723+
VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }),
1724+
3,
1725+
3,
1726+
3,
1727+
null,
1728+
null,
1729+
null,
1730+
null
1731+
)
17131732
);
17141733
assertThat(
17151734
e.getMessage(),
@@ -1722,6 +1741,7 @@ public void testByteVectorQueryBoundaries() throws IOException {
17221741
VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }),
17231742
3,
17241743
3,
1744+
3,
17251745
null,
17261746
null,
17271747
null,
@@ -1736,6 +1756,7 @@ public void testByteVectorQueryBoundaries() throws IOException {
17361756
VectorData.fromFloats(new float[] { Float.POSITIVE_INFINITY, 0f, 0.0f }),
17371757
3,
17381758
3,
1759+
3,
17391760
null,
17401761
null,
17411762
null,
@@ -1753,6 +1774,7 @@ public void testByteVectorQueryBoundaries() throws IOException {
17531774
VectorData.fromFloats(new float[] { 0, Float.NEGATIVE_INFINITY, 0.0f }),
17541775
3,
17551776
3,
1777+
3,
17561778
null,
17571779
null,
17581780
null,
@@ -1787,6 +1809,7 @@ public void testFloatVectorQueryBoundaries() throws IOException {
17871809
VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }),
17881810
3,
17891811
3,
1812+
3,
17901813
null,
17911814
null,
17921815
null,
@@ -1801,6 +1824,7 @@ public void testFloatVectorQueryBoundaries() throws IOException {
18011824
VectorData.fromFloats(new float[] { Float.POSITIVE_INFINITY, 0f, 0.0f }),
18021825
3,
18031826
3,
1827+
3,
18041828
null,
18051829
null,
18061830
null,
@@ -1818,6 +1842,7 @@ public void testFloatVectorQueryBoundaries() throws IOException {
18181842
VectorData.fromFloats(new float[] { 0, Float.NEGATIVE_INFINITY, 0.0f }),
18191843
3,
18201844
3,
1845+
3,
18211846
null,
18221847
null,
18231848
null,

0 commit comments

Comments
 (0)