Skip to content

Commit a9f706f

Browse files
Successful task creation
1 parent eb5df36 commit a9f706f

File tree

7 files changed

+128
-38
lines changed

7 files changed

+128
-38
lines changed

x-pack/plugin/inference/src/main/java/module-info.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
requires org.elasticsearch.sslconfig;
3838
requires org.apache.commons.text;
3939
requires software.amazon.awssdk.services.sagemakerruntime;
40-
requires org.elasticsearch.inference;
4140

4241
exports org.elasticsearch.xpack.inference.action;
4342
exports org.elasticsearch.xpack.inference.registry;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,19 @@
1111
import org.apache.logging.log4j.Logger;
1212
import org.apache.lucene.util.SetOnce;
1313
import org.elasticsearch.action.support.MappedActionFilter;
14+
import org.elasticsearch.client.internal.Client;
1415
import org.elasticsearch.cluster.NamedDiff;
1516
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
1617
import org.elasticsearch.cluster.metadata.Metadata;
1718
import org.elasticsearch.cluster.node.DiscoveryNodes;
19+
import org.elasticsearch.cluster.service.ClusterService;
1820
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1921
import org.elasticsearch.common.settings.ClusterSettings;
2022
import org.elasticsearch.common.settings.IndexScopedSettings;
2123
import org.elasticsearch.common.settings.Setting;
2224
import org.elasticsearch.common.settings.Settings;
2325
import org.elasticsearch.common.settings.SettingsFilter;
26+
import org.elasticsearch.common.settings.SettingsModule;
2427
import org.elasticsearch.common.util.LazyInitializable;
2528
import org.elasticsearch.core.IOUtils;
2629
import org.elasticsearch.core.TimeValue;
@@ -36,10 +39,12 @@
3639
import org.elasticsearch.license.LicensedFeature;
3740
import org.elasticsearch.license.XPackLicenseState;
3841
import org.elasticsearch.node.PluginComponentBinding;
42+
import org.elasticsearch.persistent.PersistentTasksExecutor;
3943
import org.elasticsearch.plugins.ActionPlugin;
4044
import org.elasticsearch.plugins.ClusterPlugin;
4145
import org.elasticsearch.plugins.ExtensiblePlugin;
4246
import org.elasticsearch.plugins.MapperPlugin;
47+
import org.elasticsearch.plugins.PersistentTaskPlugin;
4348
import org.elasticsearch.plugins.Plugin;
4449
import org.elasticsearch.plugins.SearchPlugin;
4550
import org.elasticsearch.plugins.SystemIndexPlugin;
@@ -79,6 +84,7 @@
7984
import org.elasticsearch.xpack.inference.action.TransportInferenceActionProxy;
8085
import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction;
8186
import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction;
87+
import org.elasticsearch.xpack.inference.action.TransportStoreEndpointsAction;
8288
import org.elasticsearch.xpack.inference.action.TransportUnifiedCompletionInferenceAction;
8389
import org.elasticsearch.xpack.inference.action.TransportUpdateInferenceModelAction;
8490
import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter;
@@ -110,6 +116,7 @@
110116
import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry;
111117
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
112118
import org.elasticsearch.xpack.inference.registry.ModelRegistryMetadata;
119+
import org.elasticsearch.xpack.inference.registry.StoreInferenceEndpointsAction;
113120
import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceEndpointAction;
114121
import org.elasticsearch.xpack.inference.rest.RestGetInferenceDiagnosticsAction;
115122
import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction;
@@ -175,7 +182,8 @@ public class InferencePlugin extends Plugin
175182
MapperPlugin,
176183
SearchPlugin,
177184
InternalSearchPlugin,
178-
ClusterPlugin {
185+
ClusterPlugin,
186+
PersistentTaskPlugin {
179187

180188
/**
181189
* When this setting is true the verification check that
@@ -227,6 +235,7 @@ public class InferencePlugin extends Plugin
227235
private final SetOnce<ShardBulkInferenceActionFilter> shardBulkInferenceActionFilter = new SetOnce<>();
228236
private final SetOnce<ModelRegistry> modelRegistry = new SetOnce<>();
229237
private List<InferenceServiceExtension> inferenceServiceExtensions;
238+
private final SetOnce<AuthorizationTaskExecutor> authorizationTaskExecutorRef = new SetOnce<>();
230239

231240
public InferencePlugin(Settings settings) {
232241
this.settings = settings;
@@ -246,7 +255,8 @@ public List<ActionHandler> getActions() {
246255
new ActionHandler(GetInferenceServicesAction.INSTANCE, TransportGetInferenceServicesAction.class),
247256
new ActionHandler(UnifiedCompletionAction.INSTANCE, TransportUnifiedCompletionInferenceAction.class),
248257
new ActionHandler(GetRerankerWindowSizeAction.INSTANCE, TransportGetRerankerWindowSizeAction.class),
249-
new ActionHandler(ClearInferenceEndpointCacheAction.INSTANCE, ClearInferenceEndpointCacheAction.class)
258+
new ActionHandler(ClearInferenceEndpointCacheAction.INSTANCE, ClearInferenceEndpointCacheAction.class),
259+
new ActionHandler(StoreInferenceEndpointsAction.INSTANCE, TransportStoreEndpointsAction.class)
250260
);
251261
}
252262

@@ -337,10 +347,12 @@ public Collection<?> createComponents(PluginServices services) {
337347
elasicInferenceServiceFactory.get().createSender(),
338348
inferenceServiceSettings,
339349
eisComponents,
340-
modelRegistry.get()
350+
modelRegistry.get(),
351+
services.client()
341352
)
342353
);
343354
authTaskExecutor.init();
355+
authorizationTaskExecutorRef.set(authTaskExecutor);
344356

345357
var sageMakerSchemas = new SageMakerSchemas();
346358
var sageMakerConfigurations = new LazyInitializable<>(new SageMakerConfiguration(sageMakerSchemas));
@@ -421,6 +433,17 @@ public Collection<?> createComponents(PluginServices services) {
421433
return components;
422434
}
423435

436+
@Override
437+
public List<PersistentTasksExecutor<?>> getPersistentTasksExecutor(
438+
ClusterService clusterService,
439+
ThreadPool threadPool,
440+
Client client,
441+
SettingsModule settingsModule,
442+
IndexNameExpressionResolver expressionResolver
443+
) {
444+
return List.of(authorizationTaskExecutorRef.get());
445+
}
446+
424447
@Override
425448
public void loadExtensions(ExtensionLoader loader) {
426449
inferenceServiceExtensions = loader.loadExtensions(InferenceServiceExtension.class);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportCreateEndpointsAction.java

Lines changed: 0 additions & 15 deletions
This file was deleted.
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.action;
9+
10+
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.action.support.ActionFilters;
12+
import org.elasticsearch.action.support.SubscribableListener;
13+
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
14+
import org.elasticsearch.cluster.ClusterState;
15+
import org.elasticsearch.cluster.block.ClusterBlockException;
16+
import org.elasticsearch.cluster.block.ClusterBlockLevel;
17+
import org.elasticsearch.cluster.service.ClusterService;
18+
import org.elasticsearch.common.util.concurrent.EsExecutors;
19+
import org.elasticsearch.injection.guice.Inject;
20+
import org.elasticsearch.tasks.Task;
21+
import org.elasticsearch.threadpool.ThreadPool;
22+
import org.elasticsearch.transport.TransportService;
23+
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
24+
import org.elasticsearch.xpack.inference.registry.ModelStoreResponse;
25+
import org.elasticsearch.xpack.inference.registry.StoreInferenceEndpointsAction;
26+
27+
import java.util.List;
28+
import java.util.Objects;
29+
30+
/**
31+
* Handles the internal action for creating multiple inference endpoints. This should not be used by external REST APIs.
32+
*/
33+
public class TransportStoreEndpointsAction extends TransportMasterNodeAction<
34+
StoreInferenceEndpointsAction.Request,
35+
StoreInferenceEndpointsAction.Response> {
36+
37+
private final ModelRegistry modelRegistry;
38+
39+
@Inject
40+
public TransportStoreEndpointsAction(
41+
TransportService transportService,
42+
ClusterService clusterService,
43+
ThreadPool threadPool,
44+
ActionFilters actionFilters,
45+
ModelRegistry modelRegistry
46+
) {
47+
super(
48+
StoreInferenceEndpointsAction.NAME,
49+
transportService,
50+
clusterService,
51+
threadPool,
52+
actionFilters,
53+
StoreInferenceEndpointsAction.Request::new,
54+
StoreInferenceEndpointsAction.Response::new,
55+
EsExecutors.DIRECT_EXECUTOR_SERVICE
56+
);
57+
58+
this.modelRegistry = Objects.requireNonNull(modelRegistry);
59+
}
60+
61+
@Override
62+
protected void masterOperation(
63+
Task task,
64+
StoreInferenceEndpointsAction.Request request,
65+
ClusterState state,
66+
ActionListener<StoreInferenceEndpointsAction.Response> masterListener
67+
) {
68+
SubscribableListener.<List<ModelStoreResponse>>newForked(
69+
listener -> modelRegistry.storeModels(request.getModels(), listener, request.masterNodeTimeout())
70+
).andThenApply(StoreInferenceEndpointsAction.Response::new).addListener(masterListener);
71+
}
72+
73+
@Override
74+
protected ClusterBlockException checkBlock(StoreInferenceEndpointsAction.Request request, ClusterState state) {
75+
return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
76+
}
77+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -688,28 +688,22 @@ private void storeModel(Model model, boolean updateClusterState, ActionListener<
688688
"Inference endpoint [{}] already exists",
689689
RestStatus.BAD_REQUEST,
690690
failureItem.failureCause(),
691-
failureItem.inferenceId
691+
failureItem.inferenceId()
692692
)
693693
);
694694
return;
695695
}
696696

697697
delegate.onFailure(
698698
new ElasticsearchStatusException(
699-
format("Failed to store inference endpoint [%s]", failureItem.inferenceId),
699+
format("Failed to store inference endpoint [%s]", failureItem.inferenceId()),
700700
RestStatus.INTERNAL_SERVER_ERROR,
701701
failureItem.failureCause()
702702
)
703703
);
704704
}), timeout);
705705
}
706706

