Skip to content

Commit 408f473

Browse files
Update model to return correct model for CHAT_COMPLETION task type (#120326)
* Update model to return correct model for CHAT_COMPLETION task type
* Update docs/changelog/120326.yaml
* Delete docs/changelog/120326.yaml
* Fixing chat completion functionality
* Fixing tests
* naming
* Fixing tests
* Adding more tests

---------

Co-authored-by: Jonathan Buttner <[email protected]>
Co-authored-by: Jonathan Buttner <[email protected]>
1 parent 165f930 commit 408f473

File tree

24 files changed, with 312 additions and 62 deletions.

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,9 @@ public ActionRequestValidationException validate() {
9292
return e;
9393
}
9494

95-
if (taskType.isAnyOrSame(TaskType.COMPLETION) == false) {
95+
if (taskType.isAnyOrSame(TaskType.CHAT_COMPLETION) == false) {
9696
var e = new ActionRequestValidationException();
97-
e.addValidationError("Field [taskType] must be [completion]");
97+
e.addValidationError("Field [taskType] must be [chat_completion]");
9898
return e;
9999
}
100100

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public void testValidation_ReturnsException_When_TaskType_IsNot_Completion() {
5252
TimeValue.timeValueSeconds(10)
5353
);
5454
var exception = request.validate();
55-
assertThat(exception.getMessage(), is("Validation Failed: 1: Field [taskType] must be [completion];"));
55+
assertThat(exception.getMessage(), is("Validation Failed: 1: Field [taskType] must be [chat_completion];"));
5656
}
5757

5858
public void testValidation_ReturnsNull_When_TaskType_IsAny() {

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -272,9 +272,9 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
272272
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
273273
if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()
274274
|| ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) {
275-
assertThat(services.size(), equalTo(2));
275+
assertThat(services.size(), equalTo(3));
276276
} else {
277-
assertThat(services.size(), equalTo(1));
277+
assertThat(services.size(), equalTo(2));
278278
}
279279

280280
String[] providers = new String[services.size()];
@@ -283,7 +283,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
283283
providers[i] = (String) serviceConfig.get("service");
284284
}
285285

286-
var providerList = new ArrayList<>(List.of("openai"));
286+
var providerList = new ArrayList<>(List.of("openai", "streaming_completion_test_service"));
287287

288288
if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()
289289
|| ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) {
@@ -519,14 +519,19 @@ public void testSupportedStream() throws Exception {
519519

520520
public void testUnifiedCompletionInference() throws Exception {
521521
String modelId = "streaming";
522-
putModel(modelId, mockCompletionServiceModelConfig(TaskType.COMPLETION));
522+
putModel(modelId, mockCompletionServiceModelConfig(TaskType.CHAT_COMPLETION));
523523
var singleModel = getModel(modelId);
524524
assertEquals(modelId, singleModel.get("inference_id"));
525-
assertEquals(TaskType.COMPLETION.toString(), singleModel.get("task_type"));
525+
assertEquals(TaskType.CHAT_COMPLETION.toString(), singleModel.get("task_type"));
526526

527527
var input = IntStream.range(1, 2 + randomInt(8)).mapToObj(i -> randomAlphanumericOfLength(5)).toList();
528528
try {
529-
var events = unifiedCompletionInferOnMockService(modelId, TaskType.COMPLETION, input, VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER);
529+
var events = unifiedCompletionInferOnMockService(
530+
modelId,
531+
TaskType.CHAT_COMPLETION,
532+
input,
533+
VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER
534+
);
530535
var expectedResponses = expectedResultsIterator(input);
531536
assertThat(events.size(), equalTo((input.size() + 1) * 2));
532537
events.forEach(event -> {

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ public List<Factory> getInferenceServiceFactories() {
5555

5656
public static class TestInferenceService extends AbstractTestInferenceService {
5757
private static final String NAME = "streaming_completion_test_service";
58-
private static final Set<TaskType> supportedStreamingTasks = Set.of(TaskType.COMPLETION);
58+
private static final Set<TaskType> supportedStreamingTasks = Set.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
5959

60-
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.COMPLETION);
60+
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
6161

6262
public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {}
6363

@@ -129,7 +129,7 @@ public void unifiedCompletionInfer(
129129
ActionListener<InferenceServiceResults> listener
130130
) {
131131
switch (model.getConfigurations().getTaskType()) {
132-
case COMPLETION -> listener.onResponse(makeUnifiedResults(request));
132+
case CHAT_COMPLETION -> listener.onResponse(makeUnifiedResults(request));
133133
default -> listener.onFailure(
134134
new ElasticsearchStatusException(
135135
TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public TransportUnifiedCompletionInferenceAction(
5252

5353
@Override
5454
protected boolean isInvalidTaskTypeForInferenceEndpoint(UnifiedCompletionAction.Request request, UnparsedModel unparsedModel) {
55-
return request.getTaskType().isAnyOrSame(TaskType.COMPLETION) == false || unparsedModel.taskType() != TaskType.COMPLETION;
55+
return request.getTaskType().isAnyOrSame(TaskType.CHAT_COMPLETION) == false || unparsedModel.taskType() != TaskType.CHAT_COMPLETION;
5656
}
5757

5858
@Override
@@ -64,7 +64,7 @@ protected ElasticsearchStatusException createInvalidTaskTypeException(
6464
"Incompatible task_type for unified API, the requested type [{}] must be one of [{}]",
6565
RestStatus.BAD_REQUEST,
6666
request.getTaskType(),
67-
TaskType.COMPLETION.toString()
67+
TaskType.CHAT_COMPLETION.toString()
6868
);
6969
}
7070

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public class OpenAiCompletionRequestManager extends OpenAiRequestManager {
2626

2727
private static final Logger logger = LogManager.getLogger(OpenAiCompletionRequestManager.class);
2828
private static final ResponseHandler HANDLER = createCompletionHandler();
29-
static final String USER_ROLE = "user";
29+
public static final String USER_ROLE = "user";
3030

3131
public static OpenAiCompletionRequestManager of(OpenAiChatCompletionModel model, ThreadPool threadPool) {
3232
return new OpenAiCompletionRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ private static ElasticInferenceServiceModel createModel(
251251
eisServiceComponents,
252252
context
253253
);
254-
case COMPLETION -> new ElasticInferenceServiceCompletionModel(
254+
case CHAT_COMPLETION -> new ElasticInferenceServiceCompletionModel(
255255
inferenceEntityId,
256256
taskType,
257257
NAME,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ public static ModelValidator buildModelValidator(TaskType taskType) {
2323
case COMPLETION -> {
2424
return new ChatCompletionModelValidator(new SimpleServiceIntegrationValidator());
2525
}
26+
case CHAT_COMPLETION -> {
27+
return new ChatCompletionModelValidator(new SimpleChatCompletionServiceIntegrationValidator());
28+
}
2629
case SPARSE_EMBEDDING, RERANK, ANY -> {
2730
return new SimpleModelValidator(new SimpleServiceIntegrationValidator());
2831
}
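x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change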
@@ -0,0 +1,59 @@
1+
2+
/*
3+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
4+
* or more contributor license agreements. Licensed under the Elastic License
5+
* 2.0; you may not use this file except in compliance with the Elastic License
6+
* 2.0.
7+
*/
8+
9+
package org.elasticsearch.xpack.inference.services.validation;
10+
11+
import org.elasticsearch.ElasticsearchStatusException;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceService;
14+
import org.elasticsearch.inference.InferenceServiceResults;
15+
import org.elasticsearch.inference.Model;
16+
import org.elasticsearch.rest.RestStatus;
17+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
18+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
19+
20+
import java.util.List;
21+
22+
import static org.elasticsearch.xpack.inference.external.http.sender.OpenAiCompletionRequestManager.USER_ROLE;
23+
24+
/**
25+
* This class uses the unified chat completion method to perform validation.
26+
*/
27+
public class SimpleChatCompletionServiceIntegrationValidator implements ServiceIntegrationValidator {
28+
private static final List<String> TEST_INPUT = List.of("how big");
29+
30+
@Override
31+
public void validate(InferenceService service, Model model, ActionListener<InferenceServiceResults> listener) {
32+
var chatCompletionInput = new UnifiedChatInput(TEST_INPUT, USER_ROLE, false);
33+
service.unifiedCompletionInfer(
34+
model,
35+
chatCompletionInput.getRequest(),
36+
InferenceAction.Request.DEFAULT_TIMEOUT,
37+
ActionListener.wrap(r -> {
38+
if (r != null) {
39+
listener.onResponse(r);
40+
} else {
41+
listener.onFailure(
42+
new ElasticsearchStatusException(
43+
"Could not complete inference endpoint creation as validation call to service returned null response.",
44+
RestStatus.BAD_REQUEST
45+
)
46+
);
47+
}
48+
}, e -> {
49+
listener.onFailure(
50+
new ElasticsearchStatusException(
51+
"Could not complete inference endpoint creation as validation call to service threw an exception.",
52+
RestStatus.BAD_REQUEST,
53+
e
54+
)
55+
);
56+
})
57+
);
58+
}
59+
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,15 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
5757
private BaseTransportInferenceAction<Request> action;
5858

5959
protected static final String serviceId = "serviceId";
60-
protected static final TaskType taskType = TaskType.COMPLETION;
60+
protected final TaskType taskType;
6161
protected static final String inferenceId = "inferenceEntityId";
6262
protected InferenceServiceRegistry serviceRegistry;
6363
protected InferenceStats inferenceStats;
6464

65+
public BaseTransportInferenceActionTestCase(TaskType taskType) {
66+
this.taskType = taskType;
67+
}
68+
6569
@Before
6670
public void setUp() throws Exception {
6771
super.setUp();
@@ -377,7 +381,7 @@ protected void mockModelAndServiceRegistry(InferenceService service) {
377381
when(serviceRegistry.getService(any())).thenReturn(Optional.of(service));
378382
}
379383

380-
protected void mockValidLicenseState(){
384+
protected void mockValidLicenseState() {
381385
when(licenseState.isAllowed(InferencePlugin.INFERENCE_API_FEATURE)).thenReturn(true);
382386
}
383387
}

0 commit comments

Comments
 (0)