Skip to content

Commit 2ca9d7a

Browse files
authored
[ML] Stream Google Completion (elastic#114596) (elastic#114762)
Google supports SSE for chat completion and sends the same payload as its non-streaming calls, so we can reuse the SSE parser with our existing parse function. The downside is that Google requires a different URI for streaming, so we refactored away from the visitor pattern to allow the URI to be created and set at request time rather than at model instantiation time.
1 parent a5e0226 commit 2ca9d7a

File tree

20 files changed

+422
-210
lines changed

20 files changed

+422
-210
lines changed

docs/changelog/114596.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 114596
2+
summary: Stream Google Completion
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioActionCreator.java

Lines changed: 0 additions & 55 deletions
This file was deleted.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioActionVisitor.java

Lines changed: 0 additions & 21 deletions
This file was deleted.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioResponseHandler.java

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,48 @@
88
package org.elasticsearch.xpack.inference.external.googleaistudio;
99

1010
import org.apache.logging.log4j.Logger;
11+
import org.elasticsearch.core.CheckedFunction;
12+
import org.elasticsearch.inference.InferenceServiceResults;
13+
import org.elasticsearch.xcontent.XContentParser;
14+
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
1115
import org.elasticsearch.xpack.inference.external.http.HttpResult;
1216
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
1317
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
1418
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
1519
import org.elasticsearch.xpack.inference.external.request.Request;
1620
import org.elasticsearch.xpack.inference.external.response.googleaistudio.GoogleAiStudioErrorResponseEntity;
21+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
22+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
1723
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
1824

25+
import java.io.IOException;
26+
import java.util.concurrent.Flow;
27+
1928
import static org.elasticsearch.core.Strings.format;
2029
import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;
2130

2231
public class GoogleAiStudioResponseHandler extends BaseResponseHandler {
2332

2433
static final String GOOGLE_AI_STUDIO_UNAVAILABLE = "The Google AI Studio service may be temporarily overloaded or down";
34+
private final boolean canHandleStreamingResponses;
35+
private final CheckedFunction<XContentParser, String, IOException> content;
2536

2637
/**
 * Creates a handler that does not accept streaming responses.
 *
 * Delegates to the four-argument constructor with streaming disabled and a
 * placeholder content extractor that trips an assertion if it is ever invoked.
 *
 * @param requestType   forwarded unchanged to the four-argument constructor
 * @param parseFunction forwarded unchanged to the four-argument constructor
 */
public GoogleAiStudioResponseHandler(String requestType, ResponseParser parseFunction) {
    // The extractor is only handed to the streaming processor, so for this handler it is a misuse guard.
    this(requestType, parseFunction, false, unusedParser -> {
        assert false : "do not call this";
        return "";
    });
}
43+
44+
/**
 * Creates a handler, optionally able to process streaming responses.
 *
 * @param requestType                 passed to the base handler
 * @param parseFunction               passed to the base handler
 * @param canHandleStreamingResponses value later reported by {@link #canHandleStreamingResponses()}
 * @param content                     function handed to the streaming processor to turn each streamed
 *                                    JSON payload into a text chunk
 */
public GoogleAiStudioResponseHandler(
    String requestType,
    ResponseParser parseFunction,
    boolean canHandleStreamingResponses,
    CheckedFunction<XContentParser, String, IOException> content
) {
    // Error bodies are always decoded with the Google AI Studio error entity parser.
    super(requestType, parseFunction, GoogleAiStudioErrorResponseEntity::fromResponse);
    this.content = content;
    this.canHandleStreamingResponses = canHandleStreamingResponses;
}
2954

3055
@Override
@@ -72,4 +97,18 @@ private static String resourceNotFoundError(Request request) {
7297
return format("Resource not found at [%s]", request.getURI());
7398
}
7499

100+
@Override
101+
public boolean canHandleStreamingResponses() {
102+
return canHandleStreamingResponses;
103+
}
104+
105+
@Override
106+
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
107+
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
108+
var googleAiProcessor = new GoogleAiStudioStreamingProcessor(content);
109+
flow.subscribe(serverSentEventProcessor);
110+
serverSentEventProcessor.subscribe(googleAiProcessor);
111+
return new StreamingChatCompletionResults(googleAiProcessor);
112+
}
113+
75114
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.googleaistudio;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
13+
import org.elasticsearch.core.CheckedFunction;
14+
import org.elasticsearch.xcontent.XContentFactory;
15+
import org.elasticsearch.xcontent.XContentParser;
16+
import org.elasticsearch.xcontent.XContentParserConfiguration;
17+
import org.elasticsearch.xcontent.XContentType;
18+
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
19+
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
20+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
21+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;
22+
23+
import java.io.IOException;
24+
import java.util.ArrayDeque;
25+
import java.util.Deque;
26+
27+
class GoogleAiStudioStreamingProcessor extends DelegatingProcessor<Deque<ServerSentEvent>, StreamingChatCompletionResults.Results> {
28+
private static final Logger log = LogManager.getLogger(GoogleAiStudioStreamingProcessor.class);
29+
private final CheckedFunction<XContentParser, String, IOException> content;
30+
31+
GoogleAiStudioStreamingProcessor(CheckedFunction<XContentParser, String, IOException> content) {
32+
this.content = content;
33+
}
34+
35+
@Override
36+
protected void next(Deque<ServerSentEvent> item) throws Exception {
37+
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
38+
var results = new ArrayDeque<StreamingChatCompletionResults.Result>(item.size());
39+
for (ServerSentEvent event : item) {
40+
if (ServerSentEventField.DATA == event.name() && event.hasValue()) {
41+
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.value())) {
42+
var delta = content.apply(jsonParser);
43+
results.offer(new StreamingChatCompletionResults.Result(delta));
44+
} catch (Exception e) {
45+
log.warn("Failed to parse event from inference provider: {}", event);
46+
throw e;
47+
}
48+
}
49+
}
50+
51+
if (results.isEmpty()) {
52+
upstream().request(1);
53+
} else {
54+
downstream().onNext(new StreamingChatCompletionResults.Results(results));
55+
}
56+
}
57+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import org.elasticsearch.xpack.inference.external.response.googleaistudio.GoogleAiStudioCompletionResponseEntity;
2020
import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel;
2121

