Skip to content

Commit 3fd0895

Browse files
Fixing tests
1 parent acc36fa commit 3fd0895

File tree

10 files changed

+211
-150
lines changed

10 files changed

+211
-150
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
9226000
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
keystore_details_in_reload_secure_settings_response,9225000
1+
ml_inference_ccm_enablement_service,9226000

x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/AuthorizationTaskExecutorUpgradeIT.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated;
2626
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_URL;
2727
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings.PERIODIC_AUTHORIZATION_ENABLED;
28+
import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings.CCM_SUPPORTED_ENVIRONMENT;
2829
import static org.hamcrest.Matchers.is;
2930

3031
public class AuthorizationTaskExecutorUpgradeIT extends ParameterizedRollingUpgradeTestCase {
@@ -46,6 +47,7 @@ public class AuthorizationTaskExecutorUpgradeIT extends ParameterizedRollingUpgr
4647
// We need a url set for the authorization task to be created, but we don't actually care if we get a valid response
4748
// just that the task will be created upon upgrade
4849
.setting(ELASTIC_INFERENCE_SERVICE_URL.getKey(), "http://localhost:12345")
50+
.setting(CCM_SUPPORTED_ENVIRONMENT.getKey(), "false")
4951
.build();
5052

5153
private static final String GET_METHOD = "GET";

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CCMServiceIT.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,16 @@ public void testIsEnabled_ReturnsFalse_WhenNoCCMConfigurationStored() {
121121
assertFalse(listener.actionGet(TimeValue.THIRTY_SECONDS));
122122
}
123123

124+
public void testIsEnabled_ReturnsFalse_WhenCCMConfigurationRemoved() {
125+
assertStoreCCMConfiguration();
126+
disableCCM();
127+
128+
var listener = new PlainActionFuture<Boolean>();
129+
ccmService.get().isEnabled(listener);
130+
131+
assertFalse(listener.actionGet(TimeValue.THIRTY_SECONDS));
132+
}
133+
124134
public void testIsEnabled_ReturnsTrue_WhenCCMConfigurationIsPresent() {
125135
assertStoreCCMConfiguration();
126136

@@ -133,6 +143,7 @@ public void testIsEnabled_ReturnsTrue_WhenCCMConfigurationIsPresent() {
133143
public void testCreatesEisChatCompletionEndpoint() throws Exception {
134144
disableCCM();
135145
waitForNoTask(AUTH_TASK_ACTION, admin());
146+
assertCCMDisabled();
136147

137148
var eisEndpoints = getEisEndpoints(modelRegistry);
138149
assertThat(eisEndpoints, empty());
@@ -159,6 +170,7 @@ private void forceClusterUpdate() {
159170
public void testDisableCCM_RemovesAuthorizationTask() throws Exception {
160171
disableCCM();
161172
waitForNoTask(AUTH_TASK_ACTION, admin());
173+
assertCCMDisabled();
162174

163175
var listener = new TestPlainActionFuture<Void>();
164176
ccmService.get().storeConfiguration(new CCMModel(new SecureString("secret".toCharArray())), listener);
@@ -172,5 +184,13 @@ public void testDisableCCM_RemovesAuthorizationTask() throws Exception {
172184

173185
disableCCM();
174186
waitForNoTask(AUTH_TASK_ACTION, admin());
187+
assertCCMDisabled();
188+
}
189+
190+
private void assertCCMDisabled() {
191+
var listener = new PlainActionFuture<Boolean>();
192+
ccmService.get().isEnabled(listener);
193+
194+
assertFalse(listener.actionGet(TimeValue.THIRTY_SECONDS));
175195
}
176196
}

x-pack/plugin/inference/src/main/java/module-info.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
requires org.elasticsearch.sslconfig;
3737
requires org.apache.commons.text;
3838
requires software.amazon.awssdk.services.sagemakerruntime;
39-
requires org.elasticsearch.inference;
4039

4140
exports org.elasticsearch.xpack.inference.action;
4241
exports org.elasticsearch.xpack.inference.registry;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ private CCMRelatedComponents createCCMDependentComponents(
462462
) {
463463
var ccmEnablementService = new CCMEnablementService(services.clusterService(), services.featureService(), ccmFeature);
464464
var ccmPersistentStorageService = new CCMPersistentStorageService(services.client());
465-
var ccmService = new CCMService(ccmPersistentStorageService, ccmEnablementService, services.client(), services.projectResolver());
465+
var ccmService = new CCMService(ccmPersistentStorageService, ccmEnablementService, services.projectResolver());
466466
var ccmAuthApplierFactory = new CCMAuthenticationApplierFactory(ccmFeature, ccmService);
467467

468468
var authorizationHandler = new ElasticInferenceServiceAuthorizationRequestHandler(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.logging.log4j.Logger;
1212
import org.elasticsearch.ResourceAlreadyExistsException;
13+
import org.elasticsearch.ResourceNotFoundException;
1314
import org.elasticsearch.action.ActionListener;
1415
import org.elasticsearch.action.ActionType;
1516
import org.elasticsearch.action.support.ActionFilters;
@@ -162,9 +163,7 @@ private void sendStartRequest(@Nullable ClusterState state) {
162163
);
163164
}
164165

165-
// Default for testing
166-
// TODO test this
167-
boolean shouldSkipCreatingTask(@Nullable ClusterState state) {
166+
private boolean shouldSkipCreatingTask(@Nullable ClusterState state) {
168167
if (state == null) {
169168
return true;
170169
}
@@ -195,6 +194,22 @@ private static boolean authorizationTaskExists(@Nullable ClusterState state) {
195194
return ClusterPersistentTasksCustomMetadata.getTaskWithId(state, TASK_NAME) != null;
196195
}
197196

197+
private void sendStopRequest() {
198+
persistentTasksService.sendClusterRemoveRequest(
199+
TASK_NAME,
200+
TimeValue.THIRTY_SECONDS,
201+
ActionListener.wrap(
202+
persistentTask -> logger.info("Stopped authorization poller task, id {}", persistentTask.getId()),
203+
exception -> {
204+
var thrownException = exception instanceof RemoteTransportException ? exception.getCause() : exception;
205+
if (thrownException instanceof ResourceNotFoundException == false) {
206+
logger.error("Failed to stop authorization poller task", exception);
207+
}
208+
}
209+
)
210+
);
211+
}
212+
198213
/**
199214
* This method should only be used for testing purposes to get the current running task.
200215
*/
@@ -276,6 +291,8 @@ public Action(
276291
protected void receiveMessage(Message message) {
277292
if (message.enable()) {
278293
authorizationTaskExecutor.sendStartRequestWithCurrentClusterState();
294+
} else {
295+
authorizationTaskExecutor.sendStopRequest();
279296
}
280297
}
281298
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMEnablementService.java

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,8 @@
1313
import org.elasticsearch.action.support.master.AcknowledgedResponse;
1414
import org.elasticsearch.cluster.AbstractNamedDiffable;
1515
import org.elasticsearch.cluster.AckedBatchedClusterStateUpdateTask;
16-
import org.elasticsearch.cluster.ClusterChangedEvent;
1716
import org.elasticsearch.cluster.ClusterState;
1817
import org.elasticsearch.cluster.ClusterStateAckListener;
19-
import org.elasticsearch.cluster.ClusterStateListener;
2018
import org.elasticsearch.cluster.NamedDiff;
2119
import org.elasticsearch.cluster.SimpleBatchedAckListenerTaskExecutor;
2220
import org.elasticsearch.cluster.metadata.Metadata;
@@ -45,7 +43,6 @@
4543
import java.util.Iterator;
4644
import java.util.List;
4745
import java.util.Objects;
48-
import java.util.concurrent.atomic.AtomicReference;
4946

5047
import static org.elasticsearch.xpack.inference.InferenceFeatures.INFERENCE_CCM_ENABLEMENT_SERVICE;
5148
import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeature.CCM_UNSUPPORTED_UNTIL_UPGRADED_EXCEPTION;
@@ -58,15 +55,14 @@
5855
*
5956
* This does not handle storing the actual CCM configuration, that is handled by {@link CCMPersistentStorageService}.
6057
*/
61-
public class CCMEnablementService implements ClusterStateListener {
58+
public class CCMEnablementService {
6259

6360
private static final String TASK_QUEUE_NAME = "inference-ccm-enabled-management";
6461
private static final TransportVersion ML_INFERENCE_CCM_ENABLEMENT_SERVICE = TransportVersion.fromName(
6562
"ml_inference_ccm_enablement_service"
6663
);
6764

6865
private final MasterServiceTaskQueue<MetadataTask> taskQueue;
69-
private final AtomicReference<Metadata> lastMetadata = new AtomicReference<>();
7066
private final FeatureService featureService;
7167
private final ClusterService clusterService;
7268
private final CCMFeature ccmFeature;
@@ -76,28 +72,24 @@ public CCMEnablementService(ClusterService clusterService, FeatureService featur
7672
this.featureService = Objects.requireNonNull(featureService);
7773
this.taskQueue = clusterService.createTaskQueue(TASK_QUEUE_NAME, Priority.NORMAL, new UpdateTaskExecutor());
7874
this.ccmFeature = Objects.requireNonNull(ccmFeature);
79-
if (this.ccmFeature.isCcmSupportedEnvironment()) {
80-
clusterService.addListener(this);
81-
}
82-
}
83-
84-
@Override
85-
public void clusterChanged(ClusterChangedEvent event) {
86-
if (lastMetadata.get() == null || event.metadataChanged()) {
87-
lastMetadata.set(event.state().metadata());
88-
}
8975
}
9076

9177
public boolean isEnabled(ProjectId projectId) {
92-
if (ccmFeature.isCcmSupportedEnvironment() == false || lastMetadata.get() == null) {
78+
if (ccmFeature.isCcmSupportedEnvironment() == false || isClusterStateReady() == false) {
9379
return false;
9480
}
9581

96-
var projectMetadata = lastMetadata.get().getProject(projectId);
82+
var projectMetadata = clusterService.state().metadata().getProject(projectId);
9783
var metadata = EnablementMetadata.fromMetadata(projectMetadata);
9884
return metadata.enabled;
9985
}
10086

87+
private boolean isClusterStateReady() {
88+
return clusterService.state() != null
89+
&& clusterService.state().clusterRecovered()
90+
&& featureService.clusterHasFeature(clusterService.state(), INFERENCE_CCM_ENABLEMENT_SERVICE);
91+
}
92+
10193
/**
10294
* This should only be called on the master node.
10395
* Sets the enabled state for CCM in cluster state.
@@ -108,8 +100,7 @@ public void setEnabled(ProjectId projectId, boolean enabled, ActionListener<Ackn
108100
return;
109101
}
110102

111-
if (clusterService.state().clusterRecovered() == false
112-
|| featureService.clusterHasFeature(clusterService.state(), INFERENCE_CCM_ENABLEMENT_SERVICE) == false) {
103+
if (isClusterStateReady() == false) {
113104
listener.onFailure(CCM_UNSUPPORTED_UNTIL_UPGRADED_EXCEPTION);
114105
return;
115106
}
@@ -159,7 +150,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
159150
public static class EnablementMetadata extends AbstractNamedDiffable<Metadata.ProjectCustom> implements Metadata.ProjectCustom {
160151
public static final String NAME = "inference-ccm-enablement-management-metadata";
161152
private static final EnablementMetadata DISABLED = new EnablementMetadata(false);
162-
private static final EnablementMetadata ENABLED = new EnablementMetadata(false);
153+
private static final EnablementMetadata ENABLED = new EnablementMetadata(true);
163154
private static final ParseField ENABLED_FIELD = new ParseField("enabled");
164155

165156
private static final ConstructingObjectParser<EnablementMetadata, Void> PARSER = new ConstructingObjectParser<>(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMService.java

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.elasticsearch.client.internal.OriginSettingClient;
1616
import org.elasticsearch.cluster.project.ProjectResolver;
1717
import org.elasticsearch.xpack.core.ClientHelper;
18+
import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationTaskExecutor;
1819

1920
import java.util.Objects;
2021

@@ -23,20 +24,20 @@ public class CCMService {
2324
private static final Logger logger = LogManager.getLogger(CCMService.class);
2425

2526
private final CCMPersistentStorageService ccmPersistentStorageService;
26-
private final Client client;
2727
private final CCMEnablementService ccmEnablementService;
2828
private final ProjectResolver projectResolver;
29+
private final Client client;
2930

3031
public CCMService(
3132
CCMPersistentStorageService ccmPersistentStorageService,
3233
CCMEnablementService enablementService,
33-
Client client,
34-
ProjectResolver projectResolver
34+
ProjectResolver projectResolver,
35+
Client client
3536
) {
3637
this.ccmPersistentStorageService = Objects.requireNonNull(ccmPersistentStorageService);
37-
this.client = new OriginSettingClient(Objects.requireNonNull(client), ClientHelper.INFERENCE_ORIGIN);
3838
this.ccmEnablementService = Objects.requireNonNull(enablementService);
3939
this.projectResolver = Objects.requireNonNull(projectResolver);
40+
this.client = new OriginSettingClient(Objects.requireNonNull(client), ClientHelper.INFERENCE_ORIGIN);
4041
// TODO initialize the cache for the CCM configuration
4142
}
4243

@@ -47,14 +48,23 @@ public void isEnabled(ActionListener<Boolean> listener) {
4748
public void storeConfiguration(CCMModel model, ActionListener<Void> listener) {
4849
SubscribableListener.<Void>newForked(storeListener -> ccmPersistentStorageService.store(model, storeListener))
4950
.<Void>andThen(
50-
enableAuthExecutorListener -> ccmEnablementService.setEnabled(
51-
projectResolver.getProjectId(),
52-
true,
51+
enablementListener -> ccmEnablementService.setEnabled(projectResolver.getProjectId(), true, ActionListener.wrap(ack -> {
52+
logger.debug("Successfully set CCM enabled in enablement service");
53+
enablementListener.onResponse(null);
54+
}, e -> {
55+
logger.atDebug().withThrowable(e).log("Failed to enable CCM in enablement service");
56+
enablementListener.onFailure(e);
57+
}))
58+
)
59+
.<Void>andThen(
60+
enableAuthExecutorListener -> client.execute(
61+
AuthorizationTaskExecutor.Action.INSTANCE,
62+
AuthorizationTaskExecutor.Action.request(AuthorizationTaskExecutor.Message.ENABLE_MESSAGE, null),
5363
ActionListener.wrap(ack -> {
54-
logger.debug("Successfully set CCM enabled in enablement service");
64+
logger.debug("Successfully enabled authorization task executor");
5565
enableAuthExecutorListener.onResponse(null);
5666
}, e -> {
57-
logger.atDebug().withThrowable(e).log("Failed to enable CCM in enablement service");
67+
logger.atDebug().withThrowable(e).log("Failed to enable authorization task executor");
5868
enableAuthExecutorListener.onFailure(e);
5969
})
6070
)
@@ -82,7 +92,22 @@ public void disableCCM(ActionListener<Void> listener) {
8292
disableAuthExecutorListener.onFailure(e);
8393
})
8494
)
85-
).andThen(ccmPersistentStorageService::delete).addListener(listener);
95+
)
96+
.andThen(ccmPersistentStorageService::delete)
97+
.<Void>andThen(
98+
disableAuthExecutorListener -> client.execute(
99+
AuthorizationTaskExecutor.Action.INSTANCE,
100+
AuthorizationTaskExecutor.Action.request(AuthorizationTaskExecutor.Message.DISABLE_MESSAGE, null),
101+
ActionListener.wrap(ack -> {
102+
logger.debug("Successfully disabled authorization task executor");
103+
disableAuthExecutorListener.onResponse(null);
104+
}, e -> {
105+
logger.atDebug().withThrowable(e).log("Failed to disable authorization task executor");
106+
disableAuthExecutorListener.onFailure(e);
107+
})
108+
)
109+
)
110+
.addListener(listener);
86111

87112
// TODO implement invalidating the cache
88113
}

0 commit comments

Comments
 (0)