Skip to content

Commit 6b6d21b

Browse files
Adding support for completion
1 parent e4e9095 commit 6b6d21b

File tree

3 files changed

+51
-9
lines changed

3 files changed

+51
-9
lines changed

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getModels;
2121
import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.ELSER_V2_ENDPOINT_ID;
2222
import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID;
23+
import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.GP_LLM_V2_COMPLETION_ENDPOINT_ID;
2324
import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.JINA_EMBED_V3_ENDPOINT_ID;
2425
import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.RAINBOW_SPRINKLES_ENDPOINT_ID;
2526
import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.RERANK_V1_ENDPOINT_ID;
@@ -62,7 +63,7 @@ public void testGetDefaultEndpoints() throws IOException {
6263

6364
assertInferenceIdTaskType(allModels, RAINBOW_SPRINKLES_ENDPOINT_ID, TaskType.CHAT_COMPLETION);
6465
assertInferenceIdTaskType(allModels, GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID, TaskType.CHAT_COMPLETION);
65-
assertInferenceIdTaskType(allModels, ".gp-llm-v2-completion", TaskType.COMPLETION);
66+
assertInferenceIdTaskType(allModels, GP_LLM_V2_COMPLETION_ENDPOINT_ID, TaskType.COMPLETION);
6667
assertInferenceIdTaskType(allModels, ELSER_V2_ENDPOINT_ID, TaskType.SPARSE_EMBEDDING);
6768
assertInferenceIdTaskType(allModels, JINA_EMBED_V3_ENDPOINT_ID, TaskType.TEXT_EMBEDDING);
6869
assertInferenceIdTaskType(allModels, RERANK_V1_ENDPOINT_ID, TaskType.RERANK);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationModel.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@
3838
import java.util.function.Function;
3939
import java.util.stream.Collectors;
4040

41-
import static org.elasticsearch.inference.TaskType.CHAT_COMPLETION;
42-
4341
/**
4442
* Transforms the response from {@link ElasticInferenceServiceAuthorizationRequestHandler} into a format for consumption by the service.
4543
*/
@@ -81,7 +79,8 @@ private static ElasticInferenceServiceModel createModel(
8179
}
8280

8381
return switch (taskType) {
84-
case CHAT_COMPLETION -> createCompletionModel(authorizedEndpoint, components);
82+
case CHAT_COMPLETION -> createCompletionModel(authorizedEndpoint, TaskType.CHAT_COMPLETION, components);
83+
case COMPLETION -> createCompletionModel(authorizedEndpoint, TaskType.COMPLETION, components);
8584
case SPARSE_EMBEDDING -> createSparseEmbeddingsModel(authorizedEndpoint, components);
8685
case TEXT_EMBEDDING -> createDenseTextEmbeddingsModel(authorizedEndpoint, components);
8786
case RERANK -> createRerankModel(authorizedEndpoint, components);
@@ -112,11 +111,12 @@ private static TaskType getTaskType(String taskType) {
112111

113112
private static ElasticInferenceServiceCompletionModel createCompletionModel(
114113
AuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint,
114+
TaskType taskType,
115115
ElasticInferenceServiceComponents components
116116
) {
117117
return new ElasticInferenceServiceCompletionModel(
118118
authorizedEndpoint.id(),
119-
CHAT_COMPLETION,
119+
taskType,
120120
ElasticInferenceService.NAME,
121121
new ElasticInferenceServiceCompletionServiceSettings(authorizedEndpoint.modelName()),
122122
EmptyTaskSettings.INSTANCE,

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ public class AuthorizationResponseEntityTests extends AbstractBWCWireSerializati
5151

5252
// gp-llm-v2
5353
public static final String GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID = ".gp-llm-v2-chat_completion";
54+
public static final String GP_LLM_V2_COMPLETION_ENDPOINT_ID = ".gp-llm-v2-completion";
5455
public static final String GP_LLM_V2_MODEL_NAME = "gp-llm-v2";
5556

5657
// elser-2
@@ -190,6 +191,19 @@ public record EisAuthorizationResponse(
190191
],
191192
"release_date": "2024-05-01"
192193
},
194+
{
195+
"id": ".gp-llm-v2-completion",
196+
"model_name": "gp-llm-v2",
197+
"task_types": {
198+
"eis": "chat",
199+
"elasticsearch": "completion"
200+
},
201+
"status": "ga",
202+
"properties": [
203+
"multilingual"
204+
],
205+
"release_date": "2024-05-01"
206+
},
193207
{
194208
"id": ".elser-2-elastic",
195209
"model_name": "elser_model_2",
@@ -300,7 +314,8 @@ public static AuthorizationResponseEntity.TaskTypeObject createTaskTypeObject(St
300314
public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEndpoints(String url) {
301315
var authorizedEndpoints = List.of(
302316
createRainbowSprinklesAuthorizedEndpoint(),
303-
createGpLlmV2AuthorizedEndpoint(),
317+
createGpLlmV2ChatCompletionAuthorizedEndpoint(),
318+
createGpLlmV2CompletionAuthorizedEndpoint(),
304319
createElserAuthorizedEndpoint(),
305320
createJinaEmbedAuthorizedEndpoint(),
306321
new AuthorizationResponseEntity.AuthorizedEndpoint(
@@ -322,7 +337,8 @@ public static EisAuthorizationResponse getEisAuthorizationResponseWithMultipleEn
322337
new AuthorizationResponseEntity(authorizedEndpoints),
323338
List.of(
324339
createRainbowSprinklesExpectedEndpoint(url),
325-
createGpLlmV2ExpectedEndpoint(url),
340+
createGpLlmV2ChatCompletionExpectedEndpoint(url),
341+
createGpLlmV2CompletionExpectedEndpoint(url),
326342
createElserExpectedEndpoint(url),
327343
createJinaExpectedEndpoint(url),
328344
new ElasticInferenceServiceRerankModel(
@@ -352,7 +368,7 @@ private static AuthorizationResponseEntity.AuthorizedEndpoint createRainbowSprin
352368
);
353369
}
354370

355-
private static ElasticInferenceServiceModel createGpLlmV2ExpectedEndpoint(String url) {
371+
private static ElasticInferenceServiceModel createGpLlmV2ChatCompletionExpectedEndpoint(String url) {
356372
return new ElasticInferenceServiceCompletionModel(
357373
GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID,
358374
TaskType.CHAT_COMPLETION,
@@ -364,7 +380,19 @@ private static ElasticInferenceServiceModel createGpLlmV2ExpectedEndpoint(String
364380
);
365381
}
366382

367-
private static AuthorizationResponseEntity.AuthorizedEndpoint createGpLlmV2AuthorizedEndpoint() {
383+
/**
 * Builds the model the tests expect for the gp-llm-v2 endpoint exposed with the {@code COMPLETION} task type.
 * <p>
 * The model uses {@link #GP_LLM_V2_COMPLETION_ENDPOINT_ID} as its inference id, {@code ElasticInferenceService.NAME}
 * as its service, service settings carrying {@link #GP_LLM_V2_MODEL_NAME}, and empty task and secret settings.
 *
 * @param url value handed to {@code ElasticInferenceServiceComponents}; presumably the EIS base URL (confirm against callers)
 * @return the expected completion model for this endpoint
 */
private static ElasticInferenceServiceModel createGpLlmV2CompletionExpectedEndpoint(String url) {
384+
return new ElasticInferenceServiceCompletionModel(
385+
GP_LLM_V2_COMPLETION_ENDPOINT_ID,
386+
TaskType.COMPLETION,
387+
ElasticInferenceService.NAME,
388+
new ElasticInferenceServiceCompletionServiceSettings(GP_LLM_V2_MODEL_NAME),
389+
EmptyTaskSettings.INSTANCE,
390+
EmptySecretSettings.INSTANCE,
391+
new ElasticInferenceServiceComponents(url)
392+
);
393+
}
394+
395+
private static AuthorizationResponseEntity.AuthorizedEndpoint createGpLlmV2ChatCompletionAuthorizedEndpoint() {
368396
return new AuthorizationResponseEntity.AuthorizedEndpoint(
369397
GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID,
370398
GP_LLM_V2_MODEL_NAME,
@@ -377,6 +405,19 @@ private static AuthorizationResponseEntity.AuthorizedEndpoint createGpLlmV2Autho
377405
);
378406
}
379407

408+
/**
 * Builds the authorization-response entry for the gp-llm-v2 endpoint exposed with the "completion" task type.
 * <p>
 * The values mirror the {@code ".gp-llm-v2-completion"} entry of the JSON fixture in this class: id, model name,
 * task types, status {@code "ga"}, the single property {@code "multilingual"} and release date {@code "2024-05-01"}.
 * The task type object is created from {@code EIS_CHAT_PATH} and {@code "completion"}; presumably these are the
 * fixture's {@code "eis"} and {@code "elasticsearch"} task types respectively (confirm against
 * {@code createTaskTypeObject}).
 *
 * @return the authorized endpoint describing the gp-llm-v2 completion endpoint
 */
private static AuthorizationResponseEntity.AuthorizedEndpoint createGpLlmV2CompletionAuthorizedEndpoint() {
409+
return new AuthorizationResponseEntity.AuthorizedEndpoint(
410+
GP_LLM_V2_COMPLETION_ENDPOINT_ID,
411+
GP_LLM_V2_MODEL_NAME,
412+
createTaskTypeObject(EIS_CHAT_PATH, "completion"),
413+
"ga",
414+
List.of("multilingual"),
415+
"2024-05-01",
416+
// NOTE(review): the last two constructor arguments are left unset; their meaning is not visible here,
// check the AuthorizedEndpoint record declaration.
null,
417+
null
418+
);
419+
}
420+
380421
private static ElasticInferenceServiceModel createRainbowSprinklesExpectedEndpoint(String url) {
381422
return new ElasticInferenceServiceCompletionModel(
382423
RAINBOW_SPRINKLES_ENDPOINT_ID,

0 commit comments

Comments
 (0)