Skip to content

Commit 01eaabd

Browse files
committed
more tests; address comments
1 parent f4ddb12 commit 01eaabd

24 files changed

+1258
-260
lines changed

x-pack/plugin/inference/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ tasks.named("dependencyLicenses").configure {
141141
mapping from: /json-utils.*/, to: 'aws-sdk-2'
142142
mapping from: /endpoints-spi.*/, to: 'aws-sdk-2'
143143
mapping from: /bedrockruntime.*/, to: 'aws-sdk-2'
144+
mapping from: /sagemakerruntime.*/, to: 'aws-sdk-2'
144145
mapping from: /netty-nio-client/, to: 'aws-sdk-2'
145146
/* Cannot use REGEX to match netty-* because netty-nio-client is an AWS package */
146147
mapping from: /netty-buffer/, to: 'netty'

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerClient.java

Lines changed: 81 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77

88
package org.elasticsearch.xpack.inference.services.sagemaker;
99

10+
import org.elasticsearch.action.support.ListenerTimeouts;
11+
12+
import org.elasticsearch.common.util.concurrent.FutureUtils;
13+
1014
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
1115
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
1216
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
@@ -23,7 +27,7 @@
2327

2428
import org.apache.logging.log4j.LogManager;
2529
import org.apache.logging.log4j.Logger;
26-
import org.elasticsearch.ElasticsearchException;
30+
import org.elasticsearch.ElasticsearchStatusException;
2731
import org.elasticsearch.ExceptionsHelper;
2832
import org.elasticsearch.SpecialPermission;
2933
import org.elasticsearch.action.ActionListener;
@@ -32,6 +36,7 @@
3236
import org.elasticsearch.common.cache.CacheLoader;
3337
import org.elasticsearch.core.TimeValue;
3438
import org.elasticsearch.core.Tuple;
39+
import org.elasticsearch.rest.RestStatus;
3540
import org.elasticsearch.threadpool.ThreadPool;
3641
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
3742
import org.elasticsearch.xpack.inference.external.http.HttpSettings;
@@ -40,9 +45,9 @@
4045
import java.io.Closeable;
4146
import java.security.AccessController;
4247
import java.security.PrivilegedExceptionAction;
48+
import java.util.concurrent.CompletableFuture;
4349
import java.util.concurrent.ExecutionException;
4450
import java.util.concurrent.Flow;
45-
import java.util.concurrent.TimeUnit;
4651
import java.util.concurrent.atomic.AtomicReference;
4752

4853
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
@@ -70,11 +75,36 @@ public void invoke(
7075
TimeValue timeout,
7176
ActionListener<InvokeEndpointResponse> listener
7277
) {
73-
var asyncClient = getOrCreateClient(regionAndSecrets);
74-
asyncClient.invokeEndpoint(request)
75-
.orTimeout(timeout.seconds(), TimeUnit.SECONDS)
76-
.thenAcceptAsync(listener::onResponse, threadPool.executor(UTILITY_THREAD_POOL_NAME))
77-
.exceptionallyAsync(t -> failAndMaybeThrowError(t, listener), threadPool.executor(UTILITY_THREAD_POOL_NAME));
78+
SageMakerRuntimeAsyncClient asyncClient;
79+
try {
80+
asyncClient = existingClients.computeIfAbsent(regionAndSecrets, clientFactory);
81+
} catch (ExecutionException e) {
82+
listener.onFailure(clientFailure(regionAndSecrets, e));
83+
return;
84+
}
85+
86+
var awsFuture = asyncClient.invokeEndpoint(request);
87+
var timeoutListener = ListenerTimeouts.wrapWithTimeout(
88+
threadPool,
89+
timeout,
90+
threadPool.executor(UTILITY_THREAD_POOL_NAME),
91+
listener,
92+
ignored -> {
93+
FutureUtils.cancel(awsFuture);
94+
listener.onFailure(new ElasticsearchStatusException("Request timed out after [{}]", RestStatus.REQUEST_TIMEOUT, timeout));
95+
}
96+
);
97+
awsFuture.thenAcceptAsync(timeoutListener::onResponse, threadPool.executor(UTILITY_THREAD_POOL_NAME))
98+
.exceptionallyAsync(t -> failAndMaybeThrowError(t, timeoutListener), threadPool.executor(UTILITY_THREAD_POOL_NAME));
99+
}
100+
101+
private static Exception clientFailure(RegionAndSecrets regionAndSecrets, Exception cause) {
102+
return new ElasticsearchStatusException(
103+
"failed to create SageMakerRuntime client for region [{}]",
104+
RestStatus.INTERNAL_SERVER_ERROR,
105+
cause,
106+
regionAndSecrets.region()
107+
);
78108
}
79109

80110
private Void failAndMaybeThrowError(Throwable t, ActionListener<?> listener) {
@@ -94,24 +124,35 @@ public void invokeStream(
94124
TimeValue timeout,
95125
ActionListener<SageMakerStream> listener
96126
) {
97-
var asyncClient = getOrCreateClient(regionAndSecrets);
98-
var runOnceListener = ActionListener.notifyOnce(listener);
127+
SageMakerRuntimeAsyncClient asyncClient;
128+
try {
129+
asyncClient = existingClients.computeIfAbsent(regionAndSecrets, clientFactory);
130+
} catch (ExecutionException e) {
131+
listener.onFailure(clientFailure(regionAndSecrets, e));
132+
return;
133+
}
134+
99135
var responseStreamProcessor = new SageMakerStreamingResponseProcessor();
136+
var cancelAwsRequestListener = new AtomicReference<CompletableFuture<?>>();
137+
var timeoutListener = ListenerTimeouts.wrapWithTimeout(
138+
threadPool,
139+
timeout,
140+
threadPool.executor(UTILITY_THREAD_POOL_NAME),
141+
listener,
142+
ignored -> {
143+
FutureUtils.cancel(cancelAwsRequestListener.get());
144+
listener.onFailure(new ElasticsearchStatusException("Request timed out after [{}]", RestStatus.REQUEST_TIMEOUT, timeout));
145+
}
146+
);
147+
// To stay consistent with HTTP providers, the timeout is cancelled as soon as onResponse fires, because we are measuring the time it takes to
148+
// start receiving bytes.
100149
var responseStreamListener = InvokeEndpointWithResponseStreamResponseHandler.builder()
101-
.onResponse(response -> runOnceListener.onResponse(new SageMakerStream(response, responseStreamProcessor)))
150+
.onResponse(response -> timeoutListener.onResponse(new SageMakerStream(response, responseStreamProcessor)))
102151
.onEventStream(publisher -> responseStreamProcessor.setPublisher(FlowAdapters.toFlowPublisher(publisher)))
103152
.build();
104-
asyncClient.invokeEndpointWithResponseStream(request, responseStreamListener)
105-
.orTimeout(timeout.seconds(), TimeUnit.SECONDS)
106-
.exceptionallyAsync(t -> failAndMaybeThrowError(t, runOnceListener), threadPool.executor(UTILITY_THREAD_POOL_NAME));
107-
}
108-
109-
private SageMakerRuntimeAsyncClient getOrCreateClient(RegionAndSecrets regionAndSecrets) {
110-
try {
111-
return existingClients.computeIfAbsent(regionAndSecrets, clientFactory);
112-
} catch (ExecutionException e) {
113-
throw new ElasticsearchException("failed to create SageMakerRuntime client", e);
114-
}
153+
var awsFuture = asyncClient.invokeEndpointWithResponseStream(request, responseStreamListener);
154+
cancelAwsRequestListener.set(awsFuture);
155+
awsFuture.exceptionallyAsync(t -> failAndMaybeThrowError(t, timeoutListener), threadPool.executor(UTILITY_THREAD_POOL_NAME));
115156
}
116157

117158
@Override
@@ -133,24 +174,25 @@ public SageMakerRuntimeAsyncClient load(RegionAndSecrets key) throws Exception {
133174
SpecialPermission.check();
134175
// TODO migrate to entitlements
135176
return AccessController.doPrivileged((PrivilegedExceptionAction<SageMakerRuntimeAsyncClient>) () -> {
136-
try (var accessKey = key.secretSettings().accessKey(); var secretKey = key.secretSettings().secretKey()) {
137-
var credentials = AwsBasicCredentials.create(accessKey.toString(), secretKey.toString());
138-
var credentialsProvider = StaticCredentialsProvider.create(credentials);
139-
var clientConfig = NettyNioAsyncHttpClient.builder().connectionTimeout(httpSettings.connectionTimeoutDuration());
140-
var override = ClientOverrideConfiguration.builder()
141-
// disable profileFile, user credentials will always come from the configured Model Secrets
142-
.defaultProfileFileSupplier(ProfileFile.aggregator()::build)
143-
.defaultProfileFile(ProfileFile.aggregator().build())
144-
.retryPolicy(retryPolicy -> retryPolicy.numRetries(3))
145-
.retryStrategy(retryStrategy -> retryStrategy.maxAttempts(3))
146-
.build();
147-
return SageMakerRuntimeAsyncClient.builder()
148-
.credentialsProvider(credentialsProvider)
149-
.region(Region.of(key.region()))
150-
.httpClientBuilder(clientConfig)
151-
.overrideConfiguration(override)
152-
.build();
153-
}
177+
var credentials = AwsBasicCredentials.create(
178+
key.secretSettings().accessKey().toString(),
179+
key.secretSettings().secretKey().toString()
180+
);
181+
var credentialsProvider = StaticCredentialsProvider.create(credentials);
182+
var clientConfig = NettyNioAsyncHttpClient.builder().connectionTimeout(httpSettings.connectionTimeoutDuration());
183+
var override = ClientOverrideConfiguration.builder()
184+
// disable profileFile, user credentials will always come from the configured Model Secrets
185+
.defaultProfileFileSupplier(ProfileFile.aggregator()::build)
186+
.defaultProfileFile(ProfileFile.aggregator().build())
187+
.retryPolicy(retryPolicy -> retryPolicy.numRetries(3))
188+
.retryStrategy(retryStrategy -> retryStrategy.maxAttempts(3))
189+
.build();
190+
return SageMakerRuntimeAsyncClient.builder()
191+
.credentialsProvider(credentialsProvider)
192+
.region(Region.of(key.region()))
193+
.httpClientBuilder(clientConfig)
194+
.overrideConfiguration(override)
195+
.build();
154196
});
155197
}
156198
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerInferenceRequest.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,18 @@
1111
import org.elasticsearch.inference.InputType;
1212

1313
import java.util.List;
14+
import java.util.Objects;
1415

1516
public record SageMakerInferenceRequest(
16-
String query,
17+
@Nullable String query,
1718
@Nullable Boolean returnDocuments,
1819
@Nullable Integer topN,
19-
@Nullable List<String> input,
20+
List<String> input,
2021
boolean stream,
2122
InputType inputType
22-
) {}
23+
) {
24+
public SageMakerInferenceRequest {
25+
Objects.requireNonNull(input);
26+
Objects.requireNonNull(inputType);
27+
}
28+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java

Lines changed: 50 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ public void infer(
149149
);
150150
}
151151
} catch (Exception e) {
152-
listener.onFailure(e);
152+
listener.onFailure(internalFailure(model, e));
153153
}
154154
}
155155

