Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/126856.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 126856
summary: "[Draft][Not For Checkin] Current `SageMaker` work"
area: Machine Learning
type: enhancement
issues: []
5 changes: 5 additions & 0 deletions gradle/verification-metadata.xml
Original file line number Diff line number Diff line change
Expand Up @@ -4912,6 +4912,11 @@
<sha256 value="da37cb021156b6aae5a30337e270a33a43817a64c59ca7aa4c39074cfda39a4b" origin="Generated by Gradle"/>
</artifact>
</component>
<component group="software.amazon.awssdk" name="sagemakerruntime" version="2.30.38">
<artifact name="sagemakerruntime-2.30.38.jar">
<sha256 value="b26ee73fa06d047eab9a174e49627972e646c0bbe909f479c18dbff193b561f5" origin="Generated by Gradle"/>
</artifact>
</component>
<component group="software.amazon.awssdk" name="sdk-core" version="2.30.38">
<artifact name="sdk-core-2.30.38.jar">
<sha256 value="556463b8c353408d93feab74719d141fcfda7fd3d7b7d1ad3a8a548b7cc2982d" origin="Generated by Gradle"/>
Expand Down
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ static TransportVersion def(int id) {
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17);
public static final TransportVersion BATCHED_QUERY_PHASE_VERSION_BACKPORT_8_X = def(8_841_0_19);
public static final TransportVersion SEARCH_INCREMENTAL_TOP_DOCS_NULL_BACKPORT_8_19 = def(8_841_0_20);
public static final TransportVersion ML_INFERENCE_SAGEMAKER_8_19 = def(8_841_0_21);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
Expand Down Expand Up @@ -232,6 +233,7 @@ static TransportVersion def(int id) {
public static final TransportVersion BATCHED_QUERY_EXECUTION_DELAYABLE_WRITABLE = def(9_057_0_00);
public static final TransportVersion SEARCH_INCREMENTAL_TOP_DOCS_NULL = def(9_058_0_00);
public static final TransportVersion COMPRESS_DELAYABLE_WRITEABLE = def(9_059_0_00);
public static final TransportVersion ML_INFERENCE_SAGEMAKER = def(9_060_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ public final List<String> validationErrors() {
return validationErrors;
}

/**
 * Throws this exception if at least one validation error has been recorded; returns normally otherwise.
 */
public final void throwIfValidationErrorsExist() {
    if (validationErrors().isEmpty()) {
        return;
    }
    throw this;
}

@Override
public final String getMessage() {
StringBuilder sb = new StringBuilder();
Expand Down
2 changes: 2 additions & 0 deletions x-pack/plugin/inference/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ dependencies {

/* AWS SDK v2 */
implementation ("software.amazon.awssdk:bedrockruntime:${versions.awsv2sdk}")
implementation ("software.amazon.awssdk:sagemakerruntime:${versions.awsv2sdk}")
api "software.amazon.awssdk:protocol-core:${versions.awsv2sdk}"
api "software.amazon.awssdk:aws-json-protocol:${versions.awsv2sdk}"
api "software.amazon.awssdk:third-party-jackson-core:${versions.awsv2sdk}"
Expand Down Expand Up @@ -140,6 +141,7 @@ tasks.named("dependencyLicenses").configure {
mapping from: /json-utils.*/, to: 'aws-sdk-2'
mapping from: /endpoints-spi.*/, to: 'aws-sdk-2'
mapping from: /bedrockruntime.*/, to: 'aws-sdk-2'
mapping from: /sagemakerruntime.*/, to: 'aws-sdk-2'
mapping from: /netty-nio-client/, to: 'aws-sdk-2'
/* Cannot use REGEX to match netty-* because netty-nio-client is an AWS package */
mapping from: /netty-buffer/, to: 'netty'
Expand Down
1 change: 1 addition & 0 deletions x-pack/plugin/inference/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
requires org.reactivestreams;
requires org.elasticsearch.logging;
requires org.elasticsearch.sslconfig;
requires software.amazon.awssdk.services.sagemakerruntime;

exports org.elasticsearch.xpack.inference.action;
exports org.elasticsearch.xpack.inference.registry;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
Expand Down Expand Up @@ -157,6 +159,8 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {

namedWriteables.addAll(StreamingTaskManager.namedWriteables());
namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables());
namedWriteables.addAll(SageMakerModel.namedWriteables());
namedWriteables.addAll(SageMakerSchemas.namedWriteables());

return namedWriteables;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerService;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;

Expand Down Expand Up @@ -293,6 +297,7 @@ public Collection<?> createComponents(PluginServices services) {
services.threadPool()
);

var sageMakerSchemas = new SageMakerSchemas();
inferenceServices.add(
() -> List.of(
context -> new ElasticInferenceService(
Expand All @@ -301,6 +306,15 @@ public Collection<?> createComponents(PluginServices services) {
inferenceServiceSettings,
modelRegistry,
authorizationHandler
),
context -> new SageMakerService(
new SageMakerModelBuilder(sageMakerSchemas),
new SageMakerClient(
new SageMakerClient.Factory(new HttpSettings(settings, services.clusterService())),
services.threadPool()
),
sageMakerSchemas,
services.threadPool()
)
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.TimeValue;

import java.time.Duration;
import java.util.List;
import java.util.Objects;

Expand Down Expand Up @@ -55,6 +56,10 @@ public int connectionTimeout() {
return connectionTimeout;
}

/**
 * Returns the connection timeout as a {@link Duration}, treating the stored value as a number of milliseconds.
 */
public Duration connectionTimeoutDuration() {
    final long timeoutMillis = connectionTimeout;
    return Duration.ofMillis(timeoutMillis);
}

private void setMaxResponseSize(ByteSizeValue maxResponseSize) {
this.maxResponseSize = maxResponseSize;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.sagemaker;

import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.profiles.ProfileFile;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sagemakerruntime.SageMakerRuntimeAsyncClient;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponseHandler;
import software.amazon.awssdk.services.sagemakerruntime.model.ResponseStream;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.SpecialPermission;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ListenerTimeouts;
import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.cache.CacheBuilder;
import org.elasticsearch.common.cache.CacheLoader;
import org.elasticsearch.common.util.concurrent.FutureUtils;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
import org.elasticsearch.xpack.inference.external.http.HttpSettings;
import org.reactivestreams.FlowAdapters;

import java.io.Closeable;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicReference;

import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;

public class SageMakerClient implements Closeable {
private static final Logger log = LogManager.getLogger(SageMakerClient.class);
// Clients are cached per region/credentials pair. An entry is configured to expire 15 minutes after its last access, and the
// removal listener closes the client whenever its entry is removed from the cache.
private final Cache<RegionAndSecrets, SageMakerRuntimeAsyncClient> existingClients = CacheBuilder.<
RegionAndSecrets,
SageMakerRuntimeAsyncClient>builder()
.removalListener(removal -> removal.getValue().close())
.setExpireAfterAccess(TimeValue.timeValueMinutes(15))
.build();

// Loader used to create a client when the cache has no entry for the requested key.
private final CacheLoader<RegionAndSecrets, SageMakerRuntimeAsyncClient> clientFactory;
// Supplies the timeout scheduling and the utility executor on which listeners are completed.
private final ThreadPool threadPool;

/**
 * @param clientFactory loader invoked on a cache miss to build the AWS client for a region/credentials pair
 * @param threadPool    thread pool used for request timeouts and for completing listeners
 */
public SageMakerClient(CacheLoader<RegionAndSecrets, SageMakerRuntimeAsyncClient> clientFactory, ThreadPool threadPool) {
this.clientFactory = clientFactory;
this.threadPool = threadPool;
}

/**
 * Sends a non-streaming request to a SageMaker endpoint and reports the outcome to {@code listener}.
 * <p>
 * The AWS client is looked up in (or loaded into) the client cache. If that fails, the listener is failed immediately and no
 * request is sent. Otherwise the listener is wrapped with a timeout: when the timeout fires, the in-flight AWS future is
 * cancelled and the listener receives a {@link RestStatus#REQUEST_TIMEOUT} failure. Responses and failures are delivered on the
 * utility thread pool.
 *
 * @param regionAndSecrets cache key identifying which client to use
 * @param request          the request to send
 * @param timeout          how long to wait for the response before failing the listener
 * @param listener         receives the response or the failure
 */
public void invoke(
    RegionAndSecrets regionAndSecrets,
    InvokeEndpointRequest request,
    TimeValue timeout,
    ActionListener<InvokeEndpointResponse> listener
) {
    final SageMakerRuntimeAsyncClient client;
    try {
        client = existingClients.computeIfAbsent(regionAndSecrets, clientFactory);
    } catch (ExecutionException e) {
        listener.onFailure(clientFailure(regionAndSecrets, e));
        return;
    }

    var pendingResponse = client.invokeEndpoint(request);
    var boundedListener = ListenerTimeouts.wrapWithTimeout(
        threadPool,
        timeout,
        threadPool.executor(UTILITY_THREAD_POOL_NAME),
        listener,
        ignored -> {
            FutureUtils.cancel(pendingResponse);
            listener.onFailure(new ElasticsearchStatusException("Request timed out after [{}]", RestStatus.REQUEST_TIMEOUT, timeout));
        }
    );

    // Success is handed to the timeout-aware listener first; any failure in the chain is routed through the shared handler.
    var delivered = pendingResponse.thenAcceptAsync(boundedListener::onResponse, threadPool.executor(UTILITY_THREAD_POOL_NAME));
    delivered.exceptionallyAsync(t -> failAndMaybeThrowError(t, boundedListener), threadPool.executor(UTILITY_THREAD_POOL_NAME));
}

/**
 * Builds the failure reported when a client for the given region could not be obtained from the cache.
 *
 * @param regionAndSecrets key whose region is included in the message (the secrets are not)
 * @param cause            the underlying failure, kept as the cause of the returned exception
 * @return an {@link ElasticsearchStatusException} with status {@link RestStatus#INTERNAL_SERVER_ERROR}
 */
private static Exception clientFailure(RegionAndSecrets regionAndSecrets, Exception cause) {
    final String region = regionAndSecrets.region();
    return new ElasticsearchStatusException(
        "failed to create SageMakerRuntime client for region [{}]",
        RestStatus.INTERNAL_SERVER_ERROR,
        cause,
        region
    );
}

private Void failAndMaybeThrowError(Throwable t, ActionListener<?> listener) {
if (t instanceof Exception e) {
listener.onFailure(e);
} else {
ExceptionsHelper.maybeError(t).ifPresent(ExceptionsHelper::maybeDieOnAnotherThread);
log.atWarn().withThrowable(t).log("Unknown failure calling SageMaker.");
listener.onFailure(new RuntimeException("Unknown failure calling SageMaker."));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
listener.onFailure(new RuntimeException("Unknown failure calling SageMaker."));
listener.onFailure(new RuntimeException("Unknown failure calling SageMaker.", t));

}
return null; // Void
}

/**
 * Sends a streaming request to a SageMaker endpoint and hands the resulting stream to {@code listener}.
 * <p>
 * The listener is completed with a {@link SageMakerStream} as soon as the initial response arrives; the event stream itself is
 * exposed through the stream's publisher. The timeout therefore only bounds the wait for the initial response.
 *
 * @param regionAndSecrets cache key identifying which client to use
 * @param request          the streaming request to send
 * @param timeout          how long to wait for the initial response before failing the listener
 * @param listener         receives the stream or the failure
 */
public void invokeStream(
RegionAndSecrets regionAndSecrets,
InvokeEndpointWithResponseStreamRequest request,
TimeValue timeout,
ActionListener<SageMakerStream> listener
) {
SageMakerRuntimeAsyncClient asyncClient;
try {
asyncClient = existingClients.computeIfAbsent(regionAndSecrets, clientFactory);
} catch (ExecutionException e) {
// No client could be obtained, so nothing is sent; report the failure and stop.
listener.onFailure(clientFailure(regionAndSecrets, e));
return;
}

var responseStreamProcessor = new SageMakerStreamingResponseProcessor();
// The AWS future does not exist yet when the timeout handler is created, so it is published through this reference once the
// request has been started (see the set(...) call below).
var cancelAwsRequestListener = new AtomicReference<CompletableFuture<?>>();
var timeoutListener = ListenerTimeouts.wrapWithTimeout(
threadPool,
timeout,
threadPool.executor(UTILITY_THREAD_POOL_NAME),
listener,
ignored -> {
// NOTE(review): if the timeout fires before the future is stored below, get() returns null here — presumably
// FutureUtils.cancel tolerates null; confirm.
FutureUtils.cancel(cancelAwsRequestListener.get());
listener.onFailure(new ElasticsearchStatusException("Request timed out after [{}]", RestStatus.REQUEST_TIMEOUT, timeout));
}
);
// To stay consistent with HTTP providers, we cancel the TimeoutListener onResponse because we are measuring the time it takes to
// start receiving bytes.
var responseStreamListener = InvokeEndpointWithResponseStreamResponseHandler.builder()
.onResponse(response -> timeoutListener.onResponse(new SageMakerStream(response, responseStreamProcessor)))
.onEventStream(publisher -> responseStreamProcessor.setPublisher(FlowAdapters.toFlowPublisher(publisher)))
.build();
var awsFuture = asyncClient.invokeEndpointWithResponseStream(request, responseStreamListener);
// Make the in-flight future visible to the timeout handler so that it can be cancelled.
cancelAwsRequestListener.set(awsFuture);
// Only the failure path is wired here; success is reported through the handler's onResponse callback above.
awsFuture.exceptionallyAsync(t -> failAndMaybeThrowError(t, timeoutListener), threadPool.executor(UTILITY_THREAD_POOL_NAME));
}

/**
 * Invalidates every entry in the client cache. The cache's removal listener calls {@code close()} on each removed client.
 */
@Override
public void close() {
existingClients.invalidateAll(); // will close each cached client
}

/**
 * Key of the client cache: the AWS region name together with the secret settings that supply the access key and secret key.
 */
public record RegionAndSecrets(String region, AwsSecretSettings secretSettings) {}

/**
 * {@link CacheLoader} that builds a {@link SageMakerRuntimeAsyncClient} for a given region/credentials pair.
 */
public static class Factory implements CacheLoader<RegionAndSecrets, SageMakerRuntimeAsyncClient> {
// Source of the connection timeout applied to the underlying Netty HTTP client.
private final HttpSettings httpSettings;

public Factory(HttpSettings httpSettings) {
this.httpSettings = httpSettings;
}

/**
 * Creates a new async client for {@code key}, using static credentials taken from the key's secret settings and the region
 * named by the key. Client construction runs inside a privileged action.
 *
 * @param key the region and secrets to build the client for
 * @return a newly built client
 * @throws Exception if the privileged action fails
 */
@Override
public SageMakerRuntimeAsyncClient load(RegionAndSecrets key) throws Exception {
SpecialPermission.check();
// TODO migrate to entitlements
return AccessController.doPrivileged((PrivilegedExceptionAction<SageMakerRuntimeAsyncClient>) () -> {
var credentials = AwsBasicCredentials.create(
key.secretSettings().accessKey().toString(),
key.secretSettings().secretKey().toString()
);
var credentialsProvider = StaticCredentialsProvider.create(credentials);
var clientConfig = NettyNioAsyncHttpClient.builder().connectionTimeout(httpSettings.connectionTimeoutDuration());
var override = ClientOverrideConfiguration.builder()
// disable profileFile, user credentials will always come from the configured Model Secrets
.defaultProfileFileSupplier(ProfileFile.aggregator()::build)
.defaultProfileFile(ProfileFile.aggregator().build())
// NOTE(review): both a retry policy (numRetries 3) and a retry strategy (maxAttempts 3) are configured — confirm
// which one the SDK honours and whether "3 retries" versus "3 attempts" is intended.
.retryPolicy(retryPolicy -> retryPolicy.numRetries(3))
.retryStrategy(retryStrategy -> retryStrategy.maxAttempts(3))
.build();
return SageMakerRuntimeAsyncClient.builder()
.credentialsProvider(credentialsProvider)
.region(Region.of(key.region()))
.httpClientBuilder(clientConfig)
.overrideConfiguration(override)
.build();
});
}
}

private static class SageMakerStreamingResponseProcessor implements Flow.Publisher<ResponseStream> {
private static final Logger log = LogManager.getLogger(SageMakerStreamingResponseProcessor.class);
private final AtomicReference<Tuple<Flow.Publisher<ResponseStream>, Flow.Subscriber<? super ResponseStream>>> holder =
new AtomicReference<>(null);

@Override
public void subscribe(Flow.Subscriber<? super ResponseStream> subscriber) {
if (holder.compareAndSet(null, Tuple.tuple(null, subscriber)) == false) {
log.debug("Subscriber connecting to publisher.");
var publisher = holder.getAndSet(null).v1();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

publisher.subscribe(subscriber);
} else {
log.debug("Subscriber waiting for connection.");
}
}

private void setPublisher(Flow.Publisher<ResponseStream> publisher) {
if (holder.compareAndSet(null, Tuple.tuple(publisher, null)) == false) {
log.debug("Publisher connecting to subscriber.");
var subscriber = holder.getAndSet(null).v2();
publisher.subscribe(subscriber);
} else {
log.debug("Publisher waiting for connection.");
}
}
}

/**
 * Result of a streaming invocation: the initial response together with the publisher that emits the {@link ResponseStream}
 * events.
 */
public record SageMakerStream(InvokeEndpointWithResponseStreamResponse response, Flow.Publisher<ResponseStream> responseStream) {}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.sagemaker;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InputType;

import java.util.List;
import java.util.Objects;

/**
 * Immutable bundle of the inputs for a single SageMaker inference call.
 *
 * @param query           optional; may be {@code null}
 * @param returnDocuments optional; may be {@code null}
 * @param topN            optional; may be {@code null}
 * @param input           the input strings; must not be {@code null}
 * @param stream          whether the caller asked for a streaming response
 * @param inputType       the input type; must not be {@code null}
 */
public record SageMakerInferenceRequest(
    @Nullable String query,
    @Nullable Boolean returnDocuments,
    @Nullable Integer topN,
    List<String> input,
    boolean stream,
    InputType inputType
) {
    /**
     * Validates the required components.
     *
     * @throws NullPointerException if {@code input} or {@code inputType} is {@code null}; the message names the component
     */
    public SageMakerInferenceRequest {
        // Name the component so a failure says which argument was missing.
        Objects.requireNonNull(input, "input");
        Objects.requireNonNull(inputType, "inputType");
    }
}
Loading
Loading