Skip to content

Commit 1cdc01f

Browse files
[ML] Store CCM enabled state in cluster state and integrate cache (elastic#138530)
* Adding enablement service with cluster state
* Fixing tests
* Working tests
* Adding more tests for enablement
* Adding integration tests
* Adding more tests and cleanup
* Skipping tests for release testing
* Trying to get request checks to work
* Fixing test with bearer
* Cleaning up
* Addressing feedback
* Renaming feedback
1 parent 65d8b48 commit 1cdc01f

25 files changed

+1189
-224
lines changed

muted-tests.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,9 +364,6 @@ tests:
364364
- class: org.elasticsearch.search.basic.SearchWithRandomDisconnectsIT
365365
method: testSearchWithRandomDisconnects
366366
issue: https://github.com/elastic/elasticsearch/issues/138128
367-
- class: org.elasticsearch.xpack.inference.integration.CCMServiceIT
368-
method: testCreatesEisChatCompletionEndpoint
369-
issue: https://github.com/elastic/elasticsearch/issues/138206
370367
- class: org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT
371368
method: test {p0=search/95_sort_mixed_numeric_types/Simple sort}
372369
issue: https://github.com/elastic/elasticsearch/issues/138297
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
9235000
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
msearch_project_routing,9234000
1+
inference_ccm_enablement_service,9235000

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ public static void cleanUpClass() {
105105
@Override
106106
protected Settings nodeSettings() {
107107
return Settings.builder()
108-
// Disable CCM to ensure that only the authorization task executor is initialized in the inference plugin when it is created
108+
// Disable CCM to ensure that we don't rely on a CCM configuration existing
109109
.put(CCMSettings.CCM_SUPPORTED_ENVIRONMENT.getKey(), false)
110110
.put(ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_URL.getKey(), gatewayUrl)
111111
// Ensure that the polling logic only occurs once so we can deterministically control when an authorization response is
@@ -119,12 +119,20 @@ protected Collection<Class<? extends Plugin>> getPlugins() {
119119
return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class);
120120
}
121121

122+
@Override
123+
protected boolean resetNodeAfterTest() {
124+
return true;
125+
}
126+
122127
public void testCreatesEisChatCompletionEndpoint() throws Exception {
123128
assertNoAuthorizedEisEndpoints();
124129

130+
webServer.clearRequests();
125131
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponseBody));
126132
restartPollingTaskAndWaitForAuthResponse();
127133

134+
assertWebServerReceivedRequest();
135+
128136
assertChatCompletionEndpointExists();
129137
}
130138

@@ -194,14 +202,21 @@ private void restartPollingTaskAndWaitForAuthResponse() throws Exception {
194202
restartPollingTaskAndWaitForAuthResponse(admin(), authorizationTaskExecutor);
195203
}
196204

197-
static void restartPollingTaskAndWaitForAuthResponse(AdminClient adminClient, AuthorizationTaskExecutor authTaskExecutor)
205+
private static void restartPollingTaskAndWaitForAuthResponse(AdminClient adminClient, AuthorizationTaskExecutor authTaskExecutor)
198206
throws Exception {
199207
cancelAuthorizationTask(adminClient);
200208

201209
// wait for the new task to be recreated and an authorization response to be processed
202210
waitForAuthorizationToComplete(authTaskExecutor);
203211
}
204212

213+
private static void assertWebServerReceivedRequest() throws Exception {
214+
assertBusy(() -> {
215+
var requests = webServer.requests();
216+
assertThat(requests.size(), is(1));
217+
});
218+
}
219+
205220
static void waitForAuthorizationToComplete(AuthorizationTaskExecutor authTaskExecutor) throws Exception {
206221
assertBusy(() -> {
207222
var newPoller = authTaskExecutor.getCurrentPollerTask();
@@ -227,29 +242,35 @@ static void cancelAuthorizationTask(AdminClient adminClient) throws Exception {
227242
public void testCreatesEisChatCompletion_DoesNotRemoveEndpointWhenNoLongerAuthorized() throws Exception {
228243
assertNoAuthorizedEisEndpoints();
229244

245+
webServer.clearRequests();
230246
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponseBody));
231247
restartPollingTaskAndWaitForAuthResponse();
248+
assertWebServerReceivedRequest();
232249

233250
assertChatCompletionEndpointExists();
234251

252+
webServer.clearRequests();
235253
// Simulate that the model is no longer authorized
236254
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE));
237255
restartPollingTaskAndWaitForAuthResponse();
256+
assertWebServerReceivedRequest();
238257

239258
assertChatCompletionEndpointExists();
240259
}
241260

