Skip to content

Commit faebbef

Browse files
propagate clustersettings to sageMaker
1 parent b8a5405 commit faebbef

File tree

3 files changed

+19
-5
lines changed

3 files changed

+19
-5
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,8 @@ public Collection<?> createComponents(PluginServices services) {
329329
),
330330
sageMakerSchemas,
331331
services.threadPool(),
332-
sageMakerConfigurations::getOrCompute
332+
sageMakerConfigurations::getOrCompute,
333+
serviceComponents.get()
333334
)
334335
)
335336
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
import org.elasticsearch.inference.UnifiedCompletionRequest;
2929
import org.elasticsearch.rest.RestStatus;
3030
import org.elasticsearch.threadpool.ThreadPool;
31+
import org.elasticsearch.xpack.inference.InferencePlugin;
3132
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
33+
import org.elasticsearch.xpack.inference.services.ServiceComponents;
3234
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
3335
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder;
3436
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
@@ -55,13 +57,15 @@ public class SageMakerService implements InferenceService {
5557
private final SageMakerSchemas schemas;
5658
private final ThreadPool threadPool;
5759
private final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration;
60+
private final ServiceComponents serviceComponents;
5861

5962
public SageMakerService(
6063
SageMakerModelBuilder modelBuilder,
6164
SageMakerClient client,
6265
SageMakerSchemas schemas,
6366
ThreadPool threadPool,
64-
CheckedSupplier<Map<String, SettingsConfiguration>, RuntimeException> configurationMap
67+
CheckedSupplier<Map<String, SettingsConfiguration>, RuntimeException> configurationMap,
68+
ServiceComponents serviceComponents
6569
) {
6670
this.modelBuilder = modelBuilder;
6771
this.client = client;
@@ -74,6 +78,7 @@ public SageMakerService(
7478
.setConfigurations(configurationMap.get())
7579
.build()
7680
);
81+
this.serviceComponents = serviceComponents;
7782
}
7883

7984
@Override
@@ -146,6 +151,10 @@ public void infer(
146151

147152
var inferenceRequest = new SageMakerInferenceRequest(query, returnDocuments, topN, input, stream, inputType);
148153

154+
if (timeout == null) {
155+
timeout = serviceComponents.clusterService().getClusterSettings().get(InferencePlugin.SEMANTIC_TEXT_INFERENCE_TIMEOUT);
156+
}
157+
149158
try {
150159
var sageMakerModel = ((SageMakerModel) model).override(taskSettings);
151160
var regionAndSecrets = regionAndSecrets(sageMakerModel);
@@ -156,7 +165,7 @@ public void infer(
156165
client.invokeStream(
157166
regionAndSecrets,
158167
request,
159-
timeout != null ? timeout : DEFAULT_TIMEOUT,
168+
timeout,
160169
ActionListener.wrap(
161170
response -> listener.onResponse(schema.streamResponse(sageMakerModel, response)),
162171
e -> listener.onFailure(schema.error(sageMakerModel, e))
@@ -168,7 +177,7 @@ public void infer(
168177
client.invoke(
169178
regionAndSecrets,
170179
request,
171-
timeout != null ? timeout : DEFAULT_TIMEOUT,
180+
timeout,
172181
ActionListener.wrap(
173182
response -> listener.onResponse(schema.response(sageMakerModel, response, threadPool.getThreadContext())),
174183
e -> listener.onFailure(schema.error(sageMakerModel, e))

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
package org.elasticsearch.xpack.inference.services.sagemaker;
99

10+
import org.elasticsearch.xpack.inference.services.ServiceComponents;
11+
1012
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
1113
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponse;
1214

@@ -75,16 +77,18 @@ public class SageMakerServiceTests extends ESTestCase {
7577
private SageMakerClient client;
7678
private SageMakerSchemas schemas;
7779
private SageMakerService sageMakerService;
80+
private ServiceComponents serviceComponents;
7881

7982
@Before
8083
public void init() {
8184
modelBuilder = mock();
8285
client = mock();
8386
schemas = mock();
87+
serviceComponents = mock();
8488
ThreadPool threadPool = mock();
8589
when(threadPool.executor(anyString())).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE);
8690
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
87-
sageMakerService = new SageMakerService(modelBuilder, client, schemas, threadPool, Map::of);
91+
sageMakerService = new SageMakerService(modelBuilder, client, schemas, threadPool, Map::of, serviceComponents);
8892
}
8993

9094
public void testSupportedTaskTypes() {

0 commit comments

Comments
 (0)