Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/126856.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 126856
summary: "[ML] Integrate SageMaker with OpenAI Embeddings"
area: Machine Learning
type: enhancement
issues: []
5 changes: 5 additions & 0 deletions gradle/verification-metadata.xml
Original file line number Diff line number Diff line change
Expand Up @@ -4912,6 +4912,11 @@
<sha256 value="c83dd82a9d82ff8c7d2eb1bdb2ae9f9505b312dad9a6bf0b80bc0136653a3a24" origin="Generated by Gradle"/>
</artifact>
</component>
<component group="software.amazon.awssdk" name="sagemakerruntime" version="2.30.38">
<artifact name="sagemakerruntime-2.30.38.jar">
<sha256 value="b26ee73fa06d047eab9a174e49627972e646c0bbe909f479c18dbff193b561f5" origin="Generated by Gradle"/>
</artifact>
</component>
<component group="software.amazon.awssdk" name="sdk-core" version="2.30.38">
<artifact name="sdk-core-2.30.38.jar">
<sha256 value="556463b8c353408d93feab74719d141fcfda7fd3d7b7d1ad3a8a548b7cc2982d" origin="Generated by Gradle"/>
Expand Down
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ static TransportVersion def(int id) {
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17);
public static final TransportVersion BATCHED_QUERY_PHASE_VERSION_BACKPORT_8_X = def(8_841_0_19);
public static final TransportVersion SEARCH_INCREMENTAL_TOP_DOCS_NULL_BACKPORT_8_19 = def(8_841_0_20);
public static final TransportVersion ML_INFERENCE_SAGEMAKER_8_19 = def(8_841_0_21);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_001_0_00);
Expand Down Expand Up @@ -228,6 +229,7 @@ static TransportVersion def(int id) {
public static final TransportVersion DENSE_VECTOR_OFF_HEAP_STATS = def(9_062_00_0);
public static final TransportVersion RANDOM_SAMPLER_QUERY_BUILDER = def(9_063_0_00);
public static final TransportVersion SETTINGS_IN_DATA_STREAMS = def(9_064_0_00);
public static final TransportVersion ML_INFERENCE_SAGEMAKER = def(9_065_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ public final List<String> validationErrors() {
return validationErrors;
}

/**
 * Throws this exception if {@link #validationErrors()} is not empty;
 * otherwise returns normally without any effect.
 */
public final void throwIfValidationErrorsExist() {
    // Nothing recorded: leave quietly so callers can invoke this unconditionally.
    if (validationErrors().isEmpty()) {
        return;
    }
    throw this;
}

@Override
public final String getMessage() {
StringBuilder sb = new StringBuilder();
Expand Down
2 changes: 2 additions & 0 deletions x-pack/plugin/inference/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ dependencies {

/* AWS SDK v2 */
implementation("software.amazon.awssdk:bedrockruntime:${versions.awsv2sdk}")
implementation("software.amazon.awssdk:sagemakerruntime:${versions.awsv2sdk}")
api "software.amazon.awssdk:protocol-core:${versions.awsv2sdk}"
api "software.amazon.awssdk:aws-json-protocol:${versions.awsv2sdk}"
api "software.amazon.awssdk:third-party-jackson-core:${versions.awsv2sdk}"
Expand Down Expand Up @@ -142,6 +143,7 @@ tasks.named("dependencyLicenses").configure {
mapping from: /json-utils.*/, to: 'aws-sdk-2'
mapping from: /endpoints-spi.*/, to: 'aws-sdk-2'
mapping from: /bedrockruntime.*/, to: 'aws-sdk-2'
mapping from: /sagemakerruntime.*/, to: 'aws-sdk-2'
mapping from: /netty-nio-client/, to: 'aws-sdk-2'
/* Cannot use REGEX to match netty-* because netty-nio-client is an AWS package */
mapping from: /netty-buffer/, to: 'netty'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,163 +18,161 @@
import java.util.Map;

import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;

public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {

@SuppressWarnings("unchecked")
public void testGetServicesWithoutTaskType() throws IOException {
List<Object> services = getAllServices();
assertThat(services.size(), equalTo(21));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
providers[i] = (String) serviceConfig.get("service");
}

assertArrayEquals(
List.of(
"alibabacloud-ai-search",
"amazonbedrock",
"anthropic",
"azureaistudio",
"azureopenai",
"cohere",
"deepseek",
"elastic",
"elasticsearch",
"googleaistudio",
"googlevertexai",
"hugging_face",
"jinaai",
"mistral",
"openai",
"streaming_completion_test_service",
"test_reranking_service",
"test_service",
"text_embedding_test_service",
"voyageai",
"watsonxai"
).toArray(),
providers
assertThat(services.size(), equalTo(22));

var providers = providers(services);

assertThat(
providers,
containsInAnyOrder(
List.of(
"alibabacloud-ai-search",
"amazonbedrock",
"anthropic",
"azureaistudio",
"azureopenai",
"cohere",
"deepseek",
"elastic",
"elasticsearch",
"googleaistudio",
"googlevertexai",
"hugging_face",
"jinaai",
"mistral",
"openai",
"streaming_completion_test_service",
"test_reranking_service",
"test_service",
"text_embedding_test_service",
"voyageai",
"watsonxai",
"sagemaker"
).toArray()
)
);
}

/**
 * Collects the value stored under the {@code "service"} key of every entry.
 *
 * @param services service entries; each element is cast to a {@code Map<String, Object>}
 * @return the {@code "service"} values, in the same order as {@code services}
 */
@SuppressWarnings("unchecked")
private Iterable<String> providers(List<Object> services) {
    return services.stream()
        .map(entry -> (Map<String, Object>) entry)
        .map(config -> (String) config.get("service"))
        .toList();
}

public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
assertThat(services.size(), equalTo(15));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
providers[i] = (String) serviceConfig.get("service");
}

assertArrayEquals(
List.of(
"alibabacloud-ai-search",
"amazonbedrock",
"azureaistudio",
"azureopenai",
"cohere",
"elasticsearch",
"googleaistudio",
"googlevertexai",
"hugging_face",
"jinaai",
"mistral",
"openai",
"text_embedding_test_service",
"voyageai",
"watsonxai"
).toArray(),
providers
assertThat(services.size(), equalTo(16));

var providers = providers(services);

assertThat(
providers,
containsInAnyOrder(
List.of(
"alibabacloud-ai-search",
"amazonbedrock",
"azureaistudio",
"azureopenai",
"cohere",
"elasticsearch",
"googleaistudio",
"googlevertexai",
"hugging_face",
"jinaai",
"mistral",
"openai",
"text_embedding_test_service",
"voyageai",
"watsonxai",
"sagemaker"
).toArray()
)
);
}

@SuppressWarnings("unchecked")
public void testGetServicesWithRerankTaskType() throws IOException {
List<Object> services = getServices(TaskType.RERANK);
assertThat(services.size(), equalTo(7));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
providers[i] = (String) serviceConfig.get("service");
}

assertArrayEquals(
List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service", "voyageai")
.toArray(),
providers
var providers = providers(services);

assertThat(
providers,
containsInAnyOrder(
List.of(
"alibabacloud-ai-search",
"cohere",
"elasticsearch",
"googlevertexai",
"jinaai",
"test_reranking_service",
"voyageai"
).toArray()
)
);
}

@SuppressWarnings("unchecked")
public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(10));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
providers[i] = (String) serviceConfig.get("service");
}

assertArrayEquals(
List.of(
"alibabacloud-ai-search",
"amazonbedrock",
"anthropic",
"azureaistudio",
"azureopenai",
"cohere",
"deepseek",
"googleaistudio",
"openai",
"streaming_completion_test_service"
).toArray(),
providers
var providers = providers(services);

assertThat(
providers,
containsInAnyOrder(
List.of(
"alibabacloud-ai-search",
"amazonbedrock",
"anthropic",
"azureaistudio",
"azureopenai",
"cohere",
"deepseek",
"googleaistudio",
"openai",
"streaming_completion_test_service"
).toArray()
)
);
}

@SuppressWarnings("unchecked")
public void testGetServicesWithChatCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
assertThat(services.size(), equalTo(4));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
providers[i] = (String) serviceConfig.get("service");
}
var providers = providers(services);

assertArrayEquals(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray(), providers);
assertThat(providers, containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray()));
}

