77
88package org .elasticsearch .xpack .inference .integration ;
99
10+ import org .elasticsearch .action .support .PlainActionFuture ;
1011import org .elasticsearch .common .settings .Settings ;
12+ import org .elasticsearch .core .TimeValue ;
13+ import org .elasticsearch .inference .TaskType ;
14+ import org .elasticsearch .inference .UnparsedModel ;
1115import org .elasticsearch .plugins .Plugin ;
1216import org .elasticsearch .reindex .ReindexPlugin ;
1317import org .elasticsearch .test .ESSingleNodeTestCase ;
18+ import org .elasticsearch .test .http .MockResponse ;
1419import org .elasticsearch .test .http .MockWebServer ;
1520import org .elasticsearch .threadpool .ThreadPool ;
1621import org .elasticsearch .xpack .inference .LocalStateInferencePlugin ;
1722import org .elasticsearch .xpack .inference .registry .ModelRegistry ;
23+ import org .elasticsearch .xpack .inference .services .elastic .ElasticInferenceService ;
1824import org .elasticsearch .xpack .inference .services .elastic .ElasticInferenceServiceSettings ;
25+ import org .elasticsearch .xpack .inference .services .elastic .InternalPreconfiguredEndpoints ;
26+ import org .elasticsearch .xpack .inference .services .elastic .authorization .AuthorizationTaskExecutor ;
1927import org .junit .After ;
28+ import org .junit .AfterClass ;
2029import org .junit .Before ;
2130import org .junit .BeforeClass ;
2231
2332import java .io .IOException ;
2433import java .util .Collection ;
34+ import java .util .List ;
35+ import java .util .function .Function ;
36+ import java .util .stream .Collectors ;
2537
2638import static org .elasticsearch .xpack .inference .Utils .inferenceUtilityExecutors ;
2739import static org .elasticsearch .xpack .inference .external .http .Utils .getUrl ;
40+ import static org .hamcrest .Matchers .empty ;
41+ import static org .hamcrest .Matchers .is ;
2842
2943public class AuthorizationTaskExecutorIT extends ESSingleNodeTestCase {
3044
45+ private static final String EMPTY_AUTH_RESPONSE = """
46+ {
47+ "models": [
48+ ]
49+ }
50+ """ ;
51+
52+ private static final String AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE = """
53+ {
54+ "models": [
55+ {
56+ "model_name": "rainbow-sprinkles",
57+ "task_types": ["chat"]
58+ }
59+ ]
60+ }
61+ """ ;
62+
3163 private static final MockWebServer webServer = new MockWebServer ();
3264 private static String gatewayUrl ;
3365
3466 private ModelRegistry modelRegistry ;
3567 private ThreadPool threadPool ;
68+ private AuthorizationTaskExecutor authorizationTaskExecutor ;
3669
3770 @ BeforeClass
3871 public static void initClass () throws IOException {
3972 webServer .start ();
4073 gatewayUrl = getUrl (webServer );
41- // TODO add response to the web server to return no authorized models
74+ webServer . enqueue ( new MockResponse (). setResponseCode ( 200 ). setBody ( EMPTY_AUTH_RESPONSE ));
4275 }
4376
4477 @ Before
4578 public void createComponents () {
4679 threadPool = createThreadPool (inferenceUtilityExecutors ());
4780 modelRegistry = node ().injector ().getInstance (ModelRegistry .class );
81+ authorizationTaskExecutor = node ().injector ().getInstance (AuthorizationTaskExecutor .class );
4882 }
4983
5084 @ After
5185 public void shutdown () {
86+ // Delete all the eis preconfigured endpoints
87+ var listener = new PlainActionFuture <Boolean >();
88+ modelRegistry .deleteModels (InternalPreconfiguredEndpoints .EIS_PRECONFIGURED_ENDPOINT_IDS , listener );
89+ listener .actionGet (TimeValue .THIRTY_SECONDS );
90+
5291 terminate (threadPool );
92+ }
93+
94+ @ AfterClass
95+ public static void cleanUpClass () {
5396 webServer .close ();
5497 }
5598
@@ -66,12 +109,126 @@ protected Collection<Class<? extends Plugin>> getPlugins() {
66109 return pluginList (ReindexPlugin .class , LocalStateInferencePlugin .class );
67110 }
68111
69- public void testCreateEndpoints () {
70- // verify that no models are authorized
71- // add request to return an authorized model
72- // cancel the task
73- // ensure the task is recreated?
74- // verify that the authorized model is present
75- fail ("Not implemented yet" );
112+ public void testCreatesEisChatCompletionEndpoint () throws Exception {
113+ assertNoAuthorizedEisEndpoints ();
114+
115+ webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE ));
116+ waitForNewAuthorizationResponse ();
117+
118+ assertChatCompletionEndpointExists ();
119+ }
120+
121+ private void assertNoAuthorizedEisEndpoints () throws Exception {
122+ assertBusy (() -> {
123+ var newPoller = authorizationTaskExecutor .getCurrentPollerTask ();
124+ assertNotNull (newPoller );
125+ newPoller .waitForAuthorizationToComplete (TimeValue .THIRTY_SECONDS );
126+ });
127+
128+ var eisEndpoints = getEisEndpoints ();
129+ assertThat (eisEndpoints , empty ());
130+ }
131+
132+ private List <UnparsedModel > getEisEndpoints () {
133+ var listener = new PlainActionFuture <List <UnparsedModel >>();
134+ modelRegistry .getAllModels (false , listener );
135+
136+ var endpoints = listener .actionGet (TimeValue .THIRTY_SECONDS );
137+ return endpoints .stream ().filter (m -> m .service ().equals (ElasticInferenceService .NAME )).toList ();
138+ }
139+
140+ private void waitForNewAuthorizationResponse () throws Exception {
141+ var taskListener = new PlainActionFuture <Void >();
142+
143+ authorizationTaskExecutor .abortTask (TimeValue .THIRTY_SECONDS , taskListener );
144+ // Ensure that the listener doesn't return a failure
145+ assertNull (taskListener .actionGet (TimeValue .THIRTY_SECONDS ));
146+
147+ // wait for the new task to be recreated
148+ assertBusy (() -> {
149+ var newPoller = authorizationTaskExecutor .getCurrentPollerTask ();
150+ assertNotNull (newPoller );
151+ newPoller .waitForAuthorizationToComplete (TimeValue .THIRTY_SECONDS );
152+ });
153+ }
154+
155+ public void testCreatesEisChatCompletion_DoesNotRemoveEndpointWhenNoLongerAuthorized () throws Exception {
156+ assertNoAuthorizedEisEndpoints ();
157+
158+ webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE ));
159+ waitForNewAuthorizationResponse ();
160+
161+ assertChatCompletionEndpointExists ();
162+
163+ // Simulate that the model is no longer authorized
164+ webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (EMPTY_AUTH_RESPONSE ));
165+ waitForNewAuthorizationResponse ();
166+
167+ assertChatCompletionEndpointExists ();
168+ }
169+
170+ private void assertChatCompletionEndpointExists () {
171+ var eisEndpoints = getEisEndpoints ();
172+ assertThat (eisEndpoints .size (), is (1 ));
173+
174+ var rainbowSprinklesModel = eisEndpoints .get (0 );
175+ assertChatCompletionUnparsedModel (rainbowSprinklesModel );
176+ }
177+
178+ private void assertChatCompletionUnparsedModel (UnparsedModel rainbowSprinklesModel ) {
179+ assertThat (rainbowSprinklesModel .taskType (), is (TaskType .CHAT_COMPLETION ));
180+ assertThat (rainbowSprinklesModel .service (), is (ElasticInferenceService .NAME ));
181+ assertThat (rainbowSprinklesModel .inferenceEntityId (), is (InternalPreconfiguredEndpoints .DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 ));
182+ }
183+
184+ public void testCreatesChatCompletion_AndThenCreatesTextEmbedding () throws Exception {
185+ assertNoAuthorizedEisEndpoints ();
186+
187+ webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE ));
188+ waitForNewAuthorizationResponse ();
189+
190+ assertChatCompletionEndpointExists ();
191+
192+ // Simulate that the model is no longer authorized
193+ webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (EMPTY_AUTH_RESPONSE ));
194+ waitForNewAuthorizationResponse ();
195+
196+ assertChatCompletionEndpointExists ();
197+
198+ // Simulate that a text embedding model is now authorized
199+ var authorizedTextEmbeddingResponse = """
200+ {
201+ "models": [
202+ {
203+ "model_name": "multilingual-embed-v1",
204+ "task_types": ["embed/text/dense"]
205+ }
206+ ]
207+ }
208+ """ ;
209+
210+ webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (authorizedTextEmbeddingResponse ));
211+ waitForNewAuthorizationResponse ();
212+
213+ var eisEndpoints = getEisEndpoints ().stream ().collect (Collectors .toMap (UnparsedModel ::inferenceEntityId , Function .identity ()));
214+ assertThat (eisEndpoints .size (), is (2 ));
215+
216+ assertTrue (eisEndpoints .containsKey (InternalPreconfiguredEndpoints .DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 ));
217+ assertChatCompletionUnparsedModel (eisEndpoints .get (InternalPreconfiguredEndpoints .DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 ));
218+
219+ assertTrue (eisEndpoints .containsKey (InternalPreconfiguredEndpoints .DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID ));
220+
221+ var textEmbeddingEndpoint = eisEndpoints .get (InternalPreconfiguredEndpoints .DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID );
222+ assertThat (textEmbeddingEndpoint .taskType (), is (TaskType .TEXT_EMBEDDING ));
223+ assertThat (textEmbeddingEndpoint .service (), is (ElasticInferenceService .NAME ));
224+ }
225+
226+ public void testRestartsTaskAfterAbort () throws Exception {
227+ // Ensure the task is created and we get an initial authorization response
228+ assertNoAuthorizedEisEndpoints ();
229+
230+ webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (EMPTY_AUTH_RESPONSE ));
231+ // Abort the task and ensure it is restarted
232+ waitForNewAuthorizationResponse ();
76233 }
77234}
0 commit comments