Commit bf2de8f

[ML] Integrate OpenAi Chat Completion in SageMaker
SageMaker now supports Completion and Chat Completion using the OpenAI interfaces. Additionally:
- Fixed a bug related to timeouts being nullable; they now default to a 30s timeout
- Exposed existing OpenAi request/response parsing logic for reuse
1 parent 5f256cc commit bf2de8f

18 files changed: 647 additions & 149 deletions

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java

Lines changed: 6 additions & 13 deletions

@@ -12,13 +12,12 @@
 import org.elasticsearch.xcontent.XContentParserConfiguration;
 import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
 
-import java.io.IOException;
 import java.util.ArrayDeque;
 import java.util.Deque;
-import java.util.Iterator;
 import java.util.concurrent.Flow;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.stream.Stream;
 
 /**
  * Processor that delegates the {@link java.util.concurrent.Flow.Subscription} to the upstream {@link java.util.concurrent.Flow.Publisher}
@@ -34,19 +33,13 @@ public abstract class DelegatingProcessor<T, R> implements Flow.Processor<T, R>
     public static <ParsedChunk> Deque<ParsedChunk> parseEvent(
         Deque<ServerSentEvent> item,
         ParseChunkFunction<ParsedChunk> parseFunction,
-        XContentParserConfiguration parserConfig,
-        Logger logger
-    ) throws Exception {
+        XContentParserConfiguration parserConfig
+    ) {
         var results = new ArrayDeque<ParsedChunk>(item.size());
         for (ServerSentEvent event : item) {
             if (event.hasData()) {
-                try {
-                    var delta = parseFunction.apply(parserConfig, event);
-                    delta.forEachRemaining(results::offer);
-                } catch (Exception e) {
-                    logger.warn("Failed to parse event from inference provider: {}", event);
-                    throw e;
-                }
+                var delta = parseFunction.apply(parserConfig, event);
+                delta.forEach(results::offer);
             }
         }
 
@@ -55,7 +48,7 @@ public static <ParsedChunk> Deque<ParsedChunk> parseEvent(
 
     @FunctionalInterface
     public interface ParseChunkFunction<ParsedChunk> {
-        Iterator<ParsedChunk> apply(XContentParserConfiguration parserConfig, ServerSentEvent event) throws IOException;
+        Stream<ParsedChunk> apply(XContentParserConfiguration parserConfig, ServerSentEvent event);
    }
 
    @Override
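
With the try/catch and Logger gone from parseEvent, parse functions now return a Stream instead of an Iterator and are responsible for surfacing their own parse failures. A minimal sketch of a conforming parse function under the new contract (ExampleChunk and the wrapper class are hypothetical, only to show the shape):

import java.util.Deque;
import java.util.stream.Stream;

import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;

class ExampleChunkParser {
    // Hypothetical chunk type, used only to fill in the generic parameter.
    record ExampleChunk(String text) {}

    // Matches the new ParseChunkFunction: returns a Stream and declares no checked exceptions.
    static Stream<ExampleChunk> parse(XContentParserConfiguration parserConfig, ServerSentEvent event) {
        return event.hasData() ? Stream.of(new ExampleChunk(event.data())) : Stream.empty();
    }

    static Deque<ExampleChunk> collect(Deque<ServerSentEvent> events, XContentParserConfiguration parserConfig) {
        // parseEvent no longer takes a Logger and no longer declares checked exceptions.
        return DelegatingProcessor.parseEvent(events, ExampleChunkParser::parse, parserConfig);
    }
}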

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java

Lines changed: 5 additions & 3 deletions

@@ -45,10 +45,12 @@ public class UnifiedChatCompletionRequestEntity implements ToXContentFragment {
     private final boolean stream;
 
     public UnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput) {
-        Objects.requireNonNull(unifiedChatInput);
+        this(Objects.requireNonNull(unifiedChatInput).getRequest(), Objects.requireNonNull(unifiedChatInput).stream());
+    }
 
-        this.unifiedRequest = unifiedChatInput.getRequest();
-        this.stream = unifiedChatInput.stream();
+    public UnifiedChatCompletionRequestEntity(UnifiedCompletionRequest unifiedRequest, boolean stream) {
+        this.unifiedRequest = Objects.requireNonNull(unifiedRequest);
+        this.stream = stream;
     }
 
    @Override
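
The added two-argument constructor lets callers that already hold a UnifiedCompletionRequest (rather than a UnifiedChatInput) build the request entity directly; presumably this is the hook the SageMaker OpenAI payload relies on. A hedged sketch of that usage:

// Assuming a UnifiedCompletionRequest named unifiedRequest is already in scope:
var entity = new UnifiedChatCompletionRequestEntity(unifiedRequest, true); // true == streaming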

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiStreamingProcessor.java

Lines changed: 14 additions & 11 deletions

@@ -9,8 +9,10 @@
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
 import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xcontent.XContentParserConfiguration;
@@ -20,11 +22,10 @@
 import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
 
 import java.io.IOException;
-import java.util.Collections;
 import java.util.Deque;
-import java.util.Iterator;
 import java.util.Objects;
 import java.util.function.Predicate;
+import java.util.stream.Stream;
 
 import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
 import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList;
@@ -113,7 +114,7 @@ public class OpenAiStreamingProcessor extends DelegatingProcessor<Deque<ServerSe
     @Override
     protected void next(Deque<ServerSentEvent> item) throws Exception {
         var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
-        var results = parseEvent(item, OpenAiStreamingProcessor::parse, parserConfig, log);
+        var results = parseEvent(item, OpenAiStreamingProcessor::parse, parserConfig);
 
         if (results.isEmpty()) {
             upstream().request(1);
@@ -122,10 +123,9 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
         }
     }
 
-    private static Iterator<StreamingChatCompletionResults.Result> parse(XContentParserConfiguration parserConfig, ServerSentEvent event)
-        throws IOException {
+    public static Stream<StreamingChatCompletionResults.Result> parse(XContentParserConfiguration parserConfig, ServerSentEvent event) {
         if (DONE_MESSAGE.equalsIgnoreCase(event.data())) {
-            return Collections.emptyIterator();
+            return Stream.empty();
         }
 
         try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) {
@@ -167,11 +167,14 @@ private static Iterator<StreamingChatCompletionResults.Result> parse(XContentPar
 
                 consumeUntilObjectEnd(parser); // end choices
                 return ""; // stopped
-            }).stream()
-                .filter(Objects::nonNull)
-                .filter(Predicate.not(String::isEmpty))
-                .map(StreamingChatCompletionResults.Result::new)
-                .iterator();
+            }).stream().filter(Objects::nonNull).filter(Predicate.not(String::isEmpty)).map(StreamingChatCompletionResults.Result::new);
+        } catch (IOException e) {
+            throw new ElasticsearchStatusException(
+                "Failed to parse event from inference provider: {}",
+                RestStatus.INTERNAL_SERVER_ERROR,
+                e,
+                event
+            );
         }
     }
 }
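
Making parse public, Stream-returning, and free of checked exceptions is what allows the OpenAI SSE parsing to be reused outside this processor, for example via DelegatingProcessor.parseEvent. A sketch of such reuse, under the assumption that the caller already holds a Deque of ServerSentEvent items:

// Inside some other streaming handler (a sketch, not the actual SageMaker payload code):
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
Deque<StreamingChatCompletionResults.Result> results =
    DelegatingProcessor.parseEvent(events, OpenAiStreamingProcessor::parse, parserConfig);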

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java

Lines changed: 4 additions & 5 deletions

@@ -40,8 +40,7 @@ public OpenAiUnifiedChatCompletionResponseHandler(String requestType, ResponsePa
     @Override
     public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
         var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
-        var openAiProcessor = new OpenAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request, m, e));
-
+        var openAiProcessor = new OpenAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request.getInferenceEntityId(), m, e));
         flow.subscribe(serverSentEventProcessor);
         serverSentEventProcessor.subscribe(openAiProcessor);
         return new StreamingUnifiedChatCompletionResults(openAiProcessor);
@@ -67,15 +66,15 @@ protected Exception buildError(String message, Request request, HttpResult resul
         }
     }
 
-    private static Exception buildMidStreamError(Request request, String message, Exception e) {
+    public static UnifiedChatCompletionException buildMidStreamError(String inferenceEntityId, String message, Exception e) {
         var errorResponse = OpenAiErrorResponse.fromString(message);
         if (errorResponse instanceof OpenAiErrorResponse oer) {
             return new UnifiedChatCompletionException(
                 RestStatus.INTERNAL_SERVER_ERROR,
                 format(
                     "%s for request from inference entity id [%s]. Error message: [%s]",
                     SERVER_ERROR_OBJECT,
-                    request.getInferenceEntityId(),
+                    inferenceEntityId,
                     errorResponse.getErrorMessage()
                 ),
                 oer.type(),
@@ -87,7 +86,7 @@ private static Exception buildMidStreamError(Request request, String message, Ex
         } else {
             return new UnifiedChatCompletionException(
                 RestStatus.INTERNAL_SERVER_ERROR,
-                format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
+                format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, inferenceEntityId),
                 errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown",
                 "stream_error"
            );
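
Since buildMidStreamError is now public, static, and keyed on the inference entity id rather than a Request, other integrations can build the same streaming processor and error mapping without an HTTP Request object in hand. A hedged sketch:

// Sketch: constructing the streaming processor elsewhere, with only the endpoint's inference entity id in scope.
var openAiProcessor = new OpenAiUnifiedStreamingProcessor(
    (message, exception) -> OpenAiUnifiedChatCompletionResponseHandler.buildMidStreamError(inferenceEntityId, message, exception)
);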

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java

Lines changed: 5 additions & 6 deletions

@@ -22,12 +22,11 @@
 
 import java.io.IOException;
 import java.util.ArrayDeque;
-import java.util.Collections;
 import java.util.Deque;
-import java.util.Iterator;
 import java.util.List;
 import java.util.concurrent.LinkedBlockingDeque;
 import java.util.function.BiFunction;
+import java.util.stream.Stream;
 
 import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
 import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
@@ -86,7 +85,7 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
             } else if (event.hasData()) {
                 try {
                     var delta = parse(parserConfig, event);
-                    delta.forEachRemaining(results::offer);
+                    delta.forEach(results::offer);
                 } catch (Exception e) {
                     logger.warn("Failed to parse event from inference provider: {}", event);
                     throw errorParser.apply(event.data(), e);
@@ -108,12 +107,12 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
         }
     }
 
-    private static Iterator<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> parse(
+    public static Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> parse(
         XContentParserConfiguration parserConfig,
         ServerSentEvent event
     ) throws IOException {
         if (DONE_MESSAGE.equalsIgnoreCase(event.data())) {
-            return Collections.emptyIterator();
+            return Stream.empty();
         }
 
         try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) {
@@ -124,7 +123,7 @@ private static Iterator<StreamingUnifiedChatCompletionResults.ChatCompletionChun
 
             StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = ChatCompletionChunkParser.parse(jsonParser);
 
-            return Collections.singleton(chunk).iterator();
+            return Stream.of(chunk);
         }
     }
 
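
parse is now public here as well, but note it still declares IOException, so callers outside this class handle parse failures themselves rather than going through DelegatingProcessor.parseEvent. A sketch:

// Sketch: reusing the chat-completion chunk parser directly from another streaming handler.
try {
    OpenAiUnifiedStreamingProcessor.parse(parserConfig, event).forEach(results::offer);
} catch (IOException e) {
    // e.g. throw OpenAiUnifiedChatCompletionResponseHandler.buildMidStreamError(inferenceEntityId, event.data(), e);
}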

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiChatCompletionResponseEntity.java

Lines changed: 5 additions & 1 deletion

@@ -67,7 +67,11 @@ public class OpenAiChatCompletionResponseEntity {
     */
 
    public static ChatCompletionResults fromResponse(Request request, HttpResult response) throws IOException {
-        try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) {
+        return fromResponse(response.body());
+    }
+
+    public static ChatCompletionResults fromResponse(byte[] response) throws IOException {
+        try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response)) {
             return CompletionResult.PARSER.apply(p, null).toChatCompletionResults();
         }
    }
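
The byte[] overload decouples the parser from the HTTP Request/HttpResult pair, so a caller that receives the response body as raw bytes (for example SageMaker, where the SDK hands back SdkBytes) can reuse the same parsing. A hedged sketch:

// Sketch: parsing an OpenAI-shaped completion response held as raw bytes.
ChatCompletionResults completionResults = OpenAiChatCompletionResponseEntity.fromResponse(sdkBytes.asByteArray());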

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java

Lines changed: 7 additions & 6 deletions

@@ -47,6 +47,7 @@
 public class SageMakerService implements InferenceService {
     public static final String NAME = "sagemaker";
     private static final int DEFAULT_BATCH_SIZE = 256;
+    private static final TimeValue DEFAULT_TIMEOUT = TimeValue.THIRTY_SECONDS;
     private final SageMakerModelBuilder modelBuilder;
     private final SageMakerClient client;
     private final SageMakerSchemas schemas;
@@ -128,7 +129,7 @@ public void infer(
         boolean stream,
         Map<String, Object> taskSettings,
         InputType inputType,
-        TimeValue timeout,
+        @Nullable TimeValue timeout,
         ActionListener<InferenceServiceResults> listener
     ) {
         if (model instanceof SageMakerModel == false) {
@@ -148,7 +149,7 @@ public void infer(
             client.invokeStream(
                 regionAndSecrets,
                 request,
-                timeout,
+                timeout != null ? timeout : DEFAULT_TIMEOUT,
                 ActionListener.wrap(
                     response -> listener.onResponse(schema.streamResponse(sageMakerModel, response)),
                     e -> listener.onFailure(schema.error(sageMakerModel, e))
@@ -160,7 +161,7 @@ public void infer(
             client.invoke(
                 regionAndSecrets,
                 request,
-                timeout,
+                timeout != null ? timeout : DEFAULT_TIMEOUT,
                 ActionListener.wrap(
                     response -> listener.onResponse(schema.response(sageMakerModel, response, threadPool.getThreadContext())),
                     e -> listener.onFailure(schema.error(sageMakerModel, e))
@@ -201,7 +202,7 @@ private static ElasticsearchStatusException internalFailure(Model model, Excepti
     public void unifiedCompletionInfer(
         Model model,
         UnifiedCompletionRequest request,
-        TimeValue timeout,
+        @Nullable TimeValue timeout,
         ActionListener<InferenceServiceResults> listener
     ) {
         if (model instanceof SageMakerModel == false) {
@@ -217,7 +218,7 @@ public void unifiedCompletionInfer(
             client.invokeStream(
                 regionAndSecrets,
                 sagemakerRequest,
-                timeout,
+                timeout != null ? timeout : DEFAULT_TIMEOUT,
                 ActionListener.wrap(
                     response -> listener.onResponse(schema.chatCompletionStreamResponse(sageMakerModel, response)),
                     e -> listener.onFailure(schema.chatCompletionError(sageMakerModel, e))
@@ -235,7 +236,7 @@ public void chunkedInfer(
         List<ChunkInferenceInput> input,
         Map<String, Object> taskSettings,
         InputType inputType,
-        TimeValue timeout,
+        @Nullable TimeValue timeout,
         ActionListener<List<ChunkedInference>> listener
     ) {
        if (model instanceof SageMakerModel == false) {
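
The timeout fix in this file: the infer, unifiedCompletionInfer, and chunkedInfer entry points may legitimately receive a null timeout, so the client invoke calls now fall back to a 30-second default instead of passing null through. The pattern in isolation (resolveTimeout is a hypothetical helper name, not part of this commit):

import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;

// Null-safe timeout resolution, equivalent to the inline "timeout != null ? timeout : DEFAULT_TIMEOUT" checks above.
static TimeValue resolveTimeout(@Nullable TimeValue timeout) {
    return timeout != null ? timeout : TimeValue.THIRTY_SECONDS;
}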

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemas.java

Lines changed: 21 additions & 3 deletions

@@ -12,13 +12,16 @@
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
+import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiCompletionPayload;
 import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiTextEmbeddingPayload;
 
 import java.util.Arrays;
 import java.util.EnumSet;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.Predicate;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
@@ -39,7 +42,7 @@ public class SageMakerSchemas {
         /*
          * Add new model API to the register call.
          */
-        schemas = register(new OpenAiTextEmbeddingPayload());
+        schemas = register(new OpenAiTextEmbeddingPayload(), new OpenAiCompletionPayload());
 
         streamSchemas = schemas.entrySet()
             .stream()
@@ -54,7 +57,13 @@ public class SageMakerSchemas {
             .collect(Collectors.groupingBy(TaskAndApi::api, Collectors.mapping(TaskAndApi::taskType, Collectors.toSet())));
 
         supportedStreamingTasks = streamSchemas.keySet().stream().map(TaskAndApi::taskType).collect(Collectors.toSet());
-        supportedTaskTypes = EnumSet.copyOf(schemas.keySet().stream().map(TaskAndApi::taskType).collect(Collectors.toSet()));
+        supportedTaskTypes = EnumSet.copyOf(
+            schemas.keySet()
+                .stream()
+                .map(TaskAndApi::taskType)
+                .filter(Predicate.not(TaskType.CHAT_COMPLETION::equals)) // chat_completion is currently never supported for non-streaming
+                .collect(Collectors.toSet())
+        );
     }
 
     private static Map<TaskAndApi, SageMakerSchema> register(SageMakerSchemaPayload... payloads) {
@@ -88,7 +97,16 @@ public static List<NamedWriteableRegistry.Entry> namedWriteables() {
                 )
             ),
             schemas.values().stream().flatMap(SageMakerSchema::namedWriteables)
-        ).toList();
+        )
+            // Dedupe based on Entry name, we allow Payloads to declare the same Entry but the Registry does not handle duplicates
+            .collect(
+                () -> new HashMap<String, NamedWriteableRegistry.Entry>(),
+                (map, entry) -> map.putIfAbsent(entry.name, entry),
+                Map::putAll
+            )
+            .values()
+            .stream()
+            .toList();
     }
 
    public SageMakerSchema schemaFor(SageMakerModel model) throws ElasticsearchStatusException {
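
The three-argument collect above deduplicates NamedWriteableRegistry entries by name, since the embeddings and completion payloads may now declare the same entry and the registry does not handle duplicates. The same idea with a standard collector, for comparison (a sketch, not the code in this commit):

// Equivalent dedupe-by-name using Collectors.toMap with a "keep the first" merge function.
List<NamedWriteableRegistry.Entry> deduped = entries.stream()
    .collect(Collectors.toMap(entry -> entry.name, entry -> entry, (first, second) -> first))
    .values()
    .stream()
    .toList();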

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStreamSchema.java

Lines changed: 9 additions & 6 deletions

@@ -20,6 +20,7 @@
 import org.elasticsearch.inference.UnifiedCompletionRequest;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
+import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
 import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
 import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient;
 import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
@@ -66,16 +67,16 @@ private InvokeEndpointWithResponseStreamRequest streamRequest(SageMakerModel mod
     }
 
     public InferenceServiceResults streamResponse(SageMakerModel model, SageMakerClient.SageMakerStream response) {
-        return streamResponse(model, response, payload::streamResponseBody, this::error);
+        return new StreamingChatCompletionResults(streamResponse(model, response, payload::streamResponseBody, this::error));
     }
 
-    private InferenceServiceResults streamResponse(
+    private <T> Flow.Publisher<T> streamResponse(
         SageMakerModel model,
         SageMakerClient.SageMakerStream response,
-        CheckedBiFunction<SageMakerModel, SdkBytes, InferenceServiceResults.Result, Exception> parseFunction,
+        CheckedBiFunction<SageMakerModel, SdkBytes, T, Exception> parseFunction,
         BiFunction<SageMakerModel, Exception, Exception> errorFunction
     ) {
-        return new StreamingChatCompletionResults(downstream -> {
+        return downstream -> {
            response.responseStream().subscribe(new Flow.Subscriber<>() {
                private volatile Flow.Subscription upstream;
 
@@ -118,15 +119,17 @@ public void onComplete() {
                 downstream.onComplete();
             }
         });
-        });
+        };
     }
 
     public InvokeEndpointWithResponseStreamRequest chatCompletionStreamRequest(SageMakerModel model, UnifiedCompletionRequest request) {
         return streamRequest(model, () -> payload.chatCompletionRequestBytes(model, request));
     }
 
     public InferenceServiceResults chatCompletionStreamResponse(SageMakerModel model, SageMakerClient.SageMakerStream response) {
-        return streamResponse(model, response, payload::chatCompletionResponseBody, this::chatCompletionError);
+        return new StreamingUnifiedChatCompletionResults(
+            streamResponse(model, response, payload::chatCompletionResponseBody, this::chatCompletionError)
+        );
     }
 
    public UnifiedChatCompletionException chatCompletionError(SageMakerModel model, Exception e) {
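
The helper is now generic over the parsed chunk type and returns a bare Flow.Publisher, with each public method choosing the result wrapper (StreamingChatCompletionResults for completion, StreamingUnifiedChatCompletionResults for chat completion). This works because Flow.Publisher has a single abstract method, so "downstream -> ..." is itself a publisher. A minimal self-contained illustration of that idea (not SageMaker code):

import java.util.concurrent.Flow;

class PublisherSketch {
    // A one-shot publisher expressed as a lambda over its subscriber, mirroring "return downstream -> { ... }" above.
    static <T> Flow.Publisher<T> single(T value) {
        return downstream -> downstream.onSubscribe(new Flow.Subscription() {
            private boolean done;

            @Override
            public void request(long n) {
                if (done == false) {
                    done = true;
                    downstream.onNext(value);
                    downstream.onComplete();
                }
            }

            @Override
            public void cancel() {
                done = true;
            }
        });
    }
}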
