Skip to content

Commit fb33b75

Browse files
Improve tests and documentation
1 parent fa5d807 commit fb33b75

File tree

2 files changed

+83
-34
lines changed

2 files changed

+83
-34
lines changed

server/src/internalClusterTest/java/org/elasticsearch/search/KnnSearchIT.java

Lines changed: 81 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,79 +11,127 @@
1111

1212
import org.elasticsearch.action.search.SearchRequest;
1313
import org.elasticsearch.action.search.SearchResponse;
14-
import org.elasticsearch.action.search.SearchScrollRequest;
1514
import org.elasticsearch.client.internal.Client;
15+
import org.elasticsearch.common.settings.Settings;
1616
import org.elasticsearch.core.TimeValue;
17+
import org.elasticsearch.index.query.QueryBuilders;
1718
import org.elasticsearch.search.builder.SearchSourceBuilder;
1819
import org.elasticsearch.search.vectors.KnnSearchBuilder;
20+
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
1921
import org.elasticsearch.test.ESIntegTestCase;
2022
import org.elasticsearch.xcontent.XContentBuilder;
2123
import org.elasticsearch.xcontent.XContentFactory;
22-
import org.elasticsearch.xcontent.XContentType;
2324

2425
import java.util.List;
2526

26-
import static org.hamcrest.Matchers.equalTo;
2727
import static org.hamcrest.Matchers.notNullValue;
2828

29-
@ESIntegTestCase.ClusterScope(minNumDataNodes = 3)
29+
@ESIntegTestCase.ClusterScope(minNumDataNodes = 2)
3030
public class KnnSearchIT extends ESIntegTestCase {
3131

3232
private static final String INDEX_NAME = "test_knn_index";
3333
private static final String VECTOR_FIELD = "vector";
34-
private static final int DIMENSION = 2;
3534

3635
private XContentBuilder createKnnMapping() throws Exception {
3736
return XContentFactory.jsonBuilder()
3837
.startObject()
3938
.startObject("properties")
4039
.startObject(VECTOR_FIELD)
4140
.field("type", "dense_vector")
42-
.field("dims", DIMENSION)
41+
.field("dims", 2)
4342
.field("index", true)
4443
.field("similarity", "l2_norm")
44+
.startObject("index_options")
45+
.field("type", "hnsw")
46+
.endObject()
47+
.endObject()
48+
.startObject("category")
49+
.field("type", "keyword")
4550
.endObject()
4651
.endObject()
4752
.endObject();
4853
}
4954

5055
public void testKnnSearchWithScroll() throws Exception {
56+
final int numShards = randomIntBetween(1, 3);
5157
Client client = client();
52-
53-
client.admin().indices().prepareCreate(INDEX_NAME).setMapping(createKnnMapping()).get();
54-
55-
int count = randomIntBetween(10, 20);
58+
client.admin()
59+
.indices()
60+
.prepareCreate(INDEX_NAME)
61+
.setSettings(Settings.builder().put("index.number_of_shards", numShards))
62+
.setMapping(createKnnMapping())
63+
.get();
64+
65+
final int count = 100;
5666
for (int i = 0; i < count; i++) {
57-
client.prepareIndex(INDEX_NAME).setSource(XContentType.JSON, VECTOR_FIELD, new float[] { i, i }).get();
67+
XContentBuilder source = XContentFactory.jsonBuilder()
68+
.startObject()
69+
.field(VECTOR_FIELD, new float[] { i * 0.1f, i * 0.1f })
70+
.field("category", i >= 90 ? "last_ten" : null)
71+
.endObject();
72+
client.prepareIndex(INDEX_NAME).setSource(source).get();
5873
}
59-
6074
refresh(INDEX_NAME);
6175

62-
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
63-
int k = count / 2;
64-
sourceBuilder.knnSearch(List.of(new KnnSearchBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, k, null, null)));
76+
final int k = randomIntBetween(11, 15);
77+
// test top level knn search
78+
{
79+
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
80+
sourceBuilder.knnSearch(List.of(new KnnSearchBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, null, null)));
81+
executeScrollSearch(client, sourceBuilder, k);
82+
}
83+
// test top level knn search + another query
84+
{
85+
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
86+
sourceBuilder.knnSearch(List.of(new KnnSearchBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, null, null)));
87+
sourceBuilder.query(QueryBuilders.existsQuery("category").boost(10));
88+
executeScrollSearch(client, sourceBuilder, k + 10);
89+
}
6590

