-
Notifications
You must be signed in to change notification settings - Fork 25.6k
ESQL - Add semantic_text support for knn function #133806
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
e8dd1e8
9db94aa
0e264b1
66db5c9
643c767
3a8ef30
fe6d2b3
a2f9b7d
10d2f48
9aae2d5
6e437ac
37b540c
3cf742e
6eb6761
7b0cc2f
cb33548
68aef53
7ea8050
d8ef102
2e9f118
d56c897
3b4e687
7fc135c
c98126f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.esql.qa.multi_node; | ||
|
||
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters; | ||
|
||
import org.elasticsearch.test.TestClustersThreadFilter; | ||
import org.elasticsearch.test.cluster.ElasticsearchCluster; | ||
import org.elasticsearch.xpack.esql.qa.rest.KnnSemanticTextTestCase; | ||
import org.junit.ClassRule; | ||
|
||
@ThreadLeakFilters(filters = TestClustersThreadFilter.class) | ||
public class KnnSemanticTextIT extends KnnSemanticTextTestCase { | ||
@ClassRule | ||
public static ElasticsearchCluster cluster = Clusters.testCluster( | ||
spec -> spec.module("x-pack-inference").plugin("inference-service-test") | ||
); | ||
|
||
@Override | ||
protected String getTestRestCluster() { | ||
return cluster.getHttpAddresses(); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.esql.qa.single_node; | ||
|
||
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters; | ||
|
||
import org.elasticsearch.test.TestClustersThreadFilter; | ||
import org.elasticsearch.test.cluster.ElasticsearchCluster; | ||
import org.elasticsearch.xpack.esql.qa.rest.KnnSemanticTextTestCase; | ||
import org.junit.ClassRule; | ||
|
||
@ThreadLeakFilters(filters = TestClustersThreadFilter.class) | ||
public class KnnSemanticTextIT extends KnnSemanticTextTestCase { | ||
|
||
@ClassRule | ||
public static ElasticsearchCluster cluster = Clusters.testCluster(spec -> spec.plugin("inference-service-test")); | ||
|
||
@Override | ||
protected String getTestRestCluster() { | ||
return cluster.getHttpAddresses(); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.esql.qa.rest; | ||
|
||
import org.elasticsearch.client.Request; | ||
import org.elasticsearch.client.ResponseException; | ||
import org.elasticsearch.test.rest.ESRestTestCase; | ||
import org.elasticsearch.xpack.esql.AssertWarnings; | ||
import org.elasticsearch.xpack.esql.CsvTestsDataLoader; | ||
import org.elasticsearch.xpack.esql.action.EsqlCapabilities; | ||
import org.junit.After; | ||
import org.junit.Before; | ||
import org.junit.Rule; | ||
|
||
import java.io.IOException; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import static org.elasticsearch.rest.RestStatus.BAD_REQUEST; | ||
import static org.elasticsearch.xpack.esql.qa.rest.RestEsqlTestCase.requestObjectBuilder; | ||
import static org.elasticsearch.xpack.esql.qa.rest.RestEsqlTestCase.runEsqlSync; | ||
import static org.hamcrest.Matchers.is; | ||
import static org.hamcrest.core.StringContains.containsString; | ||
|
||
public class KnnSemanticTextTestCase extends ESRestTestCase { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if we should add test cases here for using knn over 2 different dense endpoints. Note that this may conflict with this open PR: #133675 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's not needed, as the SemanticQueryBuilder is not used on this approach - it's just knn being done over two dense_vector fields. The |
||
|
||
@Rule(order = Integer.MIN_VALUE) | ||
public ProfileLogger profileLogger = new ProfileLogger(); | ||
|
||
@Before | ||
public void checkCapability() { | ||
assumeTrue("knn with semantic text not available", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); | ||
} | ||
|
||
@SuppressWarnings("unchecked") | ||
public void testKnnQueryWithSemanticText() throws IOException { | ||
String knnQuery = """ | ||
FROM semantic-test METADATA _score | ||
| WHERE knn(dense_semantic, [0, 1, 2], 10) | ||
| KEEP id, _score, dense_semantic | ||
| SORT _score DESC | ||
| LIMIT 10 | ||
"""; | ||
|
||
Map<String, Object> response = runEsqlQuery(knnQuery); | ||
List<Map<String, Object>> columns = (List<Map<String, Object>>) response.get("columns"); | ||
assertThat(columns.size(), is(3)); | ||
List<List<Object>> rows = (List<List<Object>>) response.get("values"); | ||
assertThat(rows.size(), is(3)); | ||
for (int row = 0; row < rows.size(); row++) { | ||
List<Object> rowData = rows.get(row); | ||
Integer id = (Integer) rowData.get(0); | ||
assertThat(id, is(3 - row)); | ||
} | ||
} | ||
|
||
public void testKnnQueryOnTextField() throws IOException { | ||
String knnQuery = """ | ||
FROM semantic-test METADATA _score | ||
| WHERE knn(text, [0, 1, 2], 10) | ||
| KEEP id, _score, dense_semantic | ||
| SORT _score DESC | ||
| LIMIT 10 | ||
"""; | ||
|
||
ResponseException re = expectThrows(ResponseException.class, () -> runEsqlQuery(knnQuery)); | ||
assertThat(re.getResponse().getStatusLine().getStatusCode(), is(BAD_REQUEST.getStatus())); | ||
assertThat(re.getMessage(), containsString("[knn] queries are only supported on [dense_vector] fields")); | ||
} | ||
|
||
public void testKnnQueryOnSparseSemanticTextField() throws IOException { | ||
String knnQuery = """ | ||
FROM semantic-test METADATA _score | ||
| WHERE knn(sparse_semantic, [0, 1, 2], 10) | ||
| KEEP id, _score, sparse_semantic | ||
| SORT _score DESC | ||
| LIMIT 10 | ||
"""; | ||
|
||
ResponseException re = expectThrows(ResponseException.class, () -> runEsqlQuery(knnQuery)); | ||
assertThat(re.getResponse().getStatusLine().getStatusCode(), is(BAD_REQUEST.getStatus())); | ||
assertThat(re.getMessage(), containsString("[knn] queries are only supported on [dense_vector] fields")); | ||
} | ||
|
||
@Before | ||
public void setUp() throws Exception { | ||
super.setUp(); | ||
setupInferenceEndpoints(); | ||
setupIndex(); | ||
} | ||
|
||
private void setupIndex() throws IOException { | ||
Request request = new Request("PUT", "/semantic-test"); | ||
request.setJsonEntity(""" | ||
{ | ||
"mappings": { | ||
"properties": { | ||
"id": { | ||
"type": "integer" | ||
}, | ||
"dense_semantic": { | ||
"type": "semantic_text", | ||
"inference_id": "test_dense_inference" | ||
}, | ||
"sparse_semantic": { | ||
"type": "semantic_text", | ||
"inference_id": "test_sparse_inference" | ||
}, | ||
"text": { | ||
"type": "text", | ||
"copy_to": ["dense_semantic", "sparse_semantic"] | ||
} | ||
} | ||
}, | ||
"settings": { | ||
"index": { | ||
"number_of_shards": 1, | ||
"number_of_replicas": 0 | ||
} | ||
} | ||
} | ||
"""); | ||
assertEquals(200, client().performRequest(request).getStatusLine().getStatusCode()); | ||
|
||
request = new Request("POST", "/_bulk?index=semantic-test&refresh=true"); | ||
request.setJsonEntity(""" | ||
{"index": {"_id": "1"}} | ||
{"id": 1, "text": "sample text"} | ||
{"index": {"_id": "2"}} | ||
{"id": 2, "text": "another sample text"} | ||
{"index": {"_id": "3"}} | ||
{"id": 3, "text": "yet another sample text"} | ||
"""); | ||
assertEquals(200, client().performRequest(request).getStatusLine().getStatusCode()); | ||
} | ||
|
||
private void setupInferenceEndpoints() throws IOException { | ||
CsvTestsDataLoader.createTextEmbeddingInferenceEndpoint(client()); | ||
CsvTestsDataLoader.createSparseEmbeddingInferenceEndpoint(client()); | ||
} | ||
|
||
@After | ||
public void tearDown() throws Exception { | ||
super.tearDown(); | ||
client().performRequest(new Request("DELETE", "semantic-test")); | ||
|
||
if (CsvTestsDataLoader.clusterHasTextEmbeddingInferenceEndpoint(client())) { | ||
CsvTestsDataLoader.deleteTextEmbeddingInferenceEndpoint(client()); | ||
} | ||
if (CsvTestsDataLoader.clusterHasSparseEmbeddingInferenceEndpoint(client())) { | ||
CsvTestsDataLoader.deleteSparseEmbeddingInferenceEndpoint(client()); | ||
} | ||
} | ||
|
||
private Map<String, Object> runEsqlQuery(String query) throws IOException { | ||
RestEsqlTestCase.RequestObjectBuilder builder = requestObjectBuilder().query(query); | ||
return runEsqlSync(builder, new AssertWarnings.NoWarnings(), profileLogger); | ||
} | ||
} |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Adds a test inference endpoint for text embedding tasks |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Adds a |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
_id:keyword,semantic_text_field:semantic_text,st_bool:semantic_text,st_cartesian_point:semantic_text,st_cartesian_shape:semantic_text,st_datetime:semantic_text,st_double:semantic_text,st_geopoint:semantic_text,st_geoshape:semantic_text,st_integer:semantic_text,st_ip:semantic_text,st_long:semantic_text,st_unsigned_long:semantic_text,st_version:semantic_text,st_multi_value:semantic_text,st_unicode:semantic_text,host:keyword,description:text,value:long,st_base64:semantic_text,st_logs:semantic_text,language_name:keyword | ||
1,live long and prosper,false,"POINT(4297.11 -1475.53)",,1953-09-02T00:00:00.000Z,5.20128E11,"POINT(42.97109630194 14.7552534413725)","POLYGON ((30 10\, 40 40\, 20 40\, 10 20\, 30 10))",23,1.1.1.1,2147483648,2147483648,1.2.3,["Hello there!", "This is a random value", "for testing purposes"],你吃饭了吗,"host1","some description1",1001,ZWxhc3RpYw==,"2024-12-23T12:15:00.000Z 1.2.3.4 [email protected] 4553",English | ||
2,all we have to decide is what to do with the time that is given to us,true,"POINT(7580.93 2272.77)",,2023-09-24T15:57:00.000Z,4541.11,"POINT(37.97109630194 21.7552534413725)","POLYGON ((30 10\, 40 40\, 20 40\, 10 20\, 30 10))",122,1.1.2.1,123,2147483648.2,9.0.0,["nice to meet you", "bye bye!"],["谢谢", "对不起我的中文不好"],"host2","some description2",1002,aGVsbG8=,"2024-01-23T12:15:00.000Z 1.2.3.4 [email protected] 42",French | ||
3,be excellent to each other,,,,,,,,,,,,,,,"host3","some description3",1003,,"2023-01-23T12:15:00.000Z 127.0.0.1 [email protected] 42",Spanish | ||
_id:keyword,semantic_text_field:semantic_text,semantic_text_dense_field:semantic_text,st_bool:semantic_text,st_cartesian_point:semantic_text,st_cartesian_shape:semantic_text,st_datetime:semantic_text,st_double:semantic_text,st_geopoint:semantic_text,st_geoshape:semantic_text,st_integer:semantic_text,st_ip:semantic_text,st_long:semantic_text,st_unsigned_long:semantic_text,st_version:semantic_text,st_multi_value:semantic_text,st_unicode:semantic_text,host:keyword,description:text,value:long,st_base64:semantic_text,st_logs:semantic_text,language_name:keyword | ||
1,live long and prosper,live long and prosper,false,"POINT(4297.11 -1475.53)",,1953-09-02T00:00:00.000Z,5.20128E11,"POINT(42.97109630194 14.7552534413725)","POLYGON ((30 10\, 40 40\, 20 40\, 10 20\, 30 10))",23,1.1.1.1,2147483648,2147483648,1.2.3,["Hello there!", "This is a random value", "for testing purposes"],你吃饭了吗,"host1","some description1",1001,ZWxhc3RpYw==,"2024-12-23T12:15:00.000Z 1.2.3.4 [email protected] 4553",English | ||
2,all we have to decide is what to do with the time that is given to us,all we have to decide is what to do with the time that is given to us,true,"POINT(7580.93 2272.77)",,2023-09-24T15:57:00.000Z,4541.11,"POINT(37.97109630194 21.7552534413725)","POLYGON ((30 10\, 40 40\, 20 40\, 10 20\, 30 10))",122,1.1.2.1,123,2147483648.2,9.0.0,["nice to meet you", "bye bye!"],["谢谢", "对不起我的中文不好"],"host2","some description2",1002,aGVsbG8=,"2024-01-23T12:15:00.000Z 1.2.3.4 [email protected] 42",French | ||
3,be excellent to each other,be excellent to each other,,,,,,,,,,,,,,,"host3","some description3",1003,,"2023-01-23T12:15:00.000Z 127.0.0.1 [email protected] 42",Spanish |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added both single-node and multi-node ITs that extend from a common superclass (used
SemanticMatchTestCase
as a template)