Skip to content

Commit acc36fa

Browse files
Adding enablement service with cluster state
1 parent b18ad86 commit acc36fa

File tree

9 files changed

+346
-89
lines changed

9 files changed

+346
-89
lines changed

x-pack/plugin/inference/src/main/java/module-info.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
requires org.elasticsearch.sslconfig;
3737
requires org.apache.commons.text;
3838
requires software.amazon.awssdk.services.sagemakerruntime;
39+
requires org.elasticsearch.inference;
3940

4041
exports org.elasticsearch.xpack.inference.action;
4142
exports org.elasticsearch.xpack.inference.registry;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,16 @@ public class InferenceFeatures implements FeatureSpecification {
5757
public static final NodeFeature INFERENCE_CCM_CACHE = new NodeFeature("inference.ccm.cache");
5858
public static final NodeFeature SEARCH_USAGE_EXTENDED_DATA = new NodeFeature("search.usage.extended_data");
5959
public static final NodeFeature INFERENCE_AUTH_POLLER_PERSISTENT_TASK = new NodeFeature("inference.auth_poller.persistent_task");
60+
public static final NodeFeature INFERENCE_CCM_ENABLEMENT_SERVICE = new NodeFeature("inference.ccm.enablement_service");
6061

6162
@Override
6263
public Set<NodeFeature> getFeatures() {
63-
return Set.of(INFERENCE_ENDPOINT_CACHE, INFERENCE_CCM_CACHE, INFERENCE_AUTH_POLLER_PERSISTENT_TASK);
64+
return Set.of(
65+
INFERENCE_ENDPOINT_CACHE,
66+
INFERENCE_CCM_CACHE,
67+
INFERENCE_AUTH_POLLER_PERSISTENT_TASK,
68+
INFERENCE_CCM_ENABLEMENT_SERVICE
69+
);
6470
}
6571

6672
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@
155155
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
156156
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory;
157157
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMCache;
158+
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMEnablementService;
158159
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeature;
159160
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMIndex;
160161
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMInformedSettings;
@@ -287,6 +288,7 @@ public List<ActionHandler> getActions() {
287288
new ActionHandler(PutCCMConfigurationAction.INSTANCE, TransportPutCCMConfigurationAction.class),
288289
new ActionHandler(DeleteCCMConfigurationAction.INSTANCE, TransportDeleteCCMConfigurationAction.class),
289290
new ActionHandler(CCMCache.ClearCCMCacheAction.INSTANCE, CCMCache.ClearCCMCacheAction.class),
291+
// TODO: determine whether the AuthorizationTaskExecutor.Action handler registered on the next line can be removed
290292
new ActionHandler(AuthorizationTaskExecutor.Action.INSTANCE, AuthorizationTaskExecutor.Action.class),
291293
new ActionHandler(GetInferenceFieldsAction.INSTANCE, TransportGetInferenceFieldsAction.class)
292294
);
@@ -458,8 +460,9 @@ private CCMRelatedComponents createCCMDependentComponents(
458460
ModelRegistry modelRegistry,
459461
CCMFeature ccmFeature
460462
) {
463+
var ccmEnablementService = new CCMEnablementService(services.clusterService(), services.featureService(), ccmFeature);
461464
var ccmPersistentStorageService = new CCMPersistentStorageService(services.client());
462-
var ccmService = new CCMService(ccmPersistentStorageService, services.client());
465+
var ccmService = new CCMService(ccmPersistentStorageService, ccmEnablementService, services.client(), services.projectResolver());
463466
var ccmAuthApplierFactory = new CCMAuthenticationApplierFactory(ccmFeature, ccmService);
464467

465468
var authorizationHandler = new ElasticInferenceServiceAuthorizationRequestHandler(
@@ -471,6 +474,8 @@ private CCMRelatedComponents createCCMDependentComponents(
471474
var authTaskExecutor = AuthorizationTaskExecutor.create(
472475
services.clusterService(),
473476
services.featureService(),
477+
ccmEnablementService,
478+
ccmFeature,
474479
new AuthorizationPoller.Parameters(
475480
serviceComponents,
476481
authorizationHandler,
@@ -483,14 +488,7 @@ private CCMRelatedComponents createCCMDependentComponents(
483488
)
484489
);
485490
authorizationTaskExecutorRef.set(authTaskExecutor);
486-
487-
// If CCM is not allowed in this environment then we can initialize the auth poller task because
488-
// authentication with EIS will be through certs that are already configured. If CCM configuration is allowed,
489-
// we need to wait for the user to provide an API key before we can start polling EIS
490-
if (ccmFeature.isCcmSupportedEnvironment() == false) {
491-
logger.info("CCM configuration is not permitted - starting EIS authorization task executor");
492-
authTaskExecutor.startAndLazyCreateTask();
493-
}
491+
authTaskExecutor.startAndLazyCreateTask();
494492

495493
return new CCMRelatedComponents(
496494
List.of(
@@ -616,7 +614,8 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
616614
)
617615
),
618616
InferenceNamedWriteablesProvider.getNamedWriteables(),
619-
AuthorizationTaskExecutor.getNamedWriteables()
617+
AuthorizationTaskExecutor.getNamedWriteables(),
618+
CCMEnablementService.getNamedWriteables()
620619
).flatMap(List::stream).toList();
621620

622621
}
@@ -636,7 +635,8 @@ public List<NamedXContentRegistry.Entry> getNamedXContent() {
636635
ClearInferenceEndpointCacheAction.InvalidateCacheMetadata::fromXContent
637636
)
638637
),
639-
AuthorizationTaskExecutor.getNamedXContentParsers()
638+
AuthorizationTaskExecutor.getNamedXContentParsers(),
639+
CCMEnablementService.getNamedXContentParsers()
640640
).flatMap(List::stream).toList();
641641
}
642642

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteCCMConfigurationAction.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ protected void masterOperation(
7474
return;
7575
}
7676

77+
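// TODO: before proceeding, check that all nodes in the cluster support the required node feature (presumably InferenceFeatures.INFERENCE_CCM_ENABLEMENT_SERVICE - confirm)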
78+
7779
var disabledListener = listener.<Void>delegateFailureIgnoreResponseAndWrap(
7880
delegate -> delegate.onResponse(new CCMEnabledActionResponse(false))
7981
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutCCMConfigurationAction.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ protected void masterOperation(
7575
return;
7676
}
7777

78+
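// TODO: before proceeding, check that all nodes in the cluster support the required node feature (presumably InferenceFeatures.INFERENCE_CCM_ENABLEMENT_SERVICE - confirm)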
79+
7880
var enabledListener = listener.<Void>delegateFailureIgnoreResponseAndWrap(
7981
delegate -> delegate.onResponse(new CCMEnabledActionResponse(true))
8082
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java

Lines changed: 31 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.logging.log4j.Logger;
1212
import org.elasticsearch.ResourceAlreadyExistsException;
13-
import org.elasticsearch.ResourceNotFoundException;
1413
import org.elasticsearch.action.ActionListener;
1514
import org.elasticsearch.action.ActionType;
1615
import org.elasticsearch.action.support.ActionFilters;
1716
import org.elasticsearch.cluster.ClusterChangedEvent;
1817
import org.elasticsearch.cluster.ClusterState;
1918
import org.elasticsearch.cluster.ClusterStateListener;
19+
import org.elasticsearch.cluster.metadata.ProjectId;
2020
import org.elasticsearch.cluster.service.ClusterService;
2121
import org.elasticsearch.common.Strings;
2222
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
@@ -35,14 +35,15 @@
3535
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
3636
import org.elasticsearch.persistent.PersistentTasksExecutor;
3737
import org.elasticsearch.persistent.PersistentTasksService;
38-
import org.elasticsearch.plugins.Plugin;
3938
import org.elasticsearch.tasks.TaskId;
4039
import org.elasticsearch.transport.RemoteTransportException;
4140
import org.elasticsearch.transport.TransportService;
4241
import org.elasticsearch.xcontent.NamedXContentRegistry;
4342
import org.elasticsearch.xcontent.ParseField;
4443
import org.elasticsearch.xpack.inference.InferenceFeatures;
4544
import org.elasticsearch.xpack.inference.common.BroadcastMessageAction;
45+
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMEnablementService;
46+
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeature;
4647

4748
import java.io.IOException;
4849
import java.util.List;
@@ -71,10 +72,14 @@ public class AuthorizationTaskExecutor extends PersistentTasksExecutor<Authoriza
7172
private final AtomicReference<AuthorizationPoller> currentTask = new AtomicReference<>();
7273
private final AtomicBoolean running = new AtomicBoolean(false);
7374
private final FeatureService featureService;
75+
private final CCMEnablementService ccmEnablementService;
76+
private final CCMFeature ccmFeature;
7477

7578
public static AuthorizationTaskExecutor create(
7679
ClusterService clusterService,
7780
FeatureService featureService,
81+
CCMEnablementService ccmEnablementService,
82+
CCMFeature ccmFeature,
7883
AuthorizationPoller.Parameters parameters
7984
) {
8085
Objects.requireNonNull(clusterService);
@@ -84,6 +89,8 @@ public static AuthorizationTaskExecutor create(
8489
clusterService,
8590
new PersistentTasksService(clusterService, parameters.serviceComponents().threadPool(), parameters.client()),
8691
featureService,
92+
ccmEnablementService,
93+
ccmFeature,
8794
parameters
8895
);
8996
}
@@ -93,12 +100,16 @@ public static AuthorizationTaskExecutor create(
93100
ClusterService clusterService,
94101
PersistentTasksService persistentTasksService,
95102
FeatureService featureService,
103+
CCMEnablementService ccmEnablementService,
104+
CCMFeature ccmFeature,
96105
AuthorizationPoller.Parameters pollerParameters
97106
) {
98107
super(TASK_NAME, pollerParameters.serviceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME));
99108
this.clusterService = Objects.requireNonNull(clusterService);
100109
this.featureService = Objects.requireNonNull(featureService);
101110
this.persistentTasksService = Objects.requireNonNull(persistentTasksService);
111+
this.ccmEnablementService = Objects.requireNonNull(ccmEnablementService);
112+
this.ccmFeature = Objects.requireNonNull(ccmFeature);
102113
this.pollerParameters = Objects.requireNonNull(pollerParameters);
103114
}
104115

@@ -109,20 +120,10 @@ public static AuthorizationTaskExecutor create(
109120
* get an error indicating that it isn't aware of whether the task is a cluster scoped task.
110121
*/
111122
public synchronized void startAndLazyCreateTask() {
112-
startInternal(false);
123+
startInternal();
113124
}
114125

115-
/**
116-
* Starts the authorization task executor and creates the persistent task if it doesn't already exist. This should only be called from
117-
* a context where the cluster state is already initialized. Don't call this from the plugin
118-
* {@link org.elasticsearch.xpack.inference.InferencePlugin#createComponents(Plugin.PluginServices)}. Use
119-
* {@link #startAndLazyCreateTask()} instead.
120-
*/
121-
public synchronized void startAndImmediatelyCreateTask() {
122-
startInternal(true);
123-
}
124-
125-
private void startInternal(boolean createPersistentTask) {
126+
private void startInternal() {
126127
var eisUrl = pollerParameters.elasticInferenceServiceSettings().getElasticInferenceServiceUrl();
127128

128129
logger.info("Authorization task executor EIS URL: [{}]", eisUrl);
@@ -131,14 +132,14 @@ private void startInternal(boolean createPersistentTask) {
131132
if (Strings.isNullOrEmpty(eisUrl) == false && running.compareAndSet(false, true)) {
132133
logger.info("Starting authorization task executor");
133134

134-
if (createPersistentTask) {
135-
sendStartRequest(clusterService.state());
136-
}
137-
138135
clusterService.addListener(this);
139136
}
140137
}
141138

139+
/**
 * Convenience wrapper that reads the current cluster state from the cluster service and passes it to
 * {@link #sendStartRequest(ClusterState)}.
 */
private void sendStartRequestWithCurrentClusterState() {
140+
sendStartRequest(clusterService.state());
141+
}
142+
142143
private void sendStartRequest(@Nullable ClusterState state) {
143144
if (shouldSkipCreatingTask(state)) {
144145
return;
@@ -161,12 +162,21 @@ private void sendStartRequest(@Nullable ClusterState state) {
161162
);
162163
}
163164

164-
private boolean shouldSkipCreatingTask(@Nullable ClusterState state) {
165+
// Package-private (default visibility) for testing
166+
// TODO: add unit tests covering shouldSkipCreatingTask
167+
boolean shouldSkipCreatingTask(@Nullable ClusterState state) {
165168
if (state == null) {
166169
return true;
167170
}
168171

169-
return clusterCanSupportFeature(state) == false || running.get() == false || authorizationTaskExists(state);
172+
return clusterCanSupportFeature(state) == false
173+
|| running.get() == false
174+
|| authorizationTaskExists(state)
175+
|| ccmSupportedButNotYetConfigured();
176+
}
177+
178+
/**
 * Returns {@code true} when the CCM feature reports that this is a CCM-supported environment but the
 * enablement service does not report CCM as enabled for {@link ProjectId#DEFAULT}.
 * NOTE(review): only the default project is consulted here - confirm this is intended for multi-project setups.
 */
private boolean ccmSupportedButNotYetConfigured() {
179+
return ccmFeature.isCcmSupportedEnvironment() && ccmEnablementService.isEnabled(ProjectId.DEFAULT) == false;
170180
}
171181

172182
private boolean clusterCanSupportFeature(@Nullable ClusterState state) {
@@ -185,31 +195,6 @@ private static boolean authorizationTaskExists(@Nullable ClusterState state) {
185195
return ClusterPersistentTasksCustomMetadata.getTaskWithId(state, TASK_NAME) != null;
186196
}
187197

188-
public synchronized void stop() {
189-
if (running.compareAndSet(true, false)) {
190-
logger.info("Shutting down authorization task executor");
191-
clusterService.removeListener(this);
192-
193-
sendStopRequest();
194-
}
195-
}
196-
197-
private void sendStopRequest() {
198-
persistentTasksService.sendClusterRemoveRequest(
199-
TASK_NAME,
200-
TimeValue.THIRTY_SECONDS,
201-
ActionListener.wrap(
202-
persistentTask -> logger.info("Stopped authorization poller task, id {}", persistentTask.getId()),
203-
exception -> {
204-
var thrownException = exception instanceof RemoteTransportException ? exception.getCause() : exception;
205-
if (thrownException instanceof ResourceNotFoundException == false) {
206-
logger.error("Failed to stop authorization poller task", exception);
207-
}
208-
}
209-
)
210-
);
211-
}
212-
213198
/**
214199
* This method should only be used for testing purposes to get the current running task.
215200
*/
@@ -270,11 +255,6 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
270255
);
271256
}
272257

273-
/**
274-
* This action is used to broadcast to all the nodes that the authorization task executor should start or stop.
275-
* This is specifically useful for CCM, since whether to do the polling depends on the CCM
276-
* configuration to exist first.
277-
*/
278258
public static class Action extends BroadcastMessageAction<Message> {
279259
public static final String NAME = "cluster:internal/xpack/inference/update_authorization_task";
280260
public static final ActionType<Response> INSTANCE = new ActionType<>(NAME);
@@ -295,9 +275,7 @@ public Action(
295275
@Override
296276
protected void receiveMessage(Message message) {
297277
if (message.enable()) {
298-
authorizationTaskExecutor.startAndImmediatelyCreateTask();
299-
} else {
300-
authorizationTaskExecutor.stop();
278+
authorizationTaskExecutor.sendStartRequestWithCurrentClusterState();
301279
}
302280
}
303281
}

0 commit comments

Comments
 (0)