Skip to content

Commit 1802d60

Browse files
committed
[Inference] Implementing the completion task type on EIS.
1 parent bc5bc54 commit 1802d60

15 files changed

+864
-69
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import org.elasticsearch.xpack.inference.services.elastic.action.ElasticInferenceServiceActionCreator;
5757
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationHandler;
5858
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
59+
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceChatCompletionModel;
5960
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
6061
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
6162
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
@@ -98,6 +99,7 @@ public class ElasticInferenceService extends SenderService {
9899
private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(
99100
TaskType.SPARSE_EMBEDDING,
100101
TaskType.CHAT_COMPLETION,
102+
TaskType.COMPLETION,
101103
TaskType.RERANK,
102104
TaskType.TEXT_EMBEDDING
103105
);
@@ -129,6 +131,7 @@ public class ElasticInferenceService extends SenderService {
129131
*/
130132
private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(
131133
TaskType.SPARSE_EMBEDDING,
134+
TaskType.COMPLETION,
132135
TaskType.RERANK,
133136
TaskType.TEXT_EMBEDDING
134137
);
@@ -188,7 +191,7 @@ private static Map<String, DefaultModelConfig> initDefaultEndpoints(
188191
return Map.of(
189192
DEFAULT_CHAT_COMPLETION_MODEL_ID_V1,
190193
new DefaultModelConfig(
191-
new ElasticInferenceServiceCompletionModel(
194+
new ElasticInferenceServiceChatCompletionModel(
192195
DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1,
193196
TaskType.CHAT_COMPLETION,
194197
NAME,
@@ -303,7 +306,7 @@ protected void doUnifiedCompletionInfer(
303306
TimeValue timeout,
304307
ActionListener<InferenceServiceResults> listener
305308
) {
306-
if (model instanceof ElasticInferenceServiceCompletionModel == false) {
309+
if (model instanceof ElasticInferenceServiceChatCompletionModel == false) {
307310
listener.onFailure(createInvalidModelException(model));
308311
return;
309312
}
@@ -313,8 +316,8 @@ protected void doUnifiedCompletionInfer(
313316
// generating a different "traceparent" as every task and every REST request creates a new span).
314317
var currentTraceInfo = getCurrentTraceInfo();
315318

316-
var completionModel = (ElasticInferenceServiceCompletionModel) model;
317-
var overriddenModel = ElasticInferenceServiceCompletionModel.of(completionModel, inputs.getRequest());
319+
var completionModel = (ElasticInferenceServiceChatCompletionModel) model;
320+
var overriddenModel = ElasticInferenceServiceChatCompletionModel.of(completionModel, inputs.getRequest());
318321
var errorMessage = constructFailedToSendRequestMessage(
319322
String.format(Locale.ROOT, "%s completions", ELASTIC_INFERENCE_SERVICE_IDENTIFIER)
320323
);
@@ -506,7 +509,17 @@ private static ElasticInferenceServiceModel createModel(
506509
context,
507510
chunkingSettings
508511
);
509-
case CHAT_COMPLETION -> new ElasticInferenceServiceCompletionModel(
512+
case CHAT_COMPLETION -> new ElasticInferenceServiceChatCompletionModel(
513+
inferenceEntityId,
514+
taskType,
515+
NAME,
516+
serviceSettings,
517+
taskSettings,
518+
secretSettings,
519+
elasticInferenceServiceComponents,
520+
context
521+
);
522+
case COMPLETION -> new ElasticInferenceServiceCompletionModel(
510523
inferenceEntityId,
511524
taskType,
512525
NAME,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedCompletionRequestManager.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
1818
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
1919
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
20-
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
20+
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceChatCompletionModel;
2121
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceUnifiedChatCompletionRequest;
2222
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
2323
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
@@ -32,7 +32,7 @@ public class ElasticInferenceServiceUnifiedCompletionRequestManager extends Elas
3232
private static final ResponseHandler HANDLER = createCompletionHandler();
3333

3434
public static ElasticInferenceServiceUnifiedCompletionRequestManager of(
35-
ElasticInferenceServiceCompletionModel model,
35+
ElasticInferenceServiceChatCompletionModel model,
3636
ThreadPool threadPool,
3737
TraceContext traceContext
3838
) {
@@ -43,11 +43,11 @@ public static ElasticInferenceServiceUnifiedCompletionRequestManager of(
4343
);
4444
}
4545

46-
private final ElasticInferenceServiceCompletionModel model;
46+
private final ElasticInferenceServiceChatCompletionModel model;
4747
private final TraceContext traceContext;
4848

4949
private ElasticInferenceServiceUnifiedCompletionRequestManager(
50-
ElasticInferenceServiceCompletionModel model,
50+
ElasticInferenceServiceChatCompletionModel model,
5151
ThreadPool threadPool,
5252
TraceContext traceContext
5353
) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1212
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
1313
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
14+
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
1415
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
1516
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
1617
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
@@ -20,13 +21,17 @@
2021
import org.elasticsearch.xpack.inference.services.ServiceComponents;
2122
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler;
2223
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsRequestManager;
24+
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
2325
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
26+
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceCompletionRequest;
2427
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceDenseTextEmbeddingsRequest;
2528
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRerankRequest;
2629
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
2730
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;
31+
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
2832
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
2933

34+
import java.util.Map;
3035
import java.util.Objects;
3136

3237
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
@@ -45,6 +50,11 @@ public class ElasticInferenceServiceActionCreator implements ElasticInferenceSer
4550
(request, response) -> ElasticInferenceServiceRerankResponseEntity.fromResponse(response)
4651
);
4752

53+
static final ResponseHandler COMPLETION_HANDLER = new ElasticInferenceServiceResponseHandler(
54+
"elastic completion",
55+
OpenAiChatCompletionResponseEntity::fromResponse
56+
);
57+
4858
private final Sender sender;
4959

5060
private final ServiceComponents serviceComponents;
@@ -108,4 +118,25 @@ public ExecutableAction create(ElasticInferenceServiceDenseTextEmbeddingsModel m
108118
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Elastic dense text embeddings");
109119
return new SenderExecutableAction(sender, manager, failedToSendRequestErrorMessage);
110120
}
121+
122+
@Override
123+
public ExecutableAction create(ElasticInferenceServiceCompletionModel model, Map<String, Object> taskSettings) {
124+
var threadPool = serviceComponents.threadPool();
125+
126+
var manager = new GenericRequestManager<>(
127+
threadPool,
128+
model,
129+
COMPLETION_HANDLER,
130+
(chatCompletionInput) -> new ElasticInferenceServiceCompletionRequest(
131+
chatCompletionInput.getInputs(),
132+
model,
133+
traceContext,
134+
extractRequestMetadataFromThreadContext(threadPool.getThreadContext())
135+
),
136+
ChatCompletionInput.class
137+
);
138+
139+
var errorMessage = constructFailedToSendRequestMessage(Strings.format("%s completion", ELASTIC_INFERENCE_SERVICE_IDENTIFIER));
140+
return new SenderExecutableAction(sender, manager, errorMessage);
141+
}
111142
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,20 @@
88
package org.elasticsearch.xpack.inference.services.elastic.action;
99

1010
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
11+
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
1112
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
1213
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
1314
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;
1415

16+
import java.util.Map;
17+
1518
public interface ElasticInferenceServiceActionVisitor {
1619

1720
ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model);
1821

1922
ExecutableAction create(ElasticInferenceServiceRerankModel model);
2023

2124
ExecutableAction create(ElasticInferenceServiceDenseTextEmbeddingsModel model);
25+
26+
ExecutableAction create(ElasticInferenceServiceCompletionModel model, Map<String, Object> taskSettings);
2227
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.elastic.completion;
9+
10+
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.core.Nullable;
12+
import org.elasticsearch.inference.EmptySecretSettings;
13+
import org.elasticsearch.inference.EmptyTaskSettings;
14+
import org.elasticsearch.inference.ModelConfigurations;
15+
import org.elasticsearch.inference.ModelSecrets;
16+
import org.elasticsearch.inference.SecretSettings;
17+
import org.elasticsearch.inference.TaskSettings;
18+
import org.elasticsearch.inference.TaskType;
19+
import org.elasticsearch.inference.UnifiedCompletionRequest;
20+
import org.elasticsearch.rest.RestStatus;
21+
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
22+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
23+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel;
24+
25+
import java.net.URI;
26+
import java.net.URISyntaxException;
27+
import java.util.Map;
28+
import java.util.Objects;
29+
30+
public class ElasticInferenceServiceChatCompletionModel extends ElasticInferenceServiceModel {
31+
32+
public static ElasticInferenceServiceChatCompletionModel of(
33+
ElasticInferenceServiceChatCompletionModel model,
34+
UnifiedCompletionRequest request
35+
) {
36+
var originalModelServiceSettings = model.getServiceSettings();
37+
var overriddenServiceSettings = new ElasticInferenceServiceCompletionServiceSettings(
38+
Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId())
39+
);
40+
41+
return new ElasticInferenceServiceChatCompletionModel(model, overriddenServiceSettings);
42+
}
43+
44+
private final URI uri;
45+
46+
public ElasticInferenceServiceChatCompletionModel(
47+
String inferenceEntityId,
48+
TaskType taskType,
49+
String service,
50+
Map<String, Object> serviceSettings,
51+
Map<String, Object> taskSettings,
52+
Map<String, Object> secrets,
53+
ElasticInferenceServiceComponents elasticInferenceServiceComponents,
54+
ConfigurationParseContext context
55+
) {
56+
this(
57+
inferenceEntityId,
58+
taskType,
59+
service,
60+
ElasticInferenceServiceCompletionServiceSettings.fromMap(serviceSettings, context),
61+
EmptyTaskSettings.INSTANCE,
62+
EmptySecretSettings.INSTANCE,
63+
elasticInferenceServiceComponents
64+
);
65+
}
66+
67+
public ElasticInferenceServiceChatCompletionModel(
68+
ElasticInferenceServiceChatCompletionModel model,
69+
ElasticInferenceServiceCompletionServiceSettings serviceSettings
70+
) {
71+
super(model, serviceSettings);
72+
this.uri = createUri();
73+
74+
}
75+
76+
public ElasticInferenceServiceChatCompletionModel(
77+
String inferenceEntityId,
78+
TaskType taskType,
79+
String service,
80+
ElasticInferenceServiceCompletionServiceSettings serviceSettings,
81+
@Nullable TaskSettings taskSettings,
82+
@Nullable SecretSettings secretSettings,
83+
ElasticInferenceServiceComponents elasticInferenceServiceComponents
84+
) {
85+
super(
86+
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings),
87+
new ModelSecrets(secretSettings),
88+
serviceSettings,
89+
elasticInferenceServiceComponents
90+
);
91+
92+
this.uri = createUri();
93+
94+
}
95+
96+
@Override
97+
public ElasticInferenceServiceCompletionServiceSettings getServiceSettings() {
98+
return (ElasticInferenceServiceCompletionServiceSettings) super.getServiceSettings();
99+
}
100+
101+
public URI uri() {
102+
return uri;
103+
}
104+
105+
private URI createUri() throws ElasticsearchStatusException {
106+
try {
107+
// TODO, consider transforming the base URL into a URI for better error handling.
108+
return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/chat");
109+
} catch (URISyntaxException e) {
110+
throw new ElasticsearchStatusException(
111+
"Failed to create URI for service ["
112+
+ this.getConfigurations().getService()
113+
+ "] with taskType ["
114+
+ this.getTaskType()
115+
+ "]: "
116+
+ e.getMessage(),
117+
RestStatus.BAD_REQUEST,
118+
e
119+
);
120+
}
121+
}
122+
123+
// TODO create/refactor the Configuration class to be extensible for different task types (i.e completion, sparse embeddings).
124+
}

0 commit comments

Comments
 (0)