2323import org .elasticsearch .xpack .inference .external .http .sender .Sender ;
2424import org .elasticsearch .xpack .inference .services .elastic .ElasticInferenceServiceResponseHandler ;
2525import org .elasticsearch .xpack .inference .services .elastic .ccm .AuthenticationFactory ;
26+ import org .elasticsearch .xpack .inference .services .elastic .ccm .CCMFeature ;
27+ import org .elasticsearch .xpack .inference .services .elastic .ccm .CCMService ;
2628import org .elasticsearch .xpack .inference .services .elastic .request .ElasticInferenceServiceAuthorizationRequest ;
2729import org .elasticsearch .xpack .inference .services .elastic .response .ElasticInferenceServiceAuthorizationResponseEntity ;
2830import org .elasticsearch .xpack .inference .telemetry .TraceContext ;
@@ -56,17 +58,23 @@ private static ResponseHandler createAuthResponseHandler() {
5658 private final Logger logger ;
5759 private final CountDownLatch requestCompleteLatch = new CountDownLatch (1 );
5860 private final AuthenticationFactory authFactory ;
61+ private final CCMFeature ccmFeature ;
62+ private final CCMService ccmService ;
5963
6064 public ElasticInferenceServiceAuthorizationRequestHandler (
6165 @ Nullable String baseUrl ,
6266 ThreadPool threadPool ,
63- AuthenticationFactory authFactory
67+ AuthenticationFactory authFactory ,
68+ CCMFeature ccmFeature ,
69+ CCMService ccmService
6470 ) {
6571 this (
6672 baseUrl ,
6773 Objects .requireNonNull (threadPool ),
6874 LogManager .getLogger (ElasticInferenceServiceAuthorizationRequestHandler .class ),
69- authFactory
75+ authFactory ,
76+ ccmFeature ,
77+ ccmService
7078 );
7179 }
7280
@@ -75,12 +83,16 @@ public ElasticInferenceServiceAuthorizationRequestHandler(
7583 @ Nullable String baseUrl ,
7684 ThreadPool threadPool ,
7785 Logger logger ,
78- AuthenticationFactory authFactory
86+ AuthenticationFactory authFactory ,
87+ CCMFeature ccmFeature ,
88+ CCMService ccmService
7989 ) {
8090 this .baseUrl = baseUrl ;
8191 this .threadPool = Objects .requireNonNull (threadPool );
8292 this .logger = Objects .requireNonNull (logger );
8393 this .authFactory = Objects .requireNonNull (authFactory );
94+ this .ccmFeature = Objects .requireNonNull (ccmFeature );
95+ this .ccmService = Objects .requireNonNull (ccmService );
8496 }
8597
8698 /**
@@ -89,57 +101,98 @@ public ElasticInferenceServiceAuthorizationRequestHandler(
89101 * @param sender a {@link Sender} for making the request to the Elastic Inference Service
90102 */
91103 public void getAuthorization (ActionListener <ElasticInferenceServiceAuthorizationModel > listener , Sender sender ) {
92- try {
93- logger .debug ("Retrieving authorization information from the Elastic Inference Service." );
104+ getAuthorizationHelper (listener , sender , false );
105+ }
106+
107+ /**
108+ * Skips retrieving the authorization information from Elastic Inference Service if CCM is not configured,
109+ * and it is a supported environment, a supported environment would be on-prem or ECK. ECH and serverless are not supported
110+ * environments for CCM (because they can already connect to EIS). For environments where CCM is not supported, it will always
111+ * attempt to retrieve the authorization information.
112+ * @param listener a listener to receive the response
113+ * @param sender a {@link Sender} for making the request to the Elastic Inference Service
114+ */
115+ public void getAuthorizationSkippingIfCcmNotConfigured (
116+ ActionListener <ElasticInferenceServiceAuthorizationModel > listener ,
117+ Sender sender
118+ ) {
119+ getAuthorizationHelper (listener , sender , true );
120+ }
94121
95- if (Strings .isNullOrEmpty (baseUrl )) {
96- logger .debug ("The base URL for the authorization service is not valid, rejecting authorization." );
97- listener .onResponse (ElasticInferenceServiceAuthorizationModel .unauthorized ());
122+ private void getAuthorizationHelper (
123+ ActionListener <ElasticInferenceServiceAuthorizationModel > listener ,
124+ Sender sender ,
125+ boolean skipIfCcmNotConfigured
126+ ) {
127+ var countdownListener = ActionListener .runAfter (listener , requestCompleteLatch ::countDown );
128+
129+ try {
130+ if (skipIfCcmNotConfigured == false || ccmFeature .isCcmSupportedEnvironment () == false ) {
131+ retrieveAuthorizationInformation (countdownListener , sender );
98132 return ;
99133 }
100134
101- var handleFailuresListener = listener .delegateResponse ((authModelListener , e ) -> {
102- // unwrap because it's likely a retry exception
103- var exception = ExceptionsHelper .unwrapCause (e );
104-
105- logger .warn (Strings .format (FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s" , exception ), exception );
106- authModelListener .onFailure (e );
135+ var isCcmEnabledListener = ActionListener .<Boolean >wrap (response -> {
136+ if (response == null || response == false ) {
137+ logger .debug ("CCM is not configured, skipping authorization request to Elastic Inference Service." );
138+ countdownListener .onResponse (ElasticInferenceServiceAuthorizationModel .unauthorized ());
139+ } else {
140+ retrieveAuthorizationInformation (countdownListener , sender );
141+ }
142+ }, e -> {
143+ logger .atDebug ().withThrowable (e ).log ("Failed to determine if CCM is configured, returning unauthorized." );
144+ countdownListener .onResponse (ElasticInferenceServiceAuthorizationModel .unauthorized ());
107145 });
108146
109- SubscribableListener .newForked (sender ::startAsynchronously )
110- .andThen (authFactory ::getAuthenticationApplier )
111- .<InferenceServiceResults >andThen ((authListener , authApplier ) -> {
112- var requestMetadata = extractRequestMetadataFromThreadContext (threadPool .getThreadContext ());
113- var request = new ElasticInferenceServiceAuthorizationRequest (
114- baseUrl ,
115- getCurrentTraceInfo (),
116- requestMetadata ,
117- authApplier
118- );
119- sender .sendWithoutQueuing (logger , request , AUTH_RESPONSE_HANDLER , DEFAULT_AUTH_TIMEOUT , authListener );
120- })
121- .andThenApply (authResult -> {
122- if (authResult instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity ) {
123- logger .debug (() -> Strings .format ("Received authorization information from gateway %s" , authResponseEntity ));
124- return ElasticInferenceServiceAuthorizationModel .of (authResponseEntity , baseUrl );
125- }
126-
127- var errorMessage = Strings .format (
128- "%s Received an invalid response type from the Elastic Inference Service: %s" ,
129- FAILED_TO_RETRIEVE_MESSAGE ,
130- authResult .getClass ().getSimpleName ()
131- );
132-
133- logger .warn (errorMessage );
134- throw new ElasticsearchException (errorMessage );
135- })
136- .addListener (ActionListener .runAfter (handleFailuresListener , requestCompleteLatch ::countDown ));
147+ ccmService .isEnabled (isCcmEnabledListener );
137148 } catch (Exception e ) {
138149 logger .warn (Strings .format ("Retrieving the authorization information encountered an exception: %s" , e ));
139- requestCompleteLatch . countDown ( );
150+ countdownListener . onFailure ( e );
140151 }
141152 }
142153
154+ private void retrieveAuthorizationInformation (ActionListener <ElasticInferenceServiceAuthorizationModel > listener , Sender sender ) {
155+ logger .debug ("Retrieving authorization information from the Elastic Inference Service." );
156+
157+ if (Strings .isNullOrEmpty (baseUrl )) {
158+ logger .debug ("The base URL for the authorization service is not valid, rejecting authorization." );
159+ listener .onResponse (ElasticInferenceServiceAuthorizationModel .unauthorized ());
160+ return ;
161+ }
162+
163+ var handleFailuresListener = listener .delegateResponse ((authModelListener , e ) -> {
164+ // unwrap because it's likely a retry exception
165+ var exception = ExceptionsHelper .unwrapCause (e );
166+
167+ logger .warn (Strings .format (FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s" , exception ), exception );
168+ authModelListener .onFailure (e );
169+ });
170+
171+ SubscribableListener .newForked (sender ::startAsynchronously )
172+ .andThen (authFactory ::getAuthenticationApplier )
173+ .<InferenceServiceResults >andThen ((authListener , authApplier ) -> {
174+ var requestMetadata = extractRequestMetadataFromThreadContext (threadPool .getThreadContext ());
175+ var request = new ElasticInferenceServiceAuthorizationRequest (baseUrl , getCurrentTraceInfo (), requestMetadata , authApplier );
176+ sender .sendWithoutQueuing (logger , request , AUTH_RESPONSE_HANDLER , DEFAULT_AUTH_TIMEOUT , authListener );
177+ })
178+ .andThenApply (authResult -> {
179+ if (authResult instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity ) {
180+ logger .debug (() -> Strings .format ("Received authorization information from gateway %s" , authResponseEntity ));
181+ return ElasticInferenceServiceAuthorizationModel .of (authResponseEntity , baseUrl );
182+ }
183+
184+ var errorMessage = Strings .format (
185+ "%s Received an invalid response type from the Elastic Inference Service: %s" ,
186+ FAILED_TO_RETRIEVE_MESSAGE ,
187+ authResult .getClass ().getSimpleName ()
188+ );
189+
190+ logger .warn (errorMessage );
191+ throw new ElasticsearchException (errorMessage );
192+ })
193+ .addListener (handleFailuresListener );
194+ }
195+
143196 private TraceContext getCurrentTraceInfo () {
144197 var traceParent = threadPool .getThreadContext ().getHeader (Task .TRACE_PARENT_HTTP_HEADER );
145198 var traceState = threadPool .getThreadContext ().getHeader (Task .TRACE_STATE );
0 commit comments