Commit 29a41cc

[ES|QL] COMPLETION command - Inference Operator implementation (elastic#127409)
1 parent f97680a commit 29a41cc

File tree

36 files changed: +2653 / -607 lines

x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java

Lines changed: 13 additions & 15 deletions
@@ -66,14 +66,11 @@
 import static org.elasticsearch.xpack.esql.CsvTestUtils.isEnabled;
 import static org.elasticsearch.xpack.esql.CsvTestUtils.loadCsvSpecValues;
 import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.availableDatasetsForEs;
-import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.clusterHasInferenceEndpoint;
-import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.clusterHasRerankInferenceEndpoint;
-import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createInferenceEndpoint;
-import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createRerankInferenceEndpoint;
-import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteInferenceEndpoint;
-import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteRerankInferenceEndpoint;
+import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createInferenceEndpoints;
+import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteInferenceEndpoints;
 import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.loadDataSetIntoEs;
 import static org.elasticsearch.xpack.esql.EsqlTestUtils.classpathResources;
+import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.COMPLETION;
 import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.RERANK;
 import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.SEMANTIC_TEXT_FIELD_CAPS;
 import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.cap;
@@ -138,12 +135,8 @@ protected EsqlSpecTestCase(

     @Before
     public void setup() throws IOException {
-        if (supportsInferenceTestService() && clusterHasInferenceEndpoint(client()) == false) {
-            createInferenceEndpoint(client());
-        }
-
-        if (supportsInferenceTestService() && clusterHasRerankInferenceEndpoint(client()) == false) {
-            createRerankInferenceEndpoint(client());
+        if (supportsInferenceTestService()) {
+            createInferenceEndpoints(adminClient());
         }

         if (indexExists(availableDatasetsForEs(client(), supportsIndexModeLookup()).iterator().next().indexName()) == false) {
@@ -166,8 +159,8 @@ public static void wipeTestData() throws IOException {
             }
         }

-        deleteInferenceEndpoint(client());
-        deleteRerankInferenceEndpoint(client());
+        deleteInferenceEndpoints(adminClient());
+
     }

     public boolean logResults() {
@@ -256,7 +249,7 @@ protected boolean supportsInferenceTestService() {
     }

     protected boolean requiresInferenceEndpoint() {
-        return Stream.of(SEMANTIC_TEXT_FIELD_CAPS.capabilityName(), RERANK.capabilityName())
+        return Stream.of(SEMANTIC_TEXT_FIELD_CAPS.capabilityName(), RERANK.capabilityName(), COMPLETION.capabilityName())
             .anyMatch(testCase.requiredCapabilities::contains);
     }

@@ -355,6 +348,11 @@ private Object valueMapper(CsvTestUtils.Type type, Object value) {
                 return new BigDecimal(s).round(new MathContext(7, RoundingMode.DOWN)).doubleValue();
             }
         }
+        if (type == CsvTestUtils.Type.TEXT || type == CsvTestUtils.Type.KEYWORD || type == CsvTestUtils.Type.SEMANTIC_TEXT) {
+            if (value instanceof String s) {
+                value = s.replaceAll("\\\\n", "\n");
+            }
+        }
         return value.toString();
     }

x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java

Lines changed: 76 additions & 53 deletions
@@ -27,6 +27,7 @@
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.logging.LogManager;
 import org.elasticsearch.logging.Logger;
 import org.elasticsearch.test.rest.ESRestTestCase;
@@ -278,7 +279,7 @@ public static void main(String[] args) throws IOException {
     }

     public static Set<TestDataset> availableDatasetsForEs(RestClient client, boolean supportsIndexModeLookup) throws IOException {
-        boolean inferenceEnabled = clusterHasInferenceEndpoint(client);
+        boolean inferenceEnabled = clusterHasSparseEmbeddingInferenceEndpoint(client);

         Set<TestDataset> testDataSets = new HashSet<>();

@@ -319,79 +320,90 @@ private static void loadDataSetIntoEs(RestClient client, boolean supportsIndexMo
         }
     }

-    /**
-     * The semantic_text mapping type require an inference endpoint that needs to be setup before creating the index.
-     */
-    public static void createInferenceEndpoint(RestClient client) throws IOException {
-        Request request = new Request("PUT", "_inference/sparse_embedding/test_sparse_inference");
-        request.setJsonEntity("""
+    public static void createInferenceEndpoints(RestClient client) throws IOException {
+        if (clusterHasSparseEmbeddingInferenceEndpoint(client) == false) {
+            createSparseEmbeddingInferenceEndpoint(client);
+        }
+
+        if (clusterHasRerankInferenceEndpoint(client) == false) {
+            createRerankInferenceEndpoint(client);
+        }
+
+        if (clusterHasCompletionInferenceEndpoint(client) == false) {
+            createCompletionInferenceEndpoint(client);
+        }
+    }
+
+    public static void deleteInferenceEndpoints(RestClient client) throws IOException {
+        deleteSparseEmbeddingInferenceEndpoint(client);
+        deleteRerankInferenceEndpoint(client);
+        deleteCompletionInferenceEndpoint(client);
+    }
+
+    /** The semantic_text mapping type require an inference endpoint that needs to be setup before creating the index. */
+    public static void createSparseEmbeddingInferenceEndpoint(RestClient client) throws IOException {
+        createInferenceEndpoint(client, TaskType.SPARSE_EMBEDDING, "test_sparse_inference", """
             {
               "service": "test_service",
-              "service_settings": {
-                  "model": "my_model",
-                  "api_key": "abc64"
-              },
-              "task_settings": {
-              }
+              "service_settings": { "model": "my_model", "api_key": "abc64" },
+              "task_settings": { }
             }
             """);
-        client.performRequest(request);
     }

-    public static void deleteInferenceEndpoint(RestClient client) throws IOException {
-        try {
-            client.performRequest(new Request("DELETE", "_inference/sparse_embedding/test_sparse_inference"));
-        } catch (ResponseException e) {
-            // 404 here means the endpoint was not created
-            if (e.getResponse().getStatusLine().getStatusCode() != 404) {
-                throw e;
-            }
-        }
+    public static void deleteSparseEmbeddingInferenceEndpoint(RestClient client) throws IOException {
+        deleteInferenceEndpoint(client, "test_sparse_inference");
     }

-    public static boolean clusterHasInferenceEndpoint(RestClient client) throws IOException {
-        Request request = new Request("GET", "_inference/sparse_embedding/test_sparse_inference");
-        try {
-            client.performRequest(request);
-        } catch (ResponseException e) {
-            if (e.getResponse().getStatusLine().getStatusCode() == 404) {
-                return false;
-            }
-            throw e;
-        }
-        return true;
+    public static boolean clusterHasSparseEmbeddingInferenceEndpoint(RestClient client) throws IOException {
+        return clusterHasInferenceEndpoint(client, TaskType.SPARSE_EMBEDDING, "test_sparse_inference");
     }

     public static void createRerankInferenceEndpoint(RestClient client) throws IOException {
-        Request request = new Request("PUT", "_inference/rerank/test_reranker");
-        request.setJsonEntity("""
+        createInferenceEndpoint(client, TaskType.RERANK, "test_reranker", """
             {
              "service": "test_reranking_service",
-              "service_settings": {
-                 "model_id": "my_model",
-                 "api_key": "abc64"
-              },
-              "task_settings": {
-                 "use_text_length": true
-              }
+              "service_settings": { "model_id": "my_model", "api_key": "abc64" },
+              "task_settings": { "use_text_length": true }
             }
             """);
-        client.performRequest(request);
     }

     public static void deleteRerankInferenceEndpoint(RestClient client) throws IOException {
-        try {
-            client.performRequest(new Request("DELETE", "_inference/rerank/test_reranker"));
-        } catch (ResponseException e) {
-            // 404 here means the endpoint was not created
-            if (e.getResponse().getStatusLine().getStatusCode() != 404) {
-                throw e;
-            }
-        }
+        deleteInferenceEndpoint(client, "test_reranker");
     }

     public static boolean clusterHasRerankInferenceEndpoint(RestClient client) throws IOException {
-        Request request = new Request("GET", "_inference/rerank/test_reranker");
+        return clusterHasInferenceEndpoint(client, TaskType.RERANK, "test_reranker");
+    }
+
+    public static void createCompletionInferenceEndpoint(RestClient client) throws IOException {
+        createInferenceEndpoint(client, TaskType.COMPLETION, "test_completion", """
+            {
+              "service": "completion_test_service",
+              "service_settings": { "model": "my_model", "api_key": "abc64" },
+              "task_settings": { "temperature": 3 }
+            }
+            """);
+    }
+
+    public static void deleteCompletionInferenceEndpoint(RestClient client) throws IOException {
+        deleteInferenceEndpoint(client, "test_completion");
+    }
+
+    public static boolean clusterHasCompletionInferenceEndpoint(RestClient client) throws IOException {
+        return clusterHasInferenceEndpoint(client, TaskType.COMPLETION, "test_completion");
+    }
+
+    private static void createInferenceEndpoint(RestClient client, TaskType taskType, String inferenceId, String modelSettings)
+        throws IOException {
+        Request request = new Request("PUT", "_inference/" + taskType.name() + "/" + inferenceId);
+        request.setJsonEntity(modelSettings);
+        client.performRequest(request);
+    }
+
+    private static boolean clusterHasInferenceEndpoint(RestClient client, TaskType taskType, String inferenceId) throws IOException {
+        Request request = new Request("GET", "_inference/" + taskType.name() + "/" + inferenceId);
         try {
             client.performRequest(request);
         } catch (ResponseException e) {
@@ -403,6 +415,17 @@ public static boolean clusterHasRerankInferenceEndpoint(RestClient client) throw
         return true;
     }

+    private static void deleteInferenceEndpoint(RestClient client, String inferenceId) throws IOException {
+        try {
+            client.performRequest(new Request("DELETE", "_inference/" + inferenceId));
+        } catch (ResponseException e) {
+            // 404 here means the endpoint was not created
+            if (e.getResponse().getStatusLine().getStatusCode() != 404) {
+                throw e;
+            }
+        }
+    }
+
     private static void loadEnrichPolicy(RestClient client, String policyName, String policyFileName, Logger logger) throws IOException {
         URL policyMapping = CsvTestsDataLoader.class.getResource("/" + policyFileName);
         if (policyMapping == null) {
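With createInferenceEndpoint, clusterHasInferenceEndpoint and deleteInferenceEndpoint now parameterized by TaskType and inference id, wiring in another test endpoint reduces to three thin wrappers plus calls from createInferenceEndpoints and deleteInferenceEndpoints. A hedged sketch of what a hypothetical text_embedding endpoint could look like inside CsvTestsDataLoader; the "test_dense_inference" id, service name and settings are invented for illustration and are not part of this commit:

// Hypothetical wrappers following the pattern introduced above; ids, service name and settings are invented.
public static void createTextEmbeddingInferenceEndpoint(RestClient client) throws IOException {
    createInferenceEndpoint(client, TaskType.TEXT_EMBEDDING, "test_dense_inference", """
        {
          "service": "text_embedding_test_service",
          "service_settings": { "model": "my_model", "api_key": "abc64", "dimensions": 128 },
          "task_settings": { }
        }
        """);
}

public static boolean clusterHasTextEmbeddingInferenceEndpoint(RestClient client) throws IOException {
    return clusterHasInferenceEndpoint(client, TaskType.TEXT_EMBEDDING, "test_dense_inference");
}

public static void deleteTextEmbeddingInferenceEndpoint(RestClient client) throws IOException {
    deleteInferenceEndpoint(client, "test_dense_inference");
}

A real addition would also register these in createInferenceEndpoints and deleteInferenceEndpoints so EsqlSpecTestCase picks the endpoint up automatically during setup and teardown.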
Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+// Note:
+// The "test_completion" service returns the prompt in uppercase, making the output easy to guess.
+
+
+completion using a ROW source operator
+required_capability: completion
+
+ROW prompt="Who is Victor Hugo?"
+| COMPLETION prompt WITH test_completion AS completion_output
+;
+
+prompt:keyword | completion_output:keyword
+Who is Victor Hugo? | WHO IS VICTOR HUGO?
+;
+
+
+completion using a ROW source operator and prompt is a multi-valued field
+required_capability: completion
+
+ROW prompt=["Answer the following question:", "Who is Victor Hugo?"]
+| COMPLETION prompt WITH test_completion AS completion_output
+;
+
+prompt:keyword | completion_output:keyword
+[Answer the following question:, Who is Victor Hugo?] | ANSWER THE FOLLOWING QUESTION:\nWHO IS VICTOR HUGO?
+;
+
+
+completion after a search
+required_capability: completion
+required_capability: match_operator_colon
+
+FROM books METADATA _score
+| WHERE title:"war and peace" AND author:"Tolstoy"
+| SORT _score DESC
+| LIMIT 2
+| COMPLETION title WITH test_completion
+| KEEP title, completion
+;
+
+title:text | completion:keyword
+War and Peace | WAR AND PEACE
+War and Peace (Signet Classics) | WAR AND PEACE (SIGNET CLASSICS)
+;
+
+completion using a function as a prompt
+required_capability: completion
+required_capability: match_operator_colon
+
+FROM books METADATA _score
+| WHERE title:"war and peace" AND author:"Tolstoy"
+| SORT _score DESC
+| LIMIT 2
+| COMPLETION CONCAT("This is a prompt: ", title) WITH test_completion
+| KEEP title, completion
+;
+
+title:text | completion:keyword
+War and Peace | THIS IS A PROMPT: WAR AND PEACE
+War and Peace (Signet Classics) | THIS IS A PROMPT: WAR AND PEACE (SIGNET CLASSICS)
+;
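The new csv-spec file above exercises the COMPLETION command end to end against the "test_completion" endpoint created by CsvTestsDataLoader (the fixture's completion service simply uppercases the prompt). Outside the csv-spec runner, the same query could be issued through the ES|QL _query REST endpoint; a hedged sketch using the low-level client already used in this PR, assuming the test endpoint has been created:

// Sketch only, not part of this commit: run the first spec query via POST /_query.
// Assumes org.elasticsearch.client.{RestClient, Request, Response} and org.apache.http.util.EntityUtils.
static String runCompletionQuery(RestClient client) throws IOException {
    Request request = new Request("POST", "/_query");
    request.setJsonEntity("""
        {
          "query": "ROW prompt=\\"Who is Victor Hugo?\\" | COMPLETION prompt WITH test_completion AS completion_output"
        }
        """);
    Response response = client.performRequest(request);
    // With the test fixture's completion service, completion_output comes back as "WHO IS VICTOR HUGO?".
    return EntityUtils.toString(response.getEntity());
}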

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 1 addition & 1 deletion
@@ -596,7 +596,7 @@ private LogicalPlan resolveCompletion(Completion p, List<Attribute> childrenOutp
         Expression prompt = p.prompt();

         if (targetField instanceof UnresolvedAttribute ua) {
-            targetField = new ReferenceAttribute(ua.source(), ua.name(), TEXT);
+            targetField = new ReferenceAttribute(ua.source(), ua.name(), KEYWORD);
         }

         if (prompt.resolved() == false) {
