1212import org .elasticsearch .action .support .PlainActionFuture ;
1313import org .elasticsearch .common .settings .Settings ;
1414import org .elasticsearch .core .TimeValue ;
15+ import org .elasticsearch .inference .InferenceServiceResults ;
1516import org .elasticsearch .inference .TaskType ;
1617import org .elasticsearch .test .ESTestCase ;
1718import org .elasticsearch .test .http .MockResponse ;
1819import org .elasticsearch .test .http .MockWebServer ;
1920import org .elasticsearch .threadpool .ThreadPool ;
21+ import org .elasticsearch .xpack .core .inference .results .ChatCompletionResults ;
2022import org .elasticsearch .xpack .inference .external .http .HttpClientManager ;
23+ import org .elasticsearch .xpack .inference .external .http .sender .HttpRequestSender ;
2124import org .elasticsearch .xpack .inference .external .http .sender .HttpRequestSenderTests ;
25+ import org .elasticsearch .xpack .inference .external .http .sender .Sender ;
2226import org .elasticsearch .xpack .inference .logging .ThrottlerManager ;
2327import org .junit .After ;
2428import org .junit .Before ;
2529import org .mockito .ArgumentCaptor ;
2630
2731import java .io .IOException ;
2832import java .util .EnumSet ;
33+ import java .util .List ;
2934import java .util .concurrent .TimeUnit ;
3035
3136import static org .elasticsearch .xpack .inference .Utils .inferenceUtilityPool ;
3439import static org .elasticsearch .xpack .inference .external .http .retry .RetryingHttpSender .MAX_RETIES ;
3540import static org .hamcrest .Matchers .is ;
3641import static org .mockito .ArgumentMatchers .any ;
42+ import static org .mockito .Mockito .doAnswer ;
3743import static org .mockito .Mockito .mock ;
3844import static org .mockito .Mockito .times ;
3945import static org .mockito .Mockito .verify ;
4046import static org .mockito .Mockito .verifyNoMoreInteractions ;
47+ import static org .mockito .Mockito .when ;
4148
4249public class ElasticInferenceServiceAuthorizationHandlerTests extends ESTestCase {
4350 private static final TimeValue TIMEOUT = new TimeValue (30 , TimeUnit .SECONDS );
@@ -185,6 +192,7 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException {
185192 verifyNoMoreInteractions (logger );
186193 }
187194 }
195+
188196 @ SuppressWarnings ("unchecked" )
189197 public void testGetAuthorization_OnResponseCalledOnce () throws IllegalStateException {
190198 var senderFactory = HttpRequestSenderTests .createSenderFactory (threadPool , clientManager );
@@ -195,64 +203,63 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IllegalStateExcep
195203 ActionListener <ElasticInferenceServiceAuthorization > listener = mock (ActionListener .class );
196204
197205 String responseJson = """
198- {
199- "models": [
200- {
201- "model_name": "model-a",
202- "task_types": ["embedding /text/sparse", "chat/completion "]
203- }
204- ]
205- }
206- """ ;
206+ {
207+ "models": [
208+ {
209+ "model_name": "model-a",
210+ "task_types": ["embed /text/sparse", "chat"]
211+ }
212+ ]
213+ }
214+ """ ;
207215
208216 webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (responseJson ));
209217
210218 authHandler .getAuthorization (listener , senderFactory .createSender ());
211- authHandler .waitForAuthRequestCompletion (TimeValue . timeValueSeconds ( 1 ) );
219+ authHandler .waitForAuthRequestCompletion (TIMEOUT );
212220
213221 verify (listener , times (1 )).onResponse (any ());
222+ var loggerArgsCaptor = ArgumentCaptor .forClass (String .class );
223+ verify (logger , times (1 )).debug (loggerArgsCaptor .capture ());
224+
225+ var message = loggerArgsCaptor .getValue ();
226+ assertThat (message , is ("Retrieving authorization information from the Elastic Inference Service." ));
214227 verifyNoMoreInteractions (logger );
215228 }
216229
217- public void testGetAuthorization_InvalidResponse () throws IOException {
218- var senderFactory = HttpRequestSenderTests .createSenderFactory (threadPool , clientManager );
219- var eisGatewayUrl = getUrl (webServer );
220- var logger = mock (Logger .class );
221- var authHandler = new ElasticInferenceServiceAuthorizationHandler (eisGatewayUrl , threadPool , logger );
230+ public void testGetAuthorization_InvalidResponse () throws IOException {
231+ var senderMock = mock (Sender .class );
232+ var senderFactory = mock (HttpRequestSender .Factory .class );
233+ when (senderFactory .createSender ()).thenReturn (senderMock );
222234
223- try (var sender = senderFactory .createSender ()) {
224- String responseJson = """
225- {
226- "completion": [
227- {
228- "result": "some result 1"
229- },
230- {
231- "result": "some result 2"
232- }
233- ]
234- }
235- """ ;
235+ doAnswer (invocationOnMock -> {
236+ ActionListener <InferenceServiceResults > listener = invocationOnMock .getArgument (4 );
237+ listener .onResponse (new ChatCompletionResults (List .of (new ChatCompletionResults .Result ("awesome" ))));
238+ return Void .TYPE ;
239+ }).when (senderMock ).sendWithoutQueuing (any (), any (), any (), any (), any ());
236240
237- webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (responseJson ));
241+ var logger = mock (Logger .class );
242+ var authHandler = new ElasticInferenceServiceAuthorizationHandler ("abc" , threadPool , logger );
243+
244+ try (var sender = senderFactory .createSender ()) {
245+ PlainActionFuture <ElasticInferenceServiceAuthorization > listener = new PlainActionFuture <>();
238246
239- PlainActionFuture < ElasticInferenceServiceAuthorization > listener = new PlainActionFuture <>( );
240- authHandler . getAuthorization ( listener , sender );
247+ authHandler . getAuthorization ( listener , sender );
248+ var result = listener . actionGet ( TIMEOUT );
241249
242- var authResponse = listener .actionGet (TIMEOUT );
243- assertTrue (authResponse .enabledTaskTypes ().isEmpty ());
244- assertFalse (authResponse .isEnabled ());
250+ assertThat (result , is (ElasticInferenceServiceAuthorization .newDisabledService ()));
245251
246- var loggerArgsCaptor = ArgumentCaptor .forClass (String .class );
247- verify (logger ).warn (loggerArgsCaptor .capture ());
248- var message = loggerArgsCaptor .getValue ();
249- assertThat (
250- message ,
251- is (
252- "Failed to retrieve the authorization information from the Elastic Inference Service."
253- + " Received an invalid response type: InferenceServiceResults"
254- ));
255- }
252+ var loggerArgsCaptor = ArgumentCaptor .forClass (String .class );
253+ verify (logger ).warn (loggerArgsCaptor .capture ());
254+ var message = loggerArgsCaptor .getValue ();
255+ assertThat (
256+ message ,
257+ is (
258+ "Failed to retrieve the authorization information from the Elastic Inference Service."
259+ + " Received an invalid response type: ChatCompletionResults"
260+ )
261+ );
262+ }
256263
257264 }
258265
0 commit comments