Skip to content

Commit 0a80088

Browse files
committed
adding integration test
1 parent a771e72 commit 0a80088

File tree

2 files changed

+128
-0
lines changed

2 files changed

+128
-0
lines changed
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.query;
11+
12+
import org.elasticsearch.cluster.metadata.IndexMetadata;
13+
import org.elasticsearch.common.settings.Settings;
14+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
15+
import org.elasticsearch.index.query.QueryBuilders;
16+
import org.elasticsearch.search.vectors.KnnSearchBuilder;
17+
import org.elasticsearch.test.ESIntegTestCase;
18+
import org.elasticsearch.xcontent.XContentBuilder;
19+
import org.elasticsearch.xcontent.XContentFactory;
20+
import org.junit.Before;
21+
22+
import java.io.IOException;
23+
import java.util.List;
24+
25+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
26+
27+
public class VectorIT extends ESIntegTestCase {
28+
29+
private static final String INDEX_NAME = "test";
30+
private static final String VECTOR_FIELD = "vector";
31+
private static final String NUM_ID_FIELD = "num_id";
32+
33+
private static void randomVector(float[] vector) {
34+
for (int i = 0; i < vector.length; i++) {
35+
vector[i] = randomFloat();
36+
}
37+
}
38+
39+
@Before
40+
public void setup() throws IOException {
41+
XContentBuilder mapping = XContentFactory.jsonBuilder()
42+
.startObject()
43+
.startObject("properties")
44+
.startObject(VECTOR_FIELD)
45+
.field("type", "dense_vector")
46+
.startObject("index_options")
47+
.field("type", "hnsw")
48+
.endObject()
49+
.endObject()
50+
.startObject(NUM_ID_FIELD)
51+
.field("type", "long")
52+
.endObject()
53+
.endObject()
54+
.endObject();
55+
56+
Settings settings = Settings.builder()
57+
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
58+
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1)
59+
.build();
60+
prepareCreate(INDEX_NAME).setMapping(mapping).setSettings(settings).get();
61+
ensureGreen(INDEX_NAME);
62+
for (int i = 0; i < 150; i++) {
63+
float[] vector = new float[8];
64+
randomVector(vector);
65+
prepareIndex(INDEX_NAME).setId(Integer.toString(i)).setSource(VECTOR_FIELD, vector, NUM_ID_FIELD, i).get();
66+
}
67+
forceMerge(true);
68+
refresh(INDEX_NAME);
69+
}
70+
71+
public void testFilteredQueryStrategy() {
72+
float[] vector = new float[8];
73+
randomVector(vector);
74+
var query = new KnnSearchBuilder(VECTOR_FIELD, vector, 1, 1, null, null).addFilterQuery(
75+
QueryBuilders.rangeQuery(NUM_ID_FIELD).lte(30)
76+
);
77+
assertResponse(client().prepareSearch(INDEX_NAME).setKnnSearch(List.of(query)).setSize(1).setProfile(true), acornResponse -> {
78+
assertNotEquals(0, acornResponse.getHits().getHits().length);
79+
var profileResults = acornResponse.getProfileResults();
80+
long vectorOpsSum = profileResults.values()
81+
.stream()
82+
.mapToLong(
83+
pr -> pr.getQueryPhase()
84+
.getSearchProfileDfsPhaseResult()
85+
.getQueryProfileShardResult()
86+
.stream()
87+
.mapToLong(qpr -> qpr.getVectorOperationsCount().longValue())
88+
.sum()
89+
)
90+
.sum();
91+
client().admin()
92+
.indices()
93+
.prepareUpdateSettings(INDEX_NAME)
94+
.setSettings(
95+
Settings.builder()
96+
.put(
97+
DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC.getKey(),
98+
DenseVectorFieldMapper.FilterHeuristic.FANOUT.toString()
99+
)
100+
)
101+
.get();
102+
assertResponse(client().prepareSearch(INDEX_NAME).setKnnSearch(List.of(query)).setSize(1).setProfile(true), fanoutResponse -> {
103+
assertNotEquals(0, fanoutResponse.getHits().getHits().length);
104+
var fanoutProfileResults = fanoutResponse.getProfileResults();
105+
long fanoutVectorOpsSum = fanoutProfileResults.values()
106+
.stream()
107+
.mapToLong(
108+
pr -> pr.getQueryPhase()
109+
.getSearchProfileDfsPhaseResult()
110+
.getQueryProfileShardResult()
111+
.stream()
112+
.mapToLong(qpr -> qpr.getVectorOperationsCount().longValue())
113+
.sum()
114+
)
115+
.sum();
116+
assertTrue(
117+
"fanoutVectorOps [" + fanoutVectorOpsSum + "] is not gt acornVectorOps [" + vectorOpsSum + "]",
118+
fanoutVectorOpsSum > vectorOpsSum
119+
);
120+
});
121+
});
122+
}
123+
124+
}

server/src/main/java/org/elasticsearch/search/profile/query/QueryProfileShardResult.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,4 +137,8 @@ public int hashCode() {
137137
public String toString() {
138138
return Strings.toString(this);
139139
}
140+
141+
public Long getVectorOperationsCount() {
142+
return vectorOperationsCount;
143+
}
140144
}

0 commit comments

Comments
 (0)