diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/20_knn_retriever.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/20_knn_retriever.yml index e49f0634a4887..1f07884c9fadf 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/20_knn_retriever.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/20_knn_retriever.yml @@ -82,7 +82,7 @@ setup: capabilities: - method: GET path: /_search - capabilities: [knn_quantized_vector_rescore] + capabilities: [knn_quantized_vector_rescore_oversample] - skip: features: "headers" @@ -100,7 +100,7 @@ setup: k: 3 num_candidates: 3 rescore_vector: - num_candidates_factor: 1.5 + oversample: 1.5 # Get rescoring scores - hit ordering may change depending on how things are distributed - match: { hits.total: 3 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml index d4bf5e7e9807f..be35dcde2eff3 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml @@ -106,9 +106,9 @@ setup: k: 3 num_candidates: 3 "rescore_vector": - "num_candidates_factor": 2.0 + "oversample": 2.0 - # We expect the knn search ops + rescoring num_cnaidates (for rescoring) per shard + # We expect the knn search ops + rescoring k * oversample (for rescoring) per shard - match: { profile.shards.0.dfs.knn.0.vector_operations_count: 6 } # Search with similarity to check number of operations are propagated correctly @@ -131,7 +131,7 @@ setup: num_candidates: 3 similarity: 100000 "rescore_vector": - "num_candidates_factor": 2.0 + "oversample": 2.0 - # We expect the knn search ops + rescoring num_cnaidates (for rescoring) per shard + # We expect the knn search ops + rescoring k * oversample (for rescoring) per shard - match: { profile.shards.0.dfs.knn.0.vector_operations_count: 6 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml index 7d4690204acc7..8f846dd76721d 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml @@ -558,7 +558,7 @@ setup: capabilities: - method: GET path: /_search - capabilities: [knn_quantized_vector_rescore] + capabilities: [knn_quantized_vector_rescore_oversample] - skip: features: "headers" @@ -598,7 +598,7 @@ setup: k: 3 num_candidates: 3 rescore_vector: - num_candidates_factor: 1.5 + oversample: 1.5 # Compare scores as hit IDs may change depending on how things are distributed - match: { hits.total: 3 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml index 2567a4ac597d9..3f81c0044d170 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml @@ -115,7 +115,7 @@ setup: capabilities: - method: GET path: /_search - capabilities: [knn_quantized_vector_rescore] + capabilities: [knn_quantized_vector_rescore_oversample] - skip: features: "headers" @@ -140,7 +140,7 @@ setup: k: 3 num_candidates: 3 rescore_vector: - num_candidates_factor: 1.5 + oversample: 1.5 # Get rescoring scores - hit ordering may change depending on how things are distributed - match: { hits.total: 3 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml index b1e35789e8737..229d705bc317c 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml @@ -378,7 +378,7 @@ setup: capabilities: - method: GET path: /_search - capabilities: [knn_quantized_vector_rescore] + capabilities: [knn_quantized_vector_rescore_oversample] - skip: features: "headers" @@ -398,7 +398,7 @@ setup: field: vector query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] rescore_vector: - num_candidates_factor: 1.5 + oversample: 1.5 # Get rescoring scores - hit ordering may change depending on how things are distributed - match: { hits.total: 3 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml index 54e9eadf42e0b..baf568762dd17 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml @@ -556,7 +556,7 @@ setup: capabilities: - method: GET path: /_search - capabilities: [knn_quantized_vector_rescore] + capabilities: [knn_quantized_vector_rescore_oversample] - skip: features: "headers" @@ -575,7 +575,7 @@ setup: k: 3 num_candidates: 3 rescore_vector: - num_candidates_factor: 1.5 + oversample: 1.5 # Get rescoring scores - hit ordering may change depending on how things are distributed - match: { hits.total: 3 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml index a3cd624ef0ab8..0bc111576c2a9 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml @@ -114,7 +114,7 @@ setup: capabilities: - method: GET path: /_search - capabilities: [knn_quantized_vector_rescore] + capabilities: [knn_quantized_vector_rescore_oversample] - skip: features: "headers" @@ -139,7 +139,7 @@ setup: k: 3 num_candidates: 3 rescore_vector: - num_candidates_factor: 1.5 + oversample: 1.5 # Get rescoring scores - hit ordering may change depending on how things are distributed - match: { hits.total: 3 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat.yml index a59aedceff3d3..358ff547036e6 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat.yml @@ -264,7 +264,7 @@ setup: capabilities: - method: GET path: /_search - capabilities: [knn_quantized_vector_rescore] + capabilities: [knn_quantized_vector_rescore_oversample] - skip: features: "headers" @@ -304,7 +304,7 @@ setup: k: 3 num_candidates: 3 rescore_vector: - num_candidates_factor: 1.5 + oversample: 1.5 # Compare scores as hit IDs may change depending on how things are distributed - match: { hits.total: 3 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml index 6796a92122f9a..0e0180e58fd96 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml @@ -352,7 +352,7 @@ setup: capabilities: - method: GET path: /_search - capabilities: [knn_quantized_vector_rescore] + capabilities: [knn_quantized_vector_rescore_oversample] - skip: features: "headers" @@ -371,7 +371,7 @@ setup: k: 3 num_candidates: 3 rescore_vector: - num_candidates_factor: 1.5 + oversample: 1.5 # Get rescoring scores - hit ordering may change depending on how things are distributed - match: { hits.total: 3 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml index d1d312449cb70..6b59b8f641ee9 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml @@ -269,7 +269,7 @@ setup: capabilities: - method: GET path: /_search - capabilities: [knn_quantized_vector_rescore] + capabilities: [knn_quantized_vector_rescore_oversample] - skip: features: "headers" @@ -288,7 +288,7 @@ setup: k: 3 num_candidates: 3 rescore_vector: - num_candidates_factor: 1.5 + oversample: 1.5 # Get rescoring scores - hit ordering may change depending on how things are distributed - match: { hits.total: 3 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml index effa3fff61525..680433a5945fd 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml @@ -414,7 +414,7 @@ setup: capabilities: - method: GET path: /_search - capabilities: [knn_quantized_vector_rescore] + capabilities: [knn_quantized_vector_rescore_oversample] - skip: features: "headers" @@ -454,7 +454,7 @@ setup: k: 3 num_candidates: 3 rescore_vector: - num_candidates_factor: 1.5 + oversample: 1.5 # Compare scores as hit IDs may change depending on how things are distributed - match: { hits.total: 3 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml index cdc1d9c64763e..783f08a5d4ff4 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml @@ -230,7 +230,7 @@ setup: capabilities: - method: GET path: /_search - capabilities: [knn_quantized_vector_rescore] + capabilities: [knn_quantized_vector_rescore_oversample] - skip: features: "headers" @@ -270,7 +270,7 @@ setup: k: 3 num_candidates: 3 rescore_vector: - num_candidates_factor: 1.5 + oversample: 1.5 # Compare scores as hit IDs may change depending on how things are distributed - match: { hits.total: 3 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml index 213b571a0b4be..6559b8d969cb9 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml @@ -263,7 +263,7 @@ setup: capabilities: - method: GET path: /_search - capabilities: [knn_quantized_vector_rescore] + capabilities: [knn_quantized_vector_rescore_oversample] - skip: features: "headers" @@ -303,7 +303,7 @@ setup: k: 3 num_candidates: 3 rescore_vector: - num_candidates_factor: 1.5 + oversample: 1.5 # Compare scores as hit IDs may change depending on how things are distributed - match: { hits.total: 3 } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index b1ea83417772e..3e0656205b976 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -125,7 +125,7 @@ public static boolean isNotUnitVector(float magnitude) { public static final short MIN_DIMS_FOR_DYNAMIC_FLOAT_MAPPING = 128; // minimum number of dims for floats to be dynamically mapped to // vector public static final int MAGNITUDE_BYTES = 4; - public static final int NUM_CANDS_OVERSAMPLE_LIMIT = 10_000; // Max oversample allowed for k and num_candidates + public static final int OVERSAMPLE_LIMIT = 10_000; // Max oversample allowed private static DenseVectorFieldMapper toType(FieldMapper in) { return (DenseVectorFieldMapper) in; @@ -2021,7 +2021,7 @@ public Query createKnnQuery( VectorData queryVector, int k, int numCands, - Float numCandsFactor, + Float oversample, Query filter, Float similarityThreshold, BitSetProducer parentFilter @@ -2037,7 +2037,7 @@ public Query createKnnQuery( queryVector.asFloatVector(), k, numCands, - numCandsFactor, + oversample, filter, similarityThreshold, parentFilter @@ -2047,7 +2047,11 @@ public Query createKnnQuery( } private boolean needsRescore(Float rescoreOversample) { - return rescoreOversample != null && (indexOptions != null && indexOptions.type != null && indexOptions.type.isQuantized()); + return rescoreOversample != null && isQuantized(); + } + + private boolean isQuantized() { + return indexOptions != null && indexOptions.type != null && indexOptions.type.isQuantized(); } private Query createKnnBitQuery( @@ -2103,7 +2107,7 @@ private Query createKnnFloatQuery( float[] queryVector, int k, int numCands, - Float numCandsFactor, + Float oversample, Query filter, Float similarityThreshold, BitSetProducer parentFilter @@ -2124,18 +2128,17 @@ && isNotUnitVector(squaredMagnitude)) { } } - Integer adjustedK = k; - int adjustedNumCands = numCands; - if (needsRescore(numCandsFactor)) { - // Get all candidates, get top k as part of rescoring - adjustedK = null; - // numCands * numCandsFactor <= NUM_CANDS_OVERSAMPLE_LIMIT. Adjust otherwise. - adjustedNumCands = Math.min((int) Math.ceil(numCands * numCandsFactor), NUM_CANDS_OVERSAMPLE_LIMIT); + int adjustedK = k; + boolean rescore = needsRescore(oversample); + if (rescore) { + // Will get k * oversample for rescoring, and get the top k + adjustedK = Math.min((int) Math.ceil(k * oversample), OVERSAMPLE_LIMIT); + numCands = Math.max(adjustedK, numCands); } Query knnQuery = parentFilter != null - ? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, adjustedK, adjustedNumCands, parentFilter) - : new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, adjustedNumCands, filter); - if (needsRescore(numCandsFactor)) { + ? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, adjustedK, numCands, parentFilter) + : new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter); + if (rescore) { knnQuery = new RescoreKnnVectorQuery( name(), queryVector, diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java index 214d257f0ed6a..8231046c6586f 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java @@ -40,7 +40,7 @@ private SearchCapabilities() {} private static final String RANDOM_SAMPLER_WITH_SCORED_SUBAGGS = "random_sampler_with_scored_subaggs"; private static final String OPTIMIZED_SCALAR_QUANTIZATION_BBQ = "optimized_scalar_quantization_bbq"; - private static final String KNN_QUANTIZED_VECTOR_RESCORE = "knn_quantized_vector_rescore"; + private static final String KNN_QUANTIZED_VECTOR_RESCORE_OVERSAMPLE = "knn_quantized_vector_rescore_oversample"; private static final String HIGHLIGHT_MAX_ANALYZED_OFFSET_DEFAULT = "highlight_max_analyzed_offset_default"; @@ -55,7 +55,7 @@ private SearchCapabilities() {} capabilities.add(NESTED_RETRIEVER_INNER_HITS_SUPPORT); capabilities.add(RANDOM_SAMPLER_WITH_SCORED_SUBAGGS); capabilities.add(OPTIMIZED_SCALAR_QUANTIZATION_BBQ); - capabilities.add(KNN_QUANTIZED_VECTOR_RESCORE); + capabilities.add(KNN_QUANTIZED_VECTOR_RESCORE_OVERSAMPLE); capabilities.add(MOVING_FN_RIGHT_MATH); capabilities.add(K_DEFAULT_TO_SIZE); capabilities.add(KQL_QUERY_SUPPORTED); diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index ea4436f8a28eb..193191658af08 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -522,7 +522,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) fieldType; String parentPath = context.nestedLookup().getNestedParent(fieldName); - Float numCandidatesFactor = rescoreVectorBuilder() == null ? null : rescoreVectorBuilder.numCandidatesFactor(); + Float oversample = rescoreVectorBuilder() == null ? null : rescoreVectorBuilder.oversample(); BitSetProducer parentBitSet = null; if (parentPath != null) { @@ -556,15 +556,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { } } - return vectorFieldType.createKnnQuery( - queryVector, - k, - adjustedNumCands, - numCandidatesFactor, - filterQuery, - vectorSimilarity, - parentBitSet - ); + return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, oversample, filterQuery, vectorSimilarity, parentBitSet); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java index 4604d4f0ea325..0e110a57d1e14 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java @@ -23,7 +23,7 @@ public class RescoreVectorBuilder implements Writeable, ToXContentObject { - public static final ParseField NUM_CANDIDATES_FACTOR_FIELD = new ParseField("num_candidates_factor"); + public static final ParseField OVERSAMPLE_FIELD = new ParseField("oversample"); public static final float MIN_OVERSAMPLE = 1.0F; private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "rescore_vector", @@ -31,33 +31,33 @@ public class RescoreVectorBuilder implements Writeable, ToXContentObject { ); static { - PARSER.declareFloat(ConstructingObjectParser.constructorArg(), NUM_CANDIDATES_FACTOR_FIELD); + PARSER.declareFloat(ConstructingObjectParser.constructorArg(), OVERSAMPLE_FIELD); } // Oversample is required as of now as it is the only field in the rescore vector - private final float numCandidatesFactor; + private final float oversample; public RescoreVectorBuilder(float numCandidatesFactor) { - Objects.requireNonNull(numCandidatesFactor, "[" + NUM_CANDIDATES_FACTOR_FIELD.getPreferredName() + "] must be set"); + Objects.requireNonNull(numCandidatesFactor, "[" + OVERSAMPLE_FIELD.getPreferredName() + "] must be set"); if (numCandidatesFactor < MIN_OVERSAMPLE) { - throw new IllegalArgumentException("[" + NUM_CANDIDATES_FACTOR_FIELD.getPreferredName() + "] must be >= " + MIN_OVERSAMPLE); + throw new IllegalArgumentException("[" + OVERSAMPLE_FIELD.getPreferredName() + "] must be >= " + MIN_OVERSAMPLE); } - this.numCandidatesFactor = numCandidatesFactor; + this.oversample = numCandidatesFactor; } public RescoreVectorBuilder(StreamInput in) throws IOException { - this.numCandidatesFactor = in.readFloat(); + this.oversample = in.readFloat(); } @Override public void writeTo(StreamOutput out) throws IOException { - out.writeFloat(numCandidatesFactor); + out.writeFloat(oversample); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(NUM_CANDIDATES_FACTOR_FIELD.getPreferredName(), numCandidatesFactor); + builder.field(OVERSAMPLE_FIELD.getPreferredName(), oversample); builder.endObject(); return builder; } @@ -71,15 +71,15 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; RescoreVectorBuilder that = (RescoreVectorBuilder) o; - return Objects.equals(numCandidatesFactor, that.numCandidatesFactor); + return Objects.equals(oversample, that.oversample); } @Override public int hashCode() { - return Objects.hashCode(numCandidatesFactor); + return Objects.hashCode(oversample); } - public float numCandidatesFactor() { - return numCandidatesFactor; + public float oversample() { + return oversample; } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index be4c677d20b03..5c067cb2d0a27 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -36,6 +36,7 @@ import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BYTE; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.FLOAT; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.OVERSAMPLE_LIMIT; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -456,20 +457,18 @@ public void testRescoreOversampleModifiesNumCandidates() { ); // Total results is k, internal k is multiplied by oversample - checkRescoreQueryParameters(fieldType, 10, 200, randomInt(), 2.5F, null, 500, 10); + checkRescoreQueryParameters(fieldType, 10, 200, 2.5F, 25, 200, 10); // If numCands < k, update numCands to k - checkRescoreQueryParameters(fieldType, 10, 20, randomInt(), 2.5F, null, 50, 10); - // Oversampling limits for num candidates - checkRescoreQueryParameters(fieldType, 1000, 1000, randomInt(), 11.0F, null, 10000, 1000); - checkRescoreQueryParameters(fieldType, 5000, 7500, randomInt(), 2.5F, null, 10000, 5000); + checkRescoreQueryParameters(fieldType, 10, 20, 2.5F, 25, 25, 10); + // Oversampling limits for k + checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, OVERSAMPLE_LIMIT, OVERSAMPLE_LIMIT, 1000); } private static void checkRescoreQueryParameters( DenseVectorFieldType fieldType, int k, int candidates, - int requestSize, - float numCandsFactor, + float oversample, Integer expectedK, int expectedCandidates, int expectedResults @@ -478,7 +477,7 @@ private static void checkRescoreQueryParameters( VectorData.fromFloats(new float[] { 1, 4, 10 }), k, candidates, - numCandsFactor, + oversample, null, null, null diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index 244d539403315..b3764d528ff0f 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -46,7 +46,8 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.NUM_CANDS_OVERSAMPLE_LIMIT; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.OVERSAMPLE_LIMIT; +import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -175,8 +176,13 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que assertThat(((VectorSimilarityQuery) query).getSimilarity(), equalTo(queryBuilder.getVectorSimilarity())); query = ((VectorSimilarityQuery) query).getInnerKnnQuery(); } + Integer k = queryBuilder.k(); + if (k == null) { + k = context.requestSize() == null || context.requestSize() < 0 ? DEFAULT_SIZE : context.requestSize(); + } if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) { RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query; + assertEquals(k.intValue(), (rescoreQuery.k())); query = rescoreQuery.innerQuery(); } switch (elementType()) { @@ -190,14 +196,11 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que } BooleanQuery booleanQuery = builder.build(); Query filterQuery = booleanQuery.clauses().isEmpty() ? null : booleanQuery; - // The field should always be resolved to the concrete field - Integer k = queryBuilder.k(); Integer numCands = queryBuilder.numCands(); if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) { - Float numCandsFactor = queryBuilder.rescoreVectorBuilder().numCandidatesFactor(); - int minCands = k == null ? 1 : k; - numCands = Math.max(minCands, (int) Math.ceil(numCands * numCandsFactor)); - numCands = Math.min(numCands, NUM_CANDS_OVERSAMPLE_LIMIT); + Float oversample = queryBuilder.rescoreVectorBuilder().oversample(); + k = Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * oversample)); + numCands = Math.max(numCands, k); } Query knnVectorQueryBuilt = switch (elementType()) { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java index 108dc60e2ee3b..8cca3f9ed8a21 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java @@ -259,7 +259,7 @@ public void testInvalidRescoreVectorBuilder() { IllegalArgumentException.class, () -> new KnnSearchBuilder("field", randomVector(3), 10, 100, new RescoreVectorBuilder(0.99F), null) ); - assertThat(e.getMessage(), containsString("[num_candidates_factor] must be >= 1.0")); + assertThat(e.getMessage(), containsString("[oversample] must be >= 1.0")); } public void testRewrite() throws Exception {