Skip to content

Commit 6cf7526

Browse files
authored
[ML] SageMaker Elastic Payload (elastic#129413) (elastic#129882)
Send the Elastic API Payload to a SageMaker endpoint, and parse the response as if it were an Elastic API response. - SageMaker now supports all task types in the Elastic API format. - Streaming is supported using the SageMaker client/server rpc, rather than SSE. Payloads must be in a complete and valid JSON structure. - Task Settings can be used for additional passthrough settings, but they will not be saved alongside the model. Elastic cannot make guarantees on the structure or contents of this payload, so Elastic will treat it like the other input payloads and only allow them during inference.
1 parent 8c7a882 commit 6cf7526

25 files changed

+1568
-55
lines changed

docs/changelog/129413.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 129413
2+
summary: '`SageMaker` Elastic Payload'
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ static TransportVersion def(int id) {
247247
public static final TransportVersion STREAMS_LOGS_SUPPORT_8_19 = def(8_841_0_54);
248248
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE_8_19 = def(8_841_0_55);
249249
public static final TransportVersion RANDOM_SAMPLER_QUERY_BUILDER_8_19 = def(8_841_0_56);
250+
public static final TransportVersion ML_INFERENCE_SAGEMAKER_ELASTIC_8_19 = def(8_841_0_57);
250251

251252
/*
252253
* STOP! READ THIS FIRST! No, really,

server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,14 @@ public static Params withMaxCompletionTokensTokens(String modelId, Params params
128128
);
129129
}
130130

131+
/**
132+
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
133+
* - Key: {@link #MAX_COMPLETION_TOKENS_FIELD}, Value: {@link #maxCompletionTokens()}
134+
*/
135+
public static Params withMaxCompletionTokensTokens(Params params) {
136+
return new DelegatingMapParams(Map.of(MAX_TOKENS_PARAM, MAX_COMPLETION_TOKENS_FIELD), params);
137+
}
138+
131139
public sealed interface Content extends NamedWriteable, ToXContent permits ContentObjects, ContentString {}
132140

133141
@SuppressWarnings("unchecked")

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ public int hashCode() {
8080
}
8181

8282
public record Result(String delta) implements ChunkedToXContent, Writeable {
83-
private static final String RESULT = "delta";
83+
public static final String RESULT = "delta";
8484

8585
private Result(StreamInput in) throws IOException {
8686
this(in.readString());

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 18 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated;
2222
import static org.hamcrest.Matchers.containsInAnyOrder;
23-
import static org.hamcrest.Matchers.equalTo;
2423

2524
public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
2625

@@ -31,13 +30,8 @@ public static void init() {
3130
}
3231

3332
public void testGetServicesWithoutTaskType() throws IOException {
34-
List<Object> services = getAllServices();
35-
assertThat(services.size(), equalTo(24));
36-
37-
var providers = providers(services);
38-
3933
assertThat(
40-
providers,
34+
allProviders(),
4135
containsInAnyOrder(
4236
List.of(
4337
"alibabacloud-ai-search",
@@ -69,6 +63,10 @@ public void testGetServicesWithoutTaskType() throws IOException {
6963
);
7064
}
7165

66+
private Iterable<String> allProviders() throws IOException {
67+
return providers(getAllServices());
68+
}
69+
7270
@SuppressWarnings("unchecked")
7371
private Iterable<String> providers(List<Object> services) {
7472
return services.stream().map(service -> {
@@ -78,13 +76,8 @@ private Iterable<String> providers(List<Object> services) {
7876
}
7977

8078
public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
81-
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
82-
assertThat(services.size(), equalTo(17));
83-
84-
var providers = providers(services);
85-
8679
assertThat(
87-
providers,
80+
providersFor(TaskType.TEXT_EMBEDDING),
8881
containsInAnyOrder(
8982
List.of(
9083
"alibabacloud-ai-search",
@@ -109,14 +102,13 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
109102
);
110103
}
111104

112-
public void testGetServicesWithRerankTaskType() throws IOException {
113-
List<Object> services = getServices(TaskType.RERANK);
114-
assertThat(services.size(), equalTo(9));
115-
116-
var providers = providers(services);
105+
private Iterable<String> providersFor(TaskType taskType) throws IOException {
106+
return providers(getServices(taskType));
107+
}
117108

109+
public void testGetServicesWithRerankTaskType() throws IOException {
118110
assertThat(
119-
providers,
111+
providersFor(TaskType.RERANK),
120112
containsInAnyOrder(
121113
List.of(
122114
"alibabacloud-ai-search",
@@ -127,20 +119,16 @@ public void testGetServicesWithRerankTaskType() throws IOException {
127119
"jinaai",
128120
"test_reranking_service",
129121
"voyageai",
130-
"hugging_face"
122+
"hugging_face",
123+
"amazon_sagemaker"
131124
).toArray()
132125
)
133126
);
134127
}
135128

136129
public void testGetServicesWithCompletionTaskType() throws IOException {
137-
List<Object> services = getServices(TaskType.COMPLETION);
138-
assertThat(services.size(), equalTo(16));
139-
140-
var providers = providers(services);
141-
142130
assertThat(
143-
providers,
131+
providersFor(TaskType.COMPLETION),
144132
containsInAnyOrder(
145133
List.of(
146134
"alibabacloud-ai-search",
@@ -165,13 +153,8 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
165153
}
166154

167155
public void testGetServicesWithChatCompletionTaskType() throws IOException {
168-
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
169-
assertThat(services.size(), equalTo(8));
170-
171-
var providers = providers(services);
172-
173156
assertThat(
174-
providers,
157+
providersFor(TaskType.CHAT_COMPLETION),
175158
containsInAnyOrder(
176159
List.of(
177160
"deepseek",
@@ -188,13 +171,8 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
188171
}
189172

190173
public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
191-
List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);
192-
assertThat(services.size(), equalTo(7));
193-
194-
var providers = providers(services);
195-
196174
assertThat(
197-
providers,
175+
providersFor(TaskType.SPARSE_EMBEDDING),
198176
containsInAnyOrder(
199177
List.of(
200178
"alibabacloud-ai-search",
@@ -203,7 +181,8 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
203181
"elasticsearch",
204182
"hugging_face",
205183
"streaming_completion_test_service",
206-
"test_service"
184+
"test_service",
185+
"amazon_sagemaker"
207186
).toArray()
208187
)
209188
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,14 @@ public static Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>
9797
return Stream.empty();
9898
}
9999

100-
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) {
100+
return parse(parserConfig, event.data());
101+
}
102+
103+
public static Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> parse(
104+
XContentParserConfiguration parserConfig,
105+
String data
106+
) throws IOException {
107+
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, data)) {
101108
moveToFirstToken(jsonParser);
102109

103110
XContentParser.Token token = jsonParser.currentToken();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchema.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ public InvokeEndpointRequest request(SageMakerModel model, SageMakerInferenceReq
6767
throw e;
6868
} catch (Exception e) {
6969
throw new ElasticsearchStatusException(
70-
"Failed to create SageMaker request for [%s]",
70+
"Failed to create SageMaker request for [{}]",
7171
RestStatus.INTERNAL_SERVER_ERROR,
7272
e,
7373
model.getInferenceEntityId()
@@ -98,7 +98,7 @@ public InferenceServiceResults response(SageMakerModel model, InvokeEndpointResp
9898
throw e;
9999
} catch (Exception e) {
100100
throw new ElasticsearchStatusException(
101-
"Failed to translate SageMaker response for [%s]",
101+
"Failed to translate SageMaker response for [{}]",
102102
RestStatus.INTERNAL_SERVER_ERROR,
103103
e,
104104
model.getInferenceEntityId()

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemas.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
import org.elasticsearch.inference.TaskType;
1313
import org.elasticsearch.rest.RestStatus;
1414
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
15+
import org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic.ElasticCompletionPayload;
16+
import org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic.ElasticRerankPayload;
17+
import org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic.ElasticSparseEmbeddingPayload;
18+
import org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic.ElasticTextEmbeddingPayload;
1519
import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiCompletionPayload;
1620
import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiTextEmbeddingPayload;
1721

@@ -41,7 +45,14 @@ public class SageMakerSchemas {
4145
/*
4246
* Add new model API to the register call.
4347
*/
44-
schemas = register(new OpenAiTextEmbeddingPayload(), new OpenAiCompletionPayload());
48+
schemas = register(
49+
new OpenAiTextEmbeddingPayload(),
50+
new OpenAiCompletionPayload(),
51+
new ElasticTextEmbeddingPayload(),
52+
new ElasticSparseEmbeddingPayload(),
53+
new ElasticCompletionPayload(),
54+
new ElasticRerankPayload()
55+
);
4556

4657
streamSchemas = schemas.entrySet()
4758
.stream()
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic;
9+
10+
import software.amazon.awssdk.core.SdkBytes;
11+
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
12+
13+
import org.elasticsearch.common.Strings;
14+
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
15+
import org.elasticsearch.inference.UnifiedCompletionRequest;
16+
import org.elasticsearch.xcontent.ConstructingObjectParser;
17+
import org.elasticsearch.xcontent.ObjectParser;
18+
import org.elasticsearch.xcontent.ParseField;
19+
import org.elasticsearch.xcontent.XContentParser;
20+
import org.elasticsearch.xcontent.XContentParserConfiguration;
21+
import org.elasticsearch.xpack.core.inference.DequeUtils;
22+
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
23+
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
24+
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
25+
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;
26+
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedStreamingProcessor;
27+
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
28+
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStreamSchemaPayload;
29+
30+
import java.util.ArrayDeque;
31+
import java.util.Deque;
32+
import java.util.List;
33+
34+
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
35+
import static org.elasticsearch.xcontent.json.JsonXContent.jsonXContent;
36+
37+
/**
38+
* Streaming payloads are expected to be in the exact format of the Elastic API. This does *not* use the Server-Sent Event transport
39+
* protocol, rather this expects the SageMaker client and the implemented Endpoint to use AWS's transport protocol to deliver entire chunks.
40+
* Each chunk should be in a valid JSON format, as that is the format the Elastic API uses.
41+
*/
42+
public class ElasticCompletionPayload implements SageMakerStreamSchemaPayload, ElasticPayload {
43+
private static final XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(
44+
LoggingDeprecationHandler.INSTANCE
45+
);
46+
47+
/**
48+
* {
49+
* "completion": [
50+
* {
51+
* "result": "some result 1"
52+
* },
53+
* {
54+
* "result": "some result 2"
55+
* }
56+
* ]
57+
* }
58+
*/
59+
@Override
60+
public ChatCompletionResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception {
61+
try (var p = jsonXContent.createParser(XContentParserConfiguration.EMPTY, response.body().asInputStream())) {
62+
return Completion.PARSER.apply(p, null);
63+
}
64+
}
65+
66+
/**
67+
* {
68+
* "completion": [
69+
* {
70+
* "delta": "some result 1"
71+
* },
72+
* {
73+
* "delta": "some result 2"
74+
* }
75+
* ]
76+
* }
77+
*/
78+
@Override
79+
public StreamingChatCompletionResults.Results streamResponseBody(SageMakerModel model, SdkBytes response) throws Exception {
80+
try (var p = jsonXContent.createParser(XContentParserConfiguration.EMPTY, response.asInputStream())) {
81+
return StreamCompletion.PARSER.apply(p, null);
82+
}
83+
}
84+
85+
@Override
86+
public SdkBytes chatCompletionRequestBytes(SageMakerModel model, UnifiedCompletionRequest request) {
87+
return SdkBytes.fromUtf8String(Strings.toString((builder, params) -> {
88+
request.toXContent(builder, UnifiedCompletionRequest.withMaxCompletionTokensTokens(params));
89+
return builder;
90+
}));
91+
}
92+
93+
@Override
94+
public StreamingUnifiedChatCompletionResults.Results chatCompletionResponseBody(SageMakerModel model, SdkBytes response) {
95+
var responseData = response.asUtf8String();
96+
try {
97+
var results = OpenAiUnifiedStreamingProcessor.parse(parserConfig, responseData)
98+
.collect(
99+
() -> new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>(),
100+
ArrayDeque::offer,
101+
ArrayDeque::addAll
102+
);
103+
return new StreamingUnifiedChatCompletionResults.Results(results);
104+
} catch (Exception e) {
105+
throw OpenAiUnifiedChatCompletionResponseHandler.buildMidStreamError(model.getInferenceEntityId(), responseData, e);
106+
}
107+
}
108+
109+
private static class Completion {
110+
@SuppressWarnings("unchecked")
111+
private static final ConstructingObjectParser<ChatCompletionResults, Void> PARSER = new ConstructingObjectParser<>(
112+
ChatCompletionResults.class.getSimpleName(),
113+
IGNORE_UNKNOWN_FIELDS,
114+
args -> new ChatCompletionResults((List<ChatCompletionResults.Result>) args[0])
115+
);
116+
private static final ConstructingObjectParser<ChatCompletionResults.Result, Void> RESULT_PARSER = new ConstructingObjectParser<>(
117+
ChatCompletionResults.Result.class.getSimpleName(),
118+
IGNORE_UNKNOWN_FIELDS,
119+
args -> new ChatCompletionResults.Result((String) args[0])
120+
);
121+
122+
static {
123+
RESULT_PARSER.declareString(constructorArg(), new ParseField(ChatCompletionResults.Result.RESULT));
124+
PARSER.declareObjectArray(constructorArg(), RESULT_PARSER::apply, new ParseField(ChatCompletionResults.COMPLETION));
125+
}
126+
}
127+
128+
private static class StreamCompletion {
129+
@SuppressWarnings("unchecked")
130+
private static final ConstructingObjectParser<StreamingChatCompletionResults.Results, Void> PARSER = new ConstructingObjectParser<>(
131+
StreamingChatCompletionResults.Results.class.getSimpleName(),
132+
IGNORE_UNKNOWN_FIELDS,
133+
args -> new StreamingChatCompletionResults.Results((Deque<StreamingChatCompletionResults.Result>) args[0])
134+
);
135+
private static final ConstructingObjectParser<StreamingChatCompletionResults.Result, Void> RESULT_PARSER =
136+
new ConstructingObjectParser<>(
137+
StreamingChatCompletionResults.Result.class.getSimpleName(),
138+
IGNORE_UNKNOWN_FIELDS,
139+
args -> new StreamingChatCompletionResults.Result((String) args[0])
140+
);
141+
142+
static {
143+
RESULT_PARSER.declareString(constructorArg(), new ParseField(StreamingChatCompletionResults.Result.RESULT));
144+
PARSER.declareField(constructorArg(), (p, c) -> {
145+
var currentToken = p.currentToken();
146+
147+
// ES allows users to send single-value strings instead of an array of one value
148+
if (currentToken.isValue()
149+
|| currentToken == XContentParser.Token.VALUE_NULL
150+
|| currentToken == XContentParser.Token.START_OBJECT) {
151+
return DequeUtils.of(RESULT_PARSER.apply(p, c));
152+
}
153+
154+
var deque = new ArrayDeque<StreamingChatCompletionResults.Result>();
155+
XContentParser.Token token;
156+
while ((token = p.nextToken()) != XContentParser.Token.END_ARRAY) {
157+
if (token.isValue() || token == XContentParser.Token.VALUE_NULL || token == XContentParser.Token.START_OBJECT) {
158+
deque.offer(RESULT_PARSER.apply(p, c));
159+
} else {
160+
throw new IllegalStateException("expected value but got [" + token + "]");
161+
}
162+
}
163+
return deque;
164+
}, new ParseField(ChatCompletionResults.COMPLETION), ObjectParser.ValueType.OBJECT_ARRAY);
165+
}
166+
}
167+
}

0 commit comments

Comments
 (0)