
Commit 993090d

[ES|QL] COMPLETION command - Inference Operator implementation (#127409)
1 parent cc5a2d8

37 files changed, +2653 −613 lines


muted-tests.yml

Lines changed: 0 additions & 9 deletions
@@ -300,9 +300,6 @@ tests:
 - class: org.elasticsearch.search.basic.SearchWithRandomDisconnectsIT
   method: testSearchWithRandomDisconnects
   issue: https://github.com/elastic/elasticsearch/issues/122707
-- class: org.elasticsearch.xpack.esql.inference.RerankOperatorTests
-  method: testSimpleCircuitBreaking
-  issue: https://github.com/elastic/elasticsearch/issues/124337
 - class: org.elasticsearch.index.engine.ThreadPoolMergeSchedulerTests
   method: testSchedulerCloseWaitsForRunningMerge
   issue: https://github.com/elastic/elasticsearch/issues/125236
@@ -384,9 +381,6 @@ tests:
 - class: org.elasticsearch.packaging.test.DockerTests
   method: test024InstallPluginFromArchiveUsingConfigFile
   issue: https://github.com/elastic/elasticsearch/issues/126936
-- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT
-  method: test {rerank.Reranker before a limit ASYNC}
-  issue: https://github.com/elastic/elasticsearch/issues/127051
 - class: org.elasticsearch.packaging.test.DockerTests
   method: test026InstallBundledRepositoryPlugins
   issue: https://github.com/elastic/elasticsearch/issues/127081
@@ -399,9 +393,6 @@ tests:
 - class: org.elasticsearch.xpack.test.rest.XPackRestIT
   method: test {p0=ml/data_frame_analytics_cat_apis/Test cat data frame analytics all jobs with header}
   issue: https://github.com/elastic/elasticsearch/issues/127625
-- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT
-  method: test {rerank.Reranker using another sort order ASYNC}
-  issue: https://github.com/elastic/elasticsearch/issues/127638
 - class: org.elasticsearch.xpack.search.CrossClusterAsyncSearchIT
   method: testCancellationViaTimeoutWithAllowPartialResultsSetToFalse
   issue: https://github.com/elastic/elasticsearch/issues/127096

x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java

Lines changed: 13 additions & 15 deletions
@@ -67,14 +67,11 @@
 import static org.elasticsearch.xpack.esql.CsvTestUtils.isEnabled;
 import static org.elasticsearch.xpack.esql.CsvTestUtils.loadCsvSpecValues;
 import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.availableDatasetsForEs;
-import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.clusterHasInferenceEndpoint;
-import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.clusterHasRerankInferenceEndpoint;
-import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createInferenceEndpoint;
-import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createRerankInferenceEndpoint;
-import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteInferenceEndpoint;
-import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteRerankInferenceEndpoint;
+import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createInferenceEndpoints;
+import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteInferenceEndpoints;
 import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.loadDataSetIntoEs;
 import static org.elasticsearch.xpack.esql.EsqlTestUtils.classpathResources;
+import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.COMPLETION;
 import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.METRICS_COMMAND;
 import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.RERANK;
 import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.SEMANTIC_TEXT_FIELD_CAPS;
@@ -138,12 +135,8 @@ protected EsqlSpecTestCase(
 
     @Before
     public void setup() throws IOException {
-        if (supportsInferenceTestService() && clusterHasInferenceEndpoint(client()) == false) {
-            createInferenceEndpoint(client());
-        }
-
-        if (supportsInferenceTestService() && clusterHasRerankInferenceEndpoint(client()) == false) {
-            createRerankInferenceEndpoint(client());
+        if (supportsInferenceTestService()) {
+            createInferenceEndpoints(adminClient());
         }
 
         boolean supportsLookup = supportsIndexModeLookup();
@@ -164,8 +157,8 @@ public static void wipeTestData() throws IOException {
             }
         }
 
-        deleteInferenceEndpoint(client());
-        deleteRerankInferenceEndpoint(client());
+        deleteInferenceEndpoints(adminClient());
+
     }
 
     public boolean logResults() {
@@ -254,7 +247,7 @@ protected boolean supportsInferenceTestService() {
     }
 
     protected boolean requiresInferenceEndpoint() {
-        return Stream.of(SEMANTIC_TEXT_FIELD_CAPS.capabilityName(), RERANK.capabilityName())
+        return Stream.of(SEMANTIC_TEXT_FIELD_CAPS.capabilityName(), RERANK.capabilityName(), COMPLETION.capabilityName())
             .anyMatch(testCase.requiredCapabilities::contains);
     }
 
@@ -372,6 +365,11 @@ private Object valueMapper(CsvTestUtils.Type type, Object value) {
                 return new BigDecimal(s).round(new MathContext(7, RoundingMode.DOWN)).doubleValue();
             }
         }
+        if (type == CsvTestUtils.Type.TEXT || type == CsvTestUtils.Type.KEYWORD || type == CsvTestUtils.Type.SEMANTIC_TEXT) {
+            if (value instanceof String s) {
+                value = s.replaceAll("\\\\n", "\n");
+            }
+        }
         return value.toString();
     }
 
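The new valueMapper branch exists because csv-spec files write each expected value on a single line, so multi-line completion output (for example the multi-valued prompt case in the new completion spec further down) is written with a literal backslash-n that has to be turned back into a newline before comparison. A minimal standalone sketch of that unescaping, using an illustrative value rather than anything taken from the commit:

public class NewlineUnescapeSketch {
    public static void main(String[] args) {
        // As read from a .csv-spec file, the expected value contains the two characters '\' and 'n'.
        String asWritten = "ANSWER THE FOLLOWING QUESTION:\\nWHO IS VICTOR HUGO?";
        // The regex "\\\\n" matches a literal backslash followed by 'n'; the replacement is a real newline.
        String unescaped = asWritten.replaceAll("\\\\n", "\n");
        System.out.println(unescaped); // prints two lines
    }
}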

x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java

Lines changed: 75 additions & 50 deletions
@@ -27,6 +27,7 @@
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.logging.LogManager;
 import org.elasticsearch.logging.Logger;
 import org.elasticsearch.test.rest.ESRestTestCase;
@@ -320,7 +321,7 @@ public static Set<TestDataset> availableDatasetsForEs(
         boolean supportsIndexModeLookup,
         boolean supportsSourceFieldMapping
     ) throws IOException {
-        boolean inferenceEnabled = clusterHasInferenceEndpoint(client);
+        boolean inferenceEnabled = clusterHasSparseEmbeddingInferenceEndpoint(client);
 
         Set<TestDataset> testDataSets = new HashSet<>();
 
@@ -382,77 +383,90 @@ private static void loadDataSetIntoEs(
         }
     }
 
+    public static void createInferenceEndpoints(RestClient client) throws IOException {
+        if (clusterHasSparseEmbeddingInferenceEndpoint(client) == false) {
+            createSparseEmbeddingInferenceEndpoint(client);
+        }
+
+        if (clusterHasRerankInferenceEndpoint(client) == false) {
+            createRerankInferenceEndpoint(client);
+        }
+
+        if (clusterHasCompletionInferenceEndpoint(client) == false) {
+            createCompletionInferenceEndpoint(client);
+        }
+    }
+
+    public static void deleteInferenceEndpoints(RestClient client) throws IOException {
+        deleteSparseEmbeddingInferenceEndpoint(client);
+        deleteRerankInferenceEndpoint(client);
+        deleteCompletionInferenceEndpoint(client);
+    }
+
     /** The semantic_text mapping type require an inference endpoint that needs to be setup before creating the index. */
-    public static void createInferenceEndpoint(RestClient client) throws IOException {
-        Request request = new Request("PUT", "_inference/sparse_embedding/test_sparse_inference");
-        request.setJsonEntity("""
+    public static void createSparseEmbeddingInferenceEndpoint(RestClient client) throws IOException {
+        createInferenceEndpoint(client, TaskType.SPARSE_EMBEDDING, "test_sparse_inference", """
            {
                "service": "test_service",
-               "service_settings": {
-                   "model": "my_model",
-                   "api_key": "abc64"
-               },
-               "task_settings": {
-               }
+               "service_settings": { "model": "my_model", "api_key": "abc64" },
+               "task_settings": { }
            }
            """);
-        client.performRequest(request);
    }
 
-    public static void deleteInferenceEndpoint(RestClient client) throws IOException {
-        try {
-            client.performRequest(new Request("DELETE", "_inference/test_sparse_inference"));
-        } catch (ResponseException e) {
-            // 404 here means the endpoint was not created
-            if (e.getResponse().getStatusLine().getStatusCode() != 404) {
-                throw e;
-            }
-        }
+    public static void deleteSparseEmbeddingInferenceEndpoint(RestClient client) throws IOException {
+        deleteInferenceEndpoint(client, "test_sparse_inference");
    }
 
-    public static boolean clusterHasInferenceEndpoint(RestClient client) throws IOException {
-        Request request = new Request("GET", "_inference/sparse_embedding/test_sparse_inference");
-        try {
-            client.performRequest(request);
-        } catch (ResponseException e) {
-            if (e.getResponse().getStatusLine().getStatusCode() == 404) {
-                return false;
-            }
-            throw e;
-        }
-        return true;
+    public static boolean clusterHasSparseEmbeddingInferenceEndpoint(RestClient client) throws IOException {
+        return clusterHasInferenceEndpoint(client, TaskType.SPARSE_EMBEDDING, "test_sparse_inference");
    }
 
     public static void createRerankInferenceEndpoint(RestClient client) throws IOException {
-        Request request = new Request("PUT", "_inference/rerank/test_reranker");
-        request.setJsonEntity("""
+        createInferenceEndpoint(client, TaskType.RERANK, "test_reranker", """
            {
                "service": "test_reranking_service",
-               "service_settings": {
-                   "model_id": "my_model",
-                   "api_key": "abc64"
-               },
-               "task_settings": {
-                   "use_text_length": true
-               }
+               "service_settings": { "model_id": "my_model", "api_key": "abc64" },
+               "task_settings": { "use_text_length": true }
            }
            """);
-        client.performRequest(request);
    }
 
     public static void deleteRerankInferenceEndpoint(RestClient client) throws IOException {
-        try {
-            client.performRequest(new Request("DELETE", "_inference/rerank/test_reranker"));
-        } catch (ResponseException e) {
-            // 404 here means the endpoint was not created
-            if (e.getResponse().getStatusLine().getStatusCode() != 404) {
-                throw e;
-            }
-        }
+        deleteInferenceEndpoint(client, "test_reranker");
    }
 
     public static boolean clusterHasRerankInferenceEndpoint(RestClient client) throws IOException {
-        Request request = new Request("GET", "_inference/rerank/test_reranker");
+        return clusterHasInferenceEndpoint(client, TaskType.RERANK, "test_reranker");
+    }
+
+    public static void createCompletionInferenceEndpoint(RestClient client) throws IOException {
+        createInferenceEndpoint(client, TaskType.COMPLETION, "test_completion", """
+           {
+               "service": "completion_test_service",
+               "service_settings": { "model": "my_model", "api_key": "abc64" },
+               "task_settings": { "temperature": 3 }
+           }
+           """);
+    }
+
+    public static void deleteCompletionInferenceEndpoint(RestClient client) throws IOException {
+        deleteInferenceEndpoint(client, "test_completion");
+    }
+
+    public static boolean clusterHasCompletionInferenceEndpoint(RestClient client) throws IOException {
+        return clusterHasInferenceEndpoint(client, TaskType.COMPLETION, "test_completion");
+    }
+
+    private static void createInferenceEndpoint(RestClient client, TaskType taskType, String inferenceId, String modelSettings)
+        throws IOException {
+        Request request = new Request("PUT", "_inference/" + taskType.name() + "/" + inferenceId);
+        request.setJsonEntity(modelSettings);
+        client.performRequest(request);
+    }
+
+    private static boolean clusterHasInferenceEndpoint(RestClient client, TaskType taskType, String inferenceId) throws IOException {
+        Request request = new Request("GET", "_inference/" + taskType.name() + "/" + inferenceId);
         try {
             client.performRequest(request);
         } catch (ResponseException e) {
@@ -464,6 +478,17 @@ public static boolean clusterHasRerankInferenceEndpoint(RestClient client) throw
         return true;
     }
 
+    private static void deleteInferenceEndpoint(RestClient client, String inferenceId) throws IOException {
+        try {
+            client.performRequest(new Request("DELETE", "_inference/" + inferenceId));
+        } catch (ResponseException e) {
+            // 404 here means the endpoint was not created
+            if (e.getResponse().getStatusLine().getStatusCode() != 404) {
+                throw e;
+            }
+        }
+    }
+
     private static void loadEnrichPolicy(RestClient client, String policyName, String policyFileName, Logger logger) throws IOException {
         URL policyMapping = getResource("/" + policyFileName);
         String entity = readTextFile(policyMapping);
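As a usage note, here is a sketch of how a REST test suite might wire the consolidated helpers into its lifecycle, mirroring the EsqlSpecTestCase change earlier in this commit; the class and method names are illustrative, while ESRestTestCase, adminClient() and the two helpers come from the existing test framework and from this change:

import java.io.IOException;

import org.elasticsearch.test.rest.ESRestTestCase;
import org.junit.AfterClass;
import org.junit.Before;

import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createInferenceEndpoints;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteInferenceEndpoints;

public abstract class ExampleInferenceSpecIT extends ESRestTestCase {

    @Before
    public void setUpInferenceEndpoints() throws IOException {
        // Creates the sparse_embedding, rerank and completion test endpoints,
        // skipping any that the cluster already has.
        createInferenceEndpoints(adminClient());
    }

    @AfterClass
    public static void tearDownInferenceEndpoints() throws IOException {
        // Deletes all three endpoints; a missing endpoint (404) is ignored.
        deleteInferenceEndpoints(adminClient());
    }
}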
Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+// Note:
+// The "test_completion" service returns the prompt in uppercase, making the output easy to guess.
+
+
+completion using a ROW source operator
+required_capability: completion
+
+ROW prompt="Who is Victor Hugo?"
+| COMPLETION prompt WITH test_completion AS completion_output
+;
+
+prompt:keyword | completion_output:keyword
+Who is Victor Hugo? | WHO IS VICTOR HUGO?
+;
+
+
+completion using a ROW source operator and prompt is a multi-valued field
+required_capability: completion
+
+ROW prompt=["Answer the following question:", "Who is Victor Hugo?"]
+| COMPLETION prompt WITH test_completion AS completion_output
+;
+
+prompt:keyword | completion_output:keyword
+[Answer the following question:, Who is Victor Hugo?] | ANSWER THE FOLLOWING QUESTION:\nWHO IS VICTOR HUGO?
+;
+
+
+completion after a search
+required_capability: completion
+required_capability: match_operator_colon
+
+FROM books METADATA _score
+| WHERE title:"war and peace" AND author:"Tolstoy"
+| SORT _score DESC
+| LIMIT 2
+| COMPLETION title WITH test_completion
+| KEEP title, completion
+;
+
+title:text | completion:keyword
+War and Peace | WAR AND PEACE
+War and Peace (Signet Classics) | WAR AND PEACE (SIGNET CLASSICS)
+;
+
+completion using a function as a prompt
+required_capability: completion
+required_capability: match_operator_colon
+
+FROM books METADATA _score
+| WHERE title:"war and peace" AND author:"Tolstoy"
+| SORT _score DESC
+| LIMIT 2
+| COMPLETION CONCAT("This is a prompt: ", title) WITH test_completion
+| KEEP title, completion
+;
+
+title:text | completion:keyword
+War and Peace | THIS IS A PROMPT: WAR AND PEACE
+War and Peace (Signet Classics) | THIS IS A PROMPT: WAR AND PEACE (SIGNET CLASSICS)
+;
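Outside the csv-spec harness, the first query above could be issued with the low-level REST client against the ES|QL _query endpoint. This is a hedged sketch only: it assumes a cluster that registers the test inference services and the test_completion endpoint created by CsvTestsDataLoader, and the class and method names are illustrative:

import java.io.IOException;

import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.RestClient;

public final class CompletionQuerySketch {

    // Runs the first completion spec query above and returns the raw HTTP response.
    static Response runCompletionQuery(RestClient client) throws IOException {
        Request request = new Request("POST", "/_query");
        request.setJsonEntity("""
            {
              "query": "ROW prompt=\\"Who is Victor Hugo?\\" | COMPLETION prompt WITH test_completion AS completion_output"
            }
            """);
        return client.performRequest(request);
    }
}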

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 1 addition & 1 deletion
@@ -617,7 +617,7 @@ private LogicalPlan resolveCompletion(Completion p, List<Attribute> childrenOutp
         Expression prompt = p.prompt();
 
         if (targetField instanceof UnresolvedAttribute ua) {
-            targetField = new ReferenceAttribute(ua.source(), ua.name(), TEXT);
+            targetField = new ReferenceAttribute(ua.source(), ua.name(), KEYWORD);
         }
 
         if (prompt.resolved() == false) {
