Skip to content

Commit a51f7b7

Browse files
jonathan-buttnerelasticsearchmachine
andauthored
[ML] Use EIS v2 authorization endpoint for inference API (#138249)
* Starting new response class * Writing tests * Fixing tests * [CI] Auto commit changes from spotless * Successful tests * Removing unused code * Renaming * [CI] Auto commit changes from spotless * Working integration tests * Fixing forbidden calls * Fixing tests * Fixing integration tests * Refactoring tests * Adding some comments * Fixing gp llm v2 name * Updating test name for rerank * Removing named writeable * Removing import * comments * Adding support for completion * Fixing tests * Addressing feedback * Addressing feedback * Refactoring into single if and removing listener --------- Co-authored-by: elasticsearchmachine <[email protected]>
1 parent d847dfc commit a51f7b7

File tree

27 files changed

+1827
-1312
lines changed

27 files changed

+1827
-1312
lines changed

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsActionRequestTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.common.io.stream.Writeable;
1313
import org.elasticsearch.inference.EmptySecretSettings;
1414
import org.elasticsearch.inference.EmptyTaskSettings;
15+
import org.elasticsearch.inference.Model;
1516
import org.elasticsearch.inference.SecretSettings;
1617
import org.elasticsearch.inference.TaskSettings;
1718
import org.elasticsearch.xpack.core.XPackClientPlugin;
@@ -43,7 +44,7 @@ protected StoreInferenceEndpointsAction.Request createTestInstance() {
4344

4445
@Override
4546
protected StoreInferenceEndpointsAction.Request mutateInstance(StoreInferenceEndpointsAction.Request instance) throws IOException {
46-
var newModels = new ArrayList<>(instance.getModels());
47+
var newModels = new ArrayList<Model>(instance.getModels());
4748
newModels.add(ModelTests.randomModel());
4849
return new StoreInferenceEndpointsAction.Request(newModels, instance.masterNodeTimeout());
4950
}

x-pack/plugin/inference/qa/inference-service-tests/build.gradle

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ dependencies {
55
javaRestTestImplementation project(path: xpackModule('core'))
66
javaRestTestImplementation project(path: xpackModule('inference'))
77
clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')
8+
9+
// Allow javaRestTest to see unit-test classes from x-pack:plugin:inference so we can import some variables
10+
javaRestTestImplementation(testArtifact(project(xpackModule('inference'))))
11+
812
// Added this to have access to MockWebServer within the tests
913
javaRestTestImplementation(testArtifact(project(xpackModule('core'))))
1014
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import org.elasticsearch.test.cluster.ElasticsearchCluster;
1818
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
1919
import org.elasticsearch.test.rest.ESRestTestCase;
20-
import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints;
2120
import org.junit.Before;
2221
import org.junit.ClassRule;
2322
import org.junit.Rule;
@@ -26,6 +25,7 @@
2625

2726
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getModel;
2827
import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings.CCM_SUPPORTED_ENVIRONMENT;
28+
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.ELSER_V2_ENDPOINT_ID;
2929

3030
public class BaseMockEISAuthServerTest extends ESRestTestCase {
3131

@@ -93,6 +93,6 @@ public void ensureEisPreconfiguredEndpointsExist() throws Exception {
9393
// available
9494
// Technically this only needs to be done before the suite runs but the underlying client is created in @Before and not statically
9595
// for the suite
96-
assertBusy(() -> getModel(InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2));
96+
assertBusy(() -> getModel(ELSER_V2_ENDPOINT_ID));
9797
}
9898
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@
1818

1919
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getAllModels;
2020
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getModels;
21+
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.ELSER_V2_ENDPOINT_ID;
22+
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID;
23+
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.GP_LLM_V2_COMPLETION_ENDPOINT_ID;
24+
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.JINA_EMBED_V3_ENDPOINT_ID;
25+
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.RAINBOW_SPRINKLES_ENDPOINT_ID;
26+
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.RERANK_V1_ENDPOINT_ID;
2127
import static org.hamcrest.Matchers.hasSize;
2228
import static org.hamcrest.Matchers.is;
2329

@@ -55,12 +61,12 @@ public void testGetDefaultEndpoints() throws IOException {
5561
assertEquals("completion", model.get("task_type"));
5662
}
5763

58-
assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION);
59-
assertInferenceIdTaskType(allModels, ".gp-llm-v2-chat_completion", TaskType.CHAT_COMPLETION);
60-
assertInferenceIdTaskType(allModels, ".gp-llm-v2-completion", TaskType.COMPLETION);
61-
assertInferenceIdTaskType(allModels, ".elser-2-elastic", TaskType.SPARSE_EMBEDDING);
62-
assertInferenceIdTaskType(allModels, ".jina-embeddings-v3", TaskType.TEXT_EMBEDDING);
63-
assertInferenceIdTaskType(allModels, ".jina-reranker-v2", TaskType.RERANK);
64+
assertInferenceIdTaskType(allModels, RAINBOW_SPRINKLES_ENDPOINT_ID, TaskType.CHAT_COMPLETION);
65+
assertInferenceIdTaskType(allModels, GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID, TaskType.CHAT_COMPLETION);
66+
assertInferenceIdTaskType(allModels, GP_LLM_V2_COMPLETION_ENDPOINT_ID, TaskType.COMPLETION);
67+
assertInferenceIdTaskType(allModels, ELSER_V2_ENDPOINT_ID, TaskType.SPARSE_EMBEDDING);
68+
assertInferenceIdTaskType(allModels, JINA_EMBED_V3_ENDPOINT_ID, TaskType.TEXT_EMBEDDING);
69+
assertInferenceIdTaskType(allModels, RERANK_V1_ENDPOINT_ID, TaskType.RERANK);
6470
}
6571

6672
private static void assertInferenceIdTaskType(List<Map<String, Object>> models, String inferenceId, TaskType taskType) {

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,10 +178,11 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
178178
"streaming_completion_test_service",
179179
"completion_test_service",
180180
"hugging_face",
181+
"elastic",
181182
"amazon_sagemaker",
182183
"mistral",
183-
"nvidia",
184-
"watsonxai"
184+
"watsonxai",
185+
"nvidia"
185186
).toArray()
186187
)
187188
);

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java