@SuppressWarnings("unchecked")
public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);
assertThat(services.size(), equalTo(6));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
providers[i] = (String) serviceConfig.get("service");
}

assertArrayEquals(
List.of(
"alibabacloud-ai-search",
"elastic",
"elasticsearch",
"hugging_face",
"streaming_completion_test_service",
"test_service"
).toArray(),
providers
var providers = providers(services);

assertThat(
providers,
containsInAnyOrder(
List.of(
"alibabacloud-ai-search",
"elastic",
"elasticsearch",
"hugging_face",
"streaming_completion_test_service",
"test_service"
).toArray()
)
);
}

Expand Down
1 change: 1 addition & 0 deletions x-pack/plugin/inference/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
requires org.elasticsearch.logging;
requires org.elasticsearch.sslconfig;
requires org.apache.commons.text;
requires software.amazon.awssdk.services.sagemakerruntime;

exports org.elasticsearch.xpack.inference.action;
exports org.elasticsearch.xpack.inference.registry;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
Expand Down Expand Up @@ -157,6 +159,8 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {

namedWriteables.addAll(StreamingTaskManager.namedWriteables());
namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables());
namedWriteables.addAll(SageMakerModel.namedWriteables());
namedWriteables.addAll(SageMakerSchemas.namedWriteables());

return namedWriteables;
}
Expand Down
Loading
Loading