Skip to content

Commit 7045500

Browse files
committed
Limit for rescoring factor is 1.0, so we can't have less rescored docs than num_candidates
1 parent 9412be0 commit 7045500

File tree

4 files changed

+6
-13
lines changed

4 files changed

+6
-13
lines changed

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2113,11 +2113,8 @@ && isNotUnitVector(squaredMagnitude)) {
21132113

21142114
int adjustedNumCands = numCands;
21152115
if (needsRescore(numCandsFactor)) {
2116-
// We shouldn't have less than k candidates (or 1 in case k is not set) to rescore
2117-
int minCands = k == null ? 1 : k;
21182116
// k <= numCands * numCandsFactor <= NUM_CANDS_OVERSAMPLE_LIMIT. Adjust otherwise.
2119-
adjustedNumCands = Math.max(minCands, (int) Math.ceil(numCands * numCandsFactor));
2120-
adjustedNumCands = Math.min(adjustedNumCands, NUM_CANDS_OVERSAMPLE_LIMIT);
2117+
adjustedNumCands = Math.min((int) Math.ceil(numCands * numCandsFactor), NUM_CANDS_OVERSAMPLE_LIMIT);
21212118
}
21222119
Query knnQuery = parentFilter != null
21232120
? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, k, adjustedNumCands, parentFilter)

server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
public class RescoreVectorBuilder implements Writeable, ToXContentObject {
2525

2626
public static final ParseField NUM_CANDIDATES_FACTOR_FIELD = new ParseField("num_candidates_factor");
27-
public static final float MIN_OVERSAMPLE = 0.0F;
27+
public static final float MIN_OVERSAMPLE = 1.0F;
2828
private static final ConstructingObjectParser<RescoreVectorBuilder, Void> PARSER = new ConstructingObjectParser<>(
2929
"rescore_vector",
3030
args -> new RescoreVectorBuilder((Float) args[0])
@@ -39,8 +39,8 @@ public class RescoreVectorBuilder implements Writeable, ToXContentObject {
3939

4040
public RescoreVectorBuilder(Float numCandidatesFactor) {
4141
Objects.requireNonNull(numCandidatesFactor, "[" + NUM_CANDIDATES_FACTOR_FIELD.getPreferredName() + "] must be set");
42-
if (numCandidatesFactor <= MIN_OVERSAMPLE) {
43-
throw new IllegalArgumentException("[" + NUM_CANDIDATES_FACTOR_FIELD.getPreferredName() + "] must be > " + MIN_OVERSAMPLE);
42+
if (numCandidatesFactor < MIN_OVERSAMPLE) {
43+
throw new IllegalArgumentException("[" + NUM_CANDIDATES_FACTOR_FIELD.getPreferredName() + "] must be >= " + MIN_OVERSAMPLE);
4444
}
4545
this.numCandidatesFactor = numCandidatesFactor;
4646
}

server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -462,10 +462,6 @@ public void testRescoreOversampleModifiesNumCandidates() {
462462
// Oversampling limits for num candidates
463463
checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, 1000, 10000);
464464
checkRescoreQueryParameters(fieldType, 5000, 7500, 2.5F, 5000, 10000);
465-
// Oversampling is capped at k as a minimum
466-
checkRescoreQueryParameters(fieldType, 10, 100, 0.01F, 10, 10);
467-
// Oversampling is capped at 1 as a minimum if k is not specified
468-
checkRescoreQueryParameters(fieldType, null, 100, 0.0001F, null, 1);
469465
}
470466

471467
private static void checkRescoreQueryParameters(

server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,9 @@ public void testInvalidK() {
257257
public void testInvalidRescoreVectorBuilder() {
258258
IllegalArgumentException e = expectThrows(
259259
IllegalArgumentException.class,
260-
() -> new KnnSearchBuilder("field", randomVector(3), 10, 100, new RescoreVectorBuilder(0.0F), null)
260+
() -> new KnnSearchBuilder("field", randomVector(3), 10, 100, new RescoreVectorBuilder(0.99F), null)
261261
);
262-
assertThat(e.getMessage(), containsString("[num_candidates_factor] must be > 0.0"));
262+
assertThat(e.getMessage(), containsString("[num_candidates_factor] must be >= 1.0"));
263263
}
264264

265265
public void testRewrite() throws Exception {

0 commit comments

Comments
 (0)