Skip to content

Commit a9df8e3

Browse files
committed
Google Vertex AI completion model, response entity and tests
1 parent e241efa commit a9df8e3

File tree

7 files changed

+311
-8
lines changed

7 files changed

+311
-8
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import org.elasticsearch.xpack.inference.services.ServiceUtils;
4343
import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator;
4444
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
45+
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiCompletionModel;
4546
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
4647
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
4748
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest;
@@ -75,7 +76,8 @@ public class GoogleVertexAiService extends SenderService {
7576
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
7677
TaskType.TEXT_EMBEDDING,
7778
TaskType.RERANK,
78-
TaskType.CHAT_COMPLETION
79+
TaskType.CHAT_COMPLETION,
80+
TaskType.COMPLETION
7981
);
8082

8183
public static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
@@ -220,11 +222,11 @@ protected void doInfer(
220222
return;
221223
}
222224

223-
GoogleVertexAiModel googleVertexAiModel = (GoogleVertexAiModel) model;
225+
var completionModel = (GoogleVertexAiCompletionModel) model;
224226

225227
var actionCreator = new GoogleVertexAiActionCreator(getSender(), getServiceComponents());
226228

227-
var action = googleVertexAiModel.accept(actionCreator, taskSettings);
229+
var action = completionModel.accept(actionCreator, taskSettings);
228230
action.execute(inputs, timeout, listener);
229231
}
230232

@@ -368,6 +370,16 @@ private static GoogleVertexAiModel createModel(
368370
context
369371
);
370372

373+
case COMPLETION -> new GoogleVertexAiCompletionModel(
374+
inferenceEntityId,
375+
taskType,
376+
NAME,
377+
serviceSettings,
378+
taskSettings,
379+
secretSettings,
380+
context
381+
);
382+
371383
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
372384
};
373385
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@
2323
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
2424
import org.elasticsearch.xpack.inference.external.http.HttpResult;
2525
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
26-
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
2726
import org.elasticsearch.xpack.inference.external.request.Request;
2827
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
2928
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
29+
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity;
3030

3131
import java.nio.charset.StandardCharsets;
3232
import java.util.Locale;
@@ -43,10 +43,8 @@ public class GoogleVertexAiUnifiedChatCompletionResponseHandler extends GoogleVe
4343
private static final String ERROR_MESSAGE_FIELD = "message";
4444
private static final String ERROR_STATUS_FIELD = "status";
4545

46-
private static final ResponseParser noopParseFunction = (a, b) -> null;
47-
4846
public GoogleVertexAiUnifiedChatCompletionResponseHandler(String requestType) {
49-
super(requestType, noopParseFunction, GoogleVertexAiErrorResponse::fromResponse, true);
47+
super(requestType, GoogleVertexAiCompletionResponseEntity::fromResponse, GoogleVertexAiErrorResponse::fromResponse, true);
5048
}
5149

5250
@Override
@@ -63,7 +61,6 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpR
6361