22-
import java.util.List;
2322
import java.util.Objects;
2423
import java.util.function.Supplier;
2524

@@ -32,7 +31,12 @@ public class GoogleAiStudioCompletionRequestManager extends GoogleAiStudioReques
3231
private final GoogleAiStudioCompletionModel model;
3332

3433
private static ResponseHandler createCompletionHandler() {
35-
return new GoogleAiStudioResponseHandler("google ai studio completion", GoogleAiStudioCompletionResponseEntity::fromResponse);
34+
return new GoogleAiStudioResponseHandler(
35+
"google ai studio completion",
36+
GoogleAiStudioCompletionResponseEntity::fromResponse,
37+
true,
38+
GoogleAiStudioCompletionResponseEntity::content
39+
);
3640
}
3741

3842
public GoogleAiStudioCompletionRequestManager(GoogleAiStudioCompletionModel model, ThreadPool threadPool) {
@@ -47,8 +51,7 @@ public void execute(
4751
Supplier<Boolean> hasRequestCompletedFunction,
4852
ActionListener<InferenceServiceResults> listener
4953
) {
50-
List<String> docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs();
51-
GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest(docsInput, model);
54+
GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest(DocumentsOnlyInput.of(inferenceInputs), model);
5255
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
5356
}
5457
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequest.java

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,46 +11,63 @@
1111
import org.apache.http.client.methods.HttpPost;
1212
import org.apache.http.entity.ByteArrayEntity;
1313
import org.elasticsearch.common.Strings;
14+
import org.elasticsearch.common.ValidationException;
15+
import org.elasticsearch.common.util.LazyInitializable;
1416
import org.elasticsearch.xcontent.XContentType;
17+
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
1518
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
1619
import org.elasticsearch.xpack.inference.external.request.Request;
1720
import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel;
1821

1922
import java.net.URI;
2023
import java.nio.charset.StandardCharsets;
21-
import java.util.List;
2224
import java.util.Objects;
2325

2426
public class GoogleAiStudioCompletionRequest implements GoogleAiStudioRequest {
27+
private static final String ALT_PARAM = "alt";
28+
private static final String SSE_VALUE = "sse";
2529

26-
private final List<String> input;
30+
private final DocumentsOnlyInput input;
2731

28-
private final URI uri;
32+
private final LazyInitializable<URI, RuntimeException> uri;
2933

3034
private final GoogleAiStudioCompletionModel model;
3135

32-
public GoogleAiStudioCompletionRequest(List<String> input, GoogleAiStudioCompletionModel model) {
33-
this.input = input;
36+
public GoogleAiStudioCompletionRequest(DocumentsOnlyInput input, GoogleAiStudioCompletionModel model) {
37+
this.input = Objects.requireNonNull(input);
3438
this.model = Objects.requireNonNull(model);
35-
this.uri = model.uri();
39+
this.uri = new LazyInitializable<>(() -> model.uri(input.stream()));
3640
}
3741

3842
@Override
3943
public HttpRequest createHttpRequest() {
40-
var httpPost = new HttpPost(uri);
41-
var requestEntity = Strings.toString(new GoogleAiStudioCompletionRequestEntity(input));
44+
var httpPost = createHttpPost();
45+
var requestEntity = Strings.toString(new GoogleAiStudioCompletionRequestEntity(input.getInputs()));
4246

4347
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
4448
httpPost.setEntity(byteEntity);
4549
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
46-
GoogleAiStudioRequest.decorateWithApiKeyParameter(httpPost, model.getSecretSettings());
4750

4851
return new HttpRequest(httpPost, getInferenceEntityId());
4952
}
5053

54+
private HttpPost createHttpPost() {
55+
try {
56+
var uriBuilder = GoogleAiStudioRequest.builderWithApiKeyParameter(uri.getOrCompute(), model.getSecretSettings());
57+
if (isStreaming()) {
58+
uriBuilder.addParameter(ALT_PARAM, SSE_VALUE);
59+
}
60+
return new HttpPost(uriBuilder.build());
61+
} catch (Exception e) {
62+
ValidationException validationException = new ValidationException(e);
63+
validationException.addValidationError(e.getMessage());
64+
throw validationException;
65+
}
66+
}
67+
5168
@Override
5269
public URI getURI() {
53-
return this.uri;
70+
return uri.getOrCompute();
5471
}
5572

5673
@Override
@@ -69,4 +86,9 @@ public boolean[] getTruncationInfo() {
6986
public String getInferenceEntityId() {
7087
return model.getInferenceEntityId();
7188
}
89+
90+
@Override
91+
public boolean isStreaming() {
92+
return input.stream();
93+
}
7294
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioRequest.java

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,15 @@
1313
import org.elasticsearch.xpack.inference.external.request.Request;
1414
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
1515

16+
import java.net.URI;
17+
1618
public interface GoogleAiStudioRequest extends Request {
1719

1820
String API_KEY_PARAMETER = "key";
1921

2022
static void decorateWithApiKeyParameter(HttpPost httpPost, DefaultSecretSettings secretSettings) {
2123
try {
22-
var uri = httpPost.getURI();
23-
var uriWithApiKey = new URIBuilder().setScheme(uri.getScheme())
24-
.setHost(uri.getHost())
25-
.setPort(uri.getPort())
26-
.setPath(uri.getPath())
27-
.addParameter(API_KEY_PARAMETER, secretSettings.apiKey().toString())
28-
.build();
29-
24+
var uriWithApiKey = builderWithApiKeyParameter(httpPost.getURI(), secretSettings).build();
3025
httpPost.setURI(uriWithApiKey);
3126
} catch (Exception e) {
3227
ValidationException validationException = new ValidationException(e);
@@ -35,4 +30,12 @@ static void decorateWithApiKeyParameter(HttpPost httpPost, DefaultSecretSettings
3530
}
3631
}
3732

33+
static URIBuilder builderWithApiKeyParameter(URI uri, DefaultSecretSettings secretSettings) {
34+
return new URIBuilder().setScheme(uri.getScheme())
35+
.setHost(uri.getHost())
36+
.setPort(uri.getPort())
37+
.setPath(uri.getPath())
38+
.addParameter(API_KEY_PARAMETER, secretSettings.apiKey().toString());
39+
}
40+
3841
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioUtils.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ public class GoogleAiStudioUtils {
1717

1818
public static final String GENERATE_CONTENT_ACTION = "generateContent";
1919

20+
public static final String STREAM_GENERATE_CONTENT_ACTION = "streamGenerateContent";
21+
2022
public static final String BATCH_EMBED_CONTENTS_ACTION = "batchEmbedContents";
2123

2224
private GoogleAiStudioUtils() {}

0 commit comments

Comments
 (0)