Skip to content

Commit f9a3721

Browse files
[ML] Adding new chat_completion task type for unified API (#119982)
* Creating new chat completion task type
* Adding some comments
* Refactoring names and removing todo
* Exposing chat completion for openai and eis for now
* Fixing tests
1 parent d7474e6 commit f9a3721

File tree

11 files changed

+283
-29
lines changed

11 files changed

+283
-29
lines changed

muted-tests.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,9 +224,6 @@ tests:
224224
- class: org.elasticsearch.search.profile.dfs.DfsProfilerIT
225225
method: testProfileDfs
226226
issue: https://github.com/elastic/elasticsearch/issues/119711
227-
- class: org.elasticsearch.xpack.inference.InferenceCrudIT
228-
method: testGetServicesWithCompletionTaskType
229-
issue: https://github.com/elastic/elasticsearch/issues/119959
230227
- class: org.elasticsearch.multi_cluster.MultiClusterYamlTestSuiteIT
231228
issue: https://github.com/elastic/elasticsearch/issues/119983
232229
- class: org.elasticsearch.xpack.test.rest.XPackRestIT

server/src/main/java/org/elasticsearch/inference/TaskType.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ public enum TaskType implements Writeable {
2929
public boolean isAnyOrSame(TaskType other) {
3030
return true;
3131
}
32-
};
32+
},
33+
CHAT_COMPLETION;
3334

3435
public static final String NAME = "task_type";
3536

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -242,12 +242,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
242242
@SuppressWarnings("unchecked")
243243
public void testGetServicesWithCompletionTaskType() throws IOException {
244244
List<Object> services = getServices(TaskType.COMPLETION);
245-
if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()
246-
|| ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) {
247-
assertThat(services.size(), equalTo(10));
248-
} else {
249-
assertThat(services.size(), equalTo(9));
250-
}
245+
assertThat(services.size(), equalTo(9));
251246

252247
String[] providers = new String[services.size()];
253248
for (int i = 0; i < services.size(); i++) {
@@ -269,9 +264,30 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
269264
)
270265
);
271266

267+
assertArrayEquals(providers, providerList.toArray());
268+
}
269+
270+
@SuppressWarnings("unchecked")
271+
public void testGetServicesWithChatCompletionTaskType() throws IOException {
272+
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
272273
if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()
273274
|| ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) {
274-
providerList.add(6, "elastic");
275+
assertThat(services.size(), equalTo(2));
276+
} else {
277+
assertThat(services.size(), equalTo(1));
278+
}
279+
280+
String[] providers = new String[services.size()];
281+
for (int i = 0; i < services.size(); i++) {
282+
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
283+
providers[i] = (String) serviceConfig.get("service");
284+
}
285+
286+
var providerList = new ArrayList<>(List.of("openai"));
287+
288+
if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()
289+
|| ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) {
290+
providerList.addFirst("elastic");
275291
}
276292