@@ -165,6 +165,19 @@ private SageMakerClient.RegionAndSecrets regionAndSecrets(SageMakerModel model)
165165
return new SageMakerClient.RegionAndSecrets(model.region(), secrets.get());
166166
}
167167

168+
private static ElasticsearchStatusException internalFailure(Model model, Exception cause) {
169+
if (cause instanceof ElasticsearchStatusException ese) {
170+
return ese;
171+
} else {
172+
return new ElasticsearchStatusException(
173+
"Failed to call SageMaker for inference id [{}].",
174+
RestStatus.INTERNAL_SERVER_ERROR,
175+
cause,
176+
model.getInferenceEntityId()
177+
);
178+
}
179+
}
180+
168181
@Override
169182
public void unifiedCompletionInfer(
170183
Model model,
@@ -181,18 +194,18 @@ public void unifiedCompletionInfer(
181194
var sageMakerModel = (SageMakerModel) model;
182195
var regionAndSecrets = regionAndSecrets(sageMakerModel);
183196
var schema = schemas.streamSchemaFor(sageMakerModel);
184-
var sagemakerRequest = schema.unifiedStreamRequest(sageMakerModel, request);
197+
var sagemakerRequest = schema.chatCompletionStreamRequest(sageMakerModel, request);
185198
client.invokeStream(
186199
regionAndSecrets,
187200
sagemakerRequest,
188201
timeout,
189202
ActionListener.wrap(
190-
response -> listener.onResponse(schema.unifiedStreamResponse(sageMakerModel, response)),
191-
e -> listener.onFailure(schema.unifiedError(sageMakerModel, e))
203+
response -> listener.onResponse(schema.chatCompletionStreamResponse(sageMakerModel, response)),
204+
e -> listener.onFailure(schema.chatCompletionError(sageMakerModel, e))
192205
)
193206
);
194207
} catch (Exception e) {
195-
listener.onFailure(e);
208+
listener.onFailure(internalFailure(model, e));
196209
}
197210
}
198211