707-
public record ModelStoreResponse(String inferenceId, RestStatus status, @Nullable Exception failureCause) {
708-
public boolean failed() {
709-
return failureCause != null;
710-
}
711-
}
712-
713707
public void storeModels(List<Model> models, ActionListener<List<ModelStoreResponse>> listener, TimeValue timeout) {
714708
storeModels(models, true, listener, timeout);
715709
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpoints.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ public record MinimalModel(
5353
private static final ElasticInferenceServiceCompletionServiceSettings COMPLETION_SERVICE_SETTINGS =
5454
new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1);
5555
private static final ElasticInferenceServiceSparseEmbeddingsServiceSettings SPARSE_EMBEDDINGS_SERVICE_SETTINGS =
56-
new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_RERANK_MODEL_ID_V1, null);
56+
new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_2_MODEL_ID, null);
5757
private static final ElasticInferenceServiceDenseTextEmbeddingsServiceSettings DENSE_TEXT_EMBEDDINGS_SERVICE_SETTINGS =
5858
new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
5959
DEFAULT_MULTILINGUAL_EMBED_MODEL_ID,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,24 @@
1111
import org.apache.logging.log4j.Logger;
1212
import org.elasticsearch.action.ActionListener;
1313
import org.elasticsearch.action.support.SubscribableListener;
14+
import org.elasticsearch.client.internal.Client;
15+
import org.elasticsearch.client.internal.OriginSettingClient;
1416
import org.elasticsearch.common.Randomness;
1517
import org.elasticsearch.common.Strings;
1618
import org.elasticsearch.core.TimeValue;
1719
import org.elasticsearch.persistent.AllocatedPersistentTask;
1820
import org.elasticsearch.tasks.TaskId;
1921
import org.elasticsearch.threadpool.Scheduler;
22+
import org.elasticsearch.xpack.core.ClientHelper;
2023
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
2124
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
25+
import org.elasticsearch.xpack.inference.registry.StoreInferenceEndpointsAction;
2226
import org.elasticsearch.xpack.inference.services.ServiceComponents;
2327
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
2428
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
2529
import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints;
2630

