Skip to content

Commit ce6d45f

Browse files
committed
Create GoogleVertexAiChatCompletionResponseHandler for streaming and non streaming responses
1 parent b6f5e34 commit ce6d45f

File tree

10 files changed

+166
-8
lines changed

10 files changed

+166
-8
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.googlevertexai;
9+
10+
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity;
11+
12+
public class GoogleVertexAiChatCompletionResponseHandler extends GoogleVertexAiResponseHandler {
13+
14+
public GoogleVertexAiChatCompletionResponseHandler(String requestType) {
15+
super(
16+
requestType,
17+
GoogleVertexAiCompletionResponseEntity::fromResponse,
18+
GoogleVertexAiUnifiedChatCompletionResponseHandler.GoogleVertexAiErrorResponse::fromResponse,
19+
true
20+
);
21+
}
22+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiResponseHandler.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,19 @@
77

88
package org.elasticsearch.xpack.inference.services.googlevertexai;
99

10+
import org.elasticsearch.inference.InferenceServiceResults;
11+
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
1012
import org.elasticsearch.xpack.inference.external.http.HttpResult;
1113
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
1214
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
1315
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
1416
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
1517
import org.elasticsearch.xpack.inference.external.request.Request;
18+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
19+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
1620
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiErrorResponseEntity;
1721

22+
import java.util.concurrent.Flow;
1823
import java.util.function.Function;
1924

2025
import static org.elasticsearch.core.Strings.format;
@@ -66,4 +71,14 @@ protected void checkForFailureStatusCode(Request request, HttpResult result) thr
6671
private static String resourceNotFoundError(Request request) {
6772
return format("Resource not found at [%s]", request.getURI());
6873
}
74+
75+
@Override
76+
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
77+
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
78+
var googleVertexAiProcessor = new GoogleVertexAiStreamingProcessor();
79+
80+
flow.subscribe(serverSentEventProcessor);
81+
serverSentEventProcessor.subscribe(googleVertexAiProcessor);
82+
return new StreamingChatCompletionResults(googleVertexAiProcessor);
83+
}
6984
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ public class GoogleVertexAiService extends SenderService {
9595

9696
@Override
9797
public Set<TaskType> supportedStreamingTasks() {
98-
return EnumSet.of(TaskType.CHAT_COMPLETION);
98+
return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION);
9999
}
100100

101101
public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.googlevertexai;
9+
10+
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
12+
import org.elasticsearch.inference.InferenceServiceResults;
13+
import org.elasticsearch.rest.RestStatus;
14+
import org.elasticsearch.xcontent.XContentFactory;
15+
import org.elasticsearch.xcontent.XContentParser;
16+
import org.elasticsearch.xcontent.XContentParserConfiguration;
17+
import org.elasticsearch.xcontent.XContentType;
18+
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
19+
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
20+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
21+
22+
import java.io.IOException;
23+
import java.util.Deque;
24+
import java.util.Objects;
25+
import java.util.stream.Stream;
26+
27+
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
28+
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
29+
30+
public class GoogleVertexAiStreamingProcessor extends DelegatingProcessor<Deque<ServerSentEvent>, InferenceServiceResults.Result> {
31+
32+
@Override
33+
protected void next(Deque<ServerSentEvent> item) throws Exception {
34+
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
35+
var results = parseEvent(item, GoogleVertexAiStreamingProcessor::parse, parserConfig);
36+
37+
if (results.isEmpty()) {
38+
upstream().request(1);
39+
} else {
40+
downstream().onNext(new StreamingChatCompletionResults.Results(results));
41+
}
42+
}
43+
44+
public static Stream<StreamingChatCompletionResults.Result> parse(XContentParserConfiguration parserConfig, ServerSentEvent event) {
45+
String data = event.data();
46+
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, data)) {
47+
moveToFirstToken(jsonParser);
48+
ensureExpectedToken(XContentParser.Token.START_OBJECT, jsonParser.currentToken(), jsonParser);
49+
50+
var chunk = GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse(jsonParser);
51+
52+
return chunk.choices()
53+
.stream()
54+
.map(choice -> choice.delta())
55+
.filter(Objects::nonNull)
56+
.map(delta -> delta.content())
57+
.filter(content -> content != null && content.isEmpty() == false)
58+
.map(StreamingChatCompletionResults.Result::new);
59+
60+
} catch (IOException e) {
61+
throw new ElasticsearchStatusException(
62+
"Failed to parse event from inference provider: {}",
63+
RestStatus.INTERNAL_SERVER_ERROR,
64+
e,
65+
event
66+
);
67+
}
68+
}
69+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpR
6161

6262
@Override
6363
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
64+
assert request.isStreaming() : "Only streaming requests support this format";
65+
6466
var responseStatusCode = result.response().getStatusLine().getStatusCode();
6567
var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
6668
var restStatus = toRestStatus(responseStatusCode);
@@ -108,7 +110,7 @@ private static Exception buildMidStreamError(Request request, String message, Ex
108110
}
109111
}
110112