91+
// test knn query
92+
{
93+
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
94+
sourceBuilder.query(new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, null, null));
95+
executeScrollSearch(client, sourceBuilder, k * numShards);
96+
}
97+
// test knn query + another query
98+
{
99+
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
100+
sourceBuilder.query(
101+
QueryBuilders.boolQuery()
102+
.should(new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, null, null))
103+
.should(QueryBuilders.existsQuery("category").boost(10))
104+
);
105+
executeScrollSearch(client, sourceBuilder, k * numShards + 10);
106+
}
107+
108+
}
109+
110+
private static void executeScrollSearch(Client client, SearchSourceBuilder sourceBuilder, int expectedNumHits) {
66111
SearchRequest searchRequest = new SearchRequest(INDEX_NAME);
67112
searchRequest.source(sourceBuilder).scroll(TimeValue.timeValueMinutes(1));
68113

69-
SearchResponse firstResponse = client.search(searchRequest).actionGet();
70-
assertThat(firstResponse.getScrollId(), notNullValue());
71-
assertThat(firstResponse.getHits().getHits().length, equalTo(k));
72-
firstResponse.decRef();
73-
74-
while (true) {
75-
SearchScrollRequest scrollRequest = new SearchScrollRequest(firstResponse.getScrollId());
76-
scrollRequest.scroll(TimeValue.timeValueMinutes(1));
77-
SearchResponse scrollResponse = client.searchScroll(scrollRequest).actionGet();
78-
if (scrollResponse.getHits().getHits().length == 0) {
79-
break;
80-
}
81-
assertThat(scrollResponse.getHits().getHits().length, equalTo(1));
82-
assertThat(scrollResponse.getScrollId(), notNullValue());
83-
assertThat(scrollResponse.getHits().getTotalHits().value(), equalTo((long) k));
84-
scrollResponse.decRef();
114+
SearchResponse searchResponse = client.search(searchRequest).actionGet();
115+
int hitsCollected = 0;
116+
float prevScore = Float.POSITIVE_INFINITY;
117+
try {
118+
do {
119+
assertThat(searchResponse.getScrollId(), notNullValue());
120+
assertEquals(expectedNumHits, searchResponse.getHits().getTotalHits().value());
121+
// assert correct order of returned hits
122+
for (var searchHit : searchResponse.getHits()) {
123+
assert (searchHit.getScore() <= prevScore);
124+
prevScore = searchHit.getScore();
125+
hitsCollected += 1;
126+
}
127+
searchResponse.decRef();
128+
searchResponse = client().prepareSearchScroll(searchResponse.getScrollId()).setScroll(TimeValue.timeValueMinutes(1)).get();
129+
} while (searchResponse.getHits().getHits().length > 0);
130+
} finally {
131+
assertEquals(expectedNumHits, hitsCollected);
132+
clearScroll(searchResponse.getScrollId());
133+
searchResponse.decRef();
85134
}
86-
87-
client.prepareClearScroll().addScrollId(firstResponse.getScrollId()).get();
88135
}
136+
89137
}

server/src/main/java/org/elasticsearch/search/internal/LegacyReaderContext.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ public Engine.Searcher acquireSearcher(String source) {
7373
@Override
7474
public ShardSearchRequest getShardSearchRequest(ShardSearchRequest other) {
7575
if (other != null) {
76-
// the source builder maybe changed in knn query or another case
76+
// The top level knn search modifies the source after the DFS phase.
77+
// so we need to update the source stored in the context.
7778
shardSearchRequest.source(other.source());
7879
}
7980
return shardSearchRequest;

0 commit comments

Comments
 (0)