Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/126856.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 126856
summary: "[Draft][Not For Checkin] Current `SageMaker` work"
area: Machine Learning
type: enhancement
issues: []
5 changes: 5 additions & 0 deletions gradle/verification-metadata.xml
Original file line number Diff line number Diff line change
Expand Up @@ -4902,6 +4902,11 @@
<sha256 value="da37cb021156b6aae5a30337e270a33a43817a64c59ca7aa4c39074cfda39a4b" origin="Generated by Gradle"/>
</artifact>
</component>
<component group="software.amazon.awssdk" name="sagemakerruntime" version="2.30.38">
<artifact name="sagemakerruntime-2.30.38.jar">
<sha256 value="b26ee73fa06d047eab9a174e49627972e646c0bbe909f479c18dbff193b561f5" origin="Generated by Gradle"/>
</artifact>
</component>
<component group="software.amazon.awssdk" name="sdk-core" version="2.30.38">
<artifact name="sdk-core-2.30.38.jar">
<sha256 value="556463b8c353408d93feab74719d141fcfda7fd3d7b7d1ad3a8a548b7cc2982d" origin="Generated by Gradle"/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ public final List<String> validationErrors() {
return validationErrors;
}

public final void throwIfValidationErrorsExist() {
if (validationErrors().isEmpty() == false) {
throw this;
}
}

