Skip to content

Commit 412f30e

Browse files
authored
[Inference] Implementing the completion task type on EIS. (#137677)
1 parent 3749b6e commit 412f30e

File tree

9 files changed

+483
-13
lines changed

9 files changed

+483
-13
lines changed

docs/changelog/137677.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 137677
2+
summary: "[Inference] Implementing the completion task type on EIS"
3+
area: "Inference"
4+
type: enhancement
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -305,12 +305,12 @@ private void executeTaskImmediately(RejectableTask task) {
305305
e
306306
);
307307

308-
task.onRejection(
309-
new EsRejectedExecutionException(
310-
format("Failed to execute request for inference id [%s]", task.getRequestManager().inferenceEntityId()),
311-
false
312-
)
308+
var rejectionException = new EsRejectedExecutionException(
309+
format("Failed to execute request for inference id [%s]", task.getRequestManager().inferenceEntityId()),
310+
false
313311
);
312+
rejectionException.initCause(e);
313+
task.onRejection(rejectionException);
314314
}
315315
}
316316

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
4040
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
4141
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
42+
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
4243
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
4344
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
4445
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
@@ -72,6 +73,7 @@
7273
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
7374
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
7475
import static org.elasticsearch.xpack.inference.services.ServiceUtils.useChatCompletionUrlMessage;
76+
import static org.elasticsearch.xpack.inference.services.openai.action.OpenAiActionCreator.USER_ROLE;
7577

