Skip to content

Commit 1b70962

Browse files
committed
Apply feedback from code review.
1 parent e2ed11c commit 1b70962

16 files changed

+472
-728
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 5 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,6 @@
5656
import org.elasticsearch.xpack.inference.services.elastic.action.ElasticInferenceServiceActionCreator;
5757
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationHandler;
5858
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
59-
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceChatCompletionModel;
6059
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
6160
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
6261
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
@@ -191,7 +190,7 @@ private static Map<String, DefaultModelConfig> initDefaultEndpoints(
191190
return Map.of(
192191
DEFAULT_CHAT_COMPLETION_MODEL_ID_V1,
193192
new DefaultModelConfig(
194-
new ElasticInferenceServiceChatCompletionModel(
193+
new ElasticInferenceServiceCompletionModel(
195194
DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1,
196195
TaskType.CHAT_COMPLETION,
197196
NAME,
@@ -306,7 +305,7 @@ protected void doUnifiedCompletionInfer(
306305
TimeValue timeout,
307306
ActionListener<InferenceServiceResults> listener
308307
) {
309-
if (model instanceof ElasticInferenceServiceChatCompletionModel == false) {
308+
if (model instanceof ElasticInferenceServiceCompletionModel == false || model.getTaskType() != TaskType.CHAT_COMPLETION) {
310309
listener.onFailure(createInvalidModelException(model));
311310
return;
312311
}
@@ -316,8 +315,8 @@ protected void doUnifiedCompletionInfer(
316315
// generating a different "traceparent" as every task and every REST request creates a new span).
317316
var currentTraceInfo = getCurrentTraceInfo();
318317

319-
var completionModel = (ElasticInferenceServiceChatCompletionModel) model;
320-
var overriddenModel = ElasticInferenceServiceChatCompletionModel.of(completionModel, inputs.getRequest());
318+
var completionModel = (ElasticInferenceServiceCompletionModel) model;
319+
var overriddenModel = ElasticInferenceServiceCompletionModel.of(completionModel, inputs.getRequest());
321320
var errorMessage = constructFailedToSendRequestMessage(
322321
String.format(Locale.ROOT, "%s completions", ELASTIC_INFERENCE_SERVICE_IDENTIFIER)
323322
);
@@ -509,17 +508,7 @@ private static ElasticInferenceServiceModel createModel(
509508
context,
510509
chunkingSettings
511510
);
512-
case CHAT_COMPLETION -> new ElasticInferenceServiceChatCompletionModel(
513-
inferenceEntityId,
514-
taskType,
515-
NAME,
516-
serviceSettings,
517-
taskSettings,
518-
secretSettings,
519-
elasticInferenceServiceComponents,
520-
context
521-
);
522-
case COMPLETION -> new ElasticInferenceServiceCompletionModel(
511+
case CHAT_COMPLETION, COMPLETION -> new ElasticInferenceServiceCompletionModel(
523512
inferenceEntityId,
524513
taskType,
525514
NAME,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedCompletionRequestManager.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
1818
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
1919
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
20-
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceChatCompletionModel;
20+
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
2121
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceUnifiedChatCompletionRequest;
2222
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
2323
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
@@ -32,7 +32,7 @@ public class ElasticInferenceServiceUnifiedCompletionRequestManager extends Elas
3232
private static final ResponseHandler HANDLER = createCompletionHandler();
3333

3434
public static ElasticInferenceServiceUnifiedCompletionRequestManager of(
35-
ElasticInferenceServiceChatCompletionModel model,
35+
ElasticInferenceServiceCompletionModel model,
3636
ThreadPool threadPool,
3737
TraceContext traceContext
3838
) {
@@ -43,11 +43,11 @@ public static ElasticInferenceServiceUnifiedCompletionRequestManager of(
4343
);
4444
}
4545

46-
private final ElasticInferenceServiceChatCompletionModel model;
46+
private final ElasticInferenceServiceCompletionModel model;
4747
private final TraceContext traceContext;
4848

4949
private ElasticInferenceServiceUnifiedCompletionRequestManager(
50-
ElasticInferenceServiceChatCompletionModel model,
50+
ElasticInferenceServiceCompletionModel model,
5151
ThreadPool threadPool,
5252
TraceContext traceContext
5353
) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,17 @@
1616
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
1717
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
1818
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
19+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
1920
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceDenseTextEmbeddingsResponseEntity;
2021
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceRerankResponseEntity;
2122
import org.elasticsearch.xpack.inference.services.ServiceComponents;
2223
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler;
2324
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsRequestManager;
2425
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
2526
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
26-
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceCompletionRequest;
2727
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceDenseTextEmbeddingsRequest;
2828
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRerankRequest;
29+
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceUnifiedChatCompletionRequest;
2930
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
3031
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;
3132
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
@@ -40,6 +41,8 @@
4041

4142
public class ElasticInferenceServiceActionCreator implements ElasticInferenceServiceActionVisitor {
4243

44+
public static final String USER_ROLE = "user";
45+
4346
static final ResponseHandler DENSE_TEXT_EMBEDDINGS_HANDLER = new ElasticInferenceServiceResponseHandler(
4447
"elastic dense text embedding",
4548
ElasticInferenceServiceDenseTextEmbeddingsResponseEntity::fromResponse
@@ -127,8 +130,8 @@ public ExecutableAction create(ElasticInferenceServiceCompletionModel model, Map
127130
threadPool,
128131
model,
129132
COMPLETION_HANDLER,
130-
(chatCompletionInput) -> new ElasticInferenceServiceCompletionRequest(
131-
chatCompletionInput.getInputs(),
133+
(chatCompletionInput) -> new ElasticInferenceServiceUnifiedChatCompletionRequest(
134+
new UnifiedChatInput(chatCompletionInput.getInputs(), USER_ROLE, false),
132135
model,
133136
traceContext,
134137
extractRequestMetadataFromThreadContext(threadPool.getThreadContext())

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceChatCompletionModel.java

Lines changed: 0 additions & 124 deletions
This file was deleted.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.inference.SecretSettings;
1717
import org.elasticsearch.inference.TaskSettings;
1818
import org.elasticsearch.inference.TaskType;
19+
import org.elasticsearch.inference.UnifiedCompletionRequest;
1920
import org.elasticsearch.rest.RestStatus;
2021
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
2122
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
@@ -26,13 +27,22 @@
2627
import java.net.URI;
2728
import java.net.URISyntaxException;
2829
import java.util.Map;
30+
import java.util.Objects;
2931

30-
/**
31-
* Adapter model for COMPLETION task type that converts simple text inputs into chat messages
32-
* and uses the chat completion endpoint.
33-
*/
3432
public class ElasticInferenceServiceCompletionModel extends ElasticInferenceServiceExecutableActionModel {
3533

34+
public static ElasticInferenceServiceCompletionModel of(
35+
ElasticInferenceServiceCompletionModel model,
36+
UnifiedCompletionRequest request
37+
) {
38+
var originalModelServiceSettings = model.getServiceSettings();
39+
var overriddenServiceSettings = new ElasticInferenceServiceCompletionServiceSettings(
40+
Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId())
41+
);
42+
43+
return new ElasticInferenceServiceCompletionModel(model, overriddenServiceSettings);
44+
}
45+
3646
private final URI uri;
3747

3848
public ElasticInferenceServiceCompletionModel(
@@ -56,6 +66,14 @@ public ElasticInferenceServiceCompletionModel(
5666
);
5767
}
5868

69+
public ElasticInferenceServiceCompletionModel(
70+
ElasticInferenceServiceCompletionModel model,
71+
ElasticInferenceServiceCompletionServiceSettings serviceSettings
72+
) {
73+
super(model, serviceSettings);
74+
this.uri = createUri();
75+
}
76+
5977
public ElasticInferenceServiceCompletionModel(
6078
String inferenceEntityId,
6179
TaskType taskType,
@@ -74,14 +92,6 @@ public ElasticInferenceServiceCompletionModel(
7492
this.uri = createUri();
7593
}
7694

77-
public ElasticInferenceServiceCompletionModel(
78-
ElasticInferenceServiceCompletionModel model,
79-
ElasticInferenceServiceCompletionServiceSettings serviceSettings
80-
) {
81-
super(model, serviceSettings);
82-
this.uri = createUri();
83-
}
84-
8595
@Override
8696
public ElasticInferenceServiceCompletionServiceSettings getServiceSettings() {
8797
return (ElasticInferenceServiceCompletionServiceSettings) super.getServiceSettings();
@@ -93,7 +103,7 @@ public URI uri() {
93103

94104
private URI createUri() throws ElasticsearchStatusException {
95105
try {
96-
// Use the same chat endpoint as CHAT_COMPLETION
106+
// TODO, consider transforming the base URL into a URI for better error handling.
97107
return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/chat");
98108
} catch (URISyntaxException e) {
99109
throw new ElasticsearchStatusException(

0 commit comments

Comments
 (0)