Skip to content

Commit 643c767

Browse files
committed
First IT test for knn with semantic_text
1 parent 0e264b1 commit 643c767

File tree

1 file changed

+129
-0
lines changed
  • x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node

1 file changed

+129
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.qa.single_node;
9+
10+
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters;
11+
12+
import org.elasticsearch.client.Request;
13+
import org.elasticsearch.test.TestClustersThreadFilter;
14+
import org.elasticsearch.test.cluster.ElasticsearchCluster;
15+
import org.elasticsearch.test.rest.ESRestTestCase;
16+
import org.elasticsearch.xpack.esql.AssertWarnings;
17+
import org.elasticsearch.xpack.esql.CsvTestsDataLoader;
18+
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
19+
import org.elasticsearch.xpack.esql.qa.rest.ProfileLogger;
20+
import org.elasticsearch.xpack.esql.qa.rest.RestEsqlTestCase;
21+
import org.junit.After;
22+
import org.junit.Before;
23+
import org.junit.ClassRule;
24+
import org.junit.Rule;
25+
26+
import java.io.IOException;
27+
import java.util.HashMap;
28+
import java.util.List;
29+
import java.util.Map;
30+
31+
import static org.elasticsearch.xpack.esql.qa.rest.RestEsqlTestCase.requestObjectBuilder;
32+
import static org.elasticsearch.xpack.esql.qa.rest.RestEsqlTestCase.runEsqlSync;
33+
import static org.hamcrest.Matchers.is;
34+
35+
@ThreadLeakFilters(filters = TestClustersThreadFilter.class)
36+
public class KnnSemanticTextIT extends ESRestTestCase {
37+
38+
@ClassRule
39+
public static ElasticsearchCluster cluster = Clusters.testCluster(spec -> spec.plugin("inference-service-test"));
40+
41+
@Rule(order = Integer.MIN_VALUE)
42+
public ProfileLogger profileLogger = new ProfileLogger();
43+
44+
private int numDocs;
45+
private final Map<Integer, String> indexedTexts = new HashMap<>();
46+
47+
@Override
48+
protected String getTestRestCluster() {
49+
return cluster.getHttpAddresses();
50+
}
51+
52+
@Before
53+
public void checkCapability() {
54+
assumeTrue("semantic text capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
55+
}
56+
57+
public void testKnnQuery() throws IOException {
58+
String knnQuery = """
59+
FROM semantic-test METADATA _score
60+
| WHERE knn(semantic, [0, 1, 2], 10)
61+
| KEEP id, _score, semantic
62+
| SORT _score DESC
63+
| LIMIT 10
64+
""";
65+
66+
Map<String, Object> response = runEsqlQuery(knnQuery);
67+
@SuppressWarnings("unchecked")
68+
List<Map<String, Object>> columns = (List<Map<String, Object>>) response.get("columns");
69+
assertThat(columns.size(), is(3));
70+
}
71+
72+
@Before
73+
public void setupIndex() throws IOException {
74+
Request request = new Request("PUT", "/semantic-test");
75+
request.setJsonEntity("""
76+
{
77+
"mappings": {
78+
"properties": {
79+
"id": {
80+
"type": "integer"
81+
},
82+
"semantic": {
83+
"type": "semantic_text",
84+
"inference_id": "test_dense_inference"
85+
}
86+
}
87+
},
88+
"settings": {
89+
"index": {
90+
"number_of_shards": 1,
91+
"number_of_replicas": 0
92+
}
93+
}
94+
}
95+
""");
96+
assertEquals(200, client().performRequest(request).getStatusLine().getStatusCode());
97+
98+
request = new Request("POST", "/_bulk?index=semantic-test&refresh=true");
99+
// 4 documents with a null in the middle, leading to 3 ESQL pages and 3 Arrow batches
100+
request.setJsonEntity("""
101+
{"index": {"_id": "1"}}
102+
{"id": 1, "semantic": "sample text one"}
103+
{"index": {"_id": "2"}}
104+
{"id": 2, "semantic": "sample text two"}
105+
{"index": {"_id": "3"}}
106+
{"id": 3, "semantic": "sample text three"}
107+
""");
108+
assertEquals(200, client().performRequest(request).getStatusLine().getStatusCode());
109+
}
110+
111+
@Before
112+
public void setupInferenceEndpoint() throws IOException {
113+
CsvTestsDataLoader.createTextEmbeddingInferenceEndpoint(client());
114+
}
115+
116+
@After
117+
public void removeIndexAndInferenceEndpoint() throws IOException {
118+
client().performRequest(new Request("DELETE", "semantic-test"));
119+
120+
if (CsvTestsDataLoader.clusterHasTextEmbeddingInferenceEndpoint(client())) {
121+
CsvTestsDataLoader.deleteTextEmbeddingInferenceEndpoint(client());
122+
}
123+
}
124+
125+
private Map<String, Object> runEsqlQuery(String query) throws IOException {
126+
RestEsqlTestCase.RequestObjectBuilder builder = requestObjectBuilder().query(query);
127+
return runEsqlSync(builder, new AssertWarnings.NoWarnings(), profileLogger);
128+
}
129+
}

0 commit comments

Comments
 (0)