2731
import java.util.EnumSet;
28-
import java.util.List;
2932
import java.util.Map;
3033
import java.util.Objects;
3134
import java.util.Set;
@@ -55,6 +58,7 @@ public class AuthorizationPoller extends AllocatedPersistentTask {
5558
private final ElasticInferenceServiceSettings elasticInferenceServiceSettings;
5659
private final AtomicBoolean initialized = new AtomicBoolean(false);
5760
private final ElasticInferenceServiceComponents elasticInferenceServiceComponents;
61+
private final Client client;
5862

5963
public record TaskFields(long id, String type, String action, String description, TaskId parentTask, Map<String, String> headers) {}
6064

@@ -64,7 +68,8 @@ public record Parameters(
6468
Sender sender,
6569
ElasticInferenceServiceSettings elasticInferenceServiceSettings,
6670
ElasticInferenceServiceComponents components,
67-
ModelRegistry modelRegistry
71+
ModelRegistry modelRegistry,
72+
Client client
6873
) {}
6974

7075
public AuthorizationPoller(TaskFields taskFields, Parameters parameters) {
@@ -76,6 +81,7 @@ public AuthorizationPoller(TaskFields taskFields, Parameters parameters) {
7681
parameters.elasticInferenceServiceSettings,
7782
parameters.components,
7883
parameters.modelRegistry,
84+
parameters.client,
7985
null
8086
);
8187
}
@@ -89,6 +95,7 @@ public AuthorizationPoller(TaskFields taskFields, Parameters parameters) {
8995
ElasticInferenceServiceSettings elasticInferenceServiceSettings,
9096
ElasticInferenceServiceComponents components,
9197
ModelRegistry modelRegistry,
98+
Client client,
9299
// this is a hack to facilitate testing
93100
Runnable callback
94101
) {
@@ -99,6 +106,7 @@ public AuthorizationPoller(TaskFields taskFields, Parameters parameters) {
99106
this.elasticInferenceServiceSettings = Objects.requireNonNull(elasticInferenceServiceSettings);
100107
this.elasticInferenceServiceComponents = Objects.requireNonNull(components);
101108
this.modelRegistry = Objects.requireNonNull(modelRegistry);
109+
this.client = new OriginSettingClient(Objects.requireNonNull(client), ClientHelper.INFERENCE_ORIGIN);
102110
this.callback = callback;
103111
}
104112

@@ -184,7 +192,10 @@ private void sendAuthorizationRequest() {
184192
callback.run();
185193
}
186194
firstAuthorizationCompletedLatch.countDown();
187-
}).delegateResponse((delegate, e) -> logger.atWarn().withThrowable(e).log("Failed processing EIS preconfigured endpoints"));
195+
}).delegateResponse((delegate, e) -> {
196+
logger.atWarn().withThrowable(e).log("Failed processing EIS preconfigured endpoints");
197+
delegate.onResponse(null);
198+
});
188199

