Skip to content

Commit 577a6f8

Browse files
authored
[ML] Integrate SageMaker with OpenAI Embeddings (elastic#126856) (elastic#127610)
Integrating with SageMaker. Current design: - SageMaker accepts any byte payload, which can be text, csv, or json. `api` represents the structure of the payload that we will send, for example `openai`, `elastic`, `common`, probably `cohere` or `huggingface` as well. - `api` implementations are extensions of `SageMakerSchemaPayload`, which supports: - "extra" service and task settings specific to the payload structure, so `cohere` would require `embedding_type` and `openai` would require `dimensions` in the `service_settings` - conversion logic from model, service settings, task settings, and input to `SdkBytes` - conversion logic from responding `SdkBytes` to `InferenceServiceResults` - Everything else is tunneling, there are a number of base `service_settings` and `task_settings` that are independent of the api format that we will store and set - We let the SDK do the bulk of the work in terms of connection details, rate limiting, retries, etc.
1 parent 686a9e4 commit 577a6f8

39 files changed

+4309
-147
lines changed

docs/changelog/126856.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 126856
2+
summary: "[ML] Integrate SageMaker with OpenAI Embeddings"
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

gradle/verification-metadata.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4802,6 +4802,11 @@
48024802
<sha256 value="c83dd82a9d82ff8c7d2eb1bdb2ae9f9505b312dad9a6bf0b80bc0136653a3a24" origin="Generated by Gradle"/>
48034803
</artifact>
48044804
</component>
4805+
<component group="software.amazon.awssdk" name="sagemakerruntime" version="2.30.38">
4806+
<artifact name="sagemakerruntime-2.30.38.jar">
4807+
<sha256 value="b26ee73fa06d047eab9a174e49627972e646c0bbe909f479c18dbff193b561f5" origin="Generated by Gradle"/>
4808+
</artifact>
4809+
</component>
48054810
<component group="software.amazon.awssdk" name="sdk-core" version="2.30.38">
48064811
<artifact name="sdk-core-2.30.38.jar">
48074812
<sha256 value="556463b8c353408d93feab74719d141fcfda7fd3d7b7d1ad3a8a548b7cc2982d" origin="Generated by Gradle"/>

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ static TransportVersion def(int id) {
207207
public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG_8_19 = def(8_841_0_18);
208208
public static final TransportVersion BATCHED_QUERY_PHASE_VERSION_BACKPORT_8_X = def(8_841_0_19);
209209
public static final TransportVersion SEARCH_INCREMENTAL_TOP_DOCS_NULL_BACKPORT_8_19 = def(8_841_0_20);
210+
public static final TransportVersion ML_INFERENCE_SAGEMAKER_8_19 = def(8_841_0_21);
210211

211212
/*
212213
* STOP! READ THIS FIRST! No, really,

server/src/main/java/org/elasticsearch/common/ValidationException.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ public final List<String> validationErrors() {
5353
return validationErrors;
5454
}
5555

56+
public final void throwIfValidationErrorsExist() {
57+
if (validationErrors().isEmpty() == false) {
58+
throw this;
59+
}
60+
}
61+
5662
@Override
5763
public final String getMessage() {
5864
StringBuilder sb = new StringBuilder();

x-pack/plugin/inference/build.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ dependencies {
6363

6464
/* AWS SDK v2 */
6565
implementation("software.amazon.awssdk:bedrockruntime:${versions.awsv2sdk}")
66+
implementation("software.amazon.awssdk:sagemakerruntime:${versions.awsv2sdk}")
6667
api "software.amazon.awssdk:protocol-core:${versions.awsv2sdk}"
6768
api "software.amazon.awssdk:aws-json-protocol:${versions.awsv2sdk}"
6869
api "software.amazon.awssdk:third-party-jackson-core:${versions.awsv2sdk}"
@@ -143,6 +144,7 @@ tasks.named("dependencyLicenses").configure {
143144
mapping from: /json-utils.*/, to: 'aws-sdk-2'
144145
mapping from: /endpoints-spi.*/, to: 'aws-sdk-2'
145146
mapping from: /bedrockruntime.*/, to: 'aws-sdk-2'
147+
mapping from: /sagemakerruntime.*/, to: 'aws-sdk-2'
146148
mapping from: /netty-nio-client/, to: 'aws-sdk-2'
147149
/* Cannot use REGEX to match netty-* because netty-nio-client is an AWS package */
148150
mapping from: /netty-buffer/, to: 'netty'

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 114 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -14,169 +14,165 @@
1414
import org.elasticsearch.inference.TaskType;
1515

1616
import java.io.IOException;
17-
import java.util.ArrayList;
1817
import java.util.List;
1918
import java.util.Map;
2019

2120
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated;
21+
import static org.hamcrest.Matchers.containsInAnyOrder;
2222
import static org.hamcrest.Matchers.equalTo;
2323

2424
public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
2525

26-
@SuppressWarnings("unchecked")
2726
public void testGetServicesWithoutTaskType() throws IOException {
2827
List<Object> services = getAllServices();
29-
assertThat(services.size(), equalTo(21));
30-
31-
String[] providers = new String[services.size()];
32-
for (int i = 0; i < services.size(); i++) {
33-
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
34-
providers[i] = (String) serviceConfig.get("service");
35-
}
36-
37-
assertArrayEquals(
38-
List.of(
39-
"alibabacloud-ai-search",
40-
"amazonbedrock",
41-
"anthropic",
42-
"azureaistudio",
43-
"azureopenai",
44-
"cohere",
45-
"deepseek",
46-
"elastic",
47-
"elasticsearch",
48-
"googleaistudio",
49-
"googlevertexai",
50-
"hugging_face",
51-
"jinaai",
52-
"mistral",
53-
"openai",
54-
"streaming_completion_test_service",
55-
"test_reranking_service",
56-
"test_service",
57-
"text_embedding_test_service",
58-
"voyageai",
59-
"watsonxai"
60-
).toArray(),
61-
providers
28+
assertThat(services.size(), equalTo(22));
29+
30+
var providers = providers(services);
31+
32+
assertThat(
33+
providers,
34+
containsInAnyOrder(
35+
List.of(
36+
"alibabacloud-ai-search",
37+
"amazonbedrock",
38+
"anthropic",
39+
"azureaistudio",
40+
"azureopenai",
41+
"cohere",
42+
"deepseek",
43+
"elastic",
44+
"elasticsearch",
45+
"googleaistudio",
46+
"googlevertexai",
47+
"hugging_face",
48+
"jinaai",
49+
"mistral",
50+
"openai",
51+
"streaming_completion_test_service",
52+
"test_reranking_service",
53+
"test_service",
54+
"text_embedding_test_service",
55+
"voyageai",
56+
"watsonxai",
57+
"sagemaker"
58+
).toArray()
59+
)
6260
);
6361
}
6462

6563
@SuppressWarnings("unchecked")
64+
private Iterable<String> providers(List<Object> services) {
65+
return services.stream().map(service -> {
66+
var serviceConfig = (Map<String, Object>) service;
67+
return (String) serviceConfig.get("service");
68+
}).toList();
69+
}
70+
6671
public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
6772
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
68-
assertThat(services.size(), equalTo(15));
69-
70-
String[] providers = new String[services.size()];
71-
for (int i = 0; i < services.size(); i++) {
72-
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
73-
providers[i] = (String) serviceConfig.get("service");
74-
}
75-
76-
assertArrayEquals(
77-
List.of(
78-
"alibabacloud-ai-search",
79-
"amazonbedrock",
80-
"azureaistudio",
81-
"azureopenai",
82-
"cohere",
83-
"elasticsearch",
84-
"googleaistudio",
85-
"googlevertexai",
86-
"hugging_face",
87-
"jinaai",
88-
"mistral",
89-
"openai",
90-
"text_embedding_test_service",
91-
"voyageai",
92-
"watsonxai"
93-
).toArray(),
94-
providers
73+
assertThat(services.size(), equalTo(16));
74+
75+
var providers = providers(services);
76+
77+
assertThat(
78+
providers,
79+
containsInAnyOrder(
80+
List.of(
81+
"alibabacloud-ai-search",
82+
"amazonbedrock",
83+
"azureaistudio",
84+
"azureopenai",
85+
"cohere",
86+
"elasticsearch",
87+
"googleaistudio",
88+
"googlevertexai",
89+
"hugging_face",
90+
"jinaai",
91+
"mistral",
92+
"openai",
93+
"text_embedding_test_service",
94+
"voyageai",
95+
"watsonxai",
96+
"sagemaker"
97+
).toArray()
98+
)
9599
);
96100
}
97101

98-
@SuppressWarnings("unchecked")
99102
public void testGetServicesWithRerankTaskType() throws IOException {
100103
List<Object> services = getServices(TaskType.RERANK);
101104
assertThat(services.size(), equalTo(7));
102105

103-
String[] providers = new String[services.size()];
104-
for (int i = 0; i < services.size(); i++) {
105-
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
106-
providers[i] = (String) serviceConfig.get("service");
107-
}
108-
109-
assertArrayEquals(
110-
List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service", "voyageai")
111-
.toArray(),
112-
providers
106+
var providers = providers(services);
107+
108+
assertThat(
109+
providers,
110+
containsInAnyOrder(
111+
List.of(
112+
"alibabacloud-ai-search",
113+
"cohere",
114+
"elasticsearch",
115+
"googlevertexai",
116+
"jinaai",
117+
"test_reranking_service",
118+
"voyageai"
119+
).toArray()
120+
)
113121
);
114122
}
115123

116-
@SuppressWarnings("unchecked")
117124
public void testGetServicesWithCompletionTaskType() throws IOException {
118125
List<Object> services = getServices(TaskType.COMPLETION);
119126
assertThat(services.size(), equalTo(10));
120127

121-
String[] providers = new String[services.size()];
122-
for (int i = 0; i < services.size(); i++) {
123-
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
124-
providers[i] = (String) serviceConfig.get("service");
125-
}
126-
127-
var providerList = new ArrayList<>(
128-
List.of(
129-
"alibabacloud-ai-search",
130-
"amazonbedrock",
131-
"anthropic",
132-
"azureaistudio",
133-
"azureopenai",
134-
"cohere",
135-
"deepseek",
136-
"googleaistudio",
137-
"openai",
138-
"streaming_completion_test_service"
128+
var providers = providers(services);
129+
130+
assertThat(
131+
providers,
132+
containsInAnyOrder(
133+
List.of(
134+
"alibabacloud-ai-search",
135+
"amazonbedrock",
136+
"anthropic",
137+
"azureaistudio",
138+
"azureopenai",
139+
"cohere",
140+
"deepseek",
141+
"googleaistudio",
142+
"openai",
143+
"streaming_completion_test_service"
144+
).toArray()
139145
)
140146
);
141-
142-
assertArrayEquals(providers, providerList.toArray());
143147
}
144148

145-
@SuppressWarnings("unchecked")
146149
public void testGetServicesWithChatCompletionTaskType() throws IOException {
147150
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
148151
assertThat(services.size(), equalTo(4));
149152

150-
String[] providers = new String[services.size()];
151-
for (int i = 0; i < services.size(); i++) {
152-
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
153-
providers[i] = (String) serviceConfig.get("service");
154-
}
153+
var providers = providers(services);
155154

156-
assertArrayEquals(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray(), providers);
155+
assertThat(providers, containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray()));
157156
}
158157

159-
@SuppressWarnings("unchecked")
160158
public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
161159
List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);
162160
assertThat(services.size(), equalTo(6));
163161

164-
String[] providers = new String[services.size()];
165-
for (int i = 0; i < services.size(); i++) {
166-
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
167-
providers[i] = (String) serviceConfig.get("service");
168-
}
169-
170-
assertArrayEquals(
171-
List.of(
172-
"alibabacloud-ai-search",
173-
"elastic",
174-
"elasticsearch",
175-
"hugging_face",
176-
"streaming_completion_test_service",
177-
"test_service"
178-
).toArray(),
179-
providers
162+
var providers = providers(services);
163+
164+
assertThat(
165+
providers,
166+
containsInAnyOrder(
167+
List.of(
168+
"alibabacloud-ai-search",
169+
"elastic",
170+
"elasticsearch",
171+
"hugging_face",
172+
"streaming_completion_test_service",
173+
"test_service"
174+
).toArray()
175+
)
180176
);
181177
}
182178

x-pack/plugin/inference/src/main/java/module-info.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
requires org.elasticsearch.logging;
3737
requires org.elasticsearch.sslconfig;
3838
requires org.apache.commons.text;
39+
requires software.amazon.awssdk.services.sagemakerruntime;
3940

4041
exports org.elasticsearch.xpack.inference.action;
4142
exports org.elasticsearch.xpack.inference.registry;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@
9292
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
9393
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
9494
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings;
95+
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
96+
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
9597
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
9698
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
9799
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
@@ -157,6 +159,8 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
157159

158160
namedWriteables.addAll(StreamingTaskManager.namedWriteables());
159161
namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables());
162+
namedWriteables.addAll(SageMakerModel.namedWriteables());
163+
namedWriteables.addAll(SageMakerSchemas.namedWriteables());
160164

161165
return namedWriteables;
162166
}

0 commit comments

Comments
 (0)