@@ -210,36 +223,41 @@ public void chunkedInfer(
210223
listener.onFailure(createInvalidModelException(model));
211224
return;
212225
}
226+
try {
227+
var sageMakerModel = ((SageMakerModel) model).override(taskSettings);
228+
var batchedRequests = new EmbeddingRequestChunker<>(
229+
input,
230+
sageMakerModel.batchSize().orElse(DEFAULT_BATCH_SIZE),
231+
sageMakerModel.getConfigurations().getChunkingSettings()
232+
).batchRequestsWithListeners(listener);
213233

214-
var sageMakerModel = ((SageMakerModel) model).override(taskSettings);
215-
var batchedRequests = new EmbeddingRequestChunker<>(
216-
input,
217-
sageMakerModel.batchSize().orElse(DEFAULT_BATCH_SIZE),
218-
sageMakerModel.getConfigurations().getChunkingSettings()
219-
).batchRequestsWithListeners(listener);
220-
221-
var subscribableListener = SubscribableListener.newSucceeded(null);
222-
for (var request : batchedRequests) {
223-
subscribableListener = subscribableListener.andThen(
224-
threadPool.executor(UTILITY_THREAD_POOL_NAME),
225-
threadPool.getThreadContext(),
226-
(l, ignored) -> infer(
227-
sageMakerModel,
228-
null, // no query when chunking?,
229-
null, // no return docs while chunking?
230-
null, // no topN while chunking?
231-
request.batch().inputs().get(),
232-
false, // we never stream when chunking
233-
null, // since we pass sageMakerModel as the model, we already overwrote the model with the task settings
234-
inputType,
235-
timeout,
236-
ActionListener.runAfter(request.listener(), () -> l.onResponse(null))
237-
)
234+
var subscribableListener = SubscribableListener.newSucceeded(null);
235+
for (var request : batchedRequests) {
236+
subscribableListener = subscribableListener.andThen(
237+
threadPool.executor(UTILITY_THREAD_POOL_NAME),
238+
threadPool.getThreadContext(),
239+
(l, ignored) -> infer(
240+
sageMakerModel,
241+
query,
242+
null, // no return docs while chunking?
243+
null, // no topN while chunking?
244+
request.batch().inputs().get(),
245+
false, // we never stream when chunking
246+
null, // since we pass sageMakerModel as the model, we already overwrote the model with the task settings
247+
inputType,
248+
timeout,
249+
ActionListener.runAfter(request.listener(), () -> l.onResponse(null))
250+
)
251+
);
252+
}
253+
// if there were any errors trying to create the SubscribableListener chain, then forward that to the listener
254+
// otherwise, BatchRequestAndListener will handle forwarding errors from the infer method
255+
subscribableListener.addListener(
256+
ActionListener.noop().delegateResponse((l, e) -> listener.onFailure(internalFailure(model, e)))
238257
);
258+
} catch (Exception e) {
259+
listener.onFailure(internalFailure(model, e));
239260
}
240-
// if there were any errors trying to create the SubscribableListener chain, then forward that to the listener
241-
// otherwise, BatchRequestAndListener will handle forwarding errors from the infer method
242-
subscribableListener.addListener(ActionListener.noop().delegateResponse((l, e) -> listener.onFailure(e)));
243261
}
244262

245263
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModel.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,11 @@ public static List<NamedWriteableRegistry.Entry> namedWriteables() {
125125
);
126126
}
127127

