|
9 | 9 |
|
10 | 10 | import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters; |
11 | 11 |
|
12 | | -import org.elasticsearch.client.Request; |
13 | | -import org.elasticsearch.client.ResponseException; |
14 | 12 | import org.elasticsearch.test.TestClustersThreadFilter; |
15 | 13 | import org.elasticsearch.test.cluster.ElasticsearchCluster; |
16 | | -import org.elasticsearch.test.rest.ESRestTestCase; |
17 | | -import org.elasticsearch.xpack.esql.AssertWarnings; |
18 | | -import org.elasticsearch.xpack.esql.CsvTestsDataLoader; |
19 | | -import org.elasticsearch.xpack.esql.action.EsqlCapabilities; |
20 | | -import org.elasticsearch.xpack.esql.qa.rest.ProfileLogger; |
21 | | -import org.elasticsearch.xpack.esql.qa.rest.RestEsqlTestCase; |
22 | | -import org.junit.After; |
23 | | -import org.junit.Before; |
| 14 | +import org.elasticsearch.xpack.esql.qa.rest.KnnSemanticTextTestCase; |
24 | 15 | import org.junit.ClassRule; |
25 | | -import org.junit.Rule; |
26 | | - |
27 | | -import java.io.IOException; |
28 | | -import java.util.HashMap; |
29 | | -import java.util.List; |
30 | | -import java.util.Map; |
31 | | - |
32 | | -import static org.elasticsearch.rest.RestStatus.BAD_REQUEST; |
33 | | -import static org.elasticsearch.xpack.esql.qa.rest.RestEsqlTestCase.requestObjectBuilder; |
34 | | -import static org.elasticsearch.xpack.esql.qa.rest.RestEsqlTestCase.runEsqlSync; |
35 | | -import static org.hamcrest.Matchers.is; |
36 | | -import static org.hamcrest.core.StringContains.containsString; |
37 | 16 |
|
38 | 17 | @ThreadLeakFilters(filters = TestClustersThreadFilter.class) |
39 | | -public class KnnSemanticTextIT extends ESRestTestCase { |
| 18 | +public class KnnSemanticTextIT extends KnnSemanticTextTestCase { |
40 | 19 |
|
41 | 20 | @ClassRule |
42 | 21 | public static ElasticsearchCluster cluster = Clusters.testCluster(spec -> spec.plugin("inference-service-test")); |
43 | 22 |
|
44 | | - @Rule(order = Integer.MIN_VALUE) |
45 | | - public ProfileLogger profileLogger = new ProfileLogger(); |
46 | | - |
47 | | - private int numDocs; |
48 | | - private final Map<Integer, String> indexedTexts = new HashMap<>(); |
49 | | - |
50 | 23 | @Override |
51 | 24 | protected String getTestRestCluster() { |
52 | 25 | return cluster.getHttpAddresses(); |
53 | 26 | } |
54 | | - |
55 | | - @Before |
56 | | - public void checkCapability() { |
57 | | - assumeTrue("knn with semantic text not available", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); |
58 | | - } |
59 | | - |
60 | | - @SuppressWarnings("unchecked") |
61 | | - public void testKnnQueryWithSemanticText() throws IOException { |
62 | | - String knnQuery = """ |
63 | | - FROM semantic-test METADATA _score |
64 | | - | WHERE knn(dense_semantic, [0, 1, 2], 10) |
65 | | - | KEEP id, _score, dense_semantic |
66 | | - | SORT _score DESC |
67 | | - | LIMIT 10 |
68 | | - """; |
69 | | - |
70 | | - Map<String, Object> response = runEsqlQuery(knnQuery); |
71 | | - List<Map<String, Object>> columns = (List<Map<String, Object>>) response.get("columns"); |
72 | | - assertThat(columns.size(), is(3)); |
73 | | - List<List<Object>> rows = (List<List<Object>>) response.get("values"); |
74 | | - assertThat(rows.size(), is(3)); |
75 | | - for (int row = 0; row < rows.size(); row++) { |
76 | | - List<Object> rowData = rows.get(row); |
77 | | - Integer id = (Integer) rowData.get(0); |
78 | | - assertThat(id, is(3 - row)); |
79 | | - } |
80 | | - } |
81 | | - |
82 | | - public void testKnnQueryOnTextField() throws IOException { |
83 | | - String knnQuery = """ |
84 | | - FROM semantic-test METADATA _score |
85 | | - | WHERE knn(text, [0, 1, 2], 10) |
86 | | - | KEEP id, _score, dense_semantic |
87 | | - | SORT _score DESC |
88 | | - | LIMIT 10 |
89 | | - """; |
90 | | - |
91 | | - ResponseException re = expectThrows(ResponseException.class, () -> runEsqlQuery(knnQuery)); |
92 | | - assertThat(re.getResponse().getStatusLine().getStatusCode(), is(BAD_REQUEST.getStatus())); |
93 | | - assertThat(re.getMessage(), containsString("[knn] queries are only supported on [dense_vector] fields")); |
94 | | - } |
95 | | - |
96 | | - public void testKnnQueryOnSparseSemanticTextField() throws IOException { |
97 | | - String knnQuery = """ |
98 | | - FROM semantic-test METADATA _score |
99 | | - | WHERE knn(sparse_semantic, [0, 1, 2], 10) |
100 | | - | KEEP id, _score, sparse_semantic |
101 | | - | SORT _score DESC |
102 | | - | LIMIT 10 |
103 | | - """; |
104 | | - |
105 | | - ResponseException re = expectThrows(ResponseException.class, () -> runEsqlQuery(knnQuery)); |
106 | | - assertThat(re.getResponse().getStatusLine().getStatusCode(), is(BAD_REQUEST.getStatus())); |
107 | | - assertThat(re.getMessage(), containsString("[knn] queries are only supported on [dense_vector] fields")); |
108 | | - } |
109 | | - |
110 | | - @Before |
111 | | - public void setUp() throws Exception { |
112 | | - super.setUp(); |
113 | | - setupInferenceEndpoints(); |
114 | | - setupIndex(); |
115 | | - } |
116 | | - |
117 | | - private void setupIndex() throws IOException { |
118 | | - Request request = new Request("PUT", "/semantic-test"); |
119 | | - request.setJsonEntity(""" |
120 | | - { |
121 | | - "mappings": { |
122 | | - "properties": { |
123 | | - "id": { |
124 | | - "type": "integer" |
125 | | - }, |
126 | | - "dense_semantic": { |
127 | | - "type": "semantic_text", |
128 | | - "inference_id": "test_dense_inference" |
129 | | - }, |
130 | | - "sparse_semantic": { |
131 | | - "type": "semantic_text", |
132 | | - "inference_id": "test_sparse_inference" |
133 | | - }, |
134 | | - "text": { |
135 | | - "type": "text", |
136 | | - "copy_to": ["dense_semantic", "sparse_semantic"] |
137 | | - } |
138 | | - } |
139 | | - }, |
140 | | - "settings": { |
141 | | - "index": { |
142 | | - "number_of_shards": 1, |
143 | | - "number_of_replicas": 0 |
144 | | - } |
145 | | - } |
146 | | - } |
147 | | - """); |
148 | | - assertEquals(200, client().performRequest(request).getStatusLine().getStatusCode()); |
149 | | - |
150 | | - request = new Request("POST", "/_bulk?index=semantic-test&refresh=true"); |
151 | | - request.setJsonEntity(""" |
152 | | - {"index": {"_id": "1"}} |
153 | | - {"id": 1, "text": "sample text"} |
154 | | - {"index": {"_id": "2"}} |
155 | | - {"id": 2, "text": "another sample text"} |
156 | | - {"index": {"_id": "3"}} |
157 | | - {"id": 3, "text": "yet another sample text"} |
158 | | - """); |
159 | | - assertEquals(200, client().performRequest(request).getStatusLine().getStatusCode()); |
160 | | - } |
161 | | - |
162 | | - private void setupInferenceEndpoints() throws IOException { |
163 | | - CsvTestsDataLoader.createTextEmbeddingInferenceEndpoint(client()); |
164 | | - CsvTestsDataLoader.createSparseEmbeddingInferenceEndpoint(client()); |
165 | | - } |
166 | | - |
167 | | - @After |
168 | | - public void tearDown() throws Exception { |
169 | | - super.tearDown(); |
170 | | - client().performRequest(new Request("DELETE", "semantic-test")); |
171 | | - |
172 | | - if (CsvTestsDataLoader.clusterHasTextEmbeddingInferenceEndpoint(client())) { |
173 | | - CsvTestsDataLoader.deleteTextEmbeddingInferenceEndpoint(client()); |
174 | | - } |
175 | | - if (CsvTestsDataLoader.clusterHasSparseEmbeddingInferenceEndpoint(client())) { |
176 | | - CsvTestsDataLoader.deleteSparseEmbeddingInferenceEndpoint(client()); |
177 | | - } |
178 | | - } |
179 | | - |
180 | | - private Map<String, Object> runEsqlQuery(String query) throws IOException { |
181 | | - RestEsqlTestCase.RequestObjectBuilder builder = requestObjectBuilder().query(query); |
182 | | - return runEsqlSync(builder, new AssertWarnings.NoWarnings(), profileLogger); |
183 | | - } |
184 | 27 | } |
0 commit comments