2828import org .elasticsearch .inference .UnifiedCompletionRequest ;
2929import org .elasticsearch .rest .RestStatus ;
3030import org .elasticsearch .threadpool .ThreadPool ;
31+ import org .elasticsearch .xpack .inference .InferencePlugin ;
3132import org .elasticsearch .xpack .inference .chunking .EmbeddingRequestChunker ;
33+ import org .elasticsearch .xpack .inference .services .ServiceComponents ;
3234import org .elasticsearch .xpack .inference .services .sagemaker .model .SageMakerModel ;
3335import org .elasticsearch .xpack .inference .services .sagemaker .model .SageMakerModelBuilder ;
3436import org .elasticsearch .xpack .inference .services .sagemaker .schema .SageMakerSchemas ;
@@ -55,13 +57,15 @@ public class SageMakerService implements InferenceService {
5557 private final SageMakerSchemas schemas ;
5658 private final ThreadPool threadPool ;
5759 private final LazyInitializable <InferenceServiceConfiguration , RuntimeException > configuration ;
60+ private final ServiceComponents serviceComponents ;
5861
5962 public SageMakerService (
6063 SageMakerModelBuilder modelBuilder ,
6164 SageMakerClient client ,
6265 SageMakerSchemas schemas ,
6366 ThreadPool threadPool ,
64- CheckedSupplier <Map <String , SettingsConfiguration >, RuntimeException > configurationMap
67+ CheckedSupplier <Map <String , SettingsConfiguration >, RuntimeException > configurationMap ,
68+ ServiceComponents serviceComponents
6569 ) {
6670 this .modelBuilder = modelBuilder ;
6771 this .client = client ;
@@ -74,6 +78,7 @@ public SageMakerService(
7478 .setConfigurations (configurationMap .get ())
7579 .build ()
7680 );
81+ this .serviceComponents = serviceComponents ;
7782 }
7883
7984 @ Override
@@ -146,6 +151,10 @@ public void infer(
146151
147152 var inferenceRequest = new SageMakerInferenceRequest (query , returnDocuments , topN , input , stream , inputType );
148153
154+ if (timeout == null ) {
155+ timeout = serviceComponents .clusterService ().getClusterSettings ().get (InferencePlugin .SEMANTIC_TEXT_INFERENCE_TIMEOUT );
156+ }
157+
149158 try {
150159 var sageMakerModel = ((SageMakerModel ) model ).override (taskSettings );
151160 var regionAndSecrets = regionAndSecrets (sageMakerModel );
@@ -156,7 +165,7 @@ public void infer(
156165 client .invokeStream (
157166 regionAndSecrets ,
158167 request ,
159- timeout != null ? timeout : DEFAULT_TIMEOUT ,
168+ timeout ,
160169 ActionListener .wrap (
161170 response -> listener .onResponse (schema .streamResponse (sageMakerModel , response )),
162171 e -> listener .onFailure (schema .error (sageMakerModel , e ))
@@ -168,7 +177,7 @@ public void infer(
168177 client .invoke (
169178 regionAndSecrets ,
170179 request ,
171- timeout != null ? timeout : DEFAULT_TIMEOUT ,
180+ timeout ,
172181 ActionListener .wrap (
173182 response -> listener .onResponse (schema .response (sageMakerModel , response , threadPool .getThreadContext ())),
174183 e -> listener .onFailure (schema .error (sageMakerModel , e ))
0 commit comments