Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -605,3 +605,51 @@ setup:
- match: { hits.hits.0._score: $knn_score0 }
- match: { hits.hits.1._score: $knn_score1 }
- match: { hits.hits.2._score: $knn_score2 }

---
"kNN search with num_candidates exceeds max allowed value":
- requires:
reason: 'num_candidates exceeds max allowed value'
test_runner_features: [capabilities]

- do:
indices.create:
index: test_num_candidates
body:
mappings:
properties:
vector:
type: dense_vector
element_type: float
dims: 5
settings:
index.max_knn_num_candidates: 500

- do:
search:
index: test_num_candidates
body:
knn:
field: vector
query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ]
k: 2
num_candidates: 200

- match: { hits.total.value: 0 }

- do:
indices.put_settings:
index: test_num_candidates
body:
index.max_knn_num_candidates: 100

- do:
catch: /\[num_candidates\] cannot exceed \[100\]/
search:
index: test_num_candidates
body:
knn:
field: vector
query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ]
k: 2
num_candidates: 200
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ public final class IndexScopedSettings extends AbstractScopedSettings {
IndexSettings.MAX_REFRESH_LISTENERS_PER_SHARD,
IndexSettings.MAX_SLICES_PER_SCROLL,
IndexSettings.MAX_REGEX_LENGTH_SETTING,
IndexSettings.INDEX_MAX_KNN_NUM_CANDIDATES_SETTING,
ShardsLimitAllocationDecider.INDEX_TOTAL_SHARDS_PER_NODE_SETTING,
IndexSettings.INDEX_GC_DELETES_SETTING,
IndexSettings.INDEX_SOFT_DELETES_SETTING,
Expand Down
23 changes: 23 additions & 0 deletions server/src/main/java/org/elasticsearch/index/IndexSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,17 @@ public final class IndexSettings {
Property.IndexScope
);

/**
* The maximum number of candidates to be considered for KNN search. The default value is 10_000.
*/
public static final Setting<Integer> INDEX_MAX_KNN_NUM_CANDIDATES_SETTING = Setting.intSetting(
"index.max_knn_num_candidates",
10_000,
1,
Property.Dynamic,
Property.IndexScope
);

public static final TimeValue DEFAULT_REFRESH_INTERVAL = new TimeValue(1, TimeUnit.SECONDS);
public static final Setting<TimeValue> NODE_DEFAULT_REFRESH_INTERVAL_SETTING = Setting.timeSetting(
"node._internal.default_refresh_interval",
Expand Down Expand Up @@ -930,6 +941,8 @@ private void setRetentionLeaseMillis(final TimeValue retentionLease) {
*/
private volatile int maxRegexLength;

private volatile int maxKnnNumCandidates;

private final IndexRouting indexRouting;

/**
Expand Down Expand Up @@ -1083,6 +1096,7 @@ public IndexSettings(final IndexMetadata indexMetadata, final Settings nodeSetti
mappingDepthLimit = scopedSettings.get(INDEX_MAPPING_DEPTH_LIMIT_SETTING);
mappingFieldNameLengthLimit = scopedSettings.get(INDEX_MAPPING_FIELD_NAME_LENGTH_LIMIT_SETTING);
mappingDimensionFieldsLimit = scopedSettings.get(INDEX_MAPPING_DIMENSION_FIELDS_LIMIT_SETTING);
maxKnnNumCandidates = scopedSettings.get(INDEX_MAX_KNN_NUM_CANDIDATES_SETTING);
indexRouting = IndexRouting.fromIndexMetadata(indexMetadata);
sourceKeepMode = scopedSettings.get(Mapper.SYNTHETIC_SOURCE_KEEP_INDEX_SETTING);
es87TSDBCodecEnabled = scopedSettings.get(TIME_SERIES_ES87TSDB_CODEC_ENABLED_SETTING);
Expand Down Expand Up @@ -1203,6 +1217,7 @@ public IndexSettings(final IndexMetadata indexMetadata, final Settings nodeSetti
this::setSkipIgnoredSourceWrite
);
scopedSettings.addSettingsUpdateConsumer(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING, this::setSkipIgnoredSourceRead);
scopedSettings.addSettingsUpdateConsumer(INDEX_MAX_KNN_NUM_CANDIDATES_SETTING, this::setMaxKnnNumCandidates);
}

private void setSearchIdleAfter(TimeValue searchIdleAfter) {
Expand Down Expand Up @@ -1821,4 +1836,12 @@ public TimestampBounds getTimestampBounds() {
public IndexRouting getIndexRouting() {
return indexRouting;
}

public int getMaxKnnNumCandidates() {
return maxKnnNumCandidates;
}

public void setMaxKnnNumCandidates(int maxKnnNumCandidates) {
this.maxKnnNumCandidates = maxKnnNumCandidates;
}
}
12 changes: 11 additions & 1 deletion server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import java.util.Map;

import static org.elasticsearch.index.query.AbstractQueryBuilder.DEFAULT_BOOST;
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD;

/**
* DFS phase of a search request, used to make scoring 100% accurate by collecting additional info from each shard before the query phase.
Expand Down Expand Up @@ -177,7 +178,7 @@ private static Timer maybeStartTimer(DfsProfiler profiler, DfsTimingType dtt) {
return null;
};

private static void executeKnnVectorQuery(SearchContext context) throws IOException {
static void executeKnnVectorQuery(SearchContext context) throws IOException {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think eagerly validating like this is OK. However, KnnVectorQueryBuilder#doToQuery should also validate as its possible to provide a knn query that isn't executed through DFS.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think the check in KnnVectorQueryBuilder#doToQuery is better, I will change the check code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

SearchSourceBuilder source = context.request().source();
if (source == null || source.knnSearch().isEmpty()) {
return;
Expand All @@ -186,6 +187,15 @@ private static void executeKnnVectorQuery(SearchContext context) throws IOExcept
SearchExecutionContext searchExecutionContext = context.getSearchExecutionContext();
List<KnnSearchBuilder> knnSearch = source.knnSearch();
List<KnnVectorQueryBuilder> knnVectorQueryBuilders = knnSearch.stream().map(KnnSearchBuilder::toQueryBuilder).toList();
int maxKnnNumCandidates = context.indexShard().indexSettings().getMaxKnnNumCandidates();
for (KnnVectorQueryBuilder knnVectorQueryBuilder : knnVectorQueryBuilders) {
if (knnVectorQueryBuilder.numCands() != null && knnVectorQueryBuilder.numCands() > maxKnnNumCandidates) {
throw new IllegalArgumentException(
"[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + maxKnnNumCandidates + "]"
);
}
}

// Since we apply boost during the DfsQueryPhase, we should not apply boost here:
knnVectorQueryBuilders.forEach(knnVectorQueryBuilder -> knnVectorQueryBuilder.boost(DEFAULT_BOOST));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
* Defines a kNN search to run in the search request.
*/
public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewriteable<KnnSearchBuilder> {
public static final int NUM_CANDS_LIMIT = 10_000;
public static final float NUM_CANDS_MULTIPLICATIVE_FACTOR = 1.5f;

public static final ParseField FIELD_FIELD = new ParseField("field");
Expand Down Expand Up @@ -264,9 +263,6 @@ private KnnSearchBuilder(
"[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot be less than " + "[" + K_FIELD.getPreferredName() + "]"
);
}
if (numCandidates > NUM_CANDS_LIMIT) {
throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
}
if (queryVector == null && queryVectorBuilder == null) {
throw new IllegalArgumentException(
format(
Expand Down Expand Up @@ -667,9 +663,7 @@ public Builder rescoreVectorBuilder(RescoreVectorBuilder rescoreVectorBuilder) {
public KnnSearchBuilder build(int size) {
int requestSize = size < 0 ? DEFAULT_SIZE : size;
int adjustedK = k == null ? requestSize : k;
int adjustedNumCandidates = numCandidates == null
? Math.round(Math.min(NUM_CANDS_LIMIT, NUM_CANDS_MULTIPLICATIVE_FACTOR * adjustedK))
: numCandidates;
int adjustedNumCandidates = numCandidates == null ? Math.round(NUM_CANDS_MULTIPLICATIVE_FACTOR * adjustedK) : numCandidates;
return new KnnSearchBuilder(
field,
queryVectorBuilder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ public void toSearchRequest(SearchRequestBuilder builder) {

// visible for testing
static class KnnSearch {
private static final int NUM_CANDS_LIMIT = 10000;
static final ParseField FIELD_FIELD = new ParseField("field");
static final ParseField K_FIELD = new ParseField("k");
static final ParseField NUM_CANDS_FIELD = new ParseField("num_candidates");
Expand Down Expand Up @@ -253,9 +252,6 @@ public KnnVectorQueryBuilder toQueryBuilder() {
"[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot be less than " + "[" + K_FIELD.getPreferredName() + "]"
);
}
if (numCands > NUM_CANDS_LIMIT) {
throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
}
return new KnnVectorQueryBuilder(field, queryVector, numCands, numCands, null, null);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
*/
public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBuilder> {
public static final String NAME = "knn";
private static final int NUM_CANDS_LIMIT = 10_000;
private static final float NUM_CANDS_MULTIPLICATIVE_FACTOR = 1.5f;

public static final ParseField FIELD_FIELD = new ParseField("field");
Expand Down Expand Up @@ -183,9 +182,6 @@ private KnnVectorQueryBuilder(
if (k != null && k < 1) {
throw new IllegalArgumentException("[" + K_FIELD.getPreferredName() + "] must be greater than 0");
}
if (numCands != null && numCands > NUM_CANDS_LIMIT) {
throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
}
if (k != null && numCands != null && numCands < k) {
throw new IllegalArgumentException(
"[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot be less than [" + K_FIELD.getPreferredName() + "]"
Expand Down Expand Up @@ -496,7 +492,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
k = Math.min(k, numCands);
}
}
int adjustedNumCands = numCands == null ? Math.round(Math.min(NUM_CANDS_MULTIPLICATIVE_FACTOR * k, NUM_CANDS_LIMIT)) : numCands;
int adjustedNumCands = numCands == null ? Math.round(NUM_CANDS_MULTIPLICATIVE_FACTOR * k) : numCands;
if (fieldType == null) {
return new MatchNoDocsQuery();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,20 @@
import org.apache.lucene.search.Query;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.index.shard.IndexShard;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.ContextIndexSearcher;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.internal.ShardSearchRequest;
import org.elasticsearch.search.profile.Profilers;
import org.elasticsearch.search.profile.SearchProfileDfsPhaseResult;
import org.elasticsearch.search.profile.query.CollectorResult;
import org.elasticsearch.search.profile.query.QueryProfileShardResult;
import org.elasticsearch.search.vectors.KnnSearchBuilder;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.IndexSettingsModule;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.junit.After;
Expand All @@ -32,6 +40,9 @@
import java.util.List;
import java.util.concurrent.ThreadPoolExecutor;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class DfsPhaseTests extends ESTestCase {

ThreadPoolExecutor threadPoolExecutor;
Expand Down Expand Up @@ -102,4 +113,37 @@ public void testSingleKnnSearch() throws IOException {
reader.close();
}
}

public void testNumCandidatesExceedsMax() {
Settings settings = Settings.builder().put("index.max_knn_num_candidates", 100).build();
IndexSettings indexSettings = IndexSettingsModule.newIndexSettings("test", settings);

SearchContext context = mock(SearchContext.class);
when(context.indexShard()).thenAnswer(invocation -> {
IndexShard mockIndexShard = mock(IndexShard.class);
when(mockIndexShard.indexSettings()).thenReturn(indexSettings);
return mockIndexShard;
});

// 构造超过最大值的查询参数
KnnSearchBuilder queryBuilder = new KnnSearchBuilder(
"float_vector",
new float[] { 0, 0, 0 },
10,
150, // 超过maxKnnNumCandidates的值
null,
null
);
SearchSourceBuilder source = new SearchSourceBuilder();
source.knnSearch(List.of(queryBuilder));
ShardSearchRequest searchRequest = mock(ShardSearchRequest.class);
when(searchRequest.source()).thenReturn(source);
when(context.request()).thenReturn(searchRequest);

// 验证异常抛出
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> DfsPhase.executeKnnVectorQuery(context));
assertEquals("[num_candidates] cannot exceed [100]", e.getMessage());

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,6 @@ public void testNumCandsLessThanK() {
assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]"));
}

public void testNumCandsExceedsLimit() {
IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
() -> new KnnSearchBuilder("field", randomVector(3), 100, 10002, null, null)
);
assertThat(e.getMessage(), containsString("[num_candidates] cannot exceed [10000]"));
}

public void testInvalidK() {
IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,22 +179,6 @@ public void testNumCandsLessThanK() throws IOException {
assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]"));
}

public void testNumCandsExceedsLimit() throws IOException {
XContentType xContentType = randomFrom(XContentType.values());
XContentBuilder builder = XContentBuilder.builder(xContentType.xContent())
.startObject()
.startObject(KnnSearchRequestParser.KNN_SECTION_FIELD.getPreferredName())
.field(KnnSearch.FIELD_FIELD.getPreferredName(), "field")
.field(KnnSearch.K_FIELD.getPreferredName(), 100)
.field(KnnSearch.NUM_CANDS_FIELD.getPreferredName(), 10002)
.field(KnnSearch.QUERY_VECTOR_FIELD.getPreferredName(), new float[] { 1.0f, 2.0f, 3.0f })
.endObject()
.endObject();

IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> parseSearchRequest(builder));
assertThat(e.getMessage(), containsString("[num_candidates] cannot exceed [10000]"));
}

public void testInvalidK() throws IOException {
XContentType xContentType = randomFrom(XContentType.values());
XContentBuilder builder = XContentBuilder.builder(xContentType.xContent())
Expand Down
Loading