Skip to content

Commit acdca34

Browse files
committed
sagemaker
1 parent 8cb4493 commit acdca34

32 files changed

+2630
-0
lines changed

gradle/verification-metadata.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4902,6 +4902,11 @@
49024902
<sha256 value="da37cb021156b6aae5a30337e270a33a43817a64c59ca7aa4c39074cfda39a4b" origin="Generated by Gradle"/>
49034903
</artifact>
49044904
</component>
4905+
<component group="software.amazon.awssdk" name="sagemakerruntime" version="2.30.38">
4906+
<artifact name="sagemakerruntime-2.30.38.jar">
4907+
<sha256 value="b26ee73fa06d047eab9a174e49627972e646c0bbe909f479c18dbff193b561f5" origin="Generated by Gradle"/>
4908+
</artifact>
4909+
</component>
49054910
<component group="software.amazon.awssdk" name="sdk-core" version="2.30.38">
49064911
<artifact name="sdk-core-2.30.38.jar">
49074912
<sha256 value="556463b8c353408d93feab74719d141fcfda7fd3d7b7d1ad3a8a548b7cc2982d" origin="Generated by Gradle"/>

server/src/main/java/org/elasticsearch/common/ValidationException.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ public final List<String> validationErrors() {
5353
return validationErrors;
5454
}
5555