128-
public SageMakerStoredServiceSchema extraServiceSettings() {
129-
return serviceSettings.extraServiceSettings();
128+
public SageMakerStoredServiceSchema apiServiceSettings() {
129+
return serviceSettings.apiServiceSettings();
130130
}
131131

132-
public SageMakerStoredTaskSchema extraTaskSettings() {
133-
return taskSettings.extraTaskSettings();
132+
public SageMakerStoredTaskSchema apiTaskSettings() {
133+
return taskSettings.apiTaskSettings();
134134
}
135135
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModelBuilder.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public SageMakerModel fromRequest(String inferenceEntityId, TaskType taskType, S
3939
var taskSettingsMap = removeFromMapOrDefaultEmpty(requestMap, ModelConfigurations.TASK_SETTINGS);
4040
var taskSettings = SageMakerTaskSettings.fromMap(
4141
taskSettingsMap,
42-
schema.extraTaskSettings(taskSettingsMap, validationException),
42+
schema.apiTaskSettings(taskSettingsMap, validationException),
4343
validationException
4444
);
4545

@@ -80,7 +80,7 @@ public SageMakerModel fromStorage(
8080

8181
var taskSettings = SageMakerTaskSettings.fromMap(
8282
taskSettingsMap,
83-
schema.extraTaskSettings(taskSettingsMap, validationException),
83+
schema.apiTaskSettings(taskSettingsMap, validationException),
8484
validationException
8585
);
8686

0 commit comments

Comments
 (0)