189200
SubscribableListener.<ElasticInferenceServiceAuthorizationModel>newForked(
190201
authModelListener -> authorizationHandler.getAuthorization(authModelListener, sender)
@@ -216,11 +227,12 @@ private void storePreconfiguredModels(Set<String> newInferenceIds, ActionListene
216227
return;
217228
}
218229

230+
logger.debug("Storing new EIS preconfigured inference endpoints with inference IDs {}", newInferenceIds);
219231
var modelsToAdd = PreconfiguredEndpointModelAdapter.getModels(newInferenceIds, elasticInferenceServiceComponents);
232+
var storeRequest = new StoreInferenceEndpointsAction.Request(modelsToAdd, TimeValue.THIRTY_SECONDS);
220233

221-
// TODO
222-
ActionListener<List<ModelRegistry.ModelStoreResponse>> storeListener = ActionListener.wrap(responses -> {
223-
for (var response : responses) {
234+
ActionListener<StoreInferenceEndpointsAction.Response> storeListener = ActionListener.wrap(responses -> {
235+
for (var response : responses.getResults()) {
224236
if (response.failed()) {
225237
logger.atWarn()
226238
.withThrowable(response.failureCause())
@@ -232,10 +244,10 @@ private void storePreconfiguredModels(Set<String> newInferenceIds, ActionListene
232244
}
233245
}, e -> logger.atWarn().withThrowable(e).log("Failed to store new EIS preconfigured inference endpoints [{}]", newInferenceIds));
234246

235-
modelRegistry.storeModels(
236-
modelsToAdd,
237-
ActionListener.runAfter(storeListener, () -> listener.onResponse(null)),
238-
TimeValue.THIRTY_SECONDS
247+
client.execute(
248+
StoreInferenceEndpointsAction.INSTANCE,
249+
storeRequest,
250+
ActionListener.runAfter(storeListener, () -> listener.onResponse(null))
239251
);
240252
}
241253
}

0 commit comments

Comments
 (0)