56+
public final void throwIfValidationErrorsExist() {
57+
if(validationErrors().isEmpty() == false) {
58+
throw this;
59+
}
60+
}
61+
5662
@Override
5763
public final String getMessage() {
5864
StringBuilder sb = new StringBuilder();

x-pack/plugin/inference/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ dependencies {
6060

6161
/* AWS SDK v2 */
6262
implementation ("software.amazon.awssdk:bedrockruntime:${versions.awsv2sdk}")
63+
implementation ("software.amazon.awssdk:sagemakerruntime:${versions.awsv2sdk}")
6364
api "software.amazon.awssdk:protocol-core:${versions.awsv2sdk}"
6465
api "software.amazon.awssdk:aws-json-protocol:${versions.awsv2sdk}"
6566
api "software.amazon.awssdk:third-party-jackson-core:${versions.awsv2sdk}"

x-pack/plugin/inference/src/main/java/module-info.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
requires org.reactivestreams;
3636
requires org.elasticsearch.logging;
3737
requires org.elasticsearch.sslconfig;
38+
requires software.amazon.awssdk.services.sagemakerruntime;
3839

3940
exports org.elasticsearch.xpack.inference.action;
4041
exports org.elasticsearch.xpack.inference.registry;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@
9292
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
9393
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
9494
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings;
95+
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
96+
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
9597
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
9698
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
9799
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
@@ -157,6 +159,8 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
157159

158160
namedWriteables.addAll(StreamingTaskManager.namedWriteables());
159161
namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables());
162+
namedWriteables.addAll(SageMakerModel.namedWriteables());
163+
namedWriteables.addAll(SageMakerSchemas.namedWriteables());
160164

161165
return namedWriteables;
162166
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,10 @@
132132
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
133133
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
134134
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
135+
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient;
136+
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerService;
137+
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder;
138+
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
135139
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService;
136140
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
137141

@@ -293,6 +297,7 @@ public Collection<?> createComponents(PluginServices services) {
293297
services.threadPool()
294298
);
295299

300+
var sageMakerSchemas = new SageMakerSchemas();
296301
inferenceServices.add(
297302
() -> List.of(
298303
context -> new ElasticInferenceService(
@@ -301,6 +306,15 @@ public Collection<?> createComponents(PluginServices services) {
301306
inferenceServiceSettings,
302307
modelRegistry,
303308
authorizationHandler
309+
),
310+
context -> new SageMakerService(
311+
new SageMakerModelBuilder(sageMakerSchemas),
312+
new SageMakerClient(
313+
new SageMakerClient.Factory(new HttpSettings(settings, services.clusterService())),
314+
services.threadPool()
315+
),
316+
sageMakerSchemas,
317+
services.threadPool()
304318
)
305319
)
306320
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpSettings.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.common.unit.ByteSizeValue;
1515
import org.elasticsearch.core.TimeValue;
1616

17+
import java.time.Duration;
1718
import java.util.List;
1819
import java.util.Objects;
1920

@@ -55,6 +56,10 @@ public int connectionTimeout() {
5556
return connectionTimeout;
5657
}
5758

59+
public Duration connectionTimeoutDuration() {
60+
return Duration.ofMillis(connectionTimeout);
61+
}
62+
5863
private void setMaxResponseSize(ByteSizeValue maxResponseSize) {
5964
this.maxResponseSize = maxResponseSize;
6065
}
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.sagemaker;
9+
10+
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
11+
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
12+
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
13+
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
14+
import software.amazon.awssdk.profiles.ProfileFile;
15+
import software.amazon.awssdk.regions.Region;
16+
import software.amazon.awssdk.services.sagemakerruntime.SageMakerRuntimeAsyncClient;
17+
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
18+
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
19+
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamRequest;
20+
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponse;
21+
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponseHandler;
22+
import software.amazon.awssdk.services.sagemakerruntime.model.ResponseStream;
23+
24+
import org.apache.logging.log4j.LogManager;
25+
import org.apache.logging.log4j.Logger;
26+
import org.elasticsearch.ElasticsearchException;
27+
import org.elasticsearch.ExceptionsHelper;
28+
import org.elasticsearch.SpecialPermission;
29+
import org.elasticsearch.action.ActionListener;
30+
import org.elasticsearch.common.cache.Cache;
31+
import org.elasticsearch.common.cache.CacheBuilder;
32+
import org.elasticsearch.common.cache.CacheLoader;
33+
import org.elasticsearch.core.TimeValue;
34+
import org.elasticsearch.core.Tuple;
35+
import org.elasticsearch.threadpool.ThreadPool;
36+
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
37+
import org.elasticsearch.xpack.inference.external.http.HttpSettings;
38+
import org.reactivestreams.FlowAdapters;
39+
40+
import java.io.Closeable;
41+
import java.security.AccessController;
42+
import java.security.PrivilegedExceptionAction;
43+
import java.util.concurrent.ExecutionException;
44+
import java.util.concurrent.Flow;
45+
import java.util.concurrent.TimeUnit;
46+
import java.util.concurrent.atomic.AtomicReference;
47+
48+
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
49+
50+
public class SageMakerClient implements Closeable {
51+
private static final Logger log = LogManager.getLogger(SageMakerClient.class);
52+
private final Cache<RegionAndSecrets, SageMakerRuntimeAsyncClient> existingClients = CacheBuilder.<
53+
RegionAndSecrets,
54+
SageMakerRuntimeAsyncClient>builder()
55+
.removalListener(removal -> removal.getValue().close())
56+
.setExpireAfterAccess(TimeValue.timeValueMinutes(15))
57+
.build();
58+
59+
private final CacheLoader<RegionAndSecrets, SageMakerRuntimeAsyncClient> clientFactory;
60+
private final ThreadPool threadPool;
61+
62+
public SageMakerClient(CacheLoader<RegionAndSecrets, SageMakerRuntimeAsyncClient> clientFactory, ThreadPool threadPool) {
63+
this.clientFactory = clientFactory;
64+
this.threadPool = threadPool;
65+
}
66+
67+
public void invoke(
68+
RegionAndSecrets regionAndSecrets,
69+
InvokeEndpointRequest request,
70+
TimeValue timeout,
71+
ActionListener<InvokeEndpointResponse> listener
72+
) {
73+
var asyncClient = getOrCreateClient(regionAndSecrets);
74+
asyncClient.invokeEndpoint(request)
75+
.orTimeout(timeout.seconds(), TimeUnit.SECONDS)
76+
.thenAcceptAsync(listener::onResponse, threadPool.executor(UTILITY_THREAD_POOL_NAME))
77+
.exceptionallyAsync(t -> failAndMaybeThrowError(t, listener), threadPool.executor(UTILITY_THREAD_POOL_NAME));
78+
}
79+
80+
private Void failAndMaybeThrowError(Throwable t, ActionListener<?> listener) {
81+
if (t instanceof Exception e) {
82+
listener.onFailure(e);
83+
} else {
84+
ExceptionsHelper.maybeError(t).ifPresent(ExceptionsHelper::maybeDieOnAnotherThread);
85+
log.atWarn().withThrowable(t).log("Unknown failure calling SageMaker.");
86+
listener.onFailure(new RuntimeException("Unknown failure calling SageMaker."));
87+
}
88+
return null; // Void
89+
}
90+
91+
public void invokeStream(
92+
RegionAndSecrets regionAndSecrets,
93+
InvokeEndpointWithResponseStreamRequest request,
94+
TimeValue timeout,
95+
ActionListener<SageMakerStream> listener
96+
) {
97+
var asyncClient = getOrCreateClient(regionAndSecrets);
98+
var runOnceListener = ActionListener.notifyOnce(listener);
99+
var responseStreamProcessor = new SageMakerStreamingResponseProcessor();
100+
var responseStreamListener = InvokeEndpointWithResponseStreamResponseHandler.builder()
101+
.onResponse(response -> runOnceListener.onResponse(new SageMakerStream(response, responseStreamProcessor)))
102+
.onEventStream(publisher -> responseStreamProcessor.setPublisher(FlowAdapters.toFlowPublisher(publisher)))
103+
.build();
104+
asyncClient.invokeEndpointWithResponseStream(request, responseStreamListener)
105+
.orTimeout(timeout.seconds(), TimeUnit.SECONDS)
106+
.exceptionallyAsync(t -> failAndMaybeThrowError(t, runOnceListener), threadPool.executor(UTILITY_THREAD_POOL_NAME));
107+
}
108+
109+
private SageMakerRuntimeAsyncClient getOrCreateClient(RegionAndSecrets regionAndSecrets) {
110+
try {
111+
return existingClients.computeIfAbsent(regionAndSecrets, clientFactory);
112+
} catch (ExecutionException e) {
113+
throw new ElasticsearchException("failed to create SageMakerRuntime client", e);
114+
}
115+
}
116+
117+
@Override
118+
public void close() {
119+
existingClients.invalidateAll(); // will close each cached client
120+
}
121+
122+
public record RegionAndSecrets(String region, AwsSecretSettings secretSettings) {}
123+
124+
public static class Factory implements CacheLoader<RegionAndSecrets, SageMakerRuntimeAsyncClient> {
125+
private final HttpSettings httpSettings;
126+
127+
public Factory(HttpSettings httpSettings) {
128+
this.httpSettings = httpSettings;
129+
}
130+
131+
@Override
132+
public SageMakerRuntimeAsyncClient load(RegionAndSecrets key) throws Exception {
133+
SpecialPermission.check();
134+
// TODO migrate to entitlements
135+
return AccessController.doPrivileged((PrivilegedExceptionAction<SageMakerRuntimeAsyncClient>) () -> {
136+
try (var accessKey = key.secretSettings().accessKey(); var secretKey = key.secretSettings().secretKey()) {
137+
var credentials = AwsBasicCredentials.create(accessKey.toString(), secretKey.toString());
138+
var credentialsProvider = StaticCredentialsProvider.create(credentials);
139+
var clientConfig = NettyNioAsyncHttpClient.builder().connectionTimeout(httpSettings.connectionTimeoutDuration());
140+
var override = ClientOverrideConfiguration.builder()
141+
// disable profileFile, user credentials will always come from the configured Model Secrets
142+
.defaultProfileFileSupplier(ProfileFile.aggregator()::build)
143+
.defaultProfileFile(ProfileFile.aggregator().build())
144+
.retryPolicy(retryPolicy -> retryPolicy.numRetries(3))
145+
.retryStrategy(retryStrategy -> retryStrategy.maxAttempts(3))
146+
.build();
147+
return SageMakerRuntimeAsyncClient.builder()
148+
.credentialsProvider(credentialsProvider)
149+
.region(Region.of(key.region()))
150+
.httpClientBuilder(clientConfig)
151+
.overrideConfiguration(override)
152+
.build();
153+
}
154+
});
155+
}
156+
}
157+
158+
private static class SageMakerStreamingResponseProcessor implements Flow.Publisher<ResponseStream> {
159+
private static final Logger log = LogManager.getLogger(SageMakerStreamingResponseProcessor.class);
160+
private final AtomicReference<Tuple<Flow.Publisher<ResponseStream>, Flow.Subscriber<? super ResponseStream>>> holder =
161+
new AtomicReference<>(null);
162+
163+
@Override
164+
public void subscribe(Flow.Subscriber<? super ResponseStream> subscriber) {
165+
if (holder.compareAndSet(null, Tuple.tuple(null, subscriber)) == false) {
166+
log.debug("Subscriber connecting to publisher.");
167+
var publisher = holder.getAndSet(null).v1();
168+
publisher.subscribe(subscriber);
169+
} else {
170+
log.debug("Subscriber waiting for connection.");
171+
}
172+
}
173+
174+
private void setPublisher(Flow.Publisher<ResponseStream> publisher) {
175+
if (holder.compareAndSet(null, Tuple.tuple(publisher, null)) == false) {
176+
log.debug("Publisher connecting to subscriber.");
177+
var subscriber = holder.getAndSet(null).v2();
178+
publisher.subscribe(subscriber);
179+
} else {
180+
log.debug("Publisher waiting for connection.");
181+
}
182+
}
183+
}
184+
185+
public record SageMakerStream(InvokeEndpointWithResponseStreamResponse response, Flow.Publisher<ResponseStream> responseStream) {}
186+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.sagemaker;
9+
10+
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.inference.InputType;
12+
13+
import java.util.List;
14+
15+
public record SageMakerInferenceRequest(
16+
String query,
17+
@Nullable Boolean returnDocuments,
18+
@Nullable Integer topN,
19+
@Nullable List<String> input,
20+
boolean stream,
21+
InputType inputType
22+
) {}

0 commit comments

Comments
 (0)