Skip to content

Commit 10d2f48

Browse files
committed
Add test for sparse vector
1 parent a2f9b7d commit 10d2f48

File tree

1 file changed

+37
-10
lines changed
  • x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node

1 file changed

+37
-10
lines changed

x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/KnnSemanticTextIT.java

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ public void checkCapability() {
6161
public void testKnnQueryWithSemanticText() throws IOException {
6262
String knnQuery = """
6363
FROM semantic-test METADATA _score
64-
| WHERE knn(semantic, [0, 1, 2], 10)
65-
| KEEP id, _score, semantic
64+
| WHERE knn(dense_semantic, [0, 1, 2], 10)
65+
| KEEP id, _score, dense_semantic
6666
| SORT _score DESC
6767
| LIMIT 10
6868
""";
@@ -83,7 +83,21 @@ public void testKnnQueryOnTextField() throws IOException {
8383
String knnQuery = """
8484
FROM semantic-test METADATA _score
8585
| WHERE knn(text, [0, 1, 2], 10)
86-
| KEEP id, _score, semantic
86+
| KEEP id, _score, dense_semantic
87+
| SORT _score DESC
88+
| LIMIT 10
89+
""";
90+
91+
ResponseException re = expectThrows(ResponseException.class, () -> runEsqlQuery(knnQuery));
92+
assertThat(re.getResponse().getStatusLine().getStatusCode(), is(BAD_REQUEST.getStatus()));
93+
assertThat(re.getMessage(), containsString("[knn] queries are only supported on [dense_vector] fields"));
94+
}
95+
96+
public void testKnnQueryOnSparseSemanticTextField() throws IOException {
97+
String knnQuery = """
98+
FROM semantic-test METADATA _score
99+
| WHERE knn(sparse_semantic, [0, 1, 2], 10)
100+
| KEEP id, _score, sparse_semantic
87101
| SORT _score DESC
88102
| LIMIT 10
89103
""";
@@ -94,7 +108,13 @@ public void testKnnQueryOnTextField() throws IOException {
94108
}
95109

96110
@Before
97-
public void setupIndex() throws IOException {
111+
public void setUp() throws Exception {
112+
super.setUp();
113+
setupInferenceEndpoints();
114+
setupIndex();
115+
}
116+
117+
private void setupIndex() throws IOException {
98118
Request request = new Request("PUT", "/semantic-test");
99119
request.setJsonEntity("""
100120
{
@@ -103,13 +123,17 @@ public void setupIndex() throws IOException {
103123
"id": {
104124
"type": "integer"
105125
},
106-
"semantic": {
126+
"dense_semantic": {
107127
"type": "semantic_text",
108128
"inference_id": "test_dense_inference"
109129
},
130+
"sparse_semantic": {
131+
"type": "semantic_text",
132+
"inference_id": "test_sparse_inference"
133+
},
110134
"text": {
111135
"type": "text",
112-
"copy_to": "semantic"
136+
"copy_to": ["dense_semantic", "sparse_semantic"]
113137
}
114138
}
115139
},
@@ -124,7 +148,6 @@ public void setupIndex() throws IOException {
124148
assertEquals(200, client().performRequest(request).getStatusLine().getStatusCode());
125149

126150
request = new Request("POST", "/_bulk?index=semantic-test&refresh=true");
127-
// 4 documents with a null in the middle, leading to 3 ESQL pages and 3 Arrow batches
128151
request.setJsonEntity("""
129152
{"index": {"_id": "1"}}
130153
{"id": 1, "text": "sample text"}
@@ -136,18 +159,22 @@ public void setupIndex() throws IOException {
136159
assertEquals(200, client().performRequest(request).getStatusLine().getStatusCode());
137160
}
138161

139-
@Before
140-
public void setupInferenceEndpoint() throws IOException {
162+
private void setupInferenceEndpoints() throws IOException {
141163
CsvTestsDataLoader.createTextEmbeddingInferenceEndpoint(client());
164+
CsvTestsDataLoader.createSparseEmbeddingInferenceEndpoint(client());
142165
}
143166

144167
@After
145-
public void removeIndexAndInferenceEndpoint() throws IOException {
168+
public void tearDown() throws Exception {
169+
super.tearDown();
146170
client().performRequest(new Request("DELETE", "semantic-test"));
147171

148172
if (CsvTestsDataLoader.clusterHasTextEmbeddingInferenceEndpoint(client())) {
149173
CsvTestsDataLoader.deleteTextEmbeddingInferenceEndpoint(client());
150174
}
175+
if (CsvTestsDataLoader.clusterHasSparseEmbeddingInferenceEndpoint(client())) {
176+
CsvTestsDataLoader.deleteSparseEmbeddingInferenceEndpoint(client());
177+
}
151178
}
152179

153180
private Map<String, Object> runEsqlQuery(String query) throws IOException {

0 commit comments

Comments
 (0)