Skip to content

Commit 87cba1e

Browse files
committed
fixed scroll with knn query
1 parent c864c6c commit 87cba1e

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search;
11+
12+
import org.elasticsearch.action.search.SearchRequest;
13+
import org.elasticsearch.action.search.SearchResponse;
14+
import org.elasticsearch.action.search.SearchScrollRequest;
15+
import org.elasticsearch.client.internal.Client;
16+
import org.elasticsearch.core.TimeValue;
17+
import org.elasticsearch.search.builder.SearchSourceBuilder;
18+
import org.elasticsearch.search.vectors.KnnSearchBuilder;
19+
import org.elasticsearch.test.ESIntegTestCase;
20+
import org.elasticsearch.xcontent.XContentBuilder;
21+
import org.elasticsearch.xcontent.XContentFactory;
22+
import org.elasticsearch.xcontent.XContentType;
23+
24+
import java.util.List;
25+
26+
import static org.hamcrest.Matchers.equalTo;
27+
import static org.hamcrest.Matchers.notNullValue;
28+
29+
@ESIntegTestCase.ClusterScope(minNumDataNodes = 3)
30+
public class KnnSearchIT extends ESIntegTestCase {
31+
32+
private static final String INDEX_NAME = "test_knn_index";
33+
private static final String VECTOR_FIELD = "vector";
34+
private static final int DIMENSION = 2;
35+
36+
private XContentBuilder createKnnMapping() throws Exception {
37+
return XContentFactory.jsonBuilder()
38+
.startObject()
39+
.startObject("properties")
40+
.startObject(VECTOR_FIELD)
41+
.field("type", "dense_vector")
42+
.field("dims", DIMENSION)
43+
.field("index", true)
44+
.field("similarity", "l2_norm")
45+
.endObject()
46+
.endObject()
47+
.endObject();
48+
}
49+
50+
public void testKnnSearchWithScroll() throws Exception {
51+
Client client = client();
52+
53+
// create index
54+
client.admin().indices().prepareCreate(INDEX_NAME).setMapping(createKnnMapping()).get();
55+
56+
// 插入测试数据
57+
int count = randomIntBetween(10, 20);
58+
for (int i = 0; i < count; i++) {
59+
client.prepareIndex(INDEX_NAME)
60+
.setSource(XContentType.JSON, VECTOR_FIELD, new float[]{i, i})
61+
.get();
62+
}
63+
64+
65+
refresh(INDEX_NAME);
66+
67+
// 构建KNN搜索请求
68+
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
69+
int k = count / 2;
70+
sourceBuilder.knnSearch(List.of(new KnnSearchBuilder(VECTOR_FIELD, new float[]{0, 0}, k, k, null, null)));
71+
72+
SearchRequest searchRequest = new SearchRequest(INDEX_NAME);
73+
searchRequest.source(sourceBuilder)
74+
.scroll(TimeValue.timeValueMinutes(1));
75+
76+
// 执行首次搜索
77+
SearchResponse firstResponse = client.search(searchRequest).actionGet();
78+
assertThat(firstResponse.getScrollId(), notNullValue());
79+
assertThat(firstResponse.getHits().getHits().length, equalTo(k));
80+
81+
while (true) {
82+
SearchScrollRequest scrollRequest = new SearchScrollRequest(firstResponse.getScrollId());
83+
scrollRequest.scroll(TimeValue.timeValueMinutes(1));
84+
SearchResponse scrollResponse = client.searchScroll(scrollRequest).actionGet();
85+
if (scrollResponse.getHits().getHits().length == 0) {
86+
break;
87+
}
88+
assertThat(scrollResponse.getHits().getHits().length, equalTo(1));
89+
assertThat(scrollResponse.getScrollId(), notNullValue());
90+
assertThat(scrollResponse.getHits().getTotalHits().value(), equalTo((long) k));
91+
}
92+
93+
// 清理Scroll上下文
94+
client.prepareClearScroll().addScrollId(firstResponse.getScrollId()).get();
95+
}
96+
}

server/src/main/java/org/elasticsearch/search/internal/LegacyReaderContext.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ public Engine.Searcher acquireSearcher(String source) {
7272

7373
@Override
7474
public ShardSearchRequest getShardSearchRequest(ShardSearchRequest other) {
75+
if (other != null) {
76+
shardSearchRequest.source(other.source());
77+
}
7578
return shardSearchRequest;
7679
}
7780

0 commit comments

Comments
 (0)