2525import org .elasticsearch .xpack .inference .registry .ModelRegistry ;
2626import org .elasticsearch .xpack .inference .services .elastic .ElasticInferenceService ;
2727import org .elasticsearch .xpack .inference .services .elastic .ElasticInferenceServiceSettings ;
28- import org .elasticsearch .xpack .inference .services .elastic .InternalPreconfiguredEndpoints ;
2928import org .elasticsearch .xpack .inference .services .elastic .authorization .AuthorizationPoller ;
3029import org .elasticsearch .xpack .inference .services .elastic .authorization .AuthorizationTaskExecutor ;
3130import org .elasticsearch .xpack .inference .services .elastic .ccm .CCMSettings ;
31+ import org .elasticsearch .xpack .inference .services .elastic .response .ElasticInferenceServiceAuthorizationResponseEntityTests ;
3232import org .junit .After ;
3333import org .junit .AfterClass ;
3434import org .junit .Before ;
3737import java .io .IOException ;
3838import java .util .Collection ;
3939import java .util .List ;
40+ import java .util .Set ;
4041import java .util .concurrent .atomic .AtomicReference ;
4142import java .util .function .Function ;
4243import java .util .stream .Collectors ;
4344
4445import static org .elasticsearch .xpack .inference .external .http .Utils .getUrl ;
46+ import static org .elasticsearch .xpack .inference .services .elastic .response .ElasticInferenceServiceAuthorizationResponseEntityTests .EIS_EMPTY_RESPONSE ;
47+ import static org .elasticsearch .xpack .inference .services .elastic .response .ElasticInferenceServiceAuthorizationResponseEntityTests .ELSER_V2_ENDPOINT_ID ;
48+ import static org .elasticsearch .xpack .inference .services .elastic .response .ElasticInferenceServiceAuthorizationResponseEntityTests .JINA_EMBED_V3_ENDPOINT_ID ;
49+ import static org .elasticsearch .xpack .inference .services .elastic .response .ElasticInferenceServiceAuthorizationResponseEntityTests .RAINBOW_SPRINKLES_ENDPOINT_ID ;
50+ import static org .elasticsearch .xpack .inference .services .elastic .response .ElasticInferenceServiceAuthorizationResponseEntityTests .RERANK_V1_ENDPOINT_ID ;
51+ import static org .elasticsearch .xpack .inference .services .elastic .response .ElasticInferenceServiceAuthorizationResponseEntityTests .getEisRainbowSprinklesAuthorizationResponse ;
4552import static org .hamcrest .Matchers .empty ;
4653import static org .hamcrest .Matchers .is ;
4754import static org .hamcrest .Matchers .not ;
4855
4956public class AuthorizationTaskExecutorIT extends ESSingleNodeTestCase {
50- public static final String AUTH_TASK_ACTION = AuthorizationPoller .TASK_NAME + "[c]" ;
5157
52- public static final String EMPTY_AUTH_RESPONSE = """
53- {
54- "models": [
55- ]
56- }
57- """ ;
58-
59- public static final String AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE = """
60- {
61- "models": [
62- {
63- "model_name": "rainbow-sprinkles",
64- "task_types": ["chat"]
65- }
66- ]
67- }
68- """ ;
58+ public static final Set <String > EIS_PRECONFIGURED_ENDPOINT_IDS = Set .of (
59+ RAINBOW_SPRINKLES_ENDPOINT_ID ,
60+ ELSER_V2_ENDPOINT_ID ,
61+ JINA_EMBED_V3_ENDPOINT_ID ,
62+ RERANK_V1_ENDPOINT_ID
63+ );
64+
65+ public static final String AUTH_TASK_ACTION = AuthorizationPoller .TASK_NAME + "[c]" ;
6966
7067 private static final MockWebServer webServer = new MockWebServer ();
7168 private static String gatewayUrl ;
69+ private static String chatCompletionResponseBody ;
7270
7371 private ModelRegistry modelRegistry ;
7472 private AuthorizationTaskExecutor authorizationTaskExecutor ;
@@ -77,7 +75,8 @@ public class AuthorizationTaskExecutorIT extends ESSingleNodeTestCase {
7775 public static void initClass () throws IOException {
7876 webServer .start ();
7977 gatewayUrl = getUrl (webServer );
80- webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (EMPTY_AUTH_RESPONSE ));
78+ webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (EIS_EMPTY_RESPONSE ));
79+ chatCompletionResponseBody = getEisRainbowSprinklesAuthorizationResponse (gatewayUrl ).responseJson ();
8180 }
8281
8382 @ Before
@@ -94,7 +93,7 @@ public void shutdown() {
9493 static void removeEisPreconfiguredEndpoints (ModelRegistry modelRegistry ) {
9594 // Delete all the eis preconfigured endpoints
9695 var listener = new PlainActionFuture <Boolean >();
97- modelRegistry .deleteModels (InternalPreconfiguredEndpoints . EIS_PRECONFIGURED_ENDPOINT_IDS , listener );
96+ modelRegistry .deleteModels (EIS_PRECONFIGURED_ENDPOINT_IDS , listener );
9897 listener .actionGet (TimeValue .THIRTY_SECONDS );
9998 }
10099
@@ -123,7 +122,7 @@ protected Collection<Class<? extends Plugin>> getPlugins() {
123122 public void testCreatesEisChatCompletionEndpoint () throws Exception {
124123 assertNoAuthorizedEisEndpoints ();
125124
126- webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE ));
125+ webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (chatCompletionResponseBody ));
127126 restartPollingTaskAndWaitForAuthResponse ();
128127
129128 assertChatCompletionEndpointExists ();
@@ -149,7 +148,7 @@ static void assertNoAuthorizedEisEndpoints(
149148 var eisEndpoints = getEisEndpoints (modelRegistry );
150149 assertThat (eisEndpoints , empty ());
151150
152- for (String eisPreconfiguredEndpoints : InternalPreconfiguredEndpoints . EIS_PRECONFIGURED_ENDPOINT_IDS ) {
151+ for (String eisPreconfiguredEndpoints : EIS_PRECONFIGURED_ENDPOINT_IDS ) {
153152 assertFalse (modelRegistry .containsPreconfiguredInferenceEndpointId (eisPreconfiguredEndpoints ));
154153 }
155154 }
@@ -228,13 +227,13 @@ static void cancelAuthorizationTask(AdminClient adminClient) throws Exception {
228227 public void testCreatesEisChatCompletion_DoesNotRemoveEndpointWhenNoLongerAuthorized () throws Exception {
229228 assertNoAuthorizedEisEndpoints ();
230229
231- webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE ));
230+ webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (chatCompletionResponseBody ));
232231 restartPollingTaskAndWaitForAuthResponse ();
233232
234233 assertChatCompletionEndpointExists ();
235234
236235 // Simulate that the model is no longer authorized
237- webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (EMPTY_AUTH_RESPONSE ));
236+ webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (EIS_EMPTY_RESPONSE ));
238237 restartPollingTaskAndWaitForAuthResponse ();
239238
240239 assertChatCompletionEndpointExists ();
@@ -250,55 +249,45 @@ static void assertChatCompletionEndpointExists(ModelRegistry modelRegistry) {
250249
251250 var rainbowSprinklesModel = eisEndpoints .get (0 );
252251 assertChatCompletionUnparsedModel (rainbowSprinklesModel );
253- assertTrue (
254- modelRegistry .containsPreconfiguredInferenceEndpointId (InternalPreconfiguredEndpoints .DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 )
255- );
252+ assertTrue (modelRegistry .containsPreconfiguredInferenceEndpointId (RAINBOW_SPRINKLES_ENDPOINT_ID ));
256253 }
257254
258255 static void assertChatCompletionUnparsedModel (UnparsedModel rainbowSprinklesModel ) {
259256 assertThat (rainbowSprinklesModel .taskType (), is (TaskType .CHAT_COMPLETION ));
260257 assertThat (rainbowSprinklesModel .service (), is (ElasticInferenceService .NAME ));
261- assertThat (rainbowSprinklesModel .inferenceEntityId (), is (InternalPreconfiguredEndpoints . DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 ));
258+ assertThat (rainbowSprinklesModel .inferenceEntityId (), is (RAINBOW_SPRINKLES_ENDPOINT_ID ));
262259 }
263260
264261 public void testCreatesChatCompletion_AndThenCreatesTextEmbedding () throws Exception {
265262 assertNoAuthorizedEisEndpoints ();
266263
267- webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE ));
264+ webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (chatCompletionResponseBody ));
268265 restartPollingTaskAndWaitForAuthResponse ();
269266
270267 assertChatCompletionEndpointExists ();
271268
272269 // Simulate that the model is no longer authorized
273- webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (EMPTY_AUTH_RESPONSE ));
270+ webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (EIS_EMPTY_RESPONSE ));
274271 restartPollingTaskAndWaitForAuthResponse ();
275272
276273 assertChatCompletionEndpointExists ();
277274
278275 // Simulate that a text embedding model is now authorized
279- var authorizedTextEmbeddingResponse = """
280- {
281- "models": [
282- {
283- "model_name": "jina-embeddings-v3",
284- "task_types": ["embed/text/dense"]
285- }
286- ]
287- }
288- """ ;
289-
290- webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (authorizedTextEmbeddingResponse ));
276+ var jinaEmbedResponseBody = ElasticInferenceServiceAuthorizationResponseEntityTests .getEisJinaEmbedAuthorizationResponse (gatewayUrl )
277+ .responseJson ();
278+ webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (jinaEmbedResponseBody ));
279+
291280 restartPollingTaskAndWaitForAuthResponse ();
292281
293282 var eisEndpoints = getEisEndpoints ().stream ().collect (Collectors .toMap (UnparsedModel ::inferenceEntityId , Function .identity ()));
294283 assertThat (eisEndpoints .size (), is (2 ));
295284
296- assertTrue (eisEndpoints .containsKey (InternalPreconfiguredEndpoints . DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 ));
297- assertChatCompletionUnparsedModel (eisEndpoints .get (InternalPreconfiguredEndpoints . DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 ));
285+ assertTrue (eisEndpoints .containsKey (RAINBOW_SPRINKLES_ENDPOINT_ID ));
286+ assertChatCompletionUnparsedModel (eisEndpoints .get (RAINBOW_SPRINKLES_ENDPOINT_ID ));
298287
299- assertTrue (eisEndpoints .containsKey (InternalPreconfiguredEndpoints . DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID ));
288+ assertTrue (eisEndpoints .containsKey (JINA_EMBED_V3_ENDPOINT_ID ));
300289
301- var textEmbeddingEndpoint = eisEndpoints .get (InternalPreconfiguredEndpoints . DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID );
290+ var textEmbeddingEndpoint = eisEndpoints .get (JINA_EMBED_V3_ENDPOINT_ID );
302291 assertThat (textEmbeddingEndpoint .taskType (), is (TaskType .TEXT_EMBEDDING ));
303292 assertThat (textEmbeddingEndpoint .service (), is (ElasticInferenceService .NAME ));
304293 }
@@ -307,7 +296,7 @@ public void testRestartsTaskAfterAbort() throws Exception {
307296 // Ensure the task is created and we get an initial authorization response
308297 assertNoAuthorizedEisEndpoints ();
309298
310- webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (EMPTY_AUTH_RESPONSE ));
299+ webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (EIS_EMPTY_RESPONSE ));
311300 // Abort the task and ensure it is restarted
312301 restartPollingTaskAndWaitForAuthResponse ();
313302 }
0 commit comments