242-
private void assertChatCompletionEndpointExists() {
261+
private void assertChatCompletionEndpointExists() throws Exception {
243262
assertChatCompletionEndpointExists(modelRegistry);
244263
}
245264

246-
static void assertChatCompletionEndpointExists(ModelRegistry modelRegistry) {
247-
var eisEndpoints = getEisEndpoints(modelRegistry);
248-
assertThat(eisEndpoints.size(), is(1));
265+
static void assertChatCompletionEndpointExists(ModelRegistry modelRegistry) throws Exception {
266+
assertBusy(() -> {
267+
var eisEndpoints = getEisEndpoints(modelRegistry);
268+
assertThat(eisEndpoints.size(), is(1));
249269

250-
var rainbowSprinklesModel = eisEndpoints.get(0);
251-
assertChatCompletionUnparsedModel(rainbowSprinklesModel);
252-
assertTrue(modelRegistry.containsPreconfiguredInferenceEndpointId(RAINBOW_SPRINKLES_ENDPOINT_ID));
270+
var rainbowSprinklesModel = eisEndpoints.get(0);
271+
assertChatCompletionUnparsedModel(rainbowSprinklesModel);
272+
assertTrue(modelRegistry.containsPreconfiguredInferenceEndpointId(RAINBOW_SPRINKLES_ENDPOINT_ID));
273+
});
253274
}
254275

255276
static void assertChatCompletionUnparsedModel(UnparsedModel rainbowSprinklesModel) {
@@ -261,23 +282,29 @@ static void assertChatCompletionUnparsedModel(UnparsedModel rainbowSprinklesMode
261282
public void testCreatesChatCompletion_AndThenCreatesTextEmbedding() throws Exception {
262283
assertNoAuthorizedEisEndpoints();
263284

285+
webServer.clearRequests();
264286
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponseBody));
265287
restartPollingTaskAndWaitForAuthResponse();
288+
assertWebServerReceivedRequest();
266289

267290
assertChatCompletionEndpointExists();
268291

269292
// Simulate that the model is no longer authorized
293+
webServer.clearRequests();
270294
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE));
271295
restartPollingTaskAndWaitForAuthResponse();
296+
assertWebServerReceivedRequest();
272297

273298
assertChatCompletionEndpointExists();
274299

300+
webServer.clearRequests();
275301
// Simulate that a text embedding model is now authorized
276302
var jinaEmbedResponseBody = ElasticInferenceServiceAuthorizationResponseEntityTests.getEisJinaEmbedAuthorizationResponse(gatewayUrl)
277303
.responseJson();
278304
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(jinaEmbedResponseBody));
279305

280306
restartPollingTaskAndWaitForAuthResponse();
307+
assertWebServerReceivedRequest();
281308