111-
private static class GoogleVertexAiErrorResponse extends ErrorResponse {
113+
public static class GoogleVertexAiErrorResponse extends ErrorResponse {
112114
private static final Logger logger = LogManager.getLogger(GoogleVertexAiErrorResponse.class);
113115
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
114116
"google_vertex_ai_error_wrapper",
@@ -135,7 +137,7 @@ private static class GoogleVertexAiErrorResponse extends ErrorResponse {
135137
);
136138
}
137139

138-
static ErrorResponse fromResponse(HttpResult response) {
140+
public static ErrorResponse fromResponse(HttpResult response) {
139141
try (
140142
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
141143
.createParser(XContentParserConfiguration.EMPTY, response.body())

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
package org.elasticsearch.xpack.inference.services.googlevertexai.action;
99

10+
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.rest.RestStatus;
1012
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1113
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
1214
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
@@ -16,14 +18,17 @@
1618
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
1719
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
1820
import org.elasticsearch.xpack.inference.services.ServiceComponents;
21+
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiChatCompletionResponseHandler;
1922
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiEmbeddingsRequestManager;
2023
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRerankRequestManager;
2124
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiUnifiedChatCompletionResponseHandler;
2225
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
26+
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiCompletionModel;
2327
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
2428
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest;
2529
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
2630

31+
import java.net.URISyntaxException;
2732
import java.util.Map;
2833
import java.util.Objects;
2934

@@ -36,9 +41,10 @@ public class GoogleVertexAiActionCreator implements GoogleVertexAiActionVisitor
3641

3742
private final ServiceComponents serviceComponents;
3843

39-
static final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
44+
static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
4045
"Google VertexAI chat completion"
4146
);
47+
static final ResponseHandler CHAT_COMPLETION_HANDLER = new GoogleVertexAiChatCompletionResponseHandler("Google VertexAI completion");
4248
static final String USER_ROLE = "user";
4349

4450
public GoogleVertexAiActionCreator(Sender sender, ServiceComponents serviceComponents) {
@@ -72,11 +78,31 @@ public ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map<Stri
7278
var manager = new GenericRequestManager<>(
7379
serviceComponents.threadPool(),
7480
model,
75-
COMPLETION_HANDLER,
81+
UNIFIED_CHAT_COMPLETION_HANDLER,
7682
inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
7783
ChatCompletionInput.class
7884
);
7985

8086
return new SingleInputSenderExecutableAction(sender, manager, failedToSendRequestErrorMessage, COMPLETION_ERROR_PREFIX);
8187
}
88+
89+
@Override
90+
public ExecutableAction create(GoogleVertexAiCompletionModel model, Map<String, Object> taskSettings) {
91+
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
92+
93+
var manager = new GenericRequestManager<>(serviceComponents.threadPool(), model, CHAT_COMPLETION_HANDLER, inputs -> {
94+
try {
95+
model.updateUri(inputs.stream());
96+
} catch (URISyntaxException e) {
97+
throw new ElasticsearchStatusException(
98+
"Error constructing URI for Google VertexAI completion",
99+
RestStatus.INTERNAL_SERVER_ERROR,
100+
e
101+
);
102+
}
103+
return new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model);
104+
}, ChatCompletionInput.class);
105+
106+
return new SingleInputSenderExecutableAction(sender, manager, failedToSendRequestErrorMessage, COMPLETION_ERROR_PREFIX);
107+
}
82108
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionVisitor.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1111
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
12+
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiCompletionModel;
1213
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
1314
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
1415

@@ -21,4 +22,6 @@ public interface GoogleVertexAiActionVisitor {
2122
ExecutableAction create(GoogleVertexAiRerankModel model, Map<String, Object> taskSettings);
2223

2324
ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map<String, Object> taskSettings);
25+
26+
ExecutableAction create(GoogleVertexAiCompletionModel model, Map<String, Object> taskSettings);
2427
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiCompletionModel.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99

1010
import org.apache.http.client.utils.URIBuilder;
1111
import org.elasticsearch.inference.TaskType;
12+
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1213
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
14+
import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionVisitor;
1315
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUtils;
1416

1517
import java.net.URI;
@@ -39,6 +41,25 @@ public GoogleVertexAiCompletionModel(
3941

4042
}
4143

44+
public void updateUri(boolean isStream) throws URISyntaxException {
45+
var location = getServiceSettings().location();
46+
var projectId = getServiceSettings().projectId();
47+
var model = getServiceSettings().modelId();
48+
49+
// Google VertexAI generates streaming response using another API. We call this
50+
// method before making the request to be sure we are calling the right API
51+
if (isStream) {
52+
this.uri = GoogleVertexAiChatCompletionModel.buildUri(location, projectId, model);
53+
} else {
54+
this.uri = GoogleVertexAiCompletionModel.buildUri(location, projectId, model);
55+
}
56+
}
57+
58+
@Override
59+
public ExecutableAction accept(GoogleVertexAiActionVisitor visitor, Map<String, Object> taskSettings) {
60+
return visitor.create(this, taskSettings);
61+
}
62+
4263
public static URI buildUri(String location, String projectId, String model) throws URISyntaxException {
4364
return new URIBuilder().setScheme("https")
4465
.setHost(format("%s%s", location, GoogleVertexAiUtils.GOOGLE_VERTEX_AI_HOST_SUFFIX))

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
3939
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
4040
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
41-
import static org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator.COMPLETION_HANDLER;
41+
import static org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator.UNIFIED_CHAT_COMPLETION_HANDLER;
4242
import static org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator.USER_ROLE;
4343
import static org.hamcrest.Matchers.is;
4444
import static org.mockito.ArgumentMatchers.any;
@@ -130,7 +130,7 @@ private ExecutableAction createAction(String location, String projectId, String
130130
var manager = new GenericRequestManager<>(
131131
threadPool,
132132
model,
133-
COMPLETION_HANDLER,
133+
UNIFIED_CHAT_COMPLETION_HANDLER,
134134
inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
135135
ChatCompletionInput.class
136136
);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiCompletionModelTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import java.util.HashMap;
1818
import java.util.Map;
1919

20-
import static org.hamcrest.Matchers.*;
20+
import static org.hamcrest.Matchers.is;
2121

2222
public class GoogleVertexAiCompletionModelTests extends ESTestCase {
2323

0 commit comments

Comments
 (0)