77
88package org .elasticsearch .xpack .inference .services .sagemaker ;
99
10+ import org .elasticsearch .action .support .ListenerTimeouts ;
11+
12+ import org .elasticsearch .common .util .concurrent .FutureUtils ;
13+
1014import software .amazon .awssdk .auth .credentials .AwsBasicCredentials ;
1115import software .amazon .awssdk .auth .credentials .StaticCredentialsProvider ;
1216import software .amazon .awssdk .core .client .config .ClientOverrideConfiguration ;
2327
2428import org .apache .logging .log4j .LogManager ;
2529import org .apache .logging .log4j .Logger ;
26- import org .elasticsearch .ElasticsearchException ;
30+ import org .elasticsearch .ElasticsearchStatusException ;
2731import org .elasticsearch .ExceptionsHelper ;
2832import org .elasticsearch .SpecialPermission ;
2933import org .elasticsearch .action .ActionListener ;
3236import org .elasticsearch .common .cache .CacheLoader ;
3337import org .elasticsearch .core .TimeValue ;
3438import org .elasticsearch .core .Tuple ;
39+ import org .elasticsearch .rest .RestStatus ;
3540import org .elasticsearch .threadpool .ThreadPool ;
3641import org .elasticsearch .xpack .inference .common .amazon .AwsSecretSettings ;
3742import org .elasticsearch .xpack .inference .external .http .HttpSettings ;
4045import java .io .Closeable ;
4146import java .security .AccessController ;
4247import java .security .PrivilegedExceptionAction ;
48+ import java .util .concurrent .CompletableFuture ;
4349import java .util .concurrent .ExecutionException ;
4450import java .util .concurrent .Flow ;
45- import java .util .concurrent .TimeUnit ;
4651import java .util .concurrent .atomic .AtomicReference ;
4752
4853import static org .elasticsearch .xpack .inference .InferencePlugin .UTILITY_THREAD_POOL_NAME ;
@@ -70,11 +75,36 @@ public void invoke(
7075 TimeValue timeout ,
7176 ActionListener <InvokeEndpointResponse > listener
7277 ) {
73- var asyncClient = getOrCreateClient (regionAndSecrets );
74- asyncClient .invokeEndpoint (request )
75- .orTimeout (timeout .seconds (), TimeUnit .SECONDS )
76- .thenAcceptAsync (listener ::onResponse , threadPool .executor (UTILITY_THREAD_POOL_NAME ))
77- .exceptionallyAsync (t -> failAndMaybeThrowError (t , listener ), threadPool .executor (UTILITY_THREAD_POOL_NAME ));
78+ SageMakerRuntimeAsyncClient asyncClient ;
79+ try {
80+ asyncClient = existingClients .computeIfAbsent (regionAndSecrets , clientFactory );
81+ } catch (ExecutionException e ) {
82+ listener .onFailure (clientFailure (regionAndSecrets , e ));
83+ return ;
84+ }
85+
86+ var awsFuture = asyncClient .invokeEndpoint (request );
87+ var timeoutListener = ListenerTimeouts .wrapWithTimeout (
88+ threadPool ,
89+ timeout ,
90+ threadPool .executor (UTILITY_THREAD_POOL_NAME ),
91+ listener ,
92+ ignored -> {
93+ FutureUtils .cancel (awsFuture );
94+ listener .onFailure (new ElasticsearchStatusException ("Request timed out after [{}]" , RestStatus .REQUEST_TIMEOUT , timeout ));
95+ }
96+ );
97+ awsFuture .thenAcceptAsync (timeoutListener ::onResponse , threadPool .executor (UTILITY_THREAD_POOL_NAME ))
98+ .exceptionallyAsync (t -> failAndMaybeThrowError (t , timeoutListener ), threadPool .executor (UTILITY_THREAD_POOL_NAME ));
99+ }
100+
101+ private static Exception clientFailure (RegionAndSecrets regionAndSecrets , Exception cause ) {
102+ return new ElasticsearchStatusException (
103+ "failed to create SageMakerRuntime client for region [{}]" ,
104+ RestStatus .INTERNAL_SERVER_ERROR ,
105+ cause ,
106+ regionAndSecrets .region ()
107+ );
78108 }
79109
80110 private Void failAndMaybeThrowError (Throwable t , ActionListener <?> listener ) {
@@ -94,24 +124,35 @@ public void invokeStream(
94124 TimeValue timeout ,
95125 ActionListener <SageMakerStream > listener
96126 ) {
97- var asyncClient = getOrCreateClient (regionAndSecrets );
98- var runOnceListener = ActionListener .notifyOnce (listener );
127+ SageMakerRuntimeAsyncClient asyncClient ;
128+ try {
129+ asyncClient = existingClients .computeIfAbsent (regionAndSecrets , clientFactory );
130+ } catch (ExecutionException e ) {
131+ listener .onFailure (clientFailure (regionAndSecrets , e ));
132+ return ;
133+ }
134+
99135 var responseStreamProcessor = new SageMakerStreamingResponseProcessor ();
136+ var cancelAwsRequestListener = new AtomicReference <CompletableFuture <?>>();
137+ var timeoutListener = ListenerTimeouts .wrapWithTimeout (
138+ threadPool ,
139+ timeout ,
140+ threadPool .executor (UTILITY_THREAD_POOL_NAME ),
141+ listener ,
142+ ignored -> {
143+ FutureUtils .cancel (cancelAwsRequestListener .get ());
144+ listener .onFailure (new ElasticsearchStatusException ("Request timed out after [{}]" , RestStatus .REQUEST_TIMEOUT , timeout ));
145+ }
146+ );
147+ // To stay consistent with HTTP providers, we cancel the TimeoutListener onResponse because we are measuring the time it takes to
148+ // start receiving bytes.
100149 var responseStreamListener = InvokeEndpointWithResponseStreamResponseHandler .builder ()
101- .onResponse (response -> runOnceListener .onResponse (new SageMakerStream (response , responseStreamProcessor )))
150+ .onResponse (response -> timeoutListener .onResponse (new SageMakerStream (response , responseStreamProcessor )))
102151 .onEventStream (publisher -> responseStreamProcessor .setPublisher (FlowAdapters .toFlowPublisher (publisher )))
103152 .build ();
104- asyncClient .invokeEndpointWithResponseStream (request , responseStreamListener )
105- .orTimeout (timeout .seconds (), TimeUnit .SECONDS )
106- .exceptionallyAsync (t -> failAndMaybeThrowError (t , runOnceListener ), threadPool .executor (UTILITY_THREAD_POOL_NAME ));
107- }
108-
109- private SageMakerRuntimeAsyncClient getOrCreateClient (RegionAndSecrets regionAndSecrets ) {
110- try {
111- return existingClients .computeIfAbsent (regionAndSecrets , clientFactory );
112- } catch (ExecutionException e ) {
113- throw new ElasticsearchException ("failed to create SageMakerRuntime client" , e );
114- }
153+ var awsFuture = asyncClient .invokeEndpointWithResponseStream (request , responseStreamListener );
154+ cancelAwsRequestListener .set (awsFuture );
155+ awsFuture .exceptionallyAsync (t -> failAndMaybeThrowError (t , timeoutListener ), threadPool .executor (UTILITY_THREAD_POOL_NAME ));
115156 }
116157
117158 @ Override
@@ -133,24 +174,25 @@ public SageMakerRuntimeAsyncClient load(RegionAndSecrets key) throws Exception {
133174 SpecialPermission .check ();
134175 // TODO migrate to entitlements
135176 return AccessController .doPrivileged ((PrivilegedExceptionAction <SageMakerRuntimeAsyncClient >) () -> {
136- try (var accessKey = key .secretSettings ().accessKey (); var secretKey = key .secretSettings ().secretKey ()) {
137- var credentials = AwsBasicCredentials .create (accessKey .toString (), secretKey .toString ());
138- var credentialsProvider = StaticCredentialsProvider .create (credentials );
139- var clientConfig = NettyNioAsyncHttpClient .builder ().connectionTimeout (httpSettings .connectionTimeoutDuration ());
140- var override = ClientOverrideConfiguration .builder ()
141- // disable profileFile, user credentials will always come from the configured Model Secrets
142- .defaultProfileFileSupplier (ProfileFile .aggregator ()::build )
143- .defaultProfileFile (ProfileFile .aggregator ().build ())
144- .retryPolicy (retryPolicy -> retryPolicy .numRetries (3 ))
145- .retryStrategy (retryStrategy -> retryStrategy .maxAttempts (3 ))
146- .build ();
147- return SageMakerRuntimeAsyncClient .builder ()
148- .credentialsProvider (credentialsProvider )
149- .region (Region .of (key .region ()))
150- .httpClientBuilder (clientConfig )
151- .overrideConfiguration (override )
152- .build ();
153- }
177+ var credentials = AwsBasicCredentials .create (
178+ key .secretSettings ().accessKey ().toString (),
179+ key .secretSettings ().secretKey ().toString ()
180+ );
181+ var credentialsProvider = StaticCredentialsProvider .create (credentials );
182+ var clientConfig = NettyNioAsyncHttpClient .builder ().connectionTimeout (httpSettings .connectionTimeoutDuration ());
183+ var override = ClientOverrideConfiguration .builder ()
184+ // disable profileFile, user credentials will always come from the configured Model Secrets
185+ .defaultProfileFileSupplier (ProfileFile .aggregator ()::build )
186+ .defaultProfileFile (ProfileFile .aggregator ().build ())
187+ .retryPolicy (retryPolicy -> retryPolicy .numRetries (3 ))
188+ .retryStrategy (retryStrategy -> retryStrategy .maxAttempts (3 ))
189+ .build ();
190+ return SageMakerRuntimeAsyncClient .builder ()
191+ .credentialsProvider (credentialsProvider )
192+ .region (Region .of (key .region ()))
193+ .httpClientBuilder (clientConfig )
194+ .overrideConfiguration (override )
195+ .build ();
154196 });
155197 }
156198 }
0 commit comments