277293
assertArrayEquals(providers, providerList.toArray());

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,15 @@ public final class Paths {
3030
+ "}/{"
3131
+ INFERENCE_ID
3232
+ "}/_stream";
33-
static final String UNIFIED_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/_unified";
33+
34+
public static final String UNIFIED_SUFFIX = "_unified";
35+
static final String UNIFIED_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/" + UNIFIED_SUFFIX;
3436
static final String UNIFIED_TASK_TYPE_INFERENCE_ID_PATH = "_inference/{"
3537
+ TASK_TYPE_OR_INFERENCE_ID
3638
+ "}/{"
3739
+ INFERENCE_ID
38-
+ "}/_unified";
40+
+ "}/"
41+
+ UNIFIED_SUFFIX;
3942

4043
private Paths() {
4144

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ public void infer(
7373

7474
private static InferenceInputs createInput(Model model, List<String> input, @Nullable String query, boolean stream) {
7575
return switch (model.getTaskType()) {
76-
case COMPLETION -> new ChatCompletionInput(input, stream);
76+
case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream);
7777
case RERANK -> new QueryAndDocsInputs(query, input, stream);
7878
case TEXT_EMBEDDING, SPARSE_EMBEDDING -> new DocumentsOnlyInput(input, stream);
7979
default -> throw new ElasticsearchStatusException(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import static org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings.ENABLED;
4343
import static org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings.MAX_NUMBER_OF_ALLOCATIONS;
4444
import static org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings.MIN_NUMBER_OF_ALLOCATIONS;
45+
import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_SUFFIX;
4546
import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
4647

4748
public final class ServiceUtils {
@@ -780,5 +781,24 @@ public static void throwUnsupportedUnifiedCompletionOperation(String serviceName
780781
throw new UnsupportedOperationException(Strings.format("The %s service does not support unified completion", serviceName));
781782
}
782783

784+
public static String unsupportedTaskTypeForInference(Model model, EnumSet<TaskType> supportedTaskTypes) {
785+
return Strings.format(
786+
"Inference entity [%s] does not support task type [%s] for inference, the task type must be one of %s.",
787+
model.getInferenceEntityId(),
788+
model.getTaskType(),
789+
supportedTaskTypes
790+
);
791+
}
792+
793+
public static String useChatCompletionUrlMessage(Model model) {
794+
return org.elasticsearch.common.Strings.format(
795+
"The task type for the inference entity is %s, please use the _inference/%s/%s/%s URL.",
796+
model.getTaskType(),
797+
model.getTaskType(),
798+
model.getInferenceEntityId(),
799+
UNIFIED_SUFFIX
800+
);
801+
}
802+
783803
private ServiceUtils() {}
784804
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
2828
import org.elasticsearch.rest.RestStatus;
2929
import org.elasticsearch.tasks.Task;
30+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
3031
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
3132
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
3233
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
@@ -41,6 +42,7 @@
4142
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
4243
import org.elasticsearch.xpack.inference.services.SenderService;
4344
import org.elasticsearch.xpack.inference.services.ServiceComponents;
45+
import org.elasticsearch.xpack.inference.services.ServiceUtils;
4446
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
4547
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
4648
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
@@ -61,6 +63,7 @@
6163
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
6264
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
6365
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
66+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.useChatCompletionUrlMessage;
6467

6568
public class ElasticInferenceService extends SenderService {
6669

@@ -69,8 +72,16 @@ public class ElasticInferenceService extends SenderService {
6972

7073
private final ElasticInferenceServiceComponents elasticInferenceServiceComponents;
7174

72-
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.COMPLETION);
75+
// The task types exposed via the _inference/_services API
76+
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES_FOR_SERVICES_API = EnumSet.of(
77+
TaskType.SPARSE_EMBEDDING,
78+
TaskType.CHAT_COMPLETION
79+
);
7380
private static final String SERVICE_NAME = "Elastic";
81+
/**
82+
* The task types that the {@link InferenceAction.Request} can accept.
83+
*/
84+
private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING);
7485

7586
public ElasticInferenceService(
7687
HttpRequestSender.Factory factory,
@@ -83,7 +94,7 @@ public ElasticInferenceService(
8394

8495
@Override
8596
public Set<TaskType> supportedStreamingTasks() {
86-
return COMPLETION_ONLY;
97+
return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY);
8798
}
8899

89100
@Override
@@ -129,6 +140,15 @@ protected void doInfer(
129140
TimeValue timeout,
130141
ActionListener<InferenceServiceResults> listener
131142
) {
143+
if (SUPPORTED_INFERENCE_ACTION_TASK_TYPES.contains(model.getTaskType()) == false) {
144+
var responseString = ServiceUtils.unsupportedTaskTypeForInference(model, SUPPORTED_INFERENCE_ACTION_TASK_TYPES);
145+
146+
if (model.getTaskType() == TaskType.CHAT_COMPLETION) {
147+
responseString = responseString + " " + useChatCompletionUrlMessage(model);
148+
}
149+
listener.onFailure(new ElasticsearchStatusException(responseString, RestStatus.BAD_REQUEST));
150+
}
151+
132152
if (model instanceof ElasticInferenceServiceExecutableActionModel == false) {
133153
listener.onFailure(createInvalidModelException(model));
134154
return;
@@ -207,7 +227,7 @@ public InferenceServiceConfiguration getConfiguration() {
207227

208228
@Override
209229
public EnumSet<TaskType> supportedTaskTypes() {
210-
return supportedTaskTypes;
230+
return SUPPORTED_TASK_TYPES_FOR_SERVICES_API;
211231
}
212232

213233
private static ElasticInferenceServiceModel createModel(
@@ -383,7 +403,7 @@ public static InferenceServiceConfiguration get() {
383403

384404
return new InferenceServiceConfiguration.Builder().setService(NAME)
385405
.setName(SERVICE_NAME)
386-
.setTaskTypes(supportedTaskTypes)
406+
.setTaskTypes(SUPPORTED_TASK_TYPES_FOR_SERVICES_API)
387407
.setConfigurations(configurationMap)
388408
.build();
389409
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.elasticsearch.inference.TaskType;
2828
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
2929
import org.elasticsearch.rest.RestStatus;
30+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
3031
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
3132
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
3233
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
@@ -63,14 +64,24 @@
6364
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
6465
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
6566
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
67+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.useChatCompletionUrlMessage;
6668
import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.EMBEDDING_MAX_BATCH_SIZE;
6769
import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.ORGANIZATION;
6870

6971
public class OpenAiService extends SenderService {
7072
public static final String NAME = "openai";
7173

7274
private static final String SERVICE_NAME = "OpenAI";
73-
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
75+
// The task types exposed via the _inference/_services API
76+
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES_FOR_SERVICES_API = EnumSet.of(
77+
TaskType.TEXT_EMBEDDING,
78+
TaskType.COMPLETION,
79+
TaskType.CHAT_COMPLETION
80+
);
81+
/**
82+
* The task types that the {@link InferenceAction.Request} can accept.
83+
*/
84+
private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
7485

7586
public OpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
7687
super(factory, serviceComponents);
@@ -164,7 +175,7 @@ private static OpenAiModel createModel(
164175
secretSettings,
165176
context
166177
);
167-
case COMPLETION -> new OpenAiChatCompletionModel(
178+
case COMPLETION, CHAT_COMPLETION -> new OpenAiChatCompletionModel(
168179
inferenceEntityId,
169180
taskType,
170181
NAME,
@@ -236,7 +247,7 @@ public InferenceServiceConfiguration getConfiguration() {
236247

237248
@Override
238249
public EnumSet<TaskType> supportedTaskTypes() {
239-
return supportedTaskTypes;
250+
return SUPPORTED_TASK_TYPES_FOR_SERVICES_API;
240251
}
241252

242253
@Override
@@ -248,6 +259,15 @@ public void doInfer(
248259
TimeValue timeout,
249260
ActionListener<InferenceServiceResults> listener
250261
) {
262+
if (SUPPORTED_INFERENCE_ACTION_TASK_TYPES.contains(model.getTaskType()) == false) {
263+
var responseString = ServiceUtils.unsupportedTaskTypeForInference(model, SUPPORTED_INFERENCE_ACTION_TASK_TYPES);
264+
265+
if (model.getTaskType() == TaskType.CHAT_COMPLETION) {
266+
responseString = responseString + " " + useChatCompletionUrlMessage(model);
267+
}
268+
listener.onFailure(new ElasticsearchStatusException(responseString, RestStatus.BAD_REQUEST));
269+
}
270+
251271
if (model instanceof OpenAiModel == false) {
252272
listener.onFailure(createInvalidModelException(model));
253273
return;
@@ -356,7 +376,7 @@ public TransportVersion getMinimalSupportedVersion() {
356376

357377
@Override
358378
public Set<TaskType> supportedStreamingTasks() {
359-
return COMPLETION_ONLY;
379+
return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION, TaskType.ANY);
360380
}
361381

362382
/**
@@ -444,7 +464,7 @@ public static InferenceServiceConfiguration get() {
444464

445465
return new InferenceServiceConfiguration.Builder().setService(NAME)
446466
.setName(SERVICE_NAME)
447-
.setTaskTypes(supportedTaskTypes)
467+
.setTaskTypes(SUPPORTED_TASK_TYPES_FOR_SERVICES_API)
448468
.setConfigurations(configurationMap)
449469
.build();
450470
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,19 +148,24 @@ private static <T> void blockingCall(
148148
latch.await();
149149
}
150150

151-
public static Model getInvalidModel(String inferenceEntityId, String serviceName) {
151+
public static Model getInvalidModel(String inferenceEntityId, String serviceName, TaskType taskType) {
152152
var mockConfigs = mock(ModelConfigurations.class);
153153
when(mockConfigs.getInferenceEntityId()).thenReturn(inferenceEntityId);
154154
when(mockConfigs.getService()).thenReturn(serviceName);
155-
when(mockConfigs.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
155+
when(mockConfigs.getTaskType()).thenReturn(taskType);
156156

157157
var mockModel = mock(Model.class);
158+
when(mockModel.getInferenceEntityId()).thenReturn(inferenceEntityId);
158159
when(mockModel.getConfigurations()).thenReturn(mockConfigs);
159-
when(mockModel.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
160+
when(mockModel.getTaskType()).thenReturn(taskType);
160161

161162
return mockModel;
162163
}
163164

165+
public static Model getInvalidModel(String inferenceEntityId, String serviceName) {
166+
return getInvalidModel(inferenceEntityId, serviceName, TaskType.TEXT_EMBEDDING);
167+
}
168+
164169
public static SimilarityMeasure randomSimilarityMeasure() {
165170
return randomFrom(SimilarityMeasure.values());
166171
}

0 commit comments

Comments (0)