7678
public class ElasticInferenceService extends SenderService {
7779

@@ -86,6 +88,7 @@ public class ElasticInferenceService extends SenderService {
8688
public static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(
8789
TaskType.SPARSE_EMBEDDING,
8890
TaskType.CHAT_COMPLETION,
91+
TaskType.COMPLETION,
8992
TaskType.RERANK,
9093
TaskType.TEXT_EMBEDDING
9194
);
@@ -101,6 +104,7 @@ public class ElasticInferenceService extends SenderService {
101104
*/
102105
private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(
103106
TaskType.SPARSE_EMBEDDING,
107+
TaskType.COMPLETION,
104108
TaskType.RERANK,
105109
TaskType.TEXT_EMBEDDING
106110
);
@@ -162,7 +166,8 @@ protected void doUnifiedCompletionInfer(
162166
TimeValue timeout,
163167
ActionListener<InferenceServiceResults> listener
164168
) {
165-
if (model instanceof ElasticInferenceServiceCompletionModel == false) {
169+
if (model instanceof ElasticInferenceServiceCompletionModel == false
170+
|| (model.getTaskType() != TaskType.CHAT_COMPLETION && model.getTaskType() != TaskType.COMPLETION)) {
166171
listener.onFailure(createInvalidModelException(model));
167172
return;
168173
}
@@ -212,10 +217,15 @@ protected void doInfer(
212217

213218
var elasticInferenceServiceModel = (ElasticInferenceServiceModel) model;
214219

220+
// For ElasticInferenceServiceCompletionModel, convert ChatCompletionInput to UnifiedChatInput
221+
// since the request manager expects UnifiedChatInput
222+
final InferenceInputs finalInputs = (elasticInferenceServiceModel instanceof ElasticInferenceServiceCompletionModel
223+
&& inputs instanceof ChatCompletionInput) ? new UnifiedChatInput((ChatCompletionInput) inputs, USER_ROLE) : inputs;
224+
215225
actionCreator.create(
216226
elasticInferenceServiceModel,
217227
currentTraceInfo,
218-
listener.delegateFailureAndWrap((delegate, action) -> action.execute(inputs, timeout, delegate))
228+
listener.delegateFailureAndWrap((delegate, action) -> action.execute(finalInputs, timeout, delegate))
219229
);
220230
}
221231

@@ -379,7 +389,7 @@ private static ElasticInferenceServiceModel createModel(
379389
context,
380390
chunkingSettings
381391
);
382-
case CHAT_COMPLETION -> new ElasticInferenceServiceCompletionModel(
392+
case CHAT_COMPLETION, COMPLETION -> new ElasticInferenceServiceCompletionModel(
383393
inferenceEntityId,
384394
taskType,
385395
NAME,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,6 @@ public String getInferenceEntityId() {
8484

8585
@Override
8686
public boolean isStreaming() {
87-
return true;
87+
return unifiedChatInput.stream();
8888
}
8989
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws
492492
thrownException.getMessage(),
493493
is(
494494
"Inference entity [model_id] does not support task type [chat_completion] "
495-
+ "for inference, the task type must be one of [text_embedding, sparse_embedding, rerank]. "
495+
+ "for inference, the task type must be one of [text_embedding, sparse_embedding, rerank, completion]. "
496496
+ "The task type for the inference entity is chat_completion, "
497497
+ "please use the _inference/chat_completion/model_id/_stream URL."
498498
)
@@ -1133,7 +1133,7 @@ private InferenceEventsAssertion testUnifiedStream(int responseCode, String resp
11331133
webServer.enqueue(new MockResponse().setResponseCode(responseCode).setBody(responseJson));
11341134
var model = new ElasticInferenceServiceCompletionModel(
11351135
"id",
1136-
TaskType.COMPLETION,
1136+
TaskType.CHAT_COMPLETION,
11371137
"elastic",
11381138
new ElasticInferenceServiceCompletionServiceSettings("model_id"),
11391139
EmptyTaskSettings.INSTANCE,

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -419,8 +419,7 @@ public void testDoesNotAttemptToStoreModelIds_ThatHaveATaskTypeThatTheEISIntegra
419419
List.of(
420420
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
421421
InternalPreconfiguredEndpoints.DEFAULT_ELSER_2_MODEL_ID,
422-
// EIS does not yet support completions so this model will be ignored
423-
EnumSet.of(TaskType.COMPLETION)
422+
EnumSet.noneOf(TaskType.class)
424423
)
425424
)
426425
)

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,4 +47,63 @@ public void testOverridingModelId() {
4747
assertThat(overriddenModel.getServiceSettings().modelId(), is("new_model_id"));
4848
assertThat(overriddenModel.getTaskType(), is(TaskType.COMPLETION));
4949
}
50+
51+
/**
 * Checks that {@code uri()} on a model built by {@code createModel} renders as the
 * configured base URL with {@code /api/v1/chat} appended.
 */
public void testUriCreation() {
52+
var url = "http://eis-gateway.com";
53+
var model = createModel(url, "my-model-id");
54+
55+
var uri = model.uri();
56+
assertThat(uri.toString(), is(url + "/api/v1/chat"));
57+
}
58+
59+
/**
 * Checks that the service settings returned by the model report the same model id
 * that was passed to {@code createModel}.
 */
public void testGetServiceSettings() {
60+
var modelId = "test-model";
61+
var model = createModel("http://eis-gateway.com", modelId);
62+
63+
var serviceSettings = model.getServiceSettings();
64+
assertThat(serviceSettings.modelId(), is(modelId));
65+
}
66+
67+
/**
 * Checks that a model built by {@code createModel} reports {@code TaskType.COMPLETION}
 * as its task type.
 */
public void testGetTaskType() {
68+
var model = createModel("http://eis-gateway.com", "my-model-id");
69+
assertThat(model.getTaskType(), is(TaskType.COMPLETION));
70+
}
71+
72+
/**
 * Checks that {@code getInferenceEntityId()} returns the id given as the first constructor
 * argument. The model is constructed directly here, rather than through {@code createModel},
 * so that the inference entity id can be chosen by the test.
 */
public void testGetInferenceEntityId() {
73+
var inferenceEntityId = "test-id";
74+
var model = new ElasticInferenceServiceCompletionModel(
75+
inferenceEntityId,
76+
TaskType.COMPLETION,
77+
"elastic",
78+
new ElasticInferenceServiceCompletionServiceSettings("my-model-id"),
79+
EmptyTaskSettings.INSTANCE,
80+
EmptySecretSettings.INSTANCE,
81+
ElasticInferenceServiceComponents.of("http://eis-gateway.com")
82+
);
83+
84+
assertThat(model.getInferenceEntityId(), is(inferenceEntityId));
85+
}
86+
87+
/**
 * Builds a second model from an existing model plus replacement service settings and checks
 * that the new model id is used, the task type is still {@code TaskType.COMPLETION}, and the
 * URI renders identically to the original model's URI.
 */
public void testModelWithOverriddenServiceSettings() {
88+
var originalModel = createModel("http://eis-gateway.com", "original-model");
89+
var newServiceSettings = new ElasticInferenceServiceCompletionServiceSettings("new-model");
90+
91+
var overriddenModel = new ElasticInferenceServiceCompletionModel(originalModel, newServiceSettings);
92+
93+
assertThat(overriddenModel.getServiceSettings().modelId(), is("new-model"));
94+
assertThat(overriddenModel.getTaskType(), is(TaskType.COMPLETION));
95+
assertThat(overriddenModel.uri().toString(), is(originalModel.uri().toString()));
96+
}
97+
98+
/**
 * Test factory for an {@link ElasticInferenceServiceCompletionModel}.
 * <p>
 * The inference entity id is fixed to {@code "id"}, the task type to
 * {@code TaskType.COMPLETION}, and the service name to {@code "elastic"}; task and secret
 * settings are the empty singletons. Only the URL and model id vary per call.
 *
 * @param url     value handed to {@code ElasticInferenceServiceComponents.of}
 * @param modelId model id stored in the completion service settings
 * @return the newly constructed model
 */
public static ElasticInferenceServiceCompletionModel createModel(String url, String modelId) {
99+
return new ElasticInferenceServiceCompletionModel(
100+
"id",
101+
TaskType.COMPLETION,
102+
"elastic",
103+
new ElasticInferenceServiceCompletionServiceSettings(modelId),
104+
EmptyTaskSettings.INSTANCE,
105+
EmptySecretSettings.INSTANCE,
106+
ElasticInferenceServiceComponents.of(url)
107+
);
108+
}
50109
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests.java

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414
import org.elasticsearch.xcontent.XContentBuilder;
1515
import org.elasticsearch.xcontent.json.JsonXContent;
1616
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
17+
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModelTests;
1718
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel;
1819
import org.elasticsearch.xpack.inference.services.openai.request.OpenAiUnifiedChatCompletionRequestEntity;
1920

2021
import java.io.IOException;
2122
import java.util.ArrayList;
23+
import java.util.List;
2224

2325
import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals;
2426
import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createCompletionModel;
@@ -67,4 +69,152 @@ public void testModelUserFieldsSerialization() throws IOException {
6769
assertJsonEquals(jsonString, expectedJson);
6870
}
6971

72+
/**
 * Serializes a request entity built from a single input string and compares the JSON with
 * the expected body: one message carrying that string with role {@code "user"}, the model id,
 * {@code "n": 1} and {@code "stream": false}.
 * <p>
 * The {@code UnifiedChatInput} is created with {@code false} as its last argument
 * (presumably the stream flag, matching the {@code "stream": false} expectation).
 *
 * @throws IOException if building the JSON content fails
 */
public void testSerialization_NonStreaming_ForCompletion() throws IOException {
73+
// Test non-streaming case (used for COMPLETION task type)
74+
var unifiedChatInput = new UnifiedChatInput(List.of("What is 2+2?"), ROLE, false);
75+
var model = ElasticInferenceServiceCompletionModelTests.createModel("http://eis-gateway.com", "my-model-id");
76+
var entity = new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId());
77+
78+
XContentBuilder builder = JsonXContent.contentBuilder();
79+
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
80+
81+
String jsonString = Strings.toString(builder);
82+
String expectedJson = """
83+
{
84+
"messages": [
85+
{
86+
"content": "What is 2+2?",
87+
"role": "user"
88+
}
89+
],
90+
"model": "my-model-id",
91+
"n": 1,
92+
"stream": false
93+
}
94+
""";
95+
assertJsonEquals(jsonString, expectedJson);
96+
}
97+
98+
/**
 * Serializes a request entity built from two input strings and checks that each string
 * becomes its own message with role {@code "user"}, in input order, alongside the model id,
 * {@code "n": 1} and {@code "stream": false}.
 *
 * @throws IOException if building the JSON content fails
 */
public void testSerialization_MultipleInputs_NonStreaming() throws IOException {
99+
// Test multiple inputs converted to messages (used for COMPLETION task type)
100+
var unifiedChatInput = new UnifiedChatInput(List.of("What is 2+2?", "What is the capital of France?"), ROLE, false);
101+
var model = ElasticInferenceServiceCompletionModelTests.createModel("http://eis-gateway.com", "my-model-id");
102+
var entity = new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId());
103+
104+
XContentBuilder builder = JsonXContent.contentBuilder();
105+
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
106+
107+
String jsonString = Strings.toString(builder);
108+
String expectedJson = """
109+
{
110+
"messages": [
111+
{
112+
"content": "What is 2+2?",
113+
"role": "user"
114+
},
115+
{
116+
"content": "What is the capital of France?",
117+
"role": "user"
118+
}
119+
],
120+
"model": "my-model-id",
121+
"n": 1,
122+
"stream": false
123+
}
124+
""";
125+
assertJsonEquals(jsonString, expectedJson);
126+
}
127+
128+
/**
 * Serializes a request entity whose only input is the empty string and checks that a
 * message with empty {@code "content"} and role {@code "user"} is still emitted, together
 * with the model id, {@code "n": 1} and {@code "stream": false}.
 *
 * @throws IOException if building the JSON content fails
 */
public void testSerialization_EmptyInput_NonStreaming() throws IOException {
129+
var unifiedChatInput = new UnifiedChatInput(List.of(""), ROLE, false);
130+
var model = ElasticInferenceServiceCompletionModelTests.createModel("http://eis-gateway.com", "my-model-id");
131+
var entity = new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId());
132+
133+
XContentBuilder builder = JsonXContent.contentBuilder();
134+
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
135+
136+
String jsonString = Strings.toString(builder);
137+
String expectedJson = """
138+
{
139+
"messages": [
140+
{
141+
"content": "",
142+
"role": "user"
143+
}
144+
],
145+
"model": "my-model-id",
146+
"n": 1,
147+
"stream": false
148+
}
149+
""";
150+
assertJsonEquals(jsonString, expectedJson);
151+
}
152+
153+
/**
 * Serializes a request entity built from three input strings and checks that the output
 * still carries {@code "n": 1}, with one {@code "user"} message per input.
 *
 * @throws IOException if building the JSON content fails
 */
public void testSerialization_AlwaysSetsNToOne_NonStreaming() throws IOException {
154+
// Verify n is always 1 regardless of number of inputs
155+
var unifiedChatInput = new UnifiedChatInput(List.of("input1", "input2", "input3"), ROLE, false);
156+
var model = ElasticInferenceServiceCompletionModelTests.createModel("http://eis-gateway.com", "my-model-id");
157+
var entity = new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId());
158+
159+
XContentBuilder builder = JsonXContent.contentBuilder();
160+
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
161+
162+
String jsonString = Strings.toString(builder);
163+
String expectedJson = """
164+
{
165+
"messages": [
166+
{
167+
"content": "input1",
168+
"role": "user"
169+
},
170+
{
171+
"content": "input2",
172+
"role": "user"
173+
},
174+
{
175+
"content": "input3",
176+
"role": "user"
177+
}
178+
],
179+
"model": "my-model-id",
180+
"n": 1,
181+
"stream": false
182+
}
183+
""";
184+
assertJsonEquals(jsonString, expectedJson);
185+
}
186+
187+
/**
 * Serializes a request entity built from three plain input strings, using model id
 * {@code "test-model"}, and checks that every resulting message has role {@code "user"}.
 *
 * @throws IOException if building the JSON content fails
 */
public void testSerialization_AllMessagesHaveUserRole_NonStreaming() throws IOException {
188+
// Verify all messages have "user" role when converting from simple inputs
189+
var unifiedChatInput = new UnifiedChatInput(List.of("first", "second", "third"), ROLE, false);
190+
var model = ElasticInferenceServiceCompletionModelTests.createModel("http://eis-gateway.com", "test-model");
191+
var entity = new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId());
192+
193+
XContentBuilder builder = JsonXContent.contentBuilder();
194+
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
195+
196+
String jsonString = Strings.toString(builder);
197+
String expectedJson = """
198+
{
199+
"messages": [
200+
{
201+
"content": "first",
202+
"role": "user"
203+
},
204+
{
205+
"content": "second",
206+
"role": "user"
207+
},
208+
{
209+
"content": "third",
210+
"role": "user"
211+
}
212+
],
213+
"model": "test-model",
214+
"n": 1,
215+
"stream": false
216+
}
217+
""";
218+
assertJsonEquals(jsonString, expectedJson);
219+
}
70220
}

0 commit comments

Comments
 (0)