282309
var eisEndpoints = getEisEndpoints().stream().collect(Collectors.toMap(UnparsedModel::inferenceEntityId, Function.identity()));
283310
assertThat(eisEndpoints.size(), is(2));
@@ -296,8 +323,10 @@ public void testRestartsTaskAfterAbort() throws Exception {
296323
// Ensure the task is created and we get an initial authorization response
297324
assertNoAuthorizedEisEndpoints();
298325

326+
webServer.clearRequests();
299327
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE));
300328
// Abort the task and ensure it is restarted
301329
restartPollingTaskAndWaitForAuthResponse();
330+
assertWebServerReceivedRequest();
302331
}
303332
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.integration;
9+
10+
import org.elasticsearch.action.support.TestPlainActionFuture;
11+
import org.elasticsearch.action.support.master.AcknowledgedResponse;
12+
import org.elasticsearch.cluster.metadata.ProjectId;
13+
import org.elasticsearch.common.settings.Settings;
14+
import org.elasticsearch.core.TimeValue;
15+
import org.elasticsearch.plugins.Plugin;
16+
import org.elasticsearch.reindex.ReindexPlugin;
17+
import org.elasticsearch.test.ESSingleNodeTestCase;
18+
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
19+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
20+
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMEnablementService;
21+
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeatureFlag;
22+
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings;
23+
import org.junit.Before;
24+
import org.junit.BeforeClass;
25+
26+
import java.util.Collection;
27+
28+
import static org.hamcrest.Matchers.is;
29+
30+
public class CCMEnablementServiceIT extends ESSingleNodeTestCase {
31+
32+
private CCMEnablementService ccmEnablementService;
33+
34+
@BeforeClass
35+
public static void classSetup() {
36+
assumeTrue("CCM is behind a feature flag and snapshot only right now", CCMFeatureFlag.FEATURE_FLAG.isEnabled());
37+
}
38+
39+
@Before
40+
public void createComponents() {
41+
ccmEnablementService = node().injector().getInstance(CCMEnablementService.class);
42+
}
43+
44+
// Ensure we have a node that doesn't contain any enablement cluster state
45+
@Override
46+
protected boolean resetNodeAfterTest() {
47+
return true;
48+
}
49+
50+
@Override
51+
protected Settings nodeSettings() {
52+
return Settings.builder()
53+
.put(CCMSettings.CCM_SUPPORTED_ENVIRONMENT.getKey(), true)
54+
// Disable the authorization task so we don't get errors about inconsistent state while we're
55+
// changing enablement
56+
.put(ElasticInferenceServiceSettings.AUTHORIZATION_ENABLED.getKey(), false)
57+
.build();
58+
}
59+
60+
@Override
61+
protected Collection<Class<? extends Plugin>> getPlugins() {
62+
return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class);
63+
}
64+
65+
public void testSetEnabled() {
66+
assertCCMDisabled();
67+
68+
enabledCCM();
69+
assertCCMEnabled();
70+
}
71+
72+
public void testIsEnabled() {
73+
assertCCMDisabled();
74+
75+
enabledCCM();
76+
assertCCMEnabled();
77+
78+
disabledCCM();
79+
assertCCMDisabled();
80+
}
81+
82+
private void enabledCCM() {
83+
setCCMState(true);
84+
}
85+
86+
private void disabledCCM() {
87+
setCCMState(false);
88+
}
89+
90+
private void setCCMState(boolean enabled) {
91+
var listener = new TestPlainActionFuture<AcknowledgedResponse>();
92+
ccmEnablementService.setEnabled(ProjectId.DEFAULT, enabled, listener);
93+
assertThat(listener.actionGet(TimeValue.THIRTY_SECONDS), is(AcknowledgedResponse.TRUE));
94+
}
95+
96+
private void assertCCMEnabled() {
97+
assertCCMState(true);
98+
}
99+
100+
private void assertCCMDisabled() {
101+
assertCCMState(false);
102+
}
103+
104+
private void assertCCMState(boolean expectedEnabled) {
105+
assertThat(expectedEnabled, is(ccmEnablementService.isEnabled(ProjectId.DEFAULT)));
106+
}
107+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.integration;
9+
10+
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.action.support.TestPlainActionFuture;
12+
import org.elasticsearch.action.support.master.AcknowledgedResponse;
13+
import org.elasticsearch.cluster.metadata.ProjectId;
14+
import org.elasticsearch.common.settings.Settings;
15+
import org.elasticsearch.core.TimeValue;
16+
import org.elasticsearch.plugins.Plugin;
17+
import org.elasticsearch.reindex.ReindexPlugin;
18+
import org.elasticsearch.test.ESSingleNodeTestCase;
19+
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
20+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
21+
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMEnablementService;
22+
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings;
23+
import org.junit.Before;
24+
25+
import java.util.Collection;
26+
27+
import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeature.CCM_FORBIDDEN_EXCEPTION;
28+
import static org.hamcrest.Matchers.is;
29+
30+
public class CCMEnablementServiceUnsupportedEnvironmentIT extends ESSingleNodeTestCase {
31+
32+
private CCMEnablementService ccmEnablementService;
33+
34+
@Before
35+
public void createComponents() {
36+
ccmEnablementService = node().injector().getInstance(CCMEnablementService.class);
37+
}
38+
39+
// Ensure we have a node that doesn't contain any enablement cluster state
40+
@Override
41+
protected boolean resetNodeAfterTest() {
42+
return true;
43+
}
44+
45+
@Override
46+
protected Settings nodeSettings() {
47+
return Settings.builder()
48+
.put(CCMSettings.CCM_SUPPORTED_ENVIRONMENT.getKey(), false)
49+
// Disable the authorization task so we don't get errors about inconsistent state while we're
50+
// changing enablement
51+
.put(ElasticInferenceServiceSettings.AUTHORIZATION_ENABLED.getKey(), false)
52+
.build();
53+
}
54+
55+
@Override
56+
protected Collection<Class<? extends Plugin>> getPlugins() {
57+
return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class);
58+
}
59+
60+
public void testSetEnabledReturnsFailure_WhenEnvironmentIsNotSupported() {
61+
assertCCMDisabled(false);
62+
63+
var listener = new TestPlainActionFuture<AcknowledgedResponse>();
64+
ccmEnablementService.setEnabled(ProjectId.DEFAULT, true, listener);
65+
66+
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TimeValue.THIRTY_SECONDS));
67+
assertThat(exception, is(CCM_FORBIDDEN_EXCEPTION));
68+
assertCCMDisabled(false);
69+
}
70+
71+
private void assertCCMDisabled(boolean expectedEnabled) {
72+
assertThat(expectedEnabled, is(ccmEnablementService.isEnabled(ProjectId.DEFAULT)));
73+
}
74+
}

0 commit comments

Comments
 (0)