|
25 | 25 | import org.elasticsearch.xpack.inference.registry.ModelRegistry; |
26 | 26 | import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; |
27 | 27 | import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; |
28 | | -import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; |
29 | 28 | import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller; |
30 | 29 | import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationTaskExecutor; |
31 | 30 | import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings; |
|
37 | 36 | import java.io.IOException; |
38 | 37 | import java.util.Collection; |
39 | 38 | import java.util.List; |
| 39 | +import java.util.Set; |
40 | 40 | import java.util.concurrent.atomic.AtomicReference; |
41 | 41 | import java.util.function.Function; |
42 | 42 | import java.util.stream.Collectors; |
|
47 | 47 | import static org.hamcrest.Matchers.not; |
48 | 48 |
|
49 | 49 | public class AuthorizationTaskExecutorIT extends ESSingleNodeTestCase { |
| 50 | + |
| 51 | + // rainbow-sprinkles |
| 52 | + public static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = ".rainbow-sprinkles-elastic"; |
| 53 | + |
| 54 | + // gp-llm-v2 |
| 55 | + public static final String GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID = ".gp-llm-v2-chat_completion"; |
| 56 | + |
| 57 | + // elser-2 |
| 58 | + public static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = ".elser-2-elastic"; |
| 59 | + |
| 60 | + // multilingual-text-embed |
| 61 | + public static final String DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID = ".jina-embeddings-v3"; |
| 62 | + |
| 63 | + // rerank-v1 |
| 64 | + public static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = ".elastic-rerank-v1"; |
| 65 | + |
| 66 | + public static final Set<String> EIS_PRECONFIGURED_ENDPOINT_IDS = Set.of( |
| 67 | + DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, |
| 68 | + GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID, |
| 69 | + DEFAULT_ELSER_ENDPOINT_ID_V2, |
| 70 | + DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, |
| 71 | + DEFAULT_RERANK_ENDPOINT_ID_V1 |
| 72 | + ); |
| 73 | + |
50 | 74 | public static final String AUTH_TASK_ACTION = AuthorizationPoller.TASK_NAME + "[c]"; |
51 | 75 |
|
52 | 76 | public static final String EMPTY_AUTH_RESPONSE = """ |
@@ -94,7 +118,7 @@ public void shutdown() { |
94 | 118 | static void removeEisPreconfiguredEndpoints(ModelRegistry modelRegistry) { |
95 | 119 | // Delete all the eis preconfigured endpoints |
96 | 120 | var listener = new PlainActionFuture<Boolean>(); |
97 | | - modelRegistry.deleteModels(InternalPreconfiguredEndpoints.EIS_PRECONFIGURED_ENDPOINT_IDS, listener); |
| 121 | + modelRegistry.deleteModels(EIS_PRECONFIGURED_ENDPOINT_IDS, listener); |
98 | 122 | listener.actionGet(TimeValue.THIRTY_SECONDS); |
99 | 123 | } |
100 | 124 |
|
@@ -149,7 +173,7 @@ static void assertNoAuthorizedEisEndpoints( |
149 | 173 | var eisEndpoints = getEisEndpoints(modelRegistry); |
150 | 174 | assertThat(eisEndpoints, empty()); |
151 | 175 |
|
152 | | - for (String eisPreconfiguredEndpoints : InternalPreconfiguredEndpoints.EIS_PRECONFIGURED_ENDPOINT_IDS) { |
| 176 | + for (String eisPreconfiguredEndpoints : EIS_PRECONFIGURED_ENDPOINT_IDS) { |
153 | 177 | assertFalse(modelRegistry.containsPreconfiguredInferenceEndpointId(eisPreconfiguredEndpoints)); |
154 | 178 | } |
155 | 179 | } |
@@ -250,15 +274,13 @@ static void assertChatCompletionEndpointExists(ModelRegistry modelRegistry) { |
250 | 274 |
|
251 | 275 | var rainbowSprinklesModel = eisEndpoints.get(0); |
252 | 276 | assertChatCompletionUnparsedModel(rainbowSprinklesModel); |
253 | | - assertTrue( |
254 | | - modelRegistry.containsPreconfiguredInferenceEndpointId(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1) |
255 | | - ); |
| 277 | + assertTrue(modelRegistry.containsPreconfiguredInferenceEndpointId(DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); |
256 | 278 | } |
257 | 279 |
|
258 | 280 | static void assertChatCompletionUnparsedModel(UnparsedModel rainbowSprinklesModel) { |
259 | 281 | assertThat(rainbowSprinklesModel.taskType(), is(TaskType.CHAT_COMPLETION)); |
260 | 282 | assertThat(rainbowSprinklesModel.service(), is(ElasticInferenceService.NAME)); |
261 | | - assertThat(rainbowSprinklesModel.inferenceEntityId(), is(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); |
| 283 | + assertThat(rainbowSprinklesModel.inferenceEntityId(), is(DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); |
262 | 284 | } |
263 | 285 |
|
264 | 286 | public void testCreatesChatCompletion_AndThenCreatesTextEmbedding() throws Exception { |
@@ -293,12 +315,12 @@ public void testCreatesChatCompletion_AndThenCreatesTextEmbedding() throws Excep |
293 | 315 | var eisEndpoints = getEisEndpoints().stream().collect(Collectors.toMap(UnparsedModel::inferenceEntityId, Function.identity())); |
294 | 316 | assertThat(eisEndpoints.size(), is(2)); |
295 | 317 |
|
296 | | - assertTrue(eisEndpoints.containsKey(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); |
297 | | - assertChatCompletionUnparsedModel(eisEndpoints.get(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); |
| 318 | + assertTrue(eisEndpoints.containsKey(DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); |
| 319 | + assertChatCompletionUnparsedModel(eisEndpoints.get(DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); |
298 | 320 |
|
299 | | - assertTrue(eisEndpoints.containsKey(InternalPreconfiguredEndpoints.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID)); |
| 321 | + assertTrue(eisEndpoints.containsKey(DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID)); |
300 | 322 |
|
301 | | - var textEmbeddingEndpoint = eisEndpoints.get(InternalPreconfiguredEndpoints.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID); |
| 323 | + var textEmbeddingEndpoint = eisEndpoints.get(DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID); |
302 | 324 | assertThat(textEmbeddingEndpoint.taskType(), is(TaskType.TEXT_EMBEDDING)); |
303 | 325 | assertThat(textEmbeddingEndpoint.service(), is(ElasticInferenceService.NAME)); |
304 | 326 | } |
|
0 commit comments