2424import org .elasticsearch .xpack .inference .external .http .sender .HttpRequestSenderTests ;
2525import org .elasticsearch .xpack .inference .external .http .sender .Sender ;
2626import org .elasticsearch .xpack .inference .logging .ThrottlerManager ;
27+ import org .elasticsearch .xpack .inference .services .elastic .ElasticInferenceServiceSettingsTests ;
2728import org .junit .After ;
2829import org .junit .Before ;
2930import org .mockito .ArgumentCaptor ;
4445import static org .mockito .Mockito .mock ;
4546import static org .mockito .Mockito .times ;
4647import static org .mockito .Mockito .verify ;
47- import static org .mockito .Mockito .verifyNoMoreInteractions ;
4848import static org .mockito .Mockito .when ;
4949
5050public class ElasticInferenceServiceAuthorizationRequestHandlerTests extends ESTestCase {
@@ -71,7 +71,11 @@ public void shutdown() throws IOException {
7171 public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsNull () throws Exception {
7272 var senderFactory = HttpRequestSenderTests .createSenderFactory (threadPool , clientManager );
7373 var logger = mock (Logger .class );
74- var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler (null , threadPool , logger );
74+ var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler (
75+ ElasticInferenceServiceSettingsTests .create (null ),
76+ threadPool ,
77+ logger
78+ );
7579
7680 try (var sender = senderFactory .createSender ()) {
7781 PlainActionFuture <ElasticInferenceServiceAuthorizationModel > listener = new PlainActionFuture <>();
@@ -93,7 +97,11 @@ public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsNull() throws
9397 public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsEmpty () throws Exception {
9498 var senderFactory = HttpRequestSenderTests .createSenderFactory (threadPool , clientManager );
9599 var logger = mock (Logger .class );
96- var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler ("" , threadPool , logger );
100+ var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler (
101+ ElasticInferenceServiceSettingsTests .create ("" ),
102+ threadPool ,
103+ logger
104+ );
97105
98106 try (var sender = senderFactory .createSender ()) {
99107 PlainActionFuture <ElasticInferenceServiceAuthorizationModel > listener = new PlainActionFuture <>();
@@ -116,7 +124,11 @@ public void testGetAuthorization_FailsWhenAnInvalidFieldIsFound() throws IOExcep
116124 var senderFactory = HttpRequestSenderTests .createSenderFactory (threadPool , clientManager );
117125 var eisGatewayUrl = getUrl (webServer );
118126 var logger = mock (Logger .class );
119- var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler (eisGatewayUrl , threadPool , logger );
127+ var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler (
128+ ElasticInferenceServiceSettingsTests .create (eisGatewayUrl ),
129+ threadPool ,
130+ logger
131+ );
120132
121133 try (var sender = senderFactory .createSender ()) {
122134 String responseJson = """
@@ -167,7 +179,11 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException {
167179 var senderFactory = HttpRequestSenderTests .createSenderFactory (threadPool , clientManager );
168180 var eisGatewayUrl = getUrl (webServer );
169181 var logger = mock (Logger .class );
170- var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler (eisGatewayUrl , threadPool , logger );
182+ var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler (
183+ ElasticInferenceServiceSettingsTests .create (eisGatewayUrl ),
184+ threadPool ,
185+ logger
186+ );
171187
172188 try (var sender = senderFactory .createSender ()) {
173189 String responseJson = """
@@ -196,7 +212,48 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException {
196212
197213 var message = loggerArgsCaptor .getValue ();
198214 assertThat (message , is ("Retrieving authorization information from the Elastic Inference Service." ));
199- verifyNoMoreInteractions (logger );
215+ }
216+ }
217+
218+ public void testGetAuthorization_PerformsADnsLookup () throws IOException {
219+ var senderFactory = HttpRequestSenderTests .createSenderFactory (threadPool , clientManager );
220+ var eisGatewayUrl = getUrl (webServer );
221+ var logger = mock (Logger .class );
222+ when (logger .isDebugEnabled ()).thenReturn (true );
223+ var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler (
224+ ElasticInferenceServiceSettingsTests .createWithUrlAndDnsLookup (eisGatewayUrl ),
225+ threadPool ,
226+ logger
227+ );
228+
229+ try (var sender = senderFactory .createSender ()) {
230+ String responseJson = """
231+ {
232+ "models": [
233+ {
234+ "model_name": "model-a",
235+ "task_types": ["embed/text/sparse", "chat"]
236+ }
237+ ]
238+ }
239+ """ ;
240+
241+ webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (responseJson ));
242+
243+ PlainActionFuture <ElasticInferenceServiceAuthorizationModel > listener = new PlainActionFuture <>();
244+ authHandler .getAuthorization (listener , sender );
245+
246+ var authResponse = listener .actionGet (TIMEOUT );
247+ assertThat (authResponse .getAuthorizedTaskTypes (), is (EnumSet .of (TaskType .SPARSE_EMBEDDING , TaskType .CHAT_COMPLETION )));
248+ assertThat (authResponse .getAuthorizedModelIds (), is (Set .of ("model-a" )));
249+ assertTrue (authResponse .isAuthorized ());
250+
251+ var loggerArgsCaptor = ArgumentCaptor .forClass (String .class );
252+ verify (logger , times (4 )).debug (loggerArgsCaptor .capture ());
253+
254+ var messages = loggerArgsCaptor .getAllValues ();
255+ assertThat (messages .size (), is (4 ));
256+ assertThat (messages .get (3 ), is ("Gateway address: 127.0.0.1" ));
200257 }
201258 }
202259
@@ -205,7 +262,11 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IOException {
205262 var senderFactory = HttpRequestSenderTests .createSenderFactory (threadPool , clientManager );
206263 var eisGatewayUrl = getUrl (webServer );
207264 var logger = mock (Logger .class );
208- var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler (eisGatewayUrl , threadPool , logger );
265+ var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler (
266+ ElasticInferenceServiceSettingsTests .create (eisGatewayUrl ),
267+ threadPool ,
268+ logger
269+ );
209270
210271 ActionListener <ElasticInferenceServiceAuthorizationModel > listener = mock (ActionListener .class );
211272 String responseJson = """
@@ -230,7 +291,6 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IOException {
230291
231292 var message = loggerArgsCaptor .getValue ();
232293 assertThat (message , is ("Retrieving authorization information from the Elastic Inference Service." ));
233- verifyNoMoreInteractions (logger );
234294 }
235295 }
236296
@@ -246,7 +306,11 @@ public void testGetAuthorization_InvalidResponse() throws IOException {
246306 }).when (senderMock ).sendWithoutQueuing (any (), any (), any (), any (), any ());
247307
248308 var logger = mock (Logger .class );
249- var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler ("abc" , threadPool , logger );
309+ var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler (
310+ ElasticInferenceServiceSettingsTests .create ("abc" ),
311+ threadPool ,
312+ logger
313+ );
250314
251315 try (var sender = senderFactory .createSender ()) {
252316 PlainActionFuture <ElasticInferenceServiceAuthorizationModel > listener = new PlainActionFuture <>();
0 commit comments