@Override
public final String getMessage() {
StringBuilder sb = new StringBuilder();
Expand Down
1 change: 1 addition & 0 deletions x-pack/plugin/inference/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ dependencies {

/* AWS SDK v2 */
implementation ("software.amazon.awssdk:bedrockruntime:${versions.awsv2sdk}")
implementation ("software.amazon.awssdk:sagemakerruntime:${versions.awsv2sdk}")
api "software.amazon.awssdk:protocol-core:${versions.awsv2sdk}"
api "software.amazon.awssdk:aws-json-protocol:${versions.awsv2sdk}"
api "software.amazon.awssdk:third-party-jackson-core:${versions.awsv2sdk}"
Expand Down
1 change: 1 addition & 0 deletions x-pack/plugin/inference/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
requires org.reactivestreams;
requires org.elasticsearch.logging;
requires org.elasticsearch.sslconfig;
requires software.amazon.awssdk.services.sagemakerruntime;

exports org.elasticsearch.xpack.inference.action;
exports org.elasticsearch.xpack.inference.registry;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
Expand Down Expand Up @@ -157,6 +159,8 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {

namedWriteables.addAll(StreamingTaskManager.namedWriteables());
namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables());
namedWriteables.addAll(SageMakerModel.namedWriteables());
namedWriteables.addAll(SageMakerSchemas.namedWriteables());

return namedWriteables;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerService;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;

Expand Down Expand Up @@ -293,6 +297,7 @@ public Collection<?> createComponents(PluginServices services) {
services.threadPool()
);

var sageMakerSchemas = new SageMakerSchemas();
inferenceServices.add(
() -> List.of(
context -> new ElasticInferenceService(
Expand All @@ -301,6 +306,15 @@ public Collection<?> createComponents(PluginServices services) {
inferenceServiceSettings,
modelRegistry,
authorizationHandler
),
context -> new SageMakerService(
new SageMakerModelBuilder(sageMakerSchemas),
new SageMakerClient(
new SageMakerClient.Factory(new HttpSettings(settings, services.clusterService())),
services.threadPool()
),
sageMakerSchemas,
services.threadPool()
)
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.TimeValue;

import java.time.Duration;
import java.util.List;
import java.util.Objects;

Expand Down Expand Up @@ -55,6 +56,10 @@ public int connectionTimeout() {
return connectionTimeout;
}

public Duration connectionTimeoutDuration() {
return Duration.ofMillis(connectionTimeout);
}

private void setMaxResponseSize(ByteSizeValue maxResponseSize) {
this.maxResponseSize = maxResponseSize;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.sagemaker;

import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.profiles.ProfileFile;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sagemakerruntime.SageMakerRuntimeAsyncClient;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponseHandler;
import software.amazon.awssdk.services.sagemakerruntime.model.ResponseStream;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.SpecialPermission;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.cache.CacheBuilder;
import org.elasticsearch.common.cache.CacheLoader;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
import org.elasticsearch.xpack.inference.external.http.HttpSettings;
import org.reactivestreams.FlowAdapters;

import java.io.Closeable;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Flow;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;

public class SageMakerClient implements Closeable {
private static final Logger log = LogManager.getLogger(SageMakerClient.class);
private final Cache<RegionAndSecrets, SageMakerRuntimeAsyncClient> existingClients = CacheBuilder.<
RegionAndSecrets,
SageMakerRuntimeAsyncClient>builder()
.removalListener(removal -> removal.getValue().close())
.setExpireAfterAccess(TimeValue.timeValueMinutes(15))
.build();

private final CacheLoader<RegionAndSecrets, SageMakerRuntimeAsyncClient> clientFactory;
private final ThreadPool threadPool;

public SageMakerClient(CacheLoader<RegionAndSecrets, SageMakerRuntimeAsyncClient> clientFactory, ThreadPool threadPool) {
this.clientFactory = clientFactory;
this.threadPool = threadPool;
}

public void invoke(
RegionAndSecrets regionAndSecrets,
InvokeEndpointRequest request,
TimeValue timeout,
ActionListener<InvokeEndpointResponse> listener
) {
var asyncClient = getOrCreateClient(regionAndSecrets);
asyncClient.invokeEndpoint(request)
.orTimeout(timeout.seconds(), TimeUnit.SECONDS)
.thenAcceptAsync(listener::onResponse, threadPool.executor(UTILITY_THREAD_POOL_NAME))
.exceptionallyAsync(t -> failAndMaybeThrowError(t, listener), threadPool.executor(UTILITY_THREAD_POOL_NAME));
}

private Void failAndMaybeThrowError(Throwable t, ActionListener<?> listener) {
if (t instanceof Exception e) {
listener.onFailure(e);
} else {
ExceptionsHelper.maybeError(t).ifPresent(ExceptionsHelper::maybeDieOnAnotherThread);
log.atWarn().withThrowable(t).log("Unknown failure calling SageMaker.");
listener.onFailure(new RuntimeException("Unknown failure calling SageMaker."));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
listener.onFailure(new RuntimeException("Unknown failure calling SageMaker."));
listener.onFailure(new RuntimeException("Unknown failure calling SageMaker.", t));

}
return null; // Void
}

public void invokeStream(
RegionAndSecrets regionAndSecrets,
InvokeEndpointWithResponseStreamRequest request,
TimeValue timeout,
ActionListener<SageMakerStream> listener
) {
var asyncClient = getOrCreateClient(regionAndSecrets);
var runOnceListener = ActionListener.notifyOnce(listener);
var responseStreamProcessor = new SageMakerStreamingResponseProcessor();
var responseStreamListener = InvokeEndpointWithResponseStreamResponseHandler.builder()
.onResponse(response -> runOnceListener.onResponse(new SageMakerStream(response, responseStreamProcessor)))
.onEventStream(publisher -> responseStreamProcessor.setPublisher(FlowAdapters.toFlowPublisher(publisher)))
.build();
asyncClient.invokeEndpointWithResponseStream(request, responseStreamListener)
.orTimeout(timeout.seconds(), TimeUnit.SECONDS)
.exceptionallyAsync(t -> failAndMaybeThrowError(t, runOnceListener), threadPool.executor(UTILITY_THREAD_POOL_NAME));
}

private SageMakerRuntimeAsyncClient getOrCreateClient(RegionAndSecrets regionAndSecrets) {
try {
return existingClients.computeIfAbsent(regionAndSecrets, clientFactory);
} catch (ExecutionException e) {
throw new ElasticsearchException("failed to create SageMakerRuntime client", e);
}
}

@Override
public void close() {
existingClients.invalidateAll(); // will close each cached client
}

public record RegionAndSecrets(String region, AwsSecretSettings secretSettings) {}

public static class Factory implements CacheLoader<RegionAndSecrets, SageMakerRuntimeAsyncClient> {
private final HttpSettings httpSettings;

public Factory(HttpSettings httpSettings) {
this.httpSettings = httpSettings;
}

@Override
public SageMakerRuntimeAsyncClient load(RegionAndSecrets key) throws Exception {
SpecialPermission.check();
// TODO migrate to entitlements
return AccessController.doPrivileged((PrivilegedExceptionAction<SageMakerRuntimeAsyncClient>) () -> {
try (var accessKey = key.secretSettings().accessKey(); var secretKey = key.secretSettings().secretKey()) {
var credentials = AwsBasicCredentials.create(accessKey.toString(), secretKey.toString());
var credentialsProvider = StaticCredentialsProvider.create(credentials);
var clientConfig = NettyNioAsyncHttpClient.builder().connectionTimeout(httpSettings.connectionTimeoutDuration());
var override = ClientOverrideConfiguration.builder()
// disable profileFile, user credentials will always come from the configured Model Secrets
.defaultProfileFileSupplier(ProfileFile.aggregator()::build)
.defaultProfileFile(ProfileFile.aggregator().build())
.retryPolicy(retryPolicy -> retryPolicy.numRetries(3))
.retryStrategy(retryStrategy -> retryStrategy.maxAttempts(3))
.build();
return SageMakerRuntimeAsyncClient.builder()
.credentialsProvider(credentialsProvider)
.region(Region.of(key.region()))
.httpClientBuilder(clientConfig)
.overrideConfiguration(override)
.build();
}
});
}
}

private static class SageMakerStreamingResponseProcessor implements Flow.Publisher<ResponseStream> {
private static final Logger log = LogManager.getLogger(SageMakerStreamingResponseProcessor.class);
private final AtomicReference<Tuple<Flow.Publisher<ResponseStream>, Flow.Subscriber<? super ResponseStream>>> holder =
new AtomicReference<>(null);

@Override
public void subscribe(Flow.Subscriber<? super ResponseStream> subscriber) {
if (holder.compareAndSet(null, Tuple.tuple(null, subscriber)) == false) {
log.debug("Subscriber connecting to publisher.");
var publisher = holder.getAndSet(null).v1();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

publisher.subscribe(subscriber);
} else {
log.debug("Subscriber waiting for connection.");
}
}

private void setPublisher(Flow.Publisher<ResponseStream> publisher) {
if (holder.compareAndSet(null, Tuple.tuple(publisher, null)) == false) {
log.debug("Publisher connecting to subscriber.");
var subscriber = holder.getAndSet(null).v2();
publisher.subscribe(subscriber);
} else {
log.debug("Publisher waiting for connection.");
}
}
}

public record SageMakerStream(InvokeEndpointWithResponseStreamResponse response, Flow.Publisher<ResponseStream> responseStream) {}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.sagemaker;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InputType;

import java.util.List;

public record SageMakerInferenceRequest(
String query,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
@Nullable List<String> input,
boolean stream,
InputType inputType
) {}
Loading
Loading