6462
@Override
6563
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
66-
assert request.isStreaming() : "Only streaming requests support this format";
6764
var responseStatusCode = result.response().getStatusLine().getStatusCode();
6865
var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
6966
var restStatus = toRestStatus(responseStatusCode);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.googlevertexai.completion;
9+
10+
import org.apache.http.client.utils.URIBuilder;
11+
import org.elasticsearch.inference.TaskType;
12+
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
13+
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUtils;
14+
15+
import java.net.URI;
16+
import java.net.URISyntaxException;
17+
import java.util.Map;
18+
19+
import static org.elasticsearch.core.Strings.format;
20+
21+
public class GoogleVertexAiCompletionModel extends GoogleVertexAiChatCompletionModel {
22+
23+
public GoogleVertexAiCompletionModel(
24+
String inferenceEntityId,
25+
TaskType taskType,
26+
String service,
27+
Map<String, Object> serviceSettings,
28+
Map<String, Object> taskSettings,
29+
Map<String, Object> secrets,
30+
ConfigurationParseContext context
31+
) {
32+
super(inferenceEntityId, taskType, service, serviceSettings, taskSettings, secrets, context);
33+
try {
34+
var modelServiceSettings = this.getServiceSettings();
35+
this.uri = buildUri(modelServiceSettings.location(), modelServiceSettings.projectId(), modelServiceSettings.modelId());
36+
} catch (URISyntaxException e) {
37+
throw new RuntimeException(e);
38+
}
39+
40+
}
41+
42+
public static URI buildUri(String location, String projectId, String model) throws URISyntaxException {
43+
return new URIBuilder().setScheme("https")
44+
.setHost(format("%s%s", location, GoogleVertexAiUtils.GOOGLE_VERTEX_AI_HOST_SUFFIX))
45+
.setPathSegments(
46+
GoogleVertexAiUtils.V1,
47+
GoogleVertexAiUtils.PROJECTS,
48+
projectId,
49+
GoogleVertexAiUtils.LOCATIONS,
50+
GoogleVertexAiUtils.GLOBAL,
51+
GoogleVertexAiUtils.PUBLISHERS,
52+
GoogleVertexAiUtils.PUBLISHER_GOOGLE,
53+
GoogleVertexAiUtils.MODELS,
54+
format("%s:%s", model, GoogleVertexAiUtils.GENERATE_CONTENT)
55+
)
56+
.build();
57+
}
58+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUtils.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ public final class GoogleVertexAiUtils {
3737

3838
public static final String STREAM_GENERATE_CONTENT = "streamGenerateContent";
3939

40+
public static final String GENERATE_CONTENT = "generateContent";
41+
4042
public static final String QUERY_PARAM_ALT_SSE = "alt=sse";
4143

4244
private GoogleVertexAiUtils() {}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.googlevertexai.response;
9+
10+
import org.elasticsearch.inference.InferenceServiceResults;
11+
import org.elasticsearch.xcontent.XContentFactory;
12+
import org.elasticsearch.xcontent.XContentParser;
13+
import org.elasticsearch.xcontent.XContentParserConfiguration;
14+
import org.elasticsearch.xcontent.XContentType;
15+
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
16+
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
17+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
18+
import org.elasticsearch.xpack.inference.external.request.Request;
19+
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiUnifiedStreamingProcessor;
20+
21+
import java.io.IOException;
22+
import java.nio.charset.StandardCharsets;
23+
24+
public class GoogleVertexAiCompletionResponseEntity {
25+
/**
26+
* Parses the response from Google Vertex AI's generateContent endpoint
27+
* For a request like:
28+
* <pre>
29+
* <code>
30+
* {
31+
* "inputs": "Please summarize this text: some text"
32+
* }
33+
* </code>
34+
* </pre>
35+
*
36+
* The response is a <a href="https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/GenerateContentResponse">GenerateContentResponse</a> objects that looks like:
37+
*
38+
* <pre>
39+
* <code>
40+
*
41+
* {
42+
* "candidates": [
43+
* {
44+
* "content": {
45+
* "role": "model",
46+
* "parts": [
47+
* {
48+
* "text": "I am sorry, I cannot summarize the text because I do not have access to the text you are referring to."
49+
* }
50+
* ]
51+
* },
52+
* "finishReason": "STOP",
53+
* "avgLogprobs": -0.19326641248620074
54+
* }
55+
* ],
56+
* "usageMetadata": {
57+
* "promptTokenCount": 71,
58+
* "candidatesTokenCount": 23,
59+
* "totalTokenCount": 94,
60+
* "trafficType": "ON_DEMAND",
61+
* "promptTokensDetails": [
62+
* {
63+
* "modality": "TEXT",
64+
* "tokenCount": 71
65+
* }
66+
* ],
67+
* "candidatesTokensDetails": [
68+
* {
69+
* "modality": "TEXT",
70+
* "tokenCount": 23
71+
* }
72+
* ]
73+
* },
74+
* "modelVersion": "gemini-2.0-flash-001",
75+
* "createTime": "2025-05-28T15:08:20.049493Z",
76+
* "responseId": "5CY3aNWCA6mm4_UPr-zduAE"
77+
* }
78+
* </code>
79+
* </pre>
80+
*
81+
* @param request The original request made to the service.
82+
**/
83+
public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
84+
var responseJson = new String(response.body(), StandardCharsets.UTF_8);
85+
86+
// Response from generateContent has the same shape as streamGenerateContent. We reuse the already implemented
87+
// class to avoid code duplication
88+
89+
StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk;
90+
try (
91+
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
92+
.createParser(XContentParserConfiguration.EMPTY, responseJson)
93+
) {
94+
parser.nextToken();
95+
chunk = GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse(parser);
96+
}
97+
var results = chunk.choices().stream().map(choice -> choice.delta().content()).map(ChatCompletionResults.Result::new).toList();
98+
99+
return new ChatCompletionResults(results);
100+
}
101+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.googlevertexai.completion;
9+
10+
import org.elasticsearch.core.Strings;
11+
import org.elasticsearch.inference.TaskType;
12+
import org.elasticsearch.test.ESTestCase;
13+
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
14+
15+
import java.net.URI;
16+
import java.net.URISyntaxException;
17+
import java.util.HashMap;
18+
import java.util.Map;
19+
20+
import static org.hamcrest.Matchers.*;
21+
22+
public class GoogleVertexAiCompletionModelTests extends ESTestCase {
23+
24+
private static final String DEFAULT_PROJECT_ID = "test-project";
25+
private static final String DEFAULT_LOCATION = "us-central1";
26+
private static final String DEFAULT_MODEL_ID = "gemini-pro";
27+
28+
public void testCreateModel() throws URISyntaxException {
29+
var model = createCompletionModel(DEFAULT_PROJECT_ID, DEFAULT_LOCATION, DEFAULT_MODEL_ID);
30+
URI expectedUri = new URI(
31+
Strings.format(
32+
"https://%s-aiplatform.googleapis.com/v1/projects/%s" + "/locations/global/publishers/google/models/%s:generateContent",
33+
DEFAULT_LOCATION,
34+
DEFAULT_PROJECT_ID,
35+
DEFAULT_MODEL_ID
36+
37+
)
38+
);
39+
assertThat(model.uri(), is(expectedUri));
40+
}
41+
42+
private static GoogleVertexAiCompletionModel createCompletionModel(String projectId, String location, String modelId) {
43+
return new GoogleVertexAiCompletionModel(
44+
"google-vertex-ai-chat-test-id",
45+
TaskType.CHAT_COMPLETION,
46+
"google_vertex_ai",
47+
new HashMap<>(Map.of("project_id", projectId, "location", location, "model_id", modelId)),
48+
new HashMap<>(),
49+
new HashMap<>(Map.of("service_account_json", "{}")),
50+
ConfigurationParseContext.PERSISTENT
51+
);
52+
}
53+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.googlevertexai.response;
9+
10+
import org.apache.http.HttpResponse;
11+
import org.elasticsearch.core.Strings;
12+
import org.elasticsearch.test.ESTestCase;
13+
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
14+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
15+
import org.elasticsearch.xpack.inference.external.request.Request;
16+
17+
import java.io.IOException;
18+
import java.nio.charset.StandardCharsets;
19+
20+
import static org.hamcrest.Matchers.is;
21+
import static org.mockito.Mockito.mock;
22+
23+
public class GoogleVertexAiCompletionResponseEntityTests extends ESTestCase {
24+
25+
public void testFromResponse_Javadoc() throws IOException {
26+
var responseText = "I am sorry, I cannot summarize the text because I do not have access to the text you are referring to.";
27+
28+
String responseJson = Strings.format("""
29+
{
30+
"candidates": [
31+
{
32+
"content": {
33+
"role": "model",
34+
"parts": [
35+
{
36+
"text": "%s"
37+
}
38+
]
39+
},
40+
"finishReason": "STOP",
41+
"avgLogprobs": -0.19326641248620074
42+
}
43+
],
44+
"usageMetadata": {
45+
"promptTokenCount": 71,
46+
"candidatesTokenCount": 23,
47+
"totalTokenCount": 94,
48+
"trafficType": "ON_DEMAND",
49+
"promptTokensDetails": [
50+
{
51+
"modality": "TEXT",
52+
"tokenCount": 71
53+
}
54+
],
55+
"candidatesTokensDetails": [
56+
{
57+
"modality": "TEXT",
58+
"tokenCount": 23
59+
}
60+
]
61+
},
62+
"modelVersion": "gemini-2.0-flash-001",
63+
"createTime": "2025-05-28T15:08:20.049493Z",
64+
"responseId": "5CY3aNWCA6mm4_UPr-zduAE"
65+
}
66+
""", responseText);
67+
68+
var parsedResults = GoogleVertexAiCompletionResponseEntity.fromResponse(
69+
mock(Request.class),
70+
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
71+
);
72+
73+
assert parsedResults instanceof ChatCompletionResults;
74+
var results = (ChatCompletionResults) parsedResults;
75+
76+
assertThat(results.isStreaming(), is(false));
77+
assertThat(results.results().size(), is(1));
78+
assertThat(results.results().get(0).content(), is(responseText));
79+
}
80+
}

0 commit comments

Comments
 (0)