Lines changed: 3 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,48 +17,16 @@
1717
import org.junit.runners.model.Statement;
1818

1919
import static org.elasticsearch.core.Strings.format;
20+
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.getEisAuthorizationResponseWithMultipleEndpoints;
2021

2122
public class MockElasticInferenceServiceAuthorizationServer implements TestRule {
2223

2324
private static final Logger logger = LogManager.getLogger(MockElasticInferenceServiceAuthorizationServer.class);
2425
private final MockWebServer webServer = new MockWebServer();
2526

26-
public static MockElasticInferenceServiceAuthorizationServer enabledWithRainbowSprinklesAndElser() {
27-
var server = new MockElasticInferenceServiceAuthorizationServer();
28-
29-
server.enqueueAuthorizeAllModelsResponse();
30-
return server;
31-
}
32-
3327
public void enqueueAuthorizeAllModelsResponse() {
34-
String responseJson = """
35-
{
36-
"models": [
37-
{
38-
"model_name": "rainbow-sprinkles",
39-
"task_types": ["chat"]
40-
},
41-
{
42-
"model_name": "gp-llm-v2",
43-
"task_types": ["chat"]
44-
},
45-
{
46-
"model_name": "elser_model_2",
47-
"task_types": ["embed/text/sparse"]
48-
},
49-
{
50-
"model_name": "jina-embeddings-v3",
51-
"task_types": ["embed/text/dense"]
52-
},
53-
{
54-
"model_name": "jina-reranker-v2",
55-
"task_types": ["rerank/text/text-similarity"]
56-
}
57-
]
58-
}
59-
""";
60-
61-
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
28+
var authResponseBody = getEisAuthorizationResponseWithMultipleEndpoints("ignored").responseJson();
29+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(authResponseBody));
6230
}
6331

6432
public String getUrl() {

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java

Lines changed: 37 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@
2525
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
2626
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
2727
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
28-
import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints;
2928
import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller;
3029
import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationTaskExecutor;
3130
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings;
31+
import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests;
3232
import org.junit.After;
3333
import org.junit.AfterClass;
3434
import org.junit.Before;
@@ -37,38 +37,36 @@
3737
import java.io.IOException;
3838
import java.util.Collection;
3939
import java.util.List;
40+
import java.util.Set;
4041
import java.util.concurrent.atomic.AtomicReference;
4142
import java.util.function.Function;
4243
import java.util.stream.Collectors;
4344

4445
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
46+
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.EIS_EMPTY_RESPONSE;
47+
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.ELSER_V2_ENDPOINT_ID;
48+
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.JINA_EMBED_V3_ENDPOINT_ID;
49+
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.RAINBOW_SPRINKLES_ENDPOINT_ID;
50+
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.RERANK_V1_ENDPOINT_ID;
51+
import static org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests.getEisRainbowSprinklesAuthorizationResponse;
4552
import static org.hamcrest.Matchers.empty;
4653
import static org.hamcrest.Matchers.is;
4754
import static org.hamcrest.Matchers.not;
4855

4956
public class AuthorizationTaskExecutorIT extends ESSingleNodeTestCase {
50-
public static final String AUTH_TASK_ACTION = AuthorizationPoller.TASK_NAME + "[c]";
5157

52-
public static final String EMPTY_AUTH_RESPONSE = """
53-
{
54-
"models": [
55-
]
56-
}
57-
""";
58-
59-
public static final String AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE = """
60-
{
61-
"models": [
62-
{
63-
"model_name": "rainbow-sprinkles",
64-
"task_types": ["chat"]
65-
}
66-
]
67-
}
68-
""";
58+
public static final Set<String> EIS_PRECONFIGURED_ENDPOINT_IDS = Set.of(
59+
RAINBOW_SPRINKLES_ENDPOINT_ID,
60+
ELSER_V2_ENDPOINT_ID,
61+
JINA_EMBED_V3_ENDPOINT_ID,
62+
RERANK_V1_ENDPOINT_ID
63+
);
64+
65+
public static final String AUTH_TASK_ACTION = AuthorizationPoller.TASK_NAME + "[c]";
6966

7067
private static final MockWebServer webServer = new MockWebServer();
7168
private static String gatewayUrl;
69+
private static String chatCompletionResponseBody;
7270

7371
private ModelRegistry modelRegistry;
7472
private AuthorizationTaskExecutor authorizationTaskExecutor;
@@ -77,7 +75,8 @@ public class AuthorizationTaskExecutorIT extends ESSingleNodeTestCase {
7775
public static void initClass() throws IOException {
7876
webServer.start();
7977
gatewayUrl = getUrl(webServer);
80-
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE));
78+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE));
79+
chatCompletionResponseBody = getEisRainbowSprinklesAuthorizationResponse(gatewayUrl).responseJson();
8180
}
8281

8382
@Before
@@ -94,7 +93,7 @@ public void shutdown() {
9493
static void removeEisPreconfiguredEndpoints(ModelRegistry modelRegistry) {
9594
// Delete all the eis preconfigured endpoints
9695
var listener = new PlainActionFuture<Boolean>();
97-
modelRegistry.deleteModels(InternalPreconfiguredEndpoints.EIS_PRECONFIGURED_ENDPOINT_IDS, listener);
96+
modelRegistry.deleteModels(EIS_PRECONFIGURED_ENDPOINT_IDS, listener);
9897
listener.actionGet(TimeValue.THIRTY_SECONDS);
9998
}
10099

@@ -123,7 +122,7 @@ protected Collection<Class<? extends Plugin>> getPlugins() {
123122
public void testCreatesEisChatCompletionEndpoint() throws Exception {
124123
assertNoAuthorizedEisEndpoints();
125124

126-
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE));
125+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponseBody));
127126
restartPollingTaskAndWaitForAuthResponse();
128127

129128
assertChatCompletionEndpointExists();
@@ -149,7 +148,7 @@ static void assertNoAuthorizedEisEndpoints(
149148
var eisEndpoints = getEisEndpoints(modelRegistry);
150149
assertThat(eisEndpoints, empty());
151150

152-
for (String eisPreconfiguredEndpoints : InternalPreconfiguredEndpoints.EIS_PRECONFIGURED_ENDPOINT_IDS) {
151+
for (String eisPreconfiguredEndpoints : EIS_PRECONFIGURED_ENDPOINT_IDS) {
153152
assertFalse(modelRegistry.containsPreconfiguredInferenceEndpointId(eisPreconfiguredEndpoints));
154153
}
155154
}
@@ -228,13 +227,13 @@ static void cancelAuthorizationTask(AdminClient adminClient) throws Exception {
228227
public void testCreatesEisChatCompletion_DoesNotRemoveEndpointWhenNoLongerAuthorized() throws Exception {
229228
assertNoAuthorizedEisEndpoints();
230229

231-
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE));
230+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponseBody));
232231
restartPollingTaskAndWaitForAuthResponse();
233232

234233
assertChatCompletionEndpointExists();
235234

236235
// Simulate that the model is no longer authorized
237-
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE));
236+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE));
238237
restartPollingTaskAndWaitForAuthResponse();
239238

240239
assertChatCompletionEndpointExists();
@@ -250,55 +249,45 @@ static void assertChatCompletionEndpointExists(ModelRegistry modelRegistry) {
250249

251250
var rainbowSprinklesModel = eisEndpoints.get(0);
252251
assertChatCompletionUnparsedModel(rainbowSprinklesModel);
253-
assertTrue(
254-
modelRegistry.containsPreconfiguredInferenceEndpointId(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)
255-
);
252+
assertTrue(modelRegistry.containsPreconfiguredInferenceEndpointId(RAINBOW_SPRINKLES_ENDPOINT_ID));
256253
}
257254

258255
static void assertChatCompletionUnparsedModel(UnparsedModel rainbowSprinklesModel) {
259256
assertThat(rainbowSprinklesModel.taskType(), is(TaskType.CHAT_COMPLETION));
260257
assertThat(rainbowSprinklesModel.service(), is(ElasticInferenceService.NAME));
261-
assertThat(rainbowSprinklesModel.inferenceEntityId(), is(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1));
258+
assertThat(rainbowSprinklesModel.inferenceEntityId(), is(RAINBOW_SPRINKLES_ENDPOINT_ID));
262259
}
263260

264261
public void testCreatesChatCompletion_AndThenCreatesTextEmbedding() throws Exception {
265262
assertNoAuthorizedEisEndpoints();
266263

267-
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE));
264+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponseBody));
268265
restartPollingTaskAndWaitForAuthResponse();
269266

270267
assertChatCompletionEndpointExists();
271268

272269
// Simulate that the model is no longer authorized
273-
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE));
270+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE));
274271
restartPollingTaskAndWaitForAuthResponse();
275272

276273
assertChatCompletionEndpointExists();
277274

278275
// Simulate that a text embedding model is now authorized
279-
var authorizedTextEmbeddingResponse = """
280-
{
281-
"models": [
282-
{
283-
"model_name": "jina-embeddings-v3",
284-
"task_types": ["embed/text/dense"]
285-
}
286-
]
287-
}
288-
""";
289-
290-
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(authorizedTextEmbeddingResponse));
276+
var jinaEmbedResponseBody = ElasticInferenceServiceAuthorizationResponseEntityTests.getEisJinaEmbedAuthorizationResponse(gatewayUrl)
277+
.responseJson();
278+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(jinaEmbedResponseBody));
279+
291280
restartPollingTaskAndWaitForAuthResponse();
292281

293282
var eisEndpoints = getEisEndpoints().stream().collect(Collectors.toMap(UnparsedModel::inferenceEntityId, Function.identity()));
294283
assertThat(eisEndpoints.size(), is(2));
295284

296-
assertTrue(eisEndpoints.containsKey(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1));
297-
assertChatCompletionUnparsedModel(eisEndpoints.get(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1));
285+
assertTrue(eisEndpoints.containsKey(RAINBOW_SPRINKLES_ENDPOINT_ID));
286+
assertChatCompletionUnparsedModel(eisEndpoints.get(RAINBOW_SPRINKLES_ENDPOINT_ID));
298287

299-
assertTrue(eisEndpoints.containsKey(InternalPreconfiguredEndpoints.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID));
288+
assertTrue(eisEndpoints.containsKey(JINA_EMBED_V3_ENDPOINT_ID));
300289

301-
var textEmbeddingEndpoint = eisEndpoints.get(InternalPreconfiguredEndpoints.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID);
290+
var textEmbeddingEndpoint = eisEndpoints.get(JINA_EMBED_V3_ENDPOINT_ID);
302291
assertThat(textEmbeddingEndpoint.taskType(), is(TaskType.TEXT_EMBEDDING));
303292
assertThat(textEmbeddingEndpoint.service(), is(ElasticInferenceService.NAME));
304293
}
@@ -307,7 +296,7 @@ public void testRestartsTaskAfterAbort() throws Exception {
307296
// Ensure the task is created and we get an initial authorization response
308297
assertNoAuthorizedEisEndpoints();
309298

310-
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE));
299+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE));
311300
// Abort the task and ensure it is restarted
312301
restartPollingTaskAndWaitForAuthResponse();
313302
}

0 commit comments

Comments
 (0)