From c6fc960cfead2545833c6141e33b1bab18a357f0 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 16 Oct 2025 14:28:14 -0400 Subject: [PATCH 01/32] Creating new cluster state listener to kick off polling logic --- .../xpack/inference/InferencePlugin.java | 20 +- .../inference/registry/ModelRegistry.java | 13 + .../registry/ModelRegistryMetadata.java | 4 + .../elastic/ElasticInferenceService.java | 4 +- .../elastic/ElasticInferenceServiceModel.java | 2 +- .../InternalPreconfiguredEndpoints.java | 141 +++++++++++ .../AuthorizationInitializer.java | 44 ++++ ...nferenceServiceAuthorizationHandlerV2.java | 227 ++++++++++++++++++ .../PreconfiguredEndpointModelAdapter.java | 43 ++++ 9 files changed, 493 insertions(+), 5 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpoints.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationInitializer.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerV2.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapter.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 60592c5dd1dbd..c51f3d3cda211 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -133,6 +133,8 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import 
org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationInitializer; +import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationHandlerV2; import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService; @@ -223,6 +225,7 @@ public class InferencePlugin extends Plugin private final SetOnce inferenceServiceRegistry = new SetOnce<>(); private final SetOnce shardBulkInferenceActionFilter = new SetOnce<>(); private final SetOnce modelRegistry = new SetOnce<>(); + private final SetOnce eisAuthorizationHandler = new SetOnce<>(); private List inferenceServiceExtensions; public InferencePlugin(Settings settings) { @@ -322,6 +325,19 @@ public Collection createComponents(PluginServices services) { services.threadPool() ); + var eisComponents = new ElasticInferenceServiceComponents(inferenceServiceSettings.getElasticInferenceServiceUrl()); + var eisAuthPoller = new ElasticInferenceServiceAuthorizationHandlerV2( + serviceComponents.get(), + authorizationHandler, + elasicInferenceServiceFactory.get().createSender(), + inferenceServiceSettings, + eisComponents, + modelRegistry.get() + ); + var eisAuthInitializer = new AuthorizationInitializer(eisAuthPoller); + services.clusterService().addListener(eisAuthInitializer); + eisAuthorizationHandler.set(eisAuthPoller); + var sageMakerSchemas = new SageMakerSchemas(); var sageMakerConfigurations = new LazyInitializable<>(new SageMakerConfiguration(sageMakerSchemas)); inferenceServices.add( @@ -595,7 +611,7 @@ public void close() { var 
serviceComponentsRef = serviceComponents.get(); var throttlerToClose = serviceComponentsRef != null ? serviceComponentsRef.throttlerManager() : null; - IOUtils.closeWhileHandlingException(inferenceServiceRegistry.get(), throttlerToClose); + IOUtils.closeWhileHandlingException(inferenceServiceRegistry.get(), throttlerToClose, eisAuthorizationHandler.get()); } @Override @@ -605,7 +621,7 @@ public Map getMetadataMappers() { // Overridable for tests protected Supplier getModelRegistry() { - return () -> modelRegistry.get(); + return modelRegistry::get; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index e556b1db9ecd8..47a919c5d9b96 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -243,6 +243,19 @@ public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId return existing; } + public Set getInferenceIds() { + synchronized (this) { + if (lastMetadata == null) { + throw new IllegalStateException("initial cluster state not set yet"); + } + } + var project = lastMetadata.getProject(ProjectId.DEFAULT); + var state = ModelRegistryMetadata.fromState(project); + var ids = new HashSet<>(state.getInferenceIds()); + ids.addAll(Set.copyOf(defaultConfigIds.keySet())); + return ids; + } + /** * Get a model with its secret settings * @param inferenceEntityId Model to get diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java index 4bf23103af5a1..9ba6bf38416c9 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java @@ -225,6 +225,10 @@ public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId return modelMap.get(inferenceEntityId); } + public Set getInferenceIds() { + return Set.copyOf(modelMap.keySet()); + } + @Override public Diff diff(Metadata.ProjectCustom before) { return new ModelRegistryMetadataDiff((ModelRegistryMetadata) before, this); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 5d476955a7ad6..402cc370957f6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -95,7 +95,7 @@ public class ElasticInferenceService extends SenderService { // A batch size of 16 provides optimal throughput and stability, especially on lower-tier instance types. 
public static final Integer SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE = 16; - private static final EnumSet IMPLEMENTED_TASK_TYPES = EnumSet.of( + public static final EnumSet IMPLEMENTED_TASK_TYPES = EnumSet.of( TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK, @@ -255,7 +255,7 @@ private static Map initDefaultEndpoints( @Override public void onNodeStarted() { - authorizationHandler.init(); +// authorizationHandler.init(); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java index 34a8086119150..ccf776f5db597 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java @@ -15,7 +15,7 @@ import java.util.Objects; -public abstract class ElasticInferenceServiceModel extends RateLimitGroupingModel { +public class ElasticInferenceServiceModel extends RateLimitGroupingModel { private final ElasticInferenceServiceRateLimitServiceSettings rateLimitServiceSettings; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpoints.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpoints.java new file mode 100644 index 0000000000000..869fb6a43a54f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpoints.java @@ -0,0 +1,141 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; + +import java.util.Map; +import java.util.Set; + +import static java.util.stream.Collectors.toMap; + +/** + * Represents the preconfigured endpoints that are included in Elasticsearch. EIS will support dynamic preconfigured endpoints which means + * it can provide new preconfigured endpoints that do not exist in the source here. 
+ */ +public class InternalPreconfiguredEndpoints { + + // rainbow-sprinkles + public static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles"; + public static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); + + // elser-2 + public static final String DEFAULT_ELSER_2_MODEL_ID = "elser_model_2"; + public static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId("elser-2"); + + // multilingual-text-embed + public static final Integer DENSE_TEXT_EMBEDDINGS_DIMENSIONS = 1024; + public static final String DEFAULT_MULTILINGUAL_EMBED_MODEL_ID = "multilingual-embed-v1"; + public static final String DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID = defaultEndpointId(DEFAULT_MULTILINGUAL_EMBED_MODEL_ID); + + // rerank-v1 + public static final String DEFAULT_RERANK_MODEL_ID_V1 = "rerank-v1"; + public static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_RERANK_MODEL_ID_V1); + + public record MinimalModel( + ModelConfigurations configurations, + ElasticInferenceServiceRateLimitServiceSettings rateLimitServiceSettings + ) {} + + private static final ElasticInferenceServiceCompletionServiceSettings COMPLETION_SERVICE_SETTINGS = + new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); + private static final ElasticInferenceServiceSparseEmbeddingsServiceSettings SPARSE_EMBEDDINGS_SERVICE_SETTINGS = + new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_2_MODEL_ID, null); + private static final ElasticInferenceServiceDenseTextEmbeddingsServiceSettings DENSE_TEXT_EMBEDDINGS_SERVICE_SETTINGS = + new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, + defaultDenseTextEmbeddingsSimilarity(), + DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + null + ); + private static final ElasticInferenceServiceRerankServiceSettings RERANK_SERVICE_SETTINGS = + new
ElasticInferenceServiceRerankServiceSettings(DEFAULT_RERANK_MODEL_ID_V1); + + private static final Map MODEL_NAME_TO_MINIMAL_MODEL = Map.of( + DEFAULT_CHAT_COMPLETION_MODEL_ID_V1, + new MinimalModel( + new ModelConfigurations( + DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + TaskType.CHAT_COMPLETION, + ElasticInferenceService.NAME, + COMPLETION_SERVICE_SETTINGS, + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + COMPLETION_SERVICE_SETTINGS + ), + DEFAULT_ELSER_2_MODEL_ID, + new MinimalModel( + new ModelConfigurations( + DEFAULT_ELSER_ENDPOINT_ID_V2, + TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, + SPARSE_EMBEDDINGS_SERVICE_SETTINGS, + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + SPARSE_EMBEDDINGS_SERVICE_SETTINGS + ), + DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, + new MinimalModel( + new ModelConfigurations( + DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, + TaskType.TEXT_EMBEDDING, + ElasticInferenceService.NAME, + DENSE_TEXT_EMBEDDINGS_SERVICE_SETTINGS, + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + DENSE_TEXT_EMBEDDINGS_SERVICE_SETTINGS + ), + DEFAULT_RERANK_MODEL_ID_V1, + new MinimalModel( + new ModelConfigurations( + DEFAULT_RERANK_ENDPOINT_ID_V1, + TaskType.RERANK, + ElasticInferenceService.NAME, + RERANK_SERVICE_SETTINGS, + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + RERANK_SERVICE_SETTINGS + ) + ); + + private static final Map INFERENCE_ID_TO_MINIMAL_MODEL = MODEL_NAME_TO_MINIMAL_MODEL.entrySet() + .stream() + .collect(toMap(e -> e.getValue().configurations().getInferenceEntityId(), Map.Entry::getValue)); + + public static final Set EIS_PRECONFIGURED_ENDPOINT_IDS = Set.copyOf(INFERENCE_ID_TO_MINIMAL_MODEL.keySet()); + + public static SimilarityMeasure defaultDenseTextEmbeddingsSimilarity() { + return SimilarityMeasure.COSINE; + } + + public static String defaultEndpointId(String modelId) { + return Strings.format(".%s-elastic", modelId); + } + + public static boolean containsModelName(String modelName) { + return 
MODEL_NAME_TO_MINIMAL_MODEL.containsKey(modelName); + } + + public static MinimalModel getWithModelName(String modelName) { + return MODEL_NAME_TO_MINIMAL_MODEL.get(modelName); + } + + public static MinimalModel getWithInferenceId(String inferenceId) { + return INFERENCE_ID_TO_MINIMAL_MODEL.get(inferenceId); + } + + private InternalPreconfiguredEndpoints() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationInitializer.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationInitializer.java new file mode 100644 index 0000000000000..341890e86cb09 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationInitializer.java @@ -0,0 +1,44 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.ClusterStateListener; +import org.elasticsearch.gateway.GatewayService; + +import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Waits for the cluster state to be recovered before initializing the authorization handler. 
+ */ +public class AuthorizationInitializer implements ClusterStateListener { + + private final ElasticInferenceServiceAuthorizationHandlerV2 authorizationHandler; + private final AtomicBoolean initializedAuthorization = new AtomicBoolean(false); + + public AuthorizationInitializer(ElasticInferenceServiceAuthorizationHandlerV2 authorizationHandler) { + this.authorizationHandler = Objects.requireNonNull(authorizationHandler); + } + + @Override + public void clusterChanged(ClusterChangedEvent event) { + if (event.localNodeMaster() == false) { + return; + } + + // wait for the cluster state to be recovered + if (event.state().blocks().hasGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK)) { + return; + } + + if (initializedAuthorization.compareAndSet(false, true)) { + authorizationHandler.init(); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerV2.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerV2.java new file mode 100644 index 0000000000000..700d7d3c5c675 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerV2.java @@ -0,0 +1,227 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.threadpool.Scheduler; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; + +import java.io.Closeable; +import java.io.IOException; +import java.util.EnumSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.IMPLEMENTED_TASK_TYPES; + +public class ElasticInferenceServiceAuthorizationHandlerV2 implements Closeable { + private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationHandlerV2.class); + + private final ServiceComponents serviceComponents; + private final ModelRegistry modelRegistry; + private final ElasticInferenceServiceAuthorizationRequestHandler authorizationHandler; + private final CountDownLatch firstAuthorizationCompletedLatch = new 
CountDownLatch(1); + private final Sender sender; + private final Runnable callback; + private final AtomicReference lastAuthTask = new AtomicReference<>(null); + private final AtomicBoolean shutdown = new AtomicBoolean(false); + private final ElasticInferenceServiceSettings elasticInferenceServiceSettings; + private final AtomicBoolean initialized = new AtomicBoolean(false); + private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; + + public ElasticInferenceServiceAuthorizationHandlerV2( + ServiceComponents serviceComponents, + ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, + Sender sender, + ElasticInferenceServiceSettings elasticInferenceServiceSettings, + ElasticInferenceServiceComponents components, + ModelRegistry modelRegistry + ) { + this( + serviceComponents, + authorizationRequestHandler, + sender, + elasticInferenceServiceSettings, + components, + modelRegistry, + null + ); + } + + // default for testing + ElasticInferenceServiceAuthorizationHandlerV2( + ServiceComponents serviceComponents, + ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, + Sender sender, + ElasticInferenceServiceSettings elasticInferenceServiceSettings, + ElasticInferenceServiceComponents components, + ModelRegistry modelRegistry, + // this is a hack to facilitate testing + Runnable callback + ) { + this.serviceComponents = Objects.requireNonNull(serviceComponents); + this.authorizationHandler = Objects.requireNonNull(authorizationRequestHandler); + this.sender = Objects.requireNonNull(sender); + this.elasticInferenceServiceSettings = Objects.requireNonNull(elasticInferenceServiceSettings); + this.elasticInferenceServiceComponents = Objects.requireNonNull(components); + this.modelRegistry = Objects.requireNonNull(modelRegistry); + this.callback = callback; + } + + public void init() { + if (initialized.compareAndSet(false, true)) { + logger.debug("Initializing authorization logic"); + 
serviceComponents.threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(this::scheduleAndSendAuthorizationRequest); + } + } + + /** + * Waits the specified amount of time for the first authorization call to complete. This is mainly to make testing easier. + * @param waitTime the max time to wait + * @throws IllegalStateException if the wait time is exceeded or the call receives an {@link InterruptedException} + */ + public void waitForAuthorizationToComplete(TimeValue waitTime) { + try { + if (firstAuthorizationCompletedLatch.await(waitTime.millis(), TimeUnit.MILLISECONDS) == false) { + throw new IllegalStateException("The wait time has expired for authorization to complete."); + } + } catch (InterruptedException e) { + throw new IllegalStateException("Waiting for authorization to complete was interrupted", e); + } + } + + @Override + public void close() throws IOException { + shutdown.set(true); + if (lastAuthTask.get() != null) { + lastAuthTask.get().cancel(); + } + } + + private void scheduleAuthorizationRequest() { + try { + if (elasticInferenceServiceSettings.isPeriodicAuthorizationEnabled() == false) { + return; + } + + // this call has to be on the individual thread otherwise we get an exception + var random = Randomness.get(); + var jitter = (long) (elasticInferenceServiceSettings.getMaxAuthorizationRequestJitter().millis() * random.nextDouble()); + var waitTime = TimeValue.timeValueMillis(elasticInferenceServiceSettings.getAuthRequestInterval().millis() + jitter); + + logger.debug( + () -> Strings.format( + "Scheduling the next authorization call with request interval: %s ms, jitter: %d ms", + elasticInferenceServiceSettings.getAuthRequestInterval().millis(), + jitter + ) + ); + logger.debug(() -> Strings.format("Next authorization call in %d minutes", waitTime.getMinutes())); + + lastAuthTask.set( + serviceComponents.threadPool() + .schedule( + this::scheduleAndSendAuthorizationRequest, + waitTime, +
serviceComponents.threadPool().executor(UTILITY_THREAD_POOL_NAME) + ) + ); + } catch (Exception e) { + logger.warn("Failed scheduling authorization request", e); + } + } + + private void scheduleAndSendAuthorizationRequest() { + if (shutdown.get()) { + return; + } + + scheduleAuthorizationRequest(); + sendAuthorizationRequest(); + } + + private void sendAuthorizationRequest() { + var finalListener = ActionListener.running(() -> { + if (callback != null) { + callback.run(); + } + firstAuthorizationCompletedLatch.countDown(); + }).delegateResponse((delegate, e) -> { + logger.atWarn().withThrowable(e).log("Failed processing EIS preconfigured endpoints"); + }); + + SubscribableListener.newForked( + authModelListener -> authorizationHandler.getAuthorization(authModelListener, sender) + ) + .andThenApply(this::getNewInferenceEndpointsToStore) + .andThen((storeListener, newInferenceIds) -> storePreconfiguredModels(newInferenceIds, storeListener)) + .addListener(finalListener); + } + + private Set getNewInferenceEndpointsToStore(ElasticInferenceServiceAuthorizationModel authModel) { + var scopedAuthModel = authModel.newLimitedToTaskTypes(EnumSet.copyOf(IMPLEMENTED_TASK_TYPES)); + + var authorizedModelIds = scopedAuthModel.getAuthorizedModelIds(); + var existingInferenceIds = modelRegistry.getInferenceIds(); + + var newInferenceIds = authorizedModelIds.stream() + .map(InternalPreconfiguredEndpoints::getWithModelName) + .filter(Objects::nonNull) + .map(model -> model.configurations().getInferenceEntityId()) + .collect(Collectors.toSet()); + + newInferenceIds.removeAll(existingInferenceIds); + return newInferenceIds; + } + + private void storePreconfiguredModels(Set newInferenceIds, ActionListener listener) { + if (newInferenceIds.isEmpty()) { + listener.onResponse(null); + return; + } + + var modelsToAdd = PreconfiguredEndpointModelAdapter.getModels(newInferenceIds, elasticInferenceServiceComponents); + + ActionListener> storeListener = ActionListener.wrap(responses -> { + 
for (var response : responses) { + if (response.failed()) { + logger.atWarn() + .withThrowable(response.failureCause()) + .log("Failed to store new EIS preconfigured inference endpoint with inference ID [{}]", response.inferenceId()); + } else { + logger.atInfo() + .log("Successfully stored EIS preconfigured inference endpoint with inference ID [{}]", response.inferenceId()); + } + } + }, e -> logger.atWarn().withThrowable(e).log("Failed to store new EIS preconfigured inference endpoints [{}]", newInferenceIds)); + + modelRegistry.storeModels( + modelsToAdd, + ActionListener.runAfter(storeListener, () -> listener.onResponse(null)), + TimeValue.THIRTY_SECONDS + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapter.java new file mode 100644 index 0000000000000..21ad626003c17 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapter.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; +import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; + +import java.util.List; +import java.util.Set; + +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.EIS_PRECONFIGURED_ENDPOINT_IDS; + +public class PreconfiguredEndpointModelAdapter { + public static List getModels(Set inferenceIds, ElasticInferenceServiceComponents elasticInferenceServiceComponents) { + return inferenceIds.stream() + .filter(EIS_PRECONFIGURED_ENDPOINT_IDS::contains) + .map(id -> createModel(InternalPreconfiguredEndpoints.getWithInferenceId(id), elasticInferenceServiceComponents)) + .toList(); + } + + private static Model createModel( + InternalPreconfiguredEndpoints.MinimalModel minimalModel, + ElasticInferenceServiceComponents elasticInferenceServiceComponents + ) { + return new ElasticInferenceServiceModel( + minimalModel.configurations(), + new ModelSecrets(EmptySecretSettings.INSTANCE), + minimalModel.rateLimitServiceSettings(), + elasticInferenceServiceComponents + ); + } + + private PreconfiguredEndpointModelAdapter() {} +} From ecfe8858a5d835f24bc0847aae714938e208e849 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com> Date: Thu, 16 Oct 2025 14:33:32 -0400 Subject: [PATCH 02/32] Update docs/changelog/136713.yaml --- docs/changelog/136713.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/136713.yaml diff --git a/docs/changelog/136713.yaml b/docs/changelog/136713.yaml new file mode 100644 index 
0000000000000..9b88a8aed1111 --- /dev/null +++ b/docs/changelog/136713.yaml @@ -0,0 +1,5 @@ +pr: 136713 +summary: Transition EIS auth polling to master node +area: Machine Learning +type: enhancement +issues: [] From b6928b602227bdda162948b5d7616a68bc1d9e86 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 16 Oct 2025 18:40:10 +0000 Subject: [PATCH 03/32] [CI] Auto commit changes from spotless --- .../services/elastic/ElasticInferenceService.java | 2 +- ...sticInferenceServiceAuthorizationHandlerV2.java | 14 ++------------ 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 402cc370957f6..213bb10e14187 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -255,7 +255,7 @@ private static Map initDefaultEndpoints( @Override public void onNodeStarted() { -// authorizationHandler.init(); + // authorizationHandler.init(); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerV2.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerV2.java index 700d7d3c5c675..b17c4b997f549 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerV2.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerV2.java 
@@ -60,15 +60,7 @@ public ElasticInferenceServiceAuthorizationHandlerV2( ElasticInferenceServiceComponents components, ModelRegistry modelRegistry ) { - this( - serviceComponents, - authorizationRequestHandler, - sender, - elasticInferenceServiceSettings, - components, - modelRegistry, - null - ); + this(serviceComponents, authorizationRequestHandler, sender, elasticInferenceServiceSettings, components, modelRegistry, null); } // default for testing @@ -169,9 +161,7 @@ private void sendAuthorizationRequest() { callback.run(); } firstAuthorizationCompletedLatch.countDown(); - }).delegateResponse((delegate, e) -> { - logger.atWarn().withThrowable(e).log("Failed processing EIS preconfigured endpoints"); - }); + }).delegateResponse((delegate, e) -> { logger.atWarn().withThrowable(e).log("Failed processing EIS preconfigured endpoints"); }); SubscribableListener.newForked( authModelListener -> authorizationHandler.getAuthorization(authModelListener, sender) From 59ce7d0262ab3082ef1bdb32442a6daf17d9aa42 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 17 Oct 2025 13:20:59 -0400 Subject: [PATCH 04/32] Starting persistent tasks --- .../AuthorizationTaskExecutor.java | 12 ++++ .../AuthorizationTaskParams.java | 67 +++++++++++++++++++ 2 files changed, 79 insertions(+) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParams.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java new file mode 100644 index 0000000000000..dbf20188600aa --- /dev/null +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java @@ -0,0 +1,12 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +public class AuthorizationTaskExecutor { + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParams.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParams.java new file mode 100644 index 0000000000000..ce854aa7787d5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParams.java @@ -0,0 +1,67 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.health.node.selection.HealthNodeTaskParams; +import org.elasticsearch.persistent.PersistentTaskParams; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; + +import static org.elasticsearch.health.node.selection.HealthNode.TASK_NAME; + +public class AuthorizationTaskParams implements PersistentTaskParams { + private static final AuthorizationTaskParams INSTANCE = new AuthorizationTaskParams(); + + private static final ObjectParser PARSER = new ObjectParser<>(TASK_NAME, true, () -> INSTANCE); + + AuthorizationTaskParams() {} + + AuthorizationTaskParams(StreamInput in) {} + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return TASK_NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.V_8_5_0; + } + + @Override + public void writeTo(StreamOutput out) {} + + public static HealthNodeTaskParams fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public boolean equals(Object obj) { + return obj instanceof HealthNodeTaskParams; + } +} From 273d5516b764dddd13c1f8292bf490f6be6fbe98 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 20 Oct 2025 15:24:12 -0400 Subject: [PATCH 05/32] Switching to a persistent task, need to create the action 
though --- .../inference/src/main/java/module-info.java | 1 + .../xpack/inference/InferencePlugin.java | 119 ++++++++--------- .../inference/registry/ModelRegistry.java | 13 ++ .../AuthorizationInitializer.java | 44 ------- ...andlerV2.java => AuthorizationPoller.java} | 46 +++++-- .../AuthorizationTaskExecutor.java | 123 +++++++++++++++++- .../AuthorizationTaskParams.java | 6 +- 7 files changed, 235 insertions(+), 117 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationInitializer.java rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/{ElasticInferenceServiceAuthorizationHandlerV2.java => AuthorizationPoller.java} (86%) diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index bd200fd88a706..a8e42b306e91f 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -36,6 +36,7 @@ requires org.elasticsearch.sslconfig; requires org.apache.commons.text; requires software.amazon.awssdk.services.sagemakerruntime; + requires org.elasticsearch.inference; exports org.elasticsearch.xpack.inference.action; exports org.elasticsearch.xpack.inference.registry; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index c51f3d3cda211..159882ad0009c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -133,8 +133,8 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import 
org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; -import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationInitializer; -import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationHandlerV2; +import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller; +import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationTaskExecutor; import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService; @@ -162,6 +162,7 @@ import java.util.Set; import java.util.function.Predicate; import java.util.function.Supplier; +import java.util.stream.Stream; import static java.util.Collections.singletonList; import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE; @@ -225,7 +226,6 @@ public class InferencePlugin extends Plugin private final SetOnce inferenceServiceRegistry = new SetOnce<>(); private final SetOnce shardBulkInferenceActionFilter = new SetOnce<>(); private final SetOnce modelRegistry = new SetOnce<>(); - private final SetOnce eisAuthorizationHandler = new SetOnce<>(); private List inferenceServiceExtensions; public InferencePlugin(Settings settings) { @@ -326,17 +326,21 @@ public Collection createComponents(PluginServices services) { ); var eisComponents = new ElasticInferenceServiceComponents(inferenceServiceSettings.getElasticInferenceServiceUrl()); - var eisAuthPoller = new ElasticInferenceServiceAuthorizationHandlerV2( - serviceComponents.get(), - authorizationHandler, - 
elasicInferenceServiceFactory.get().createSender(), - inferenceServiceSettings, - eisComponents, - modelRegistry.get() + + var authTaskExecutor = new AuthorizationTaskExecutor( + services.client(), + services.clusterService(), + services.threadPool(), + new AuthorizationPoller.Parameters( + serviceComponents.get(), + authorizationHandler, + elasicInferenceServiceFactory.get().createSender(), + inferenceServiceSettings, + eisComponents, + modelRegistry.get() + ) ); - var eisAuthInitializer = new AuthorizationInitializer(eisAuthPoller); - services.clusterService().addListener(eisAuthInitializer); - eisAuthorizationHandler.set(eisAuthPoller); + authTaskExecutor.init(); var sageMakerSchemas = new SageMakerSchemas(); var sageMakerConfigurations = new LazyInitializable<>(new SageMakerConfiguration(sageMakerSchemas)); @@ -412,6 +416,7 @@ public Collection createComponents(PluginServices services) { services.featureService() ) ); + components.add(authTaskExecutor); return components; } @@ -449,54 +454,52 @@ public List getInferenceServiceFactories() { @Override public List getNamedWriteables() { - var entries = new ArrayList<>(InferenceNamedWriteablesProvider.getNamedWriteables()); - entries.add(new NamedWriteableRegistry.Entry(RankBuilder.class, TextSimilarityRankBuilder.NAME, TextSimilarityRankBuilder::new)); - entries.add(new NamedWriteableRegistry.Entry(RankBuilder.class, RandomRankBuilder.NAME, RandomRankBuilder::new)); - entries.add(new NamedWriteableRegistry.Entry(RankDoc.class, TextSimilarityRankDoc.NAME, TextSimilarityRankDoc::new)); - entries.add(new NamedWriteableRegistry.Entry(Metadata.ProjectCustom.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::new)); - entries.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::readDiffFrom)); - entries.add( - new NamedWriteableRegistry.Entry( - QueryBuilder.class, - InterceptedInferenceMatchQueryBuilder.NAME, - InterceptedInferenceMatchQueryBuilder::new - ) - ); 
- entries.add( - new NamedWriteableRegistry.Entry( - QueryBuilder.class, - InterceptedInferenceKnnVectorQueryBuilder.NAME, - InterceptedInferenceKnnVectorQueryBuilder::new - ) - ); - entries.add( - new NamedWriteableRegistry.Entry( - QueryBuilder.class, - InterceptedInferenceSparseVectorQueryBuilder.NAME, - InterceptedInferenceSparseVectorQueryBuilder::new - ) - ); - return entries; + return Stream.of( + List.of( + new NamedWriteableRegistry.Entry(RankBuilder.class, TextSimilarityRankBuilder.NAME, TextSimilarityRankBuilder::new), + new NamedWriteableRegistry.Entry(RankBuilder.class, RandomRankBuilder.NAME, RandomRankBuilder::new), + new NamedWriteableRegistry.Entry(RankDoc.class, TextSimilarityRankDoc.NAME, TextSimilarityRankDoc::new), + new NamedWriteableRegistry.Entry(Metadata.ProjectCustom.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::new), + new NamedWriteableRegistry.Entry(NamedDiff.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::readDiffFrom), + new NamedWriteableRegistry.Entry( + QueryBuilder.class, + InterceptedInferenceMatchQueryBuilder.NAME, + InterceptedInferenceMatchQueryBuilder::new + ), + new NamedWriteableRegistry.Entry( + QueryBuilder.class, + InterceptedInferenceKnnVectorQueryBuilder.NAME, + InterceptedInferenceKnnVectorQueryBuilder::new + ), + new NamedWriteableRegistry.Entry( + QueryBuilder.class, + InterceptedInferenceSparseVectorQueryBuilder.NAME, + InterceptedInferenceSparseVectorQueryBuilder::new + ) + ), + InferenceNamedWriteablesProvider.getNamedWriteables(), + AuthorizationTaskExecutor.getNamedWriteables() + ).flatMap(List::stream).toList(); + } @Override public List getNamedXContent() { - List namedXContent = new ArrayList<>(); - namedXContent.add( - new NamedXContentRegistry.Entry( - Metadata.ProjectCustom.class, - new ParseField(ModelRegistryMetadata.TYPE), - ModelRegistryMetadata::fromXContent - ) - ); - namedXContent.add( - new NamedXContentRegistry.Entry( - Metadata.ProjectCustom.class, - new 
ParseField(ClearInferenceEndpointCacheAction.InvalidateCacheMetadata.NAME), - ClearInferenceEndpointCacheAction.InvalidateCacheMetadata::fromXContent - ) - ); - return namedXContent; + return Stream.of( + List.of( + new NamedXContentRegistry.Entry( + Metadata.ProjectCustom.class, + new ParseField(ModelRegistryMetadata.TYPE), + ModelRegistryMetadata::fromXContent + ), + new NamedXContentRegistry.Entry( + Metadata.ProjectCustom.class, + new ParseField(ClearInferenceEndpointCacheAction.InvalidateCacheMetadata.NAME), + ClearInferenceEndpointCacheAction.InvalidateCacheMetadata::fromXContent + ) + ), + AuthorizationTaskExecutor.getNamedXContentParsers() + ).flatMap(List::stream).toList(); } @Override @@ -611,7 +614,7 @@ public void close() { var serviceComponentsRef = serviceComponents.get(); var throttlerToClose = serviceComponentsRef != null ? serviceComponentsRef.throttlerManager() : null; - IOUtils.closeWhileHandlingException(inferenceServiceRegistry.get(), throttlerToClose, eisAuthorizationHandler.get()); + IOUtils.closeWhileHandlingException(inferenceServiceRegistry.get(), throttlerToClose); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 47a919c5d9b96..a6f554fcb6423 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -87,6 +87,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -147,10 +148,12 @@ public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) private final MasterServiceTaskQueue metadataTaskQueue; private final AtomicBoolean 
upgradeMetadataInProgress = new AtomicBoolean(false); private final Set preventDeletionLock = Collections.newSetFromMap(new ConcurrentHashMap<>()); + private final ClusterService clusterService; private volatile Metadata lastMetadata; public ModelRegistry(ClusterService clusterService, Client client) { + this.clusterService = Objects.requireNonNull(clusterService); this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN); this.defaultConfigIds = new ConcurrentHashMap<>(); var executor = new SimpleBatchedAckListenerTaskExecutor() { @@ -954,6 +957,16 @@ private void updateClusterState(List models, ActionListener inferenceEntityIds, ActionListener listener) { if (inferenceEntityIds.isEmpty()) { listener.onResponse(true); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationInitializer.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationInitializer.java deleted file mode 100644 index 341890e86cb09..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationInitializer.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elastic.authorization; - -import org.elasticsearch.cluster.ClusterChangedEvent; -import org.elasticsearch.cluster.ClusterStateListener; -import org.elasticsearch.gateway.GatewayService; - -import java.util.Objects; -import java.util.concurrent.atomic.AtomicBoolean; - -/** - * Waits for the cluster state to be recovered before initializing the authorization handler. 
- */ -public class AuthorizationInitializer implements ClusterStateListener { - - private final ElasticInferenceServiceAuthorizationHandlerV2 authorizationHandler; - private final AtomicBoolean initializedAuthorization = new AtomicBoolean(false); - - public AuthorizationInitializer(ElasticInferenceServiceAuthorizationHandlerV2 authorizationHandler) { - this.authorizationHandler = Objects.requireNonNull(authorizationHandler); - } - - @Override - public void clusterChanged(ClusterChangedEvent event) { - if (event.localNodeMaster() == false) { - return; - } - - // wait for the cluster state to be recovered - if (event.state().blocks().hasGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK)) { - return; - } - - if (initializedAuthorization.compareAndSet(false, true)) { - authorizationHandler.init(); - } - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerV2.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java similarity index 86% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerV2.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java index b17c4b997f549..9687f105926df 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerV2.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java @@ -14,6 +14,8 @@ import org.elasticsearch.common.Randomness; import org.elasticsearch.common.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.persistent.AllocatedPersistentTask; +import 
org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.Scheduler; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.registry.ModelRegistry; @@ -22,10 +24,9 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; -import java.io.Closeable; -import java.io.IOException; import java.util.EnumSet; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.concurrent.CountDownLatch; @@ -37,8 +38,11 @@ import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.IMPLEMENTED_TASK_TYPES; -public class ElasticInferenceServiceAuthorizationHandlerV2 implements Closeable { - private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationHandlerV2.class); +public class AuthorizationPoller extends AllocatedPersistentTask { + + public static final String TASK_NAME = "eis-authorization-poller"; + + private static final Logger logger = LogManager.getLogger(AuthorizationPoller.class); private final ServiceComponents serviceComponents; private final ModelRegistry modelRegistry; @@ -52,19 +56,33 @@ public class ElasticInferenceServiceAuthorizationHandlerV2 implements Closeable private final AtomicBoolean initialized = new AtomicBoolean(false); private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; - public ElasticInferenceServiceAuthorizationHandlerV2( + public record TaskFields(long id, String type, String action, String description, TaskId parentTask, Map headers) {} + + public record Parameters( ServiceComponents serviceComponents, ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, Sender sender, ElasticInferenceServiceSettings 
elasticInferenceServiceSettings, ElasticInferenceServiceComponents components, ModelRegistry modelRegistry - ) { - this(serviceComponents, authorizationRequestHandler, sender, elasticInferenceServiceSettings, components, modelRegistry, null); + ) {} + + public AuthorizationPoller(TaskFields taskFields, Parameters parameters) { + this( + taskFields, + parameters.serviceComponents, + parameters.authorizationRequestHandler, + parameters.sender, + parameters.elasticInferenceServiceSettings, + parameters.components, + parameters.modelRegistry, + null + ); } // default for testing - ElasticInferenceServiceAuthorizationHandlerV2( + AuthorizationPoller( + TaskFields taskFields, ServiceComponents serviceComponents, ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, Sender sender, @@ -74,6 +92,7 @@ public ElasticInferenceServiceAuthorizationHandlerV2( // this is a hack to facilitate testing Runnable callback ) { + super(taskFields.id, taskFields.type, taskFields.action, taskFields.description, taskFields.parentTask, taskFields.headers); this.serviceComponents = Objects.requireNonNull(serviceComponents); this.authorizationHandler = Objects.requireNonNull(authorizationRequestHandler); this.sender = Objects.requireNonNull(sender); @@ -83,7 +102,7 @@ public ElasticInferenceServiceAuthorizationHandlerV2( this.callback = callback; } - public void init() { + public void start() { if (initialized.compareAndSet(false, true)) { logger.debug("Initializing authorization logic"); serviceComponents.threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(this::scheduleAndSendAuthorizationRequest); @@ -106,7 +125,7 @@ public void waitForAuthorizationToComplete(TimeValue waitTime) { } @Override - public void close() throws IOException { + protected void onCancelled() { shutdown.set(true); if (lastAuthTask.get() != null) { lastAuthTask.get().cancel(); @@ -156,12 +175,16 @@ private void scheduleAndSendAuthorizationRequest() { } private void 
sendAuthorizationRequest() { + if (modelRegistry.isReady() == false) { + return; + } + var finalListener = ActionListener.running(() -> { if (callback != null) { callback.run(); } firstAuthorizationCompletedLatch.countDown(); - }).delegateResponse((delegate, e) -> { logger.atWarn().withThrowable(e).log("Failed processing EIS preconfigured endpoints"); }); + }).delegateResponse((delegate, e) -> logger.atWarn().withThrowable(e).log("Failed processing EIS preconfigured endpoints")); SubscribableListener.newForked( authModelListener -> authorizationHandler.getAuthorization(authModelListener, sender) @@ -195,6 +218,7 @@ private void storePreconfiguredModels(Set newInferenceIds, ActionListene var modelsToAdd = PreconfiguredEndpointModelAdapter.getModels(newInferenceIds, elasticInferenceServiceComponents); + // TODO ActionListener> storeListener = ActionListener.wrap(responses -> { for (var response : responses) { if (response.failed()) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java index dbf20188600aa..e1f58812ca40a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java @@ -7,6 +7,127 @@ package org.elasticsearch.xpack.inference.services.elastic.authorization; -public class AuthorizationTaskExecutor { +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ResourceAlreadyExistsException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.ClusterChangedEvent; +import 
org.elasticsearch.cluster.ClusterStateListener; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.persistent.AllocatedPersistentTask; +import org.elasticsearch.persistent.ClusterPersistentTasksCustomMetadata; +import org.elasticsearch.persistent.PersistentTaskParams; +import org.elasticsearch.persistent.PersistentTaskState; +import org.elasticsearch.persistent.PersistentTasksCustomMetadata; +import org.elasticsearch.persistent.PersistentTasksExecutor; +import org.elasticsearch.persistent.PersistentTasksService; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.RemoteTransportException; +import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.ParseField; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; +import static org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller.TASK_NAME; + +public class AuthorizationTaskExecutor extends PersistentTasksExecutor implements ClusterStateListener { + + private static final Logger logger = LogManager.getLogger(AuthorizationTaskExecutor.class); + + private final ClusterService clusterService; + private final PersistentTasksService persistentTasksService; + private final AuthorizationPoller.Parameters pollerParameters; + + public AuthorizationTaskExecutor( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + AuthorizationPoller.Parameters pollerParameters + ) { + super(TASK_NAME, threadPool.executor(UTILITY_THREAD_POOL_NAME)); + this.clusterService = Objects.requireNonNull(clusterService); + this.persistentTasksService = new PersistentTasksService(clusterService, threadPool, client); + this.pollerParameters = 
Objects.requireNonNull(pollerParameters); + } + + public void init() { + clusterService.addListener(this); + } + + @Override + protected void nodeOperation(AllocatedPersistentTask task, AuthorizationTaskParams params, PersistentTaskState state) { + // TODO remove + logger.warn("Starting authorization poller task"); + var authPoller = (AuthorizationPoller) task; + authPoller.start(); + } + + @Override + public Scope scope() { + return Scope.CLUSTER; + } + + @Override + protected AuthorizationPoller createTask( + long id, + String type, + String action, + TaskId parentTaskId, + PersistentTasksCustomMetadata.PersistentTask taskInProgress, + Map headers + ) { + return new AuthorizationPoller( + new AuthorizationPoller.TaskFields(id, type, action, getDescription(taskInProgress), parentTaskId, headers), + pollerParameters + ); + } + + @Override + public void clusterChanged(ClusterChangedEvent event) { + if (authorizationTaskExists(event)) { + return; + } + + persistentTasksService.sendClusterStartRequest( + TASK_NAME, + TASK_NAME, + new AuthorizationTaskParams(), + TimeValue.THIRTY_SECONDS, + ActionListener.wrap(persistentTask -> { + logger.warn("Created authorization poller task"); + }, e -> { + var t = e instanceof RemoteTransportException ? 
e.getCause() : e; + if (t instanceof ResourceAlreadyExistsException == false) { + logger.error("Failed to create authorization poller task", e); + } + }) + ); + } + + private static boolean authorizationTaskExists(ClusterChangedEvent event) { + return ClusterPersistentTasksCustomMetadata.getTaskWithId(event.state(), TASK_NAME) != null; + } + + public static List getNamedXContentParsers() { + return List.of( + new NamedXContentRegistry.Entry( + PersistentTaskParams.class, + new ParseField(AuthorizationPoller.TASK_NAME), + AuthorizationTaskParams::fromXContent + ) + ); + } + + public static List getNamedWriteables() { + return List.of( + new NamedWriteableRegistry.Entry(PersistentTaskParams.class, AuthorizationPoller.TASK_NAME, AuthorizationTaskParams::new) + ); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParams.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParams.java index ce854aa7787d5..38a9fced3607d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParams.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParams.java @@ -20,12 +20,12 @@ import java.io.IOException; -import static org.elasticsearch.health.node.selection.HealthNode.TASK_NAME; +import static org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller.TASK_NAME; public class AuthorizationTaskParams implements PersistentTaskParams { private static final AuthorizationTaskParams INSTANCE = new AuthorizationTaskParams(); - private static final ObjectParser PARSER = new ObjectParser<>(TASK_NAME, true, () -> INSTANCE); + private static final ObjectParser PARSER = new ObjectParser<>(TASK_NAME, true, () -> INSTANCE); AuthorizationTaskParams() {} @@ 
-51,7 +51,7 @@ public TransportVersion getMinimalSupportedVersion() { @Override public void writeTo(StreamOutput out) {} - public static HealthNodeTaskParams fromXContent(XContentParser parser) { + public static AuthorizationTaskParams fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } From 458664eddfbeb6bc4c47d902924be91f6bd1456c Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 21 Oct 2025 09:48:50 -0400 Subject: [PATCH 06/32] Adding master action --- .../org/elasticsearch/inference/Model.java | 18 ++++- .../TransportCreateEndpointsAction.java | 15 ++++ .../registry/ModelStoreResponse.java | 38 +++++++++ .../StoreInferenceEndpointsAction.java | 81 +++++++++++++++++++ 4 files changed, 151 insertions(+), 1 deletion(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportCreateEndpointsAction.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelStoreResponse.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/StoreInferenceEndpointsAction.java diff --git a/server/src/main/java/org/elasticsearch/inference/Model.java b/server/src/main/java/org/elasticsearch/inference/Model.java index 87744fbd09574..369783463cd86 100644 --- a/server/src/main/java/org/elasticsearch/inference/Model.java +++ b/server/src/main/java/org/elasticsearch/inference/Model.java @@ -9,9 +9,14 @@ package org.elasticsearch.inference; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; + +import java.io.IOException; import java.util.Objects; -public class Model { +public class Model implements Writeable { public static String documentId(String modelId) { return "model_" + modelId; } @@ -42,6 +47,11 @@ public Model(ModelConfigurations configurations) { this(configurations, new ModelSecrets()); } 
+ public Model(StreamInput in) throws IOException { + this.configurations = new ModelConfigurations(in); + this.secrets = new ModelSecrets(in); + } + public String getInferenceEntityId() { return configurations.getInferenceEntityId(); } @@ -111,4 +121,10 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(configurations, secrets); } + + @Override + public void writeTo(StreamOutput out) throws IOException { + configurations.writeTo(out); + secrets.writeTo(out); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportCreateEndpointsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportCreateEndpointsAction.java new file mode 100644 index 0000000000000..43c3d285bc320 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportCreateEndpointsAction.java @@ -0,0 +1,15 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action; + +/** + * Handles the internal action for creating multiple inference endpoints. + */ +public class TransportCreateEndpointsAction { + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelStoreResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelStoreResponse.java new file mode 100644 index 0000000000000..9689e875f62ad --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelStoreResponse.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.registry; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.rest.RestStatus; + +import java.io.IOException; + +public record ModelStoreResponse(String inferenceId, RestStatus status, @Nullable Exception failureCause) implements Writeable { + + public ModelStoreResponse(StreamInput in) throws IOException { + this( + in.readString(), + RestStatus.readFrom(in), + in.readException() + ); + } + + public boolean failed() { + return failureCause != null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(inferenceId); + RestStatus.writeTo(out, status); + out.writeException(failureCause); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/StoreInferenceEndpointsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/StoreInferenceEndpointsAction.java new file mode 100644 index 0000000000000..047f5e5fdb368 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/StoreInferenceEndpointsAction.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.registry; + +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.master.AcknowledgedRequest; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.Model; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public class StoreInferenceEndpointsAction extends ActionType { + + public static final StoreInferenceEndpointsAction INSTANCE = new StoreInferenceEndpointsAction(); + public static final String NAME = "cluster:internal/xpack/inference/create_endpoints"; + + public StoreInferenceEndpointsAction() { + super(NAME); + } + + public static class Request extends AcknowledgedRequest { + private final List models; + + public Request(List models, TimeValue timeout) { + super(timeout, DEFAULT_ACK_TIMEOUT); + this.models = Objects.requireNonNull(models); + } + + public Request(StreamInput in) throws IOException { + super(in); + models = in.readCollectionAsImmutableList(Model::new); + } + + public List getModels() { + return models; + } + } + + public static class Response extends ActionResponse { + private final List results; + + public Response(List results) { + this.results = results; + } + + public Response(StreamInput in) throws IOException { + results = in.readCollectionAsImmutableList(ModelStoreResponse::new); + } + + public List getResults() { + return results; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(results); + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + Response response = (Response) o; + return Objects.equals(results, response.results); + } + + @Override + public int hashCode() { + return Objects.hash(results); + } + } +} From 
a7f0f916c351ebbbe8057a812154c6acee2d90e8 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 21 Oct 2025 15:48:29 -0400 Subject: [PATCH 07/32] Successful task creation --- .../inference/src/main/java/module-info.java | 1 - .../xpack/inference/InferencePlugin.java | 29 ++++++- .../TransportCreateEndpointsAction.java | 15 ---- .../action/TransportStoreEndpointsAction.java | 77 +++++++++++++++++++ .../inference/registry/ModelRegistry.java | 6 -- .../InternalPreconfiguredEndpoints.java | 2 +- .../authorization/AuthorizationPoller.java | 32 +++++--- 7 files changed, 126 insertions(+), 36 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportCreateEndpointsAction.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportStoreEndpointsAction.java diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index a8e42b306e91f..bd200fd88a706 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -36,7 +36,6 @@ requires org.elasticsearch.sslconfig; requires org.apache.commons.text; requires software.amazon.awssdk.services.sagemakerruntime; - requires org.elasticsearch.inference; exports org.elasticsearch.xpack.inference.action; exports org.elasticsearch.xpack.inference.registry; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 159882ad0009c..78d45b30269fc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -11,16 +11,19 @@ import org.apache.logging.log4j.Logger; import org.apache.lucene.util.SetOnce; 
import org.elasticsearch.action.support.MappedActionFilter; +import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.NamedDiff; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.IndexScopedSettings; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.SettingsFilter; +import org.elasticsearch.common.settings.SettingsModule; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.TimeValue; @@ -36,10 +39,12 @@ import org.elasticsearch.license.LicensedFeature; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.node.PluginComponentBinding; +import org.elasticsearch.persistent.PersistentTasksExecutor; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.ClusterPlugin; import org.elasticsearch.plugins.ExtensiblePlugin; import org.elasticsearch.plugins.MapperPlugin; +import org.elasticsearch.plugins.PersistentTaskPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.plugins.SystemIndexPlugin; @@ -79,6 +84,7 @@ import org.elasticsearch.xpack.inference.action.TransportInferenceActionProxy; import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction; import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction; +import org.elasticsearch.xpack.inference.action.TransportStoreEndpointsAction; import org.elasticsearch.xpack.inference.action.TransportUnifiedCompletionInferenceAction; import 
org.elasticsearch.xpack.inference.action.TransportUpdateInferenceModelAction; import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter; @@ -110,6 +116,7 @@ import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.registry.ModelRegistryMetadata; +import org.elasticsearch.xpack.inference.registry.StoreInferenceEndpointsAction; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceEndpointAction; import org.elasticsearch.xpack.inference.rest.RestGetInferenceDiagnosticsAction; import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction; @@ -175,7 +182,8 @@ public class InferencePlugin extends Plugin MapperPlugin, SearchPlugin, InternalSearchPlugin, - ClusterPlugin { + ClusterPlugin, + PersistentTaskPlugin { /** * When this setting is true the verification check that @@ -227,6 +235,7 @@ public class InferencePlugin extends Plugin private final SetOnce shardBulkInferenceActionFilter = new SetOnce<>(); private final SetOnce modelRegistry = new SetOnce<>(); private List inferenceServiceExtensions; + private final SetOnce authorizationTaskExecutorRef = new SetOnce<>(); public InferencePlugin(Settings settings) { this.settings = settings; @@ -246,7 +255,8 @@ public List getActions() { new ActionHandler(GetInferenceServicesAction.INSTANCE, TransportGetInferenceServicesAction.class), new ActionHandler(UnifiedCompletionAction.INSTANCE, TransportUnifiedCompletionInferenceAction.class), new ActionHandler(GetRerankerWindowSizeAction.INSTANCE, TransportGetRerankerWindowSizeAction.class), - new ActionHandler(ClearInferenceEndpointCacheAction.INSTANCE, ClearInferenceEndpointCacheAction.class) + new ActionHandler(ClearInferenceEndpointCacheAction.INSTANCE, ClearInferenceEndpointCacheAction.class), + new ActionHandler(StoreInferenceEndpointsAction.INSTANCE, TransportStoreEndpointsAction.class) ); } @@ -337,10 
+347,12 @@ public Collection createComponents(PluginServices services) { elasicInferenceServiceFactory.get().createSender(), inferenceServiceSettings, eisComponents, - modelRegistry.get() + modelRegistry.get(), + services.client() ) ); authTaskExecutor.init(); + authorizationTaskExecutorRef.set(authTaskExecutor); var sageMakerSchemas = new SageMakerSchemas(); var sageMakerConfigurations = new LazyInitializable<>(new SageMakerConfiguration(sageMakerSchemas)); @@ -421,6 +433,17 @@ public Collection createComponents(PluginServices services) { return components; } + @Override + public List> getPersistentTasksExecutor( + ClusterService clusterService, + ThreadPool threadPool, + Client client, + SettingsModule settingsModule, + IndexNameExpressionResolver expressionResolver + ) { + return List.of(authorizationTaskExecutorRef.get()); + } + @Override public void loadExtensions(ExtensionLoader loader) { inferenceServiceExtensions = loader.loadExtensions(InferenceServiceExtension.class); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportCreateEndpointsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportCreateEndpointsAction.java deleted file mode 100644 index 43c3d285bc320..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportCreateEndpointsAction.java +++ /dev/null @@ -1,15 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.action; - -/** - * Handles the internal action for creating multiple inference endpoints. 
- */ -public class TransportCreateEndpointsAction { - -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportStoreEndpointsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportStoreEndpointsAction.java new file mode 100644 index 0000000000000..a1cd042467a72 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportStoreEndpointsAction.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.action.support.master.TransportMasterNodeAction; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.injection.guice.Inject; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.ModelStoreResponse; +import org.elasticsearch.xpack.inference.registry.StoreInferenceEndpointsAction; + +import java.util.List; +import java.util.Objects; + +/** + * Handles the internal action for creating multiple inference endpoints. This should not be used by external REST APIs. 
+ */ +public class TransportStoreEndpointsAction extends TransportMasterNodeAction< + StoreInferenceEndpointsAction.Request, + StoreInferenceEndpointsAction.Response> { + + private final ModelRegistry modelRegistry; + + @Inject + public TransportStoreEndpointsAction( + TransportService transportService, + ClusterService clusterService, + ThreadPool threadPool, + ActionFilters actionFilters, + ModelRegistry modelRegistry + ) { + super( + StoreInferenceEndpointsAction.NAME, + transportService, + clusterService, + threadPool, + actionFilters, + StoreInferenceEndpointsAction.Request::new, + StoreInferenceEndpointsAction.Response::new, + EsExecutors.DIRECT_EXECUTOR_SERVICE + ); + + this.modelRegistry = Objects.requireNonNull(modelRegistry); + } + + @Override + protected void masterOperation( + Task task, + StoreInferenceEndpointsAction.Request request, + ClusterState state, + ActionListener masterListener + ) { + SubscribableListener.>newForked( + listener -> modelRegistry.storeModels(request.getModels(), listener, request.masterNodeTimeout()) + ).andThenApply(StoreInferenceEndpointsAction.Response::new).addListener(masterListener); + } + + @Override + protected ClusterBlockException checkBlock(StoreInferenceEndpointsAction.Request request, ClusterState state) { + return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index a6f554fcb6423..c4e087ef4cc53 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -707,12 +707,6 @@ private void storeModel(Model model, boolean updateClusterState, ActionListener< }), timeout); } - public record ModelStoreResponse(String 
inferenceId, RestStatus status, @Nullable Exception failureCause) { - public boolean failed() { - return failureCause != null; - } - } - public void storeModels(List models, ActionListener> listener, TimeValue timeout) { storeModels(models, true, listener, timeout); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpoints.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpoints.java index 869fb6a43a54f..e1751b5edbe7a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpoints.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpoints.java @@ -53,7 +53,7 @@ public record MinimalModel( private static final ElasticInferenceServiceCompletionServiceSettings COMPLETION_SERVICE_SETTINGS = new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); private static final ElasticInferenceServiceSparseEmbeddingsServiceSettings SPARSE_EMBEDDINGS_SERVICE_SETTINGS = - new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_RERANK_MODEL_ID_V1, null); + new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_2_MODEL_ID, null); private static final ElasticInferenceServiceDenseTextEmbeddingsServiceSettings DENSE_TEXT_EMBEDDINGS_SERVICE_SETTINGS = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java index 9687f105926df..d2ba6481a3c1a 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java @@ -11,21 +11,24 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.common.Randomness; import org.elasticsearch.common.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.persistent.AllocatedPersistentTask; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.Scheduler; +import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.StoreInferenceEndpointsAction; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; import java.util.EnumSet; -import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; @@ -55,6 +58,7 @@ public class AuthorizationPoller extends AllocatedPersistentTask { private final ElasticInferenceServiceSettings elasticInferenceServiceSettings; private final AtomicBoolean initialized = new AtomicBoolean(false); private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; + private final Client client; public record TaskFields(long id, String type, String action, String description, TaskId parentTask, 
Map headers) {} @@ -64,7 +68,8 @@ public record Parameters( Sender sender, ElasticInferenceServiceSettings elasticInferenceServiceSettings, ElasticInferenceServiceComponents components, - ModelRegistry modelRegistry + ModelRegistry modelRegistry, + Client client ) {} public AuthorizationPoller(TaskFields taskFields, Parameters parameters) { @@ -76,6 +81,7 @@ public AuthorizationPoller(TaskFields taskFields, Parameters parameters) { parameters.elasticInferenceServiceSettings, parameters.components, parameters.modelRegistry, + parameters.client, null ); } @@ -89,6 +95,7 @@ public AuthorizationPoller(TaskFields taskFields, Parameters parameters) { ElasticInferenceServiceSettings elasticInferenceServiceSettings, ElasticInferenceServiceComponents components, ModelRegistry modelRegistry, + Client client, // this is a hack to facilitate testing Runnable callback ) { @@ -99,6 +106,7 @@ public AuthorizationPoller(TaskFields taskFields, Parameters parameters) { this.elasticInferenceServiceSettings = Objects.requireNonNull(elasticInferenceServiceSettings); this.elasticInferenceServiceComponents = Objects.requireNonNull(components); this.modelRegistry = Objects.requireNonNull(modelRegistry); + this.client = new OriginSettingClient(Objects.requireNonNull(client), ClientHelper.INFERENCE_ORIGIN); this.callback = callback; } @@ -184,7 +192,10 @@ private void sendAuthorizationRequest() { callback.run(); } firstAuthorizationCompletedLatch.countDown(); - }).delegateResponse((delegate, e) -> logger.atWarn().withThrowable(e).log("Failed processing EIS preconfigured endpoints")); + }).delegateResponse((delegate, e) -> { + logger.atWarn().withThrowable(e).log("Failed processing EIS preconfigured endpoints"); + delegate.onResponse(null); + }); SubscribableListener.newForked( authModelListener -> authorizationHandler.getAuthorization(authModelListener, sender) @@ -216,11 +227,12 @@ private void storePreconfiguredModels(Set newInferenceIds, ActionListene return; } + logger.debug("Storing 
new EIS preconfigured inference endpoints with inference IDs {}", newInferenceIds); var modelsToAdd = PreconfiguredEndpointModelAdapter.getModels(newInferenceIds, elasticInferenceServiceComponents); + var storeRequest = new StoreInferenceEndpointsAction.Request(modelsToAdd, TimeValue.THIRTY_SECONDS); - // TODO - ActionListener> storeListener = ActionListener.wrap(responses -> { - for (var response : responses) { + ActionListener storeListener = ActionListener.wrap(responses -> { + for (var response : responses.getResults()) { if (response.failed()) { logger.atWarn() .withThrowable(response.failureCause()) @@ -232,10 +244,10 @@ private void storePreconfiguredModels(Set newInferenceIds, ActionListene } }, e -> logger.atWarn().withThrowable(e).log("Failed to store new EIS preconfigured inference endpoints [{}]", newInferenceIds)); - modelRegistry.storeModels( - modelsToAdd, - ActionListener.runAfter(storeListener, () -> listener.onResponse(null)), - TimeValue.THIRTY_SECONDS + client.execute( + StoreInferenceEndpointsAction.INSTANCE, + storeRequest, + ActionListener.runAfter(storeListener, () -> listener.onResponse(null)) ); } } From 2e3246ce03b1ecd191c6283f616cea26ccd7a944 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 22 Oct 2025 08:27:41 -0400 Subject: [PATCH 08/32] Starting tests --- .../integration/ModelRegistryIT.java | 25 +- .../elastic/ElasticInferenceService.java | 5 - .../authorization/AuthorizationPoller.java | 33 +-- .../PreconfiguredEndpointModelAdapter.java | 2 +- .../AuthorizationPollerTests.java | 228 ++++++++++++++++++ 5 files changed, 252 insertions(+), 41 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java 
b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index 67d67183a37e3..28541abeb1390 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -53,6 +53,7 @@ import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.ModelStoreResponse; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettingsTests; @@ -600,7 +601,7 @@ public void testGetUnparsedModelMap_ThrowsResourceNotFound_WhenNoHitsReturned() } public void testStoreModels_ReturnsEmptyList_WhenGivenNoModelsToStore() { - PlainActionFuture> storeListener = new PlainActionFuture<>(); + PlainActionFuture> storeListener = new PlainActionFuture<>(); modelRegistry.storeModels(List.of(), storeListener, TimeValue.THIRTY_SECONDS); var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); @@ -620,12 +621,12 @@ public void testStoreModels_StoresSingleInferenceEndpoint() { new TestModel.TestSecretSettings(secrets) ); - PlainActionFuture> storeListener = new PlainActionFuture<>(); + PlainActionFuture> storeListener = new PlainActionFuture<>(); modelRegistry.storeModels(List.of(model), storeListener, TimeValue.THIRTY_SECONDS); var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); assertThat(response.size(), Matchers.is(1)); - assertThat(response.get(0), Matchers.is(new ModelRegistry.ModelStoreResponse(inferenceId, RestStatus.CREATED, 
null))); + assertThat(response.get(0), Matchers.is(new ModelStoreResponse(inferenceId, RestStatus.CREATED, null))); assertMinimalServiceSettings(modelRegistry, model); @@ -659,13 +660,13 @@ public void testStoreModels_StoresMultipleInferenceEndpoints() { new TestModel.TestSecretSettings(secrets) ); - PlainActionFuture> storeListener = new PlainActionFuture<>(); + PlainActionFuture> storeListener = new PlainActionFuture<>(); modelRegistry.storeModels(List.of(model1, model2), storeListener, TimeValue.THIRTY_SECONDS); var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); assertThat(response.size(), Matchers.is(2)); - assertThat(response.get(0), Matchers.is(new ModelRegistry.ModelStoreResponse(inferenceId1, RestStatus.CREATED, null))); - assertThat(response.get(1), Matchers.is(new ModelRegistry.ModelStoreResponse(inferenceId2, RestStatus.CREATED, null))); + assertThat(response.get(0), Matchers.is(new ModelStoreResponse(inferenceId1, RestStatus.CREATED, null))); + assertThat(response.get(1), Matchers.is(new ModelStoreResponse(inferenceId2, RestStatus.CREATED, null))); assertModelAndMinimalSettingsWithSecrets(modelRegistry, model1, secrets); assertModelAndMinimalSettingsWithSecrets(modelRegistry, model2, secrets); @@ -716,12 +717,12 @@ public void testStoreModels_StoresOneModel_FailsToStoreSecond_WhenVersionConflic new TestModel.TestSecretSettings(secrets) ); - PlainActionFuture> storeListener = new PlainActionFuture<>(); + PlainActionFuture> storeListener = new PlainActionFuture<>(); modelRegistry.storeModels(List.of(model1, model2), storeListener, TimeValue.THIRTY_SECONDS); var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); assertThat(response.size(), Matchers.is(2)); - assertThat(response.get(0), Matchers.is(new ModelRegistry.ModelStoreResponse(inferenceId, RestStatus.CREATED, null))); + assertThat(response.get(0), Matchers.is(new ModelStoreResponse(inferenceId, RestStatus.CREATED, null))); assertThat(response.get(1).inferenceId(), 
Matchers.is(model2.getInferenceEntityId())); assertThat(response.get(1).status(), Matchers.is(RestStatus.CONFLICT)); assertTrue(response.get(1).failed()); @@ -758,12 +759,12 @@ public void testStoreModels_StoresOneModel_RemovesSecondDuplicateModelFromList_D new TestModel.TestSecretSettings(secrets) ); - PlainActionFuture> storeListener = new PlainActionFuture<>(); + PlainActionFuture> storeListener = new PlainActionFuture<>(); modelRegistry.storeModels(List.of(model1, model1, model2), storeListener, TimeValue.THIRTY_SECONDS); var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); assertThat(response.size(), Matchers.is(1)); - assertThat(response.get(0), Matchers.is(new ModelRegistry.ModelStoreResponse(inferenceId, RestStatus.CREATED, null))); + assertThat(response.get(0), Matchers.is(new ModelStoreResponse(inferenceId, RestStatus.CREATED, null))); assertModelAndMinimalSettingsWithSecrets(modelRegistry, model1, secrets); assertIndicesContainExpectedDocsCount(model1, 2); @@ -783,7 +784,7 @@ public void testStoreModels_FailsToStoreModel_WhenInferenceIndexDocumentAlreadyE storeCorruptedModel(model, false); - PlainActionFuture> storeListener = new PlainActionFuture<>(); + PlainActionFuture> storeListener = new PlainActionFuture<>(); modelRegistry.storeModels(List.of(model), storeListener, TimeValue.THIRTY_SECONDS); var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); @@ -837,7 +838,7 @@ public void testStoreModels_OnFailure_RemovesPartialWritesOfInferenceEndpoint() storeCorruptedModel(model1, false); storeCorruptedModel(model2, true); - PlainActionFuture> storeListener = new PlainActionFuture<>(); + PlainActionFuture> storeListener = new PlainActionFuture<>(); modelRegistry.storeModels(List.of(model1, model2, model3), storeListener, TimeValue.THIRTY_SECONDS); var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 213bb10e14187..9fa47cdb23b2d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -253,11 +253,6 @@ private static Map initDefaultEndpoints( ); } - @Override - public void onNodeStarted() { - // authorizationHandler.init(); - } - @Override protected void validateRerankParameters(Boolean returnDocuments, Integer topN, ValidationException validationException) { if (returnDocuments != null) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java index d2ba6481a3c1a..ea27d974ed314 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java @@ -32,8 +32,6 @@ import java.util.Map; import java.util.Objects; import java.util.Set; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; @@ -50,7 +48,6 @@ public class AuthorizationPoller extends AllocatedPersistentTask { private final ServiceComponents serviceComponents; private final ModelRegistry modelRegistry; private final ElasticInferenceServiceAuthorizationRequestHandler 
authorizationHandler; - private final CountDownLatch firstAuthorizationCompletedLatch = new CountDownLatch(1); private final Sender sender; private final Runnable callback; private final AtomicReference lastAuthTask = new AtomicReference<>(null); @@ -117,23 +114,13 @@ public void start() { } } - /** - * Waits the specified amount of time for the first authorization call to complete. This is mainly to make testing easier. - * @param waitTime the max time to wait - * @throws IllegalStateException if the wait time is exceeded or the call receives an {@link InterruptedException} - */ - public void waitForAuthorizationToComplete(TimeValue waitTime) { - try { - if (firstAuthorizationCompletedLatch.await(waitTime.getSeconds(), TimeUnit.SECONDS) == false) { - throw new IllegalStateException("The wait time has expired for authorization to complete."); - } - } catch (InterruptedException e) { - throw new IllegalStateException("Waiting for authorization to complete was interrupted"); - } - } - @Override protected void onCancelled() { + shutdown(); + } + + // default for testing + void shutdown() { shutdown.set(true); if (lastAuthTask.get() != null) { lastAuthTask.get().cancel(); @@ -182,7 +169,8 @@ private void scheduleAndSendAuthorizationRequest() { sendAuthorizationRequest(); } - private void sendAuthorizationRequest() { + // default for testing + void sendAuthorizationRequest() { if (modelRegistry.isReady() == false) { return; } @@ -191,7 +179,6 @@ private void sendAuthorizationRequest() { if (callback != null) { callback.run(); } - firstAuthorizationCompletedLatch.countDown(); }).delegateResponse((delegate, e) -> { logger.atWarn().withThrowable(e).log("Failed processing EIS preconfigured endpoints"); delegate.onResponse(null); @@ -227,11 +214,11 @@ private void storePreconfiguredModels(Set newInferenceIds, ActionListene return; } - logger.debug("Storing new EIS preconfigured inference endpoints with inference IDs {}", newInferenceIds); + logger.info("Storing new EIS 
preconfigured inference endpoints with inference IDs {}", newInferenceIds); var modelsToAdd = PreconfiguredEndpointModelAdapter.getModels(newInferenceIds, elasticInferenceServiceComponents); var storeRequest = new StoreInferenceEndpointsAction.Request(modelsToAdd, TimeValue.THIRTY_SECONDS); - ActionListener storeListener = ActionListener.wrap(responses -> { + ActionListener logResultsListener = ActionListener.wrap(responses -> { for (var response : responses.getResults()) { if (response.failed()) { logger.atWarn() @@ -247,7 +234,7 @@ private void storePreconfiguredModels(Set newInferenceIds, ActionListene client.execute( StoreInferenceEndpointsAction.INSTANCE, storeRequest, - ActionListener.runAfter(storeListener, () -> listener.onResponse(null)) + ActionListener.runAfter(logResultsListener, () -> listener.onResponse(null)) ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapter.java index 21ad626003c17..8a7729a693abb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapter.java @@ -27,7 +27,7 @@ public static List getModels(Set inferenceIds, ElasticInferenceSe .toList(); } - private static Model createModel( + public static Model createModel( InternalPreconfiguredEndpoints.MinimalModel minimalModel, ElasticInferenceServiceComponents elasticInferenceServiceComponents ) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java new file mode 100644 index 0000000000000..a2a908b49efe7 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java @@ -0,0 +1,228 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.StoreInferenceEndpointsAction; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; +import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity; +import org.junit.Before; +import org.mockito.ArgumentCaptor; + +import java.util.EnumSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.Matchers.is; +import static 
org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class AuthorizationPollerTests extends ESTestCase { + private DeterministicTaskQueue taskQueue; + + @Before + public void init() throws Exception { + taskQueue = new DeterministicTaskQueue(); + } + + public void testDoesNotSendAuthorizationRequest_WhenModelRegistryIsNotReady() { + var mockRegistry = mock(ModelRegistry.class); + when(mockRegistry.isReady()).thenReturn(false); + + var authorizationRequestHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); + + var poller = new AuthorizationPoller( + new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()), + createWithEmptySettings(taskQueue.getThreadPool()), + authorizationRequestHandler, + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + new ElasticInferenceServiceComponents(""), + mockRegistry, + mock(Client.class), + null + ); + + poller.sendAuthorizationRequest(); + + verify(authorizationRequestHandler, never()).getAuthorization(any(), any()); + } + + public void testSendsAuthorizationRequest_WhenModelRegistryIsReady() { + var mockRegistry = mock(ModelRegistry.class); + when(mockRegistry.isReady()).thenReturn(true); + when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2")); + + var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + 
InternalPreconfiguredEndpoints.DEFAULT_ELSER_2_MODEL_ID, + EnumSet.of(TaskType.SPARSE_EMBEDDING) + ) + ) + ) + ) + ); + return Void.TYPE; + }).when(mockAuthHandler).getAuthorization(any(), any()); + + var mockClient = mock(Client.class); + when(mockClient.threadPool()).thenReturn(taskQueue.getThreadPool()); + + var eisComponents = new ElasticInferenceServiceComponents(""); + + var poller = new AuthorizationPoller( + new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()), + createWithEmptySettings(taskQueue.getThreadPool()), + mockAuthHandler, + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + eisComponents, + mockRegistry, + mockClient, + null + ); + + var requestArgCaptor = ArgumentCaptor.forClass(StoreInferenceEndpointsAction.Request.class); + + poller.sendAuthorizationRequest(); + verify(mockClient).execute(eq(StoreInferenceEndpointsAction.INSTANCE), requestArgCaptor.capture(), any()); + var capturedRequest = requestArgCaptor.getValue(); + assertThat( + capturedRequest.getModels(), + is( + List.of( + PreconfiguredEndpointModelAdapter.createModel( + InternalPreconfiguredEndpoints.getWithInferenceId(InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2), + eisComponents + ) + ) + ) + ); + } + + public void testDoesNotAttemptToStoreModelIds_ThatDoNotExistInThePreconfiguredMapping() { + var mockRegistry = mock(ModelRegistry.class); + when(mockRegistry.isReady()).thenReturn(true); + when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2")); + + var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + // 
This is a model id that does not exist in the preconfigured endpoints map so it will not be stored + "abc", + EnumSet.of(TaskType.SPARSE_EMBEDDING) + ) + ) + ) + ) + ); + return Void.TYPE; + }).when(mockAuthHandler).getAuthorization(any(), any()); + + var mockClient = mock(Client.class); + when(mockClient.threadPool()).thenReturn(taskQueue.getThreadPool()); + + var eisComponents = new ElasticInferenceServiceComponents(""); + + var poller = new AuthorizationPoller( + new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()), + createWithEmptySettings(taskQueue.getThreadPool()), + mockAuthHandler, + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + eisComponents, + mockRegistry, + mockClient, + null + ); + + var requestArgCaptor = ArgumentCaptor.forClass(StoreInferenceEndpointsAction.Request.class); + + poller.sendAuthorizationRequest(); + verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), requestArgCaptor.capture(), any()); + } + + public void testDoesNotAttemptToStoreModelIds_ThatHaveATaskTypeThatTheEISIntegration_DoesNotSupport() { + var mockRegistry = mock(ModelRegistry.class); + when(mockRegistry.isReady()).thenReturn(true); + when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2")); + + var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + InternalPreconfiguredEndpoints.DEFAULT_ELSER_2_MODEL_ID, + // EIS does not yet support completions so this model will be ignored + EnumSet.of(TaskType.COMPLETION) + ) + ) + ) + ) + ); + return Void.TYPE; + 
}).when(mockAuthHandler).getAuthorization(any(), any()); + + var mockClient = mock(Client.class); + when(mockClient.threadPool()).thenReturn(taskQueue.getThreadPool()); + + var eisComponents = new ElasticInferenceServiceComponents(""); + + var poller = new AuthorizationPoller( + new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()), + createWithEmptySettings(taskQueue.getThreadPool()), + mockAuthHandler, + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + eisComponents, + mockRegistry, + mockClient, + null + ); + + var requestArgCaptor = ArgumentCaptor.forClass(StoreInferenceEndpointsAction.Request.class); + + poller.sendAuthorizationRequest(); + verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), requestArgCaptor.capture(), any()); + } + + public void testSendsTwoAuthorizationRequests() { + fail("TODO"); + } +} From 9e5ed519b9ae1b8397488f9df46f75699dec2191 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 22 Oct 2025 17:13:44 -0400 Subject: [PATCH 09/32] More tests --- .../xpack/inference/InferencePlugin.java | 5 +- .../elastic/ElasticInferenceServiceModel.java | 14 ++ .../authorization/AuthorizationPoller.java | 12 +- .../AuthorizationTaskExecutor.java | 44 ++-- .../AuthorizationTaskParams.java | 7 +- .../PreconfiguredEndpointModelAdapter.java | 1 + .../AuthorizationPollerTests.java | 69 +++++- .../AuthorizationTaskExecutorTests.java | 223 ++++++++++++++++++ .../AuthorizationTaskParamsTests.java | 37 +++ ...reconfiguredEndpointModelAdapterTests.java | 176 ++++++++++++++ 10 files changed, 559 insertions(+), 29 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutorTests.java create mode 100644 
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParamsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapterTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 78d45b30269fc..54276a10c0d8d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -337,10 +337,8 @@ public Collection createComponents(PluginServices services) { var eisComponents = new ElasticInferenceServiceComponents(inferenceServiceSettings.getElasticInferenceServiceUrl()); - var authTaskExecutor = new AuthorizationTaskExecutor( - services.client(), + var authTaskExecutor = AuthorizationTaskExecutor.create( services.clusterService(), - services.threadPool(), new AuthorizationPoller.Parameters( serviceComponents.get(), authorizationHandler, @@ -351,7 +349,6 @@ public Collection createComponents(PluginServices services) { services.client() ) ); - authTaskExecutor.init(); authorizationTaskExecutorRef.set(authTaskExecutor); var sageMakerSchemas = new SageMakerSchemas(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java index ccf776f5db597..9f5f6d1b75dfc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java @@ -53,4 +53,18 @@ public RateLimitSettings rateLimitSettings() { public ElasticInferenceServiceComponents elasticInferenceServiceComponents() { return elasticInferenceServiceComponents; } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + if (super.equals(o) == false) return false; + ElasticInferenceServiceModel that = (ElasticInferenceServiceModel) o; + return Objects.equals(rateLimitServiceSettings, that.rateLimitServiceSettings) + && Objects.equals(elasticInferenceServiceComponents, that.elasticInferenceServiceComponents); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), rateLimitServiceSettings, elasticInferenceServiceComponents); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java index ea27d974ed314..c578f991df2a5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java @@ -64,19 +64,23 @@ public record Parameters( ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, Sender sender, ElasticInferenceServiceSettings elasticInferenceServiceSettings, - ElasticInferenceServiceComponents components, + ElasticInferenceServiceComponents eisComponents, ModelRegistry modelRegistry, Client client ) {} - public AuthorizationPoller(TaskFields taskFields, Parameters parameters) { + public static AuthorizationPoller create(TaskFields taskFields, Parameters parameters) { + return new 
AuthorizationPoller(Objects.requireNonNull(taskFields), Objects.requireNonNull(parameters)); + } + + private AuthorizationPoller(TaskFields taskFields, Parameters parameters) { this( taskFields, parameters.serviceComponents, parameters.authorizationRequestHandler, parameters.sender, parameters.elasticInferenceServiceSettings, - parameters.components, + parameters.eisComponents, parameters.modelRegistry, parameters.client, null @@ -109,7 +113,7 @@ public AuthorizationPoller(TaskFields taskFields, Parameters parameters) { public void start() { if (initialized.compareAndSet(false, true)) { - logger.debug("Initializing authorization logic"); + logger.debug("Initializing EIS authorization logic"); serviceComponents.threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(this::scheduleAndSendAuthorizationRequest); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java index e1f58812ca40a..b932d91e42745 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java @@ -11,10 +11,10 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterStateListener; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.core.TimeValue; import 
org.elasticsearch.persistent.AllocatedPersistentTask; @@ -25,7 +25,6 @@ import org.elasticsearch.persistent.PersistentTasksExecutor; import org.elasticsearch.persistent.PersistentTasksService; import org.elasticsearch.tasks.TaskId; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.RemoteTransportException; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ParseField; @@ -33,6 +32,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; import static org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller.TASK_NAME; @@ -44,27 +44,43 @@ public class AuthorizationTaskExecutor extends PersistentTasksExecutor currentTask = new AtomicReference<>(); - public AuthorizationTaskExecutor( - Client client, + public static AuthorizationTaskExecutor create(ClusterService clusterService, AuthorizationPoller.Parameters parameters) { + Objects.requireNonNull(clusterService); + Objects.requireNonNull(parameters); + + var executor = new AuthorizationTaskExecutor( + clusterService, + new PersistentTasksService(clusterService, parameters.serviceComponents().threadPool(), parameters.client()), + parameters + ); + executor.init(); + return executor; + } + + // default for testing + AuthorizationTaskExecutor( ClusterService clusterService, - ThreadPool threadPool, + PersistentTasksService persistentTasksService, AuthorizationPoller.Parameters pollerParameters ) { - super(TASK_NAME, threadPool.executor(UTILITY_THREAD_POOL_NAME)); + super(TASK_NAME, pollerParameters.serviceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME)); this.clusterService = Objects.requireNonNull(clusterService); - this.persistentTasksService = new PersistentTasksService(clusterService, threadPool, client); + this.persistentTasksService = 
Objects.requireNonNull(persistentTasksService); this.pollerParameters = Objects.requireNonNull(pollerParameters); } - public void init() { - clusterService.addListener(this); + // default for testing + void init() { + // If the EIS url is not configured, then we won't be able to interact with the service, so don't start the task. + if (Strings.isNullOrEmpty(pollerParameters.eisComponents().elasticInferenceServiceUrl()) == false) { + clusterService.addListener(this); + } } @Override protected void nodeOperation(AllocatedPersistentTask task, AuthorizationTaskParams params, PersistentTaskState state) { - // TODO remove - logger.warn("Starting authorization poller task"); var authPoller = (AuthorizationPoller) task; authPoller.start(); } @@ -83,7 +99,7 @@ protected AuthorizationPoller createTask( PersistentTasksCustomMetadata.PersistentTask taskInProgress, Map headers ) { - return new AuthorizationPoller( + return AuthorizationPoller.create( new AuthorizationPoller.TaskFields(id, type, action, getDescription(taskInProgress), parentTaskId, headers), pollerParameters ); @@ -100,9 +116,7 @@ public void clusterChanged(ClusterChangedEvent event) { TASK_NAME, new AuthorizationTaskParams(), TimeValue.THIRTY_SECONDS, - ActionListener.wrap(persistentTask -> { - logger.warn("Created authorization poller task"); - }, e -> { + ActionListener.wrap(persistentTask -> logger.debug("Created authorization poller task"), e -> { var t = e instanceof RemoteTransportException ? 
e.getCause() : e; if (t instanceof ResourceAlreadyExistsException == false) { logger.error("Failed to create authorization poller task", e); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParams.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParams.java index 38a9fced3607d..a93428e35ae3c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParams.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParams.java @@ -11,7 +11,6 @@ import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.health.node.selection.HealthNodeTaskParams; import org.elasticsearch.persistent.PersistentTaskParams; import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ToXContent; @@ -23,7 +22,7 @@ import static org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller.TASK_NAME; public class AuthorizationTaskParams implements PersistentTaskParams { - private static final AuthorizationTaskParams INSTANCE = new AuthorizationTaskParams(); + public static final AuthorizationTaskParams INSTANCE = new AuthorizationTaskParams(); private static final ObjectParser PARSER = new ObjectParser<>(TASK_NAME, true, () -> INSTANCE); @@ -61,7 +60,7 @@ public int hashCode() { } @Override - public boolean equals(Object obj) { - return obj instanceof HealthNodeTaskParams; + public boolean equals(Object o) { + return this == o || (o != null && getClass() == o.getClass()); } } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapter.java index 8a7729a693abb..ab23da7cab5b2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapter.java @@ -22,6 +22,7 @@ public class PreconfiguredEndpointModelAdapter { public static List getModels(Set inferenceIds, ElasticInferenceServiceComponents elasticInferenceServiceComponents) { return inferenceIds.stream() + .sorted() .filter(EIS_PRECONFIGURED_ENDPOINT_IDS::contains) .map(id -> createModel(InternalPreconfiguredEndpoints.getWithInferenceId(id), elasticInferenceServiceComponents)) .toList(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java index a2a908b49efe7..8e0896fe65ad9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java @@ -28,6 +28,10 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import static 
org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.Matchers.is; @@ -222,7 +226,68 @@ public void testDoesNotAttemptToStoreModelIds_ThatHaveATaskTypeThatTheEISIntegra verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), requestArgCaptor.capture(), any()); } - public void testSendsTwoAuthorizationRequests() { - fail("TODO"); + public void testSendsTwoAuthorizationRequests() throws InterruptedException { + var mockRegistry = mock(ModelRegistry.class); + when(mockRegistry.isReady()).thenReturn(true); + when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2")); + + var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + // this is an unknown model id so it won't trigger storing an inference endpoint because + // it doesn't map to a known one + "abc", + EnumSet.of(TaskType.SPARSE_EMBEDDING) + ) + ) + ) + ) + ); + return Void.TYPE; + }).when(mockAuthHandler).getAuthorization(any(), any()); + + var mockClient = mock(Client.class); + var eisComponents = new ElasticInferenceServiceComponents(""); + + var callbackCount = new AtomicInteger(0); + var latch = new CountDownLatch(2); + final var pollerRef = new AtomicReference(); + + Runnable callback = () -> { + var count = callbackCount.incrementAndGet(); + latch.countDown(); + + // we only want to run the tasks twice, so advance the time on the queue + // which flags the scheduled authorization request to be ready to run + if (count == 1) { + taskQueue.advanceTime(); + } else { + pollerRef.get().shutdown(); + } + }; + + var poller = new AuthorizationPoller( + new AuthorizationPoller.TaskFields(0, 
"abc", "abc", "abc", new TaskId("abc", 0), Map.of()), + createWithEmptySettings(taskQueue.getThreadPool()), + mockAuthHandler, + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + eisComponents, + mockRegistry, + mockClient, + callback + ); + pollerRef.set(poller); + poller.start(); + taskQueue.runAllRunnableTasks(); + latch.await(TimeValue.THIRTY_SECONDS.getSeconds(), TimeUnit.SECONDS); + + assertThat(callbackCount.get(), is(2)); + verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), any(), any()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutorTests.java new file mode 100644 index 0000000000000..5b7be633bf584 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutorTests.java @@ -0,0 +1,223 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.node.DiscoveryNodeUtils; +import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.persistent.ClusterPersistentTasksCustomMetadata; +import org.elasticsearch.persistent.PersistentTasksService; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests; +import org.junit.After; +import org.junit.Before; +import org.mockito.Mockito; + +import static org.elasticsearch.cluster.metadata.Metadata.EMPTY_METADATA; +import static org.elasticsearch.persistent.PersistentTasksExecutor.NO_NODE_FOUND; +import static org.elasticsearch.test.ClusterServiceUtils.createClusterService; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class AuthorizationTaskExecutorTests extends 
ESTestCase { + + private ThreadPool threadPool; + private ClusterService clusterService; + private PersistentTasksService persistentTasksService; + private String localNodeId; + + @Before + public void setUp() throws Exception { + super.setUp(); + threadPool = createThreadPool(inferenceUtilityExecutors()); + clusterService = createClusterService(threadPool); + persistentTasksService = mock(PersistentTasksService.class); + localNodeId = clusterService.localNode().getId(); + } + + @After + public void tearDown() throws Exception { + super.tearDown(); + clusterService.close(); + terminate(threadPool); + } + + public void testCreatesTask_WhenItDoesNotExistOnClusterStateChange() { + var eisUrl = "abc"; + + var executor = new AuthorizationTaskExecutor( + clusterService, + persistentTasksService, + new AuthorizationPoller.Parameters( + createWithEmptySettings(threadPool), + mock(ElasticInferenceServiceAuthorizationRequestHandler.class), + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create(eisUrl, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + new ElasticInferenceServiceComponents(eisUrl), + mock(ModelRegistry.class), + mock(Client.class) + ) + ); + executor.init(); + + var listener1 = new PlainActionFuture(); + clusterService.getClusterApplierService().onNewClusterState("initialization", this::initialState, listener1); + listener1.actionGet(TimeValue.THIRTY_SECONDS); + verify(persistentTasksService, times(1)).sendClusterStartRequest( + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationTaskParams.INSTANCE), + any(), + any() + ); + + Mockito.clearInvocations(persistentTasksService); + // Ensure that if the task is gone, it will be recreated. 
+ var listener2 = new PlainActionFuture(); + clusterService.getClusterApplierService().onNewClusterState("initialization", this::initialState, listener2); + listener2.actionGet(TimeValue.THIRTY_SECONDS); + verify(persistentTasksService, times(1)).sendClusterStartRequest( + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationPoller.TASK_NAME), + eq(new AuthorizationTaskParams()), + any(), + any() + ); + } + + private ClusterState initialState() { + DiscoveryNodes.Builder nodes = DiscoveryNodes.builder() + .add(DiscoveryNodeUtils.create(localNodeId)) + .localNodeId(localNodeId) + .masterNodeId(localNodeId); + + return ClusterState.builder(ClusterName.DEFAULT).nodes(nodes).metadata(EMPTY_METADATA).build(); + } + + public void testDoesNotRegisterClusterStateListener_DoesNotCreateTask_WhenTheEisUrlIsEmpty() { + var executor = new AuthorizationTaskExecutor( + clusterService, + persistentTasksService, + new AuthorizationPoller.Parameters( + createWithEmptySettings(threadPool), + mock(ElasticInferenceServiceAuthorizationRequestHandler.class), + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + new ElasticInferenceServiceComponents(""), + mock(ModelRegistry.class), + mock(Client.class) + ) + ); + executor.init(); + + var listener1 = new PlainActionFuture(); + clusterService.getClusterApplierService().onNewClusterState("initialization", this::initialState, listener1); + listener1.actionGet(TimeValue.THIRTY_SECONDS); + verify(persistentTasksService, never()).sendClusterStartRequest( + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationTaskParams.INSTANCE), + any(), + any() + ); + } + + public void testDoesNotRegisterClusterStateListener_DoesNotCreateTask_WhenTheEisUrlIsNull() { + var executor = new AuthorizationTaskExecutor( + clusterService, + persistentTasksService, + new AuthorizationPoller.Parameters( + createWithEmptySettings(threadPool), + 
mock(ElasticInferenceServiceAuthorizationRequestHandler.class), + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + new ElasticInferenceServiceComponents(null), + mock(ModelRegistry.class), + mock(Client.class) + ) + ); + executor.init(); + + var listener1 = new PlainActionFuture(); + clusterService.getClusterApplierService().onNewClusterState("initialization", this::initialState, listener1); + listener1.actionGet(TimeValue.THIRTY_SECONDS); + verify(persistentTasksService, never()).sendClusterStartRequest( + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationTaskParams.INSTANCE), + any(), + any() + ); + } + + public void testDoesNotCreateTask_OnClusterStateChange_WhenItAlreadyExists() { + var initialState = initialState(); + var event = new ClusterChangedEvent( + "testClusterChanged", + ClusterState.builder(initialState) + .metadata( + Metadata.builder(initialState.metadata()) + .putCustom( + ClusterPersistentTasksCustomMetadata.TYPE, + ClusterPersistentTasksCustomMetadata.builder() + .addTask( + AuthorizationPoller.TASK_NAME, + AuthorizationPoller.TASK_NAME, + AuthorizationTaskParams.INSTANCE, + NO_NODE_FOUND + ) + .build() + ) + ) + .build(), + ClusterState.EMPTY_STATE + ); + + var eisUrl = "abc"; + var executor = new AuthorizationTaskExecutor( + clusterService, + persistentTasksService, + new AuthorizationPoller.Parameters( + createWithEmptySettings(threadPool), + mock(ElasticInferenceServiceAuthorizationRequestHandler.class), + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create(eisUrl, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + new ElasticInferenceServiceComponents(eisUrl), + mock(ModelRegistry.class), + mock(Client.class) + ) + ); + executor.init(); + + executor.clusterChanged(event); + verify(persistentTasksService, never()).sendClusterStartRequest( + eq(AuthorizationPoller.TASK_NAME), + 
eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationTaskParams.INSTANCE), + any(), + any() + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParamsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParamsTests.java new file mode 100644 index 0000000000000..b07036bb2a612 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParamsTests.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; + +public class AuthorizationTaskParamsTests extends AbstractBWCWireSerializationTestCase { + + @Override + protected AuthorizationTaskParams mutateInstanceForVersion(AuthorizationTaskParams instance, TransportVersion version) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return AuthorizationTaskParams::new; + } + + @Override + protected AuthorizationTaskParams createTestInstance() { + return new AuthorizationTaskParams(); + } + + @Override + protected AuthorizationTaskParams mutateInstance(AuthorizationTaskParams instance) throws IOException { + return null; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapterTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapterTests.java new file mode 100644 index 0000000000000..b3e3305f54dcb --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapterTests.java @@ -0,0 +1,176 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; + +import java.util.Collections; +import java.util.List; +import java.util.Set; + +import static 
org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1; +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_MODEL_ID_V1; +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_ELSER_2_MODEL_ID; +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2; +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_MULTILINGUAL_EMBED_MODEL_ID; +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_RERANK_ENDPOINT_ID_V1; +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_RERANK_MODEL_ID_V1; +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DENSE_TEXT_EMBEDDINGS_DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.defaultDenseTextEmbeddingsSimilarity; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; + +public class PreconfiguredEndpointModelAdapterTests extends ESTestCase { + + private static final ElasticInferenceServiceSparseEmbeddingsServiceSettings SPARSE_SETTINGS = + new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_2_MODEL_ID, null); + private static final ElasticInferenceServiceCompletionServiceSettings COMPLETION_SETTINGS = + new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); + private static final ElasticInferenceServiceDenseTextEmbeddingsServiceSettings DENSE_SETTINGS = + new 
ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, + defaultDenseTextEmbeddingsSimilarity(), + DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + null + ); + private static final ElasticInferenceServiceRerankServiceSettings RERANK_SETTINGS = new ElasticInferenceServiceRerankServiceSettings( + DEFAULT_RERANK_MODEL_ID_V1 + ); + private static final ElasticInferenceServiceComponents EIS_COMPONENTS = new ElasticInferenceServiceComponents(""); + + public void testGetModelsWithValidId() { + var endpointIds = Set.of( + DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + DEFAULT_ELSER_ENDPOINT_ID_V2, + DEFAULT_RERANK_ENDPOINT_ID_V1, + DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID + ); + var models = PreconfiguredEndpointModelAdapter.getModels( + endpointIds, + EIS_COMPONENTS + ); + + assertThat(models, hasSize(endpointIds.size())); + assertThat( + models, + containsInAnyOrder( + new ElasticInferenceServiceModel( + new ModelConfigurations( + DEFAULT_ELSER_ENDPOINT_ID_V2, + TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, + SPARSE_SETTINGS, + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + new ModelSecrets(EmptySecretSettings.INSTANCE), + SPARSE_SETTINGS, + EIS_COMPONENTS + ), + new ElasticInferenceServiceModel( + new ModelConfigurations( + DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + TaskType.CHAT_COMPLETION, + ElasticInferenceService.NAME, + COMPLETION_SETTINGS, + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + new ModelSecrets(EmptySecretSettings.INSTANCE), + COMPLETION_SETTINGS, + EIS_COMPONENTS + ), + new ElasticInferenceServiceModel( + new ModelConfigurations( + DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, + TaskType.TEXT_EMBEDDING, + ElasticInferenceService.NAME, + DENSE_SETTINGS, + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + new ModelSecrets(EmptySecretSettings.INSTANCE), + DENSE_SETTINGS, + EIS_COMPONENTS + ), + new ElasticInferenceServiceModel( + new ModelConfigurations( + DEFAULT_RERANK_ENDPOINT_ID_V1, + TaskType.RERANK, + 
ElasticInferenceService.NAME, + RERANK_SETTINGS, + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + new ModelSecrets(EmptySecretSettings.INSTANCE), + RERANK_SETTINGS, + EIS_COMPONENTS + ) + ) + ); + } + + public void testGetModelsWithValidAndInvalidIds() { + var models = PreconfiguredEndpointModelAdapter.getModels( + Set.of( + DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + "some-invalid-id", + DEFAULT_ELSER_ENDPOINT_ID_V2 + ), + EIS_COMPONENTS + ); + + assertThat(models, hasSize(2)); + assertThat( + models, + containsInAnyOrder( + new ElasticInferenceServiceModel( + new ModelConfigurations( + DEFAULT_ELSER_ENDPOINT_ID_V2, + TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, + SPARSE_SETTINGS, + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + new ModelSecrets(EmptySecretSettings.INSTANCE), + SPARSE_SETTINGS, + EIS_COMPONENTS + ), + new ElasticInferenceServiceModel( + new ModelConfigurations( + DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + TaskType.CHAT_COMPLETION, + ElasticInferenceService.NAME, + COMPLETION_SETTINGS, + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + new ModelSecrets(EmptySecretSettings.INSTANCE), + COMPLETION_SETTINGS, + EIS_COMPONENTS + ) + ) + ); + } + + public void testGetModelsWithOnlyInvalidId() { + assertThat(PreconfiguredEndpointModelAdapter.getModels( + Collections.singleton("nonexistent-id"), + EIS_COMPONENTS + ), is(List.of())); + } +} From 36deff5a0f9419daa808c7b3e22c79f838bea6c8 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 23 Oct 2025 17:23:17 -0400 Subject: [PATCH 10/32] Even more tests --- .../StoreInferenceEndpointsAction.java | 21 ++- .../inference/results/ModelStoreResponse.java | 65 +++++++++ .../xpack/core/inference/ModelTests.java | 129 ++++++++++++++++++ ...eInferenceEndpointsActionRequestTests.java | 61 +++++++++ ...InferenceEndpointsActionResponseTests.java | 45 ++++++ .../results/ModelStoreResponseTests.java | 86 ++++++++++++ .../AuthorizationTaskExecutorIT.java | 62 +++++++++ 
.../xpack/inference/InferencePlugin.java | 16 +-- .../action/TransportStoreEndpointsAction.java | 4 +- .../inference/registry/ModelRegistry.java | 1 + .../registry/ModelStoreResponse.java | 38 ------ .../authorization/AuthorizationPoller.java | 9 +- .../AuthorizationTaskExecutor.java | 2 +- .../AuthorizationPollerTests.java | 26 +--- .../AuthorizationTaskExecutorTests.java | 5 - .../xpack/security/operator/Constants.java | 1 + 16 files changed, 490 insertions(+), 81 deletions(-) rename x-pack/plugin/{inference/src/main/java/org/elasticsearch/xpack/inference/registry => core/src/main/java/org/elasticsearch/xpack/core/inference/action}/StoreInferenceEndpointsAction.java (79%) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ModelStoreResponse.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/ModelTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsActionRequestTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsActionResponseTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/ModelStoreResponseTests.java create mode 100644 x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelStoreResponse.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/StoreInferenceEndpointsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsAction.java similarity index 79% rename from 
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/StoreInferenceEndpointsAction.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsAction.java index 047f5e5fdb368..1b41e72e2af6f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/StoreInferenceEndpointsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsAction.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.registry; +package org.elasticsearch.xpack.core.inference.action; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; @@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.Model; +import org.elasticsearch.xpack.core.inference.results.ModelStoreResponse; import java.io.IOException; import java.util.List; @@ -41,9 +42,27 @@ public Request(StreamInput in) throws IOException { models = in.readCollectionAsImmutableList(Model::new); } + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeCollection(models); + } + public List getModels() { return models; } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + Request request = (Request) o; + return Objects.equals(models, request.models); + } + + @Override + public int hashCode() { + return Objects.hashCode(models); + } } public static class Response extends ActionResponse { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ModelStoreResponse.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ModelStoreResponse.java new file mode 100644 index 0000000000000..30cbdfdfb96cd --- /dev/null +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ModelStoreResponse.java @@ -0,0 +1,65 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.rest.RestStatus; + +import java.io.IOException; +import java.util.Objects; + +/** + * Response for storing a model in the model registry using the bulk API. + */ +public record ModelStoreResponse(String inferenceId, RestStatus status, @Nullable Exception failureCause) implements Writeable { + + public ModelStoreResponse(StreamInput in) throws IOException { + this(in.readString(), RestStatus.readFrom(in), in.readException()); + } + + public boolean failed() { + return failureCause != null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(inferenceId); + RestStatus.writeTo(out, status); + out.writeException(failureCause); + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + ModelStoreResponse that = (ModelStoreResponse) o; + return status == that.status && Objects.equals(inferenceId, that.inferenceId) + // Exception does not have hashCode() or equals() so assume errors are equal iff class and message are equal + && Objects.equals( + failureCause == null ? null : failureCause.getMessage(), + that.failureCause == null ? null : that.failureCause.getMessage() + ) + && Objects.equals( + failureCause == null ? null : failureCause.getClass(), + that.failureCause == null ? 
null : that.failureCause.getClass() + ); + } + + @Override + public int hashCode() { + return Objects.hash( + inferenceId, + status, + // Exception does not have hashCode() or equals() so assume errors are equal iff class and message are equal + failureCause == null ? null : failureCause.getMessage(), + failureCause == null ? null : failureCause.getClass() + ); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/ModelTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/ModelTests.java new file mode 100644 index 0000000000000..e616954c89c45 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/ModelTests.java @@ -0,0 +1,129 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import 
org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsTests; + +import java.io.IOException; +import java.util.List; + +public class ModelTests extends ESTestCase { + public static Model randomModel() { + return new Model( + new ModelConfigurations( + randomAlphaOfLength(6), + randomFrom(TaskType.values()), + randomAlphaOfLength(6), + new TestServiceSettings( + randomAlphaOfLength(10), + randomIntBetween(1, 1024), + randomFrom(SimilarityMeasure.values()), + randomFrom(DenseVectorFieldMapper.ElementType.values()) + ), + EmptyTaskSettings.INSTANCE, + randomBoolean() ? ChunkingSettingsTests.createRandomChunkingSettings() : null + ), + new ModelSecrets(EmptySecretSettings.INSTANCE) + ); + } + + public record TestServiceSettings( + String model, + Integer dimensions, + @Nullable SimilarityMeasure similarity, + @Nullable DenseVectorFieldMapper.ElementType elementType + ) implements ServiceSettings { + + static final String NAME = "test_text_embedding_service_settings"; + + public TestServiceSettings(StreamInput in) throws IOException { + this( + in.readString(), + in.readInt(), + in.readOptionalEnum(SimilarityMeasure.class), + in.readOptionalEnum(DenseVectorFieldMapper.ElementType.class) + ); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("model", model); + builder.field("dimensions", dimensions); + if (similarity != null) { + builder.field("similarity", similarity); + } + if (elementType != null) { + builder.field("element_type", elementType); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(model); + 
out.writeInt(dimensions); + out.writeOptionalEnum(similarity); + out.writeOptionalEnum(elementType); + } + + @Override + public ToXContentObject getFilteredXContentObject() { + return this; + } + + @Override + public SimilarityMeasure similarity() { + return similarity != null ? similarity : SimilarityMeasure.COSINE; + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return elementType != null ? elementType : DenseVectorFieldMapper.ElementType.FLOAT; + } + + @Override + public String modelId() { + return model; + } + } + + public static List getNamedWriteables() { + return List.of( + new NamedWriteableRegistry.Entry(ServiceSettings.class, TestServiceSettings.NAME, TestServiceSettings::new) + ); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsActionRequestTests.java new file mode 100644 index 0000000000000..3673296c29ce7 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsActionRequestTests.java @@ -0,0 +1,61 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xpack.core.XPackClientPlugin; +import org.elasticsearch.xpack.core.inference.ModelTests; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.ArrayList; + +public class StoreInferenceEndpointsActionRequestTests extends AbstractBWCWireSerializationTestCase { + + @Override + protected StoreInferenceEndpointsAction.Request mutateInstanceForVersion( + StoreInferenceEndpointsAction.Request instance, + TransportVersion version + ) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return StoreInferenceEndpointsAction.Request::new; + } + + @Override + protected StoreInferenceEndpointsAction.Request createTestInstance() { + return new StoreInferenceEndpointsAction.Request(randomList(5, ModelTests::randomModel), randomTimeValue()); + } + + @Override + protected StoreInferenceEndpointsAction.Request mutateInstance(StoreInferenceEndpointsAction.Request instance) throws IOException { + var newModels = new ArrayList<>(instance.getModels()); + newModels.add(ModelTests.randomModel()); + return new StoreInferenceEndpointsAction.Request(newModels, instance.masterNodeTimeout()); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + var namedWriteables = new ArrayList(); + namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, EmptyTaskSettings.NAME, EmptyTaskSettings::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, 
EmptySecretSettings.NAME, EmptySecretSettings::new)); + namedWriteables.addAll(ModelTests.getNamedWriteables()); + namedWriteables.addAll(XPackClientPlugin.getChunkingSettingsNamedWriteables()); + + return new NamedWriteableRegistry(namedWriteables); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsActionResponseTests.java new file mode 100644 index 0000000000000..c7a692f0c5e9e --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsActionResponseTests.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xpack.core.inference.results.ModelStoreResponseTests; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.ArrayList; + +public class StoreInferenceEndpointsActionResponseTests extends AbstractBWCWireSerializationTestCase< + StoreInferenceEndpointsAction.Response> { + + @Override + protected StoreInferenceEndpointsAction.Response mutateInstanceForVersion( + StoreInferenceEndpointsAction.Response instance, + TransportVersion version + ) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return StoreInferenceEndpointsAction.Response::new; + } + + @Override + protected StoreInferenceEndpointsAction.Response createTestInstance() { + return new 
StoreInferenceEndpointsAction.Response(randomList(5, ModelStoreResponseTests::randomModelStoreResponse)); + } + + @Override + protected StoreInferenceEndpointsAction.Response mutateInstance(StoreInferenceEndpointsAction.Response instance) throws IOException { + var newResults = new ArrayList<>(instance.getResults()); + newResults.add(ModelStoreResponseTests.randomModelStoreResponse()); + return new StoreInferenceEndpointsAction.Response(newResults); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/ModelStoreResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/ModelStoreResponseTests.java new file mode 100644 index 0000000000000..b13772d00a840 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/ModelStoreResponseTests.java @@ -0,0 +1,86 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; + +public class ModelStoreResponseTests extends AbstractBWCWireSerializationTestCase { + + public static ModelStoreResponse randomModelStoreResponse() { + return new ModelStoreResponse( + randomAlphaOfLength(10), + randomFrom(RestStatus.values()), + randomBoolean() ? 
null : new IllegalStateException("Test exception") + ); + } + + public void testFailed() { + { + var successResponse = new ModelStoreResponse("model_1", RestStatus.OK, null); + assertFalse(successResponse.failed()); + } + { + var failedResponse = new ModelStoreResponse( + "model_2", + RestStatus.INTERNAL_SERVER_ERROR, + new IllegalStateException("Test failure") + ); + assertTrue(failedResponse.failed()); + } + { + var failedResponse = new ModelStoreResponse( + "model_2", + RestStatus.OK, + new IllegalStateException("Test failure") + ); + assertTrue(failedResponse.failed()); + } + } + + @Override + protected ModelStoreResponse mutateInstanceForVersion(ModelStoreResponse instance, TransportVersion version) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return ModelStoreResponse::new; + } + + @Override + protected ModelStoreResponse createTestInstance() { + return randomModelStoreResponse(); + } + + @Override + protected ModelStoreResponse mutateInstance(ModelStoreResponse instance) throws IOException { + int choice = randomIntBetween(0, 2); + return switch (choice) { + case 0 -> { + String newInferenceId = instance.inferenceId() + "_mutated"; + yield new ModelStoreResponse(newInferenceId, instance.status(), instance.failureCause()); + } + case 1 -> + new ModelStoreResponse( + instance.inferenceId(), + randomValueOtherThan(instance.status(), () -> randomFrom(RestStatus.values())), + instance.failureCause() + ); + case 2 -> { + Exception newFailureCause = instance.failureCause() == null ? 
new IllegalStateException("Mutated exception") : null; + yield new ModelStoreResponse(instance.inferenceId(), instance.status(), newFailureCause); + } + default -> throw new IllegalStateException("Unexpected value: " + choice); + }; + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java new file mode 100644 index 0000000000000..48251009ab42f --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.reindex.ReindexPlugin; +import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationTaskExecutor; +import org.junit.After; +import org.junit.Before; + +import java.util.Collection; + +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; + +public class AuthorizationTaskExecutorIT extends ESSingleNodeTestCase { + + private ModelRegistry modelRegistry; + private final MockWebServer webServer = new 
MockWebServer(); + private ThreadPool threadPool; + private String gatewayUrl; + + @Before + public void createComponents() throws Exception { + threadPool = createThreadPool(inferenceUtilityExecutors()); + webServer.start(); + gatewayUrl = getUrl(webServer); + modelRegistry = node().injector().getInstance(ModelRegistry.class); + node().injector().getInstance(ClusterState.class); + } + + @After + public void shutdown() { + terminate(threadPool); + webServer.close(); + } + + @Override + protected boolean resetNodeAfterTest() { + return true; + } + + @Override + protected Collection> getPlugins() { + return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class); + } + + public void testCreateEndpoints() { + var executor = new AuthorizationTaskExecutor(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 54276a10c0d8d..fd4f97b1c7c60 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -7,8 +7,6 @@ package org.elasticsearch.xpack.inference; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.support.MappedActionFilter; import org.elasticsearch.client.internal.Client; @@ -72,6 +70,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.core.inference.action.StoreInferenceEndpointsAction; import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction; 
import org.elasticsearch.xpack.core.ssl.SSLService; @@ -116,7 +115,6 @@ import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.registry.ModelRegistryMetadata; -import org.elasticsearch.xpack.inference.registry.StoreInferenceEndpointsAction; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceEndpointAction; import org.elasticsearch.xpack.inference.rest.RestGetInferenceDiagnosticsAction; import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction; @@ -220,12 +218,10 @@ public class InferencePlugin extends Plugin public static final String UTILITY_THREAD_POOL_NAME = "inference_utility"; public static final String INFERENCE_RESPONSE_THREAD_POOL_NAME = "inference_response"; - private static final Logger log = LogManager.getLogger(InferencePlugin.class); - private final Settings settings; private final SetOnce httpFactory = new SetOnce<>(); private final SetOnce amazonBedrockFactory = new SetOnce<>(); - private final SetOnce elasicInferenceServiceFactory = new SetOnce<>(); + private final SetOnce elasticInferenceServiceFactory = new SetOnce<>(); private final SetOnce serviceComponents = new SetOnce<>(); // This is mainly so that the rest handlers can access the ThreadPool in a way that avoids potential null pointers from it // not being initialized yet @@ -328,7 +324,7 @@ public Collection createComponents(PluginServices services) { elasticInferenceServiceHttpClientManager, services.clusterService() ); - elasicInferenceServiceFactory.set(elasticInferenceServiceRequestSenderFactory); + elasticInferenceServiceFactory.set(elasticInferenceServiceRequestSenderFactory); var authorizationHandler = new ElasticInferenceServiceAuthorizationRequestHandler( inferenceServiceSettings.getElasticInferenceServiceUrl(), @@ -342,7 +338,7 @@ public Collection createComponents(PluginServices services) { new AuthorizationPoller.Parameters( 
serviceComponents.get(), authorizationHandler, - elasicInferenceServiceFactory.get().createSender(), + elasticInferenceServiceFactory.get().createSender(), inferenceServiceSettings, eisComponents, modelRegistry.get(), @@ -356,7 +352,7 @@ public Collection createComponents(PluginServices services) { inferenceServices.add( () -> List.of( context -> new ElasticInferenceService( - elasicInferenceServiceFactory.get(), + elasticInferenceServiceFactory.get(), serviceComponents.get(), inferenceServiceSettings, modelRegistry.get(), @@ -414,7 +410,7 @@ public Collection createComponents(PluginServices services) { ); components.add(inferenceStatsBinding); components.add(authorizationHandler); - components.add(new PluginComponentBinding<>(Sender.class, elasicInferenceServiceFactory.get().createSender())); + components.add(new PluginComponentBinding<>(Sender.class, elasticInferenceServiceFactory.get().createSender())); components.add( new InferenceEndpointRegistry( services.clusterService(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportStoreEndpointsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportStoreEndpointsAction.java index a1cd042467a72..96905892c5c4f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportStoreEndpointsAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportStoreEndpointsAction.java @@ -20,9 +20,9 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.StoreInferenceEndpointsAction; +import org.elasticsearch.xpack.core.inference.results.ModelStoreResponse; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.registry.ModelStoreResponse; -import 
org.elasticsearch.xpack.inference.registry.StoreInferenceEndpointsAction; import java.util.List; import java.util.Objects; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index c4e087ef4cc53..7968745fff2a8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -73,6 +73,7 @@ import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction; +import org.elasticsearch.xpack.core.inference.results.ModelStoreResponse; import org.elasticsearch.xpack.inference.InferenceIndex; import org.elasticsearch.xpack.inference.InferenceSecretsIndex; import org.elasticsearch.xpack.inference.services.ServiceUtils; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelStoreResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelStoreResponse.java deleted file mode 100644 index 9689e875f62ad..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelStoreResponse.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.registry; - -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.rest.RestStatus; - -import java.io.IOException; - -public record ModelStoreResponse(String inferenceId, RestStatus status, @Nullable Exception failureCause) implements Writeable { - - public ModelStoreResponse(StreamInput in) throws IOException { - this( - in.readString(), - RestStatus.readFrom(in), - in.readException() - ); - } - - public boolean failed() { - return failureCause != null; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(inferenceId); - RestStatus.writeTo(out, status); - out.writeException(failureCause); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java index c578f991df2a5..8ace4613b6c76 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java @@ -20,9 +20,9 @@ import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.Scheduler; import org.elasticsearch.xpack.core.ClientHelper; +import org.elasticsearch.xpack.core.inference.action.StoreInferenceEndpointsAction; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.registry.StoreInferenceEndpointsAction; import org.elasticsearch.xpack.inference.services.ServiceComponents; import 
org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; @@ -64,7 +64,6 @@ public record Parameters( ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, Sender sender, ElasticInferenceServiceSettings elasticInferenceServiceSettings, - ElasticInferenceServiceComponents eisComponents, ModelRegistry modelRegistry, Client client ) {} @@ -80,7 +79,6 @@ private AuthorizationPoller(TaskFields taskFields, Parameters parameters) { parameters.authorizationRequestHandler, parameters.sender, parameters.elasticInferenceServiceSettings, - parameters.eisComponents, parameters.modelRegistry, parameters.client, null @@ -94,7 +92,6 @@ private AuthorizationPoller(TaskFields taskFields, Parameters parameters) { ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, Sender sender, ElasticInferenceServiceSettings elasticInferenceServiceSettings, - ElasticInferenceServiceComponents components, ModelRegistry modelRegistry, Client client, // this is a hack to facilitate testing @@ -105,7 +102,9 @@ private AuthorizationPoller(TaskFields taskFields, Parameters parameters) { this.authorizationHandler = Objects.requireNonNull(authorizationRequestHandler); this.sender = Objects.requireNonNull(sender); this.elasticInferenceServiceSettings = Objects.requireNonNull(elasticInferenceServiceSettings); - this.elasticInferenceServiceComponents = Objects.requireNonNull(components); + this.elasticInferenceServiceComponents = new ElasticInferenceServiceComponents( + elasticInferenceServiceSettings.getElasticInferenceServiceUrl() + ); this.modelRegistry = Objects.requireNonNull(modelRegistry); this.client = new OriginSettingClient(Objects.requireNonNull(client), ClientHelper.INFERENCE_ORIGIN); this.callback = callback; diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java index b932d91e42745..26fe45cb2ee73 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java @@ -74,7 +74,7 @@ public static AuthorizationTaskExecutor create(ClusterService clusterService, Au // default for testing void init() { // If the EIS url is not configured, then we won't be able to interact with the service, so don't start the task. - if (Strings.isNullOrEmpty(pollerParameters.eisComponents().elasticInferenceServiceUrl()) == false) { + if (Strings.isNullOrEmpty(pollerParameters.elasticInferenceServiceSettings().getElasticInferenceServiceUrl()) == false) { clusterService.addListener(this); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java index 8e0896fe65ad9..0e96f2d7f5310 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java @@ -14,9 +14,9 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.action.StoreInferenceEndpointsAction; import 
org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.registry.StoreInferenceEndpointsAction; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests; import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; @@ -62,8 +62,7 @@ public void testDoesNotSendAuthorizationRequest_WhenModelRegistryIsNotReady() { createWithEmptySettings(taskQueue.getThreadPool()), authorizationRequestHandler, mock(Sender.class), - ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), - new ElasticInferenceServiceComponents(""), + ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), mockRegistry, mock(Client.class), null @@ -100,15 +99,12 @@ public void testSendsAuthorizationRequest_WhenModelRegistryIsReady() { var mockClient = mock(Client.class); when(mockClient.threadPool()).thenReturn(taskQueue.getThreadPool()); - var eisComponents = new ElasticInferenceServiceComponents(""); - var poller = new AuthorizationPoller( new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()), createWithEmptySettings(taskQueue.getThreadPool()), mockAuthHandler, mock(Sender.class), - ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), - eisComponents, + ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), mockRegistry, mockClient, null @@ -125,7 +121,7 @@ public void testSendsAuthorizationRequest_WhenModelRegistryIsReady() { List.of( PreconfiguredEndpointModelAdapter.createModel( 
InternalPreconfiguredEndpoints.getWithInferenceId(InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2), - eisComponents + new ElasticInferenceServiceComponents("") ) ) ) @@ -159,15 +155,12 @@ public void testDoesNotAttemptToStoreModelIds_ThatDoNotExistInThePreconfiguredMa var mockClient = mock(Client.class); when(mockClient.threadPool()).thenReturn(taskQueue.getThreadPool()); - var eisComponents = new ElasticInferenceServiceComponents(""); - var poller = new AuthorizationPoller( new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()), createWithEmptySettings(taskQueue.getThreadPool()), mockAuthHandler, mock(Sender.class), - ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), - eisComponents, + ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), mockRegistry, mockClient, null @@ -206,15 +199,12 @@ public void testDoesNotAttemptToStoreModelIds_ThatHaveATaskTypeThatTheEISIntegra var mockClient = mock(Client.class); when(mockClient.threadPool()).thenReturn(taskQueue.getThreadPool()); - var eisComponents = new ElasticInferenceServiceComponents(""); - var poller = new AuthorizationPoller( new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()), createWithEmptySettings(taskQueue.getThreadPool()), mockAuthHandler, mock(Sender.class), - ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), - eisComponents, + ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), mockRegistry, mockClient, null @@ -252,7 +242,6 @@ public void testSendsTwoAuthorizationRequests() throws InterruptedException { }).when(mockAuthHandler).getAuthorization(any(), any()); var mockClient = mock(Client.class); - var eisComponents = new ElasticInferenceServiceComponents(""); var 
callbackCount = new AtomicInteger(0); var latch = new CountDownLatch(2); @@ -276,8 +265,7 @@ public void testSendsTwoAuthorizationRequests() throws InterruptedException { createWithEmptySettings(taskQueue.getThreadPool()), mockAuthHandler, mock(Sender.class), - ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), - eisComponents, + ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), mockRegistry, mockClient, callback diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutorTests.java index 5b7be633bf584..1e375c3df1a2e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutorTests.java @@ -23,7 +23,6 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests; import org.junit.After; import org.junit.Before; @@ -75,7 +74,6 @@ public void testCreatesTask_WhenItDoesNotExistOnClusterStateChange() { mock(ElasticInferenceServiceAuthorizationRequestHandler.class), mock(Sender.class), ElasticInferenceServiceSettingsTests.create(eisUrl, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), - new ElasticInferenceServiceComponents(eisUrl), mock(ModelRegistry.class), mock(Client.class) 
) @@ -125,7 +123,6 @@ public void testDoesNotRegisterClusterStateListener_DoesNotCreateTask_WhenTheEis mock(ElasticInferenceServiceAuthorizationRequestHandler.class), mock(Sender.class), ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), - new ElasticInferenceServiceComponents(""), mock(ModelRegistry.class), mock(Client.class) ) @@ -153,7 +150,6 @@ public void testDoesNotRegisterClusterStateListener_DoesNotCreateTask_WhenTheEis mock(ElasticInferenceServiceAuthorizationRequestHandler.class), mock(Sender.class), ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), - new ElasticInferenceServiceComponents(null), mock(ModelRegistry.class), mock(Client.class) ) @@ -204,7 +200,6 @@ public void testDoesNotCreateTask_OnClusterStateChange_WhenItAlreadyExists() { mock(ElasticInferenceServiceAuthorizationRequestHandler.class), mock(Sender.class), ElasticInferenceServiceSettingsTests.create(eisUrl, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), - new ElasticInferenceServiceComponents(eisUrl), mock(ModelRegistry.class), mock(Client.class) ) diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index 89a47a9e070d0..621ae74ecd4aa 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -175,6 +175,7 @@ public class Constants { "cluster:admin/xpack/enrich/put", "cluster:admin/xpack/enrich/reindex", "cluster:internal/xpack/inference/clear_inference_endpoint_cache", + 
"cluster:internal/xpack/inference/create_endpoints", "cluster:admin/xpack/inference/delete", "cluster:admin/xpack/inference/put", "cluster:admin/xpack/inference/update", From 83f2c655c8d3fc8f998b371acb6bacc54e51d5be Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Tue, 21 Oct 2025 19:55:48 +0000 Subject: [PATCH 11/32] [CI] Auto commit changes from spotless --- .../xpack/inference/integration/ModelRegistryIT.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index 28541abeb1390..09d444638564d 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -48,12 +48,12 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsTests; +import org.elasticsearch.xpack.core.inference.results.ModelStoreResponse; import org.elasticsearch.xpack.inference.InferenceIndex; import org.elasticsearch.xpack.inference.InferenceSecretsIndex; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.registry.ModelStoreResponse; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettingsTests; From 
d39d7ef7be338bc12454d5c62cfaf76b133e80ea Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 24 Oct 2025 17:05:59 -0400 Subject: [PATCH 12/32] Starting integration tests --- .../AuthorizationTaskExecutorIT.java | 28 +++++++++++++------ .../xpack/inference/InferencePlugin.java | 3 -- .../ElasticInferenceServiceSettings.java | 4 +-- 3 files changed, 21 insertions(+), 14 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java index 48251009ab42f..fa9b8735bdca8 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java @@ -7,7 +7,7 @@ package org.elasticsearch.xpack.inference.integration; -import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.reindex.ReindexPlugin; import org.elasticsearch.test.ESSingleNodeTestCase; @@ -15,10 +15,13 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationTaskExecutor; import org.junit.After; import org.junit.Before; +import org.junit.BeforeClass; +import java.io.IOException; import java.util.Collection; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; @@ -26,18 +29,22 @@ public class AuthorizationTaskExecutorIT extends 
ESSingleNodeTestCase { + private static final MockWebServer webServer = new MockWebServer(); + private static String gatewayUrl; + private ModelRegistry modelRegistry; - private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; - private String gatewayUrl; - @Before - public void createComponents() throws Exception { - threadPool = createThreadPool(inferenceUtilityExecutors()); + @BeforeClass + public static void initClass() throws IOException { webServer.start(); gatewayUrl = getUrl(webServer); + } + + @Before + public void createComponents() { + threadPool = createThreadPool(inferenceUtilityExecutors()); modelRegistry = node().injector().getInstance(ModelRegistry.class); - node().injector().getInstance(ClusterState.class); } @After @@ -47,8 +54,11 @@ public void shutdown() { } @Override - protected boolean resetNodeAfterTest() { - return true; + protected Settings nodeSettings() { + return Settings.builder() + .put(ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_URL.getKey(), gatewayUrl) + .put(ElasticInferenceServiceSettings.PERIODIC_AUTHORIZATION_ENABLED.getKey(), false) + .build(); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index fd4f97b1c7c60..90f72dba305f6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -331,8 +331,6 @@ public Collection createComponents(PluginServices services) { services.threadPool() ); - var eisComponents = new ElasticInferenceServiceComponents(inferenceServiceSettings.getElasticInferenceServiceUrl()); - var authTaskExecutor = AuthorizationTaskExecutor.create( services.clusterService(), new AuthorizationPoller.Parameters( @@ -340,7 +338,6 @@ public Collection 
createComponents(PluginServices services) { authorizationHandler, elasticInferenceServiceFactory.get().createSender(), inferenceServiceSettings, - eisComponents, modelRegistry.get(), services.client() ) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java index 0d8bef246b35d..fcc9808c7b8ff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java @@ -28,7 +28,7 @@ public class ElasticInferenceServiceSettings { @Deprecated static final Setting EIS_GATEWAY_URL = Setting.simpleString("xpack.inference.eis.gateway.url", Setting.Property.NodeScope); - static final Setting ELASTIC_INFERENCE_SERVICE_URL = Setting.simpleString( + public static final Setting ELASTIC_INFERENCE_SERVICE_URL = Setting.simpleString( "xpack.inference.elastic.url", Setting.Property.NodeScope ); @@ -37,7 +37,7 @@ public class ElasticInferenceServiceSettings { * This setting is for testing only. It controls whether authorization is only performed once at bootup. If set to true, an * authorization request will be made repeatedly on an interval. 
*/ - static final Setting PERIODIC_AUTHORIZATION_ENABLED = Setting.boolSetting( + public static final Setting PERIODIC_AUTHORIZATION_ENABLED = Setting.boolSetting( "xpack.inference.elastic.periodic_authorization_enabled", true, Setting.Property.NodeScope From 02d47661be15325727f08a314328f09845ddc775 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 24 Oct 2025 17:08:14 -0400 Subject: [PATCH 13/32] Adding test stub --- .../integration/AuthorizationTaskExecutorIT.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java index fa9b8735bdca8..a7a2fd02bc1b4 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java @@ -16,7 +16,6 @@ import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; -import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationTaskExecutor; import org.junit.After; import org.junit.Before; import org.junit.BeforeClass; @@ -39,6 +38,7 @@ public class AuthorizationTaskExecutorIT extends ESSingleNodeTestCase { public static void initClass() throws IOException { webServer.start(); gatewayUrl = getUrl(webServer); + // TODO add response to the web server to return no authorized models } @Before @@ -67,6 +67,11 @@ protected Collection> getPlugins() { } public void testCreateEndpoints() { - var executor = new AuthorizationTaskExecutor(); + // verify that no 
models are authorized + // add request to return an authorized model + // cancel the task + // ensure the task is recreated? + // verify that the authorized model is present + fail("Not implemented yet"); } } From b41b1d419ef4fd0b9648bb1341324430e327c121 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Fri, 24 Oct 2025 21:15:28 +0000 Subject: [PATCH 14/32] [CI] Auto commit changes from spotless --- .../xpack/core/inference/ModelTests.java | 4 +--- .../results/ModelStoreResponseTests.java | 17 ++++++----------- .../PreconfiguredEndpointModelAdapterTests.java | 16 +++------------- 3 files changed, 10 insertions(+), 27 deletions(-) diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/ModelTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/ModelTests.java index e616954c89c45..86787473edd3a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/ModelTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/ModelTests.java @@ -122,8 +122,6 @@ public String modelId() { } public static List getNamedWriteables() { - return List.of( - new NamedWriteableRegistry.Entry(ServiceSettings.class, TestServiceSettings.NAME, TestServiceSettings::new) - ); + return List.of(new NamedWriteableRegistry.Entry(ServiceSettings.class, TestServiceSettings.NAME, TestServiceSettings::new)); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/ModelStoreResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/ModelStoreResponseTests.java index b13772d00a840..e7160fc360bb3 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/ModelStoreResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/ModelStoreResponseTests.java @@ -38,11 +38,7 @@ public void testFailed() { 
assertTrue(failedResponse.failed()); } { - var failedResponse = new ModelStoreResponse( - "model_2", - RestStatus.OK, - new IllegalStateException("Test failure") - ); + var failedResponse = new ModelStoreResponse("model_2", RestStatus.OK, new IllegalStateException("Test failure")); assertTrue(failedResponse.failed()); } } @@ -70,12 +66,11 @@ protected ModelStoreResponse mutateInstance(ModelStoreResponse instance) throws String newInferenceId = instance.inferenceId() + "_mutated"; yield new ModelStoreResponse(newInferenceId, instance.status(), instance.failureCause()); } - case 1 -> - new ModelStoreResponse( - instance.inferenceId(), - randomValueOtherThan(instance.status(), () -> randomFrom(RestStatus.values())), - instance.failureCause() - ); + case 1 -> new ModelStoreResponse( + instance.inferenceId(), + randomValueOtherThan(instance.status(), () -> randomFrom(RestStatus.values())), + instance.failureCause() + ); case 2 -> { Exception newFailureCause = instance.failureCause() == null ? 
new IllegalStateException("Mutated exception") : null; yield new ModelStoreResponse(instance.inferenceId(), instance.status(), newFailureCause); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapterTests.java index b3e3305f54dcb..e718c83c3f965 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapterTests.java @@ -64,10 +64,7 @@ public void testGetModelsWithValidId() { DEFAULT_RERANK_ENDPOINT_ID_V1, DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID ); - var models = PreconfiguredEndpointModelAdapter.getModels( - endpointIds, - EIS_COMPONENTS - ); + var models = PreconfiguredEndpointModelAdapter.getModels(endpointIds, EIS_COMPONENTS); assertThat(models, hasSize(endpointIds.size())); assertThat( @@ -127,11 +124,7 @@ public void testGetModelsWithValidId() { public void testGetModelsWithValidAndInvalidIds() { var models = PreconfiguredEndpointModelAdapter.getModels( - Set.of( - DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, - "some-invalid-id", - DEFAULT_ELSER_ENDPOINT_ID_V2 - ), + Set.of(DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, "some-invalid-id", DEFAULT_ELSER_ENDPOINT_ID_V2), EIS_COMPONENTS ); @@ -168,9 +161,6 @@ public void testGetModelsWithValidAndInvalidIds() { } public void testGetModelsWithOnlyInvalidId() { - assertThat(PreconfiguredEndpointModelAdapter.getModels( - Collections.singleton("nonexistent-id"), - EIS_COMPONENTS - ), is(List.of())); + assertThat(PreconfiguredEndpointModelAdapter.getModels(Collections.singleton("nonexistent-id"), EIS_COMPONENTS), is(List.of())); } } From 
10ad8f2d9bbd28a675d0960c250828a829d3fa9e Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 28 Oct 2025 16:01:06 -0400 Subject: [PATCH 15/32] Adding integration test --- .../AuthorizationTaskExecutorIT.java | 173 +++++++++++++++++- .../authorization/AuthorizationPoller.java | 18 ++ .../AuthorizationTaskExecutor.java | 40 ++++ 3 files changed, 223 insertions(+), 8 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java index a7a2fd02bc1b4..76da8e7f62882 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java @@ -7,49 +7,92 @@ package org.elasticsearch.xpack.inference.integration; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.reindex.ReindexPlugin; import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; 
+import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationTaskExecutor; import org.junit.After; +import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; import java.io.IOException; import java.util.Collection; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.is; public class AuthorizationTaskExecutorIT extends ESSingleNodeTestCase { + private static final String EMPTY_AUTH_RESPONSE = """ + { + "models": [ + ] + } + """; + + private static final String AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE = """ + { + "models": [ + { + "model_name": "rainbow-sprinkles", + "task_types": ["chat"] + } + ] + } + """; + private static final MockWebServer webServer = new MockWebServer(); private static String gatewayUrl; private ModelRegistry modelRegistry; private ThreadPool threadPool; + private AuthorizationTaskExecutor authorizationTaskExecutor; @BeforeClass public static void initClass() throws IOException { webServer.start(); gatewayUrl = getUrl(webServer); - // TODO add response to the web server to return no authorized models + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE)); } @Before public void createComponents() { threadPool = createThreadPool(inferenceUtilityExecutors()); modelRegistry = node().injector().getInstance(ModelRegistry.class); + authorizationTaskExecutor = node().injector().getInstance(AuthorizationTaskExecutor.class); } @After public void shutdown() { + // Delete all the eis preconfigured endpoints + var listener = new PlainActionFuture(); + modelRegistry.deleteModels(InternalPreconfiguredEndpoints.EIS_PRECONFIGURED_ENDPOINT_IDS, listener); + listener.actionGet(TimeValue.THIRTY_SECONDS); 
+ terminate(threadPool); + } + + @AfterClass + public static void cleanUpClass() { webServer.close(); } @@ -66,12 +109,126 @@ protected Collection> getPlugins() { return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class); } - public void testCreateEndpoints() { - // verify that no models are authorized - // add request to return an authorized model - // cancel the task - // ensure the task is recreated? - // verify that the authorized model is present - fail("Not implemented yet"); + public void testCreatesEisChatCompletionEndpoint() throws Exception { + assertNoAuthorizedEisEndpoints(); + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE)); + waitForNewAuthorizationResponse(); + + assertChatCompletionEndpointExists(); + } + + private void assertNoAuthorizedEisEndpoints() throws Exception { + assertBusy(() -> { + var newPoller = authorizationTaskExecutor.getCurrentPollerTask(); + assertNotNull(newPoller); + newPoller.waitForAuthorizationToComplete(TimeValue.THIRTY_SECONDS); + }); + + var eisEndpoints = getEisEndpoints(); + assertThat(eisEndpoints, empty()); + } + + private List getEisEndpoints() { + var listener = new PlainActionFuture>(); + modelRegistry.getAllModels(false, listener); + + var endpoints = listener.actionGet(TimeValue.THIRTY_SECONDS); + return endpoints.stream().filter(m -> m.service().equals(ElasticInferenceService.NAME)).toList(); + } + + private void waitForNewAuthorizationResponse() throws Exception { + var taskListener = new PlainActionFuture(); + + authorizationTaskExecutor.abortTask(TimeValue.THIRTY_SECONDS, taskListener); + // Ensure that the listener doesn't return a failure + assertNull(taskListener.actionGet(TimeValue.THIRTY_SECONDS)); + + // wait for the new task to be recreated + assertBusy(() -> { + var newPoller = authorizationTaskExecutor.getCurrentPollerTask(); + assertNotNull(newPoller); + newPoller.waitForAuthorizationToComplete(TimeValue.THIRTY_SECONDS); + }); 
+ } + + public void testCreatesEisChatCompletion_DoesNotRemoveEndpointWhenNoLongerAuthorized() throws Exception { + assertNoAuthorizedEisEndpoints(); + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE)); + waitForNewAuthorizationResponse(); + + assertChatCompletionEndpointExists(); + + // Simulate that the model is no longer authorized + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE)); + waitForNewAuthorizationResponse(); + + assertChatCompletionEndpointExists(); + } + + private void assertChatCompletionEndpointExists() { + var eisEndpoints = getEisEndpoints(); + assertThat(eisEndpoints.size(), is(1)); + + var rainbowSprinklesModel = eisEndpoints.get(0); + assertChatCompletionUnparsedModel(rainbowSprinklesModel); + } + + private void assertChatCompletionUnparsedModel(UnparsedModel rainbowSprinklesModel) { + assertThat(rainbowSprinklesModel.taskType(), is(TaskType.CHAT_COMPLETION)); + assertThat(rainbowSprinklesModel.service(), is(ElasticInferenceService.NAME)); + assertThat(rainbowSprinklesModel.inferenceEntityId(), is(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); + } + + public void testCreatesChatCompletion_AndThenCreatesTextEmbedding() throws Exception { + assertNoAuthorizedEisEndpoints(); + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE)); + waitForNewAuthorizationResponse(); + + assertChatCompletionEndpointExists(); + + // Simulate that the model is no longer authorized + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE)); + waitForNewAuthorizationResponse(); + + assertChatCompletionEndpointExists(); + + // Simulate that a text embedding model is now authorized + var authorizedTextEmbeddingResponse = """ + { + "models": [ + { + "model_name": "multilingual-embed-v1", + "task_types": ["embed/text/dense"] + } + ] + } + """; + + 
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(authorizedTextEmbeddingResponse)); + waitForNewAuthorizationResponse(); + + var eisEndpoints = getEisEndpoints().stream().collect(Collectors.toMap(UnparsedModel::inferenceEntityId, Function.identity())); + assertThat(eisEndpoints.size(), is(2)); + + assertTrue(eisEndpoints.containsKey(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); + assertChatCompletionUnparsedModel(eisEndpoints.get(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); + + assertTrue(eisEndpoints.containsKey(InternalPreconfiguredEndpoints.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID)); + + var textEmbeddingEndpoint = eisEndpoints.get(InternalPreconfiguredEndpoints.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID); + assertThat(textEmbeddingEndpoint.taskType(), is(TaskType.TEXT_EMBEDDING)); + assertThat(textEmbeddingEndpoint.service(), is(ElasticInferenceService.NAME)); + } + + public void testRestartsTaskAfterAbort() throws Exception { + // Ensure the task is created and we get an initial authorization response + assertNoAuthorizedEisEndpoints(); + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE)); + // Abort the task and ensure it is restarted + waitForNewAuthorizationResponse(); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java index 8ace4613b6c76..c8bc51603b0e1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java @@ -32,6 +32,8 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import 
java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; @@ -56,6 +58,7 @@ public class AuthorizationPoller extends AllocatedPersistentTask { private final AtomicBoolean initialized = new AtomicBoolean(false); private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; private final Client client; + private final CountDownLatch receivedFirstAuthResponseLatch = new CountDownLatch(1); public record TaskFields(long id, String type, String action, String description, TaskId parentTask, Map headers) {} @@ -117,9 +120,23 @@ public void start() { } } + /** + * This should only be used for testing to wait for the first authorization response to be received. + */ + public void waitForAuthorizationToComplete(TimeValue waitTime) { + try { + if (receivedFirstAuthResponseLatch.await(waitTime.getSeconds(), TimeUnit.SECONDS) == false) { + throw new IllegalStateException("The wait time has expired for first authorization response to be received."); + } + } catch (InterruptedException e) { + throw new IllegalStateException("Waiting for first authorization response to complete was interrupted"); + } + } + @Override protected void onCancelled() { shutdown(); + markAsCompleted(); } // default for testing @@ -182,6 +199,7 @@ void sendAuthorizationRequest() { if (callback != null) { callback.run(); } + receivedFirstAuthResponseLatch.countDown(); }).delegateResponse((delegate, e) -> { logger.atWarn().withThrowable(e).log("Failed processing EIS preconfigured endpoints"); delegate.onResponse(null); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java index 26fe45cb2ee73..b637700dbc754 
100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java @@ -79,9 +79,49 @@ void init() { } } + /** + * This method should only be used for testing purposes to simulate a task being recreated. + */ + public void abortTask(TimeValue timeout, ActionListener listener) { + var task = currentTask.get(); + if (task != null && task.isCancelled() == false) { + task.markAsLocallyAborted("testing task cancellation"); + currentTask.set(null); + waitForNullTask(task, timeout, listener); + } else { + listener.onFailure(new IllegalStateException("Authorization poller task was not created yet, or was already aborted")); + } + } + + private void waitForNullTask(AllocatedPersistentTask task, TimeValue timeout, ActionListener listener) { + task.waitForPersistentTask( + Objects::isNull, + timeout, + new PersistentTasksService.WaitForPersistentTaskListener() { + @Override + public void onResponse(PersistentTasksCustomMetadata.PersistentTask persistentTask) { + listener.onResponse(null); + } + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + } + ); + } + + /** + * This method should only be used for testing purposes to get the current running task. 
+ */ + public AuthorizationPoller getCurrentPollerTask() { + return currentTask.get(); + } + @Override protected void nodeOperation(AllocatedPersistentTask task, AuthorizationTaskParams params, PersistentTaskState state) { var authPoller = (AuthorizationPoller) task; + currentTask.set(authPoller); authPoller.start(); } From 9762fc63a45541855e9ec9229e7042e7eb4560bd Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 29 Oct 2025 17:12:12 -0400 Subject: [PATCH 16/32] Fixing relocation test --- ..._api_eis_authorization_persistent_task.csv | 1 + .../resources/transport/upper_bounds/9.3.csv | 2 +- .../AuthorizationTaskExecutorIT.java | 10 +- ...AuthorizationTaskExecutorRelocationIT.java | 174 ++++++++++++++++++ .../AuthorizationTaskParams.java | 6 +- 5 files changed, 182 insertions(+), 11 deletions(-) create mode 100644 server/src/main/resources/transport/definitions/referable/inference_api_eis_authorization_persistent_task.csv create mode 100644 x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorRelocationIT.java diff --git a/server/src/main/resources/transport/definitions/referable/inference_api_eis_authorization_persistent_task.csv b/server/src/main/resources/transport/definitions/referable/inference_api_eis_authorization_persistent_task.csv new file mode 100644 index 0000000000000..bdb12ee6b228e --- /dev/null +++ b/server/src/main/resources/transport/definitions/referable/inference_api_eis_authorization_persistent_task.csv @@ -0,0 +1 @@ +9205000 diff --git a/server/src/main/resources/transport/upper_bounds/9.3.csv b/server/src/main/resources/transport/upper_bounds/9.3.csv index 80e35f2bfbc93..bc59eeabda3d2 100644 --- a/server/src/main/resources/transport/upper_bounds/9.3.csv +++ b/server/src/main/resources/transport/upper_bounds/9.3.csv @@ -1 +1 @@ -esql_resolve_fields_response_used,9204000 +inference_api_eis_authorization_persistent_task,9205000 diff --git 
a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java index 76da8e7f62882..f2a2dae30d10f 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java @@ -17,7 +17,6 @@ import org.elasticsearch.test.ESSingleNodeTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; @@ -35,21 +34,20 @@ import java.util.function.Function; import java.util.stream.Collectors; -import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.is; public class AuthorizationTaskExecutorIT extends ESSingleNodeTestCase { - private static final String EMPTY_AUTH_RESPONSE = """ + public static final String EMPTY_AUTH_RESPONSE = """ { "models": [ ] } """; - private static final String AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE = """ + public static final String AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE = """ { "models": [ { @@ -64,7 +62,6 @@ public class AuthorizationTaskExecutorIT extends ESSingleNodeTestCase { private static String gatewayUrl; private ModelRegistry modelRegistry; - private ThreadPool threadPool; private AuthorizationTaskExecutor authorizationTaskExecutor; @BeforeClass @@ -76,7 
+73,6 @@ public static void initClass() throws IOException { @Before public void createComponents() { - threadPool = createThreadPool(inferenceUtilityExecutors()); modelRegistry = node().injector().getInstance(ModelRegistry.class); authorizationTaskExecutor = node().injector().getInstance(AuthorizationTaskExecutor.class); } @@ -87,8 +83,6 @@ public void shutdown() { var listener = new PlainActionFuture(); modelRegistry.deleteModels(InternalPreconfiguredEndpoints.EIS_PRECONFIGURED_ENDPOINT_IDS, listener); listener.actionGet(TimeValue.THIRTY_SECONDS); - - terminate(threadPool); } @AfterClass diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorRelocationIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorRelocationIT.java new file mode 100644 index 0000000000000..7a17d8a331809 --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorRelocationIT.java @@ -0,0 +1,174 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.license.LicenseSettings; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.reindex.ReindexPlugin; +import org.elasticsearch.tasks.TaskInfo; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; +import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; + +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE; +import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.EMPTY_AUTH_RESPONSE; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.not; + +/** + * These tests ensure that when a node is shutdown that is running an AuthorizationTaskExecutor, + * the task is properly relocated to another node. 
+ */ +@ESTestCase.WithoutEntitlements // due to dependency issue ES-12435 +public class AuthorizationTaskExecutorRelocationIT extends ESIntegTestCase { + + private static final String AUTH_TASK_ACTION = AuthorizationPoller.TASK_NAME + "[c]"; + private static final MockWebServer webServer = new MockWebServer(); + private static String gatewayUrl; + + @BeforeClass + public static void initClass() throws IOException { + webServer.start(); + gatewayUrl = getUrl(webServer); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE)); + } + + @AfterClass + public static void cleanUpClass() { + webServer.close(); + } + + @Override + protected Collection> nodePlugins() { + return List.of(ReindexPlugin.class, LocalStateInferencePlugin.class); + } + + @Override + protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { + return Settings.builder() + .put(super.nodeSettings(nodeOrdinal, otherSettings)) + .put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial") + .put(ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_URL.getKey(), gatewayUrl) + .put(ElasticInferenceServiceSettings.PERIODIC_AUTHORIZATION_ENABLED.getKey(), false) + .build(); + } + + @Override + public Settings indexSettings() { + return Settings.builder() + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(3, 10)).build(); + } + + public void testAuthorizationTaskGetsRelocatedToAnotherNode_WhenTheNodeThatIsRunningItShutsDown() throws Exception { + // Ensure we have multiple master and data nodes so we have somewhere to place the inference indices and so that we can safely + // shut down the node that is running the authorization task. 
If there is only one master we'll get an error that we can't shut + // down the only eligible master node + internalCluster().startMasterOnlyNodes(2); + internalCluster().ensureAtLeastNumDataNodes(2); + awaitMasterNode(); + + var nodeNameMapping = getNodeNames(internalCluster().getNodeNames()); + + var pollerTask = waitForTask(internalCluster().getNodeNames(), AUTH_TASK_ACTION); + + var getAllEndpointsRequest = new GetInferenceModelAction.Request("*", TaskType.ANY, true); + var endpoints = client().execute(GetInferenceModelAction.INSTANCE, getAllEndpointsRequest).actionGet(); + assertTrue( + "expected no authorized EIS endpoints", + endpoints.getEndpoints().stream().noneMatch(endpoint -> endpoint.getService().equals(ElasticInferenceService.NAME)) + ); + + // queue a response that authorizes one model + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE)); + + assertTrue(internalCluster().stopNode(nodeNameMapping.get(pollerTask.node()))); + awaitMasterNode(); + + assertBusy(() -> { + var relocatedPollerTask = waitForTask(internalCluster().getNodeNames(), AUTH_TASK_ACTION); + assertThat(relocatedPollerTask.node(), not(is(pollerTask.node()))); + }); + + assertBusy(() -> { + var allEndpoints = client().execute(GetInferenceModelAction.INSTANCE, getAllEndpointsRequest).actionGet(); + var eisEndpoints = allEndpoints.getEndpoints() + .stream() + .filter(endpoint -> endpoint.getService().equals(ElasticInferenceService.NAME)) + .toList(); + assertThat(eisEndpoints.size(), is(1)); + + var rainbowSprinklesEndpoint = eisEndpoints.get(0); + assertThat(rainbowSprinklesEndpoint.getService(), is(ElasticInferenceService.NAME)); + assertThat( + rainbowSprinklesEndpoint.getInferenceEntityId(), + is(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1) + ); + assertThat(rainbowSprinklesEndpoint.getTaskType(), is(TaskType.CHAT_COMPLETION)); + }); + } + + private TaskInfo waitForTask(String[] nodes, String taskAction) 
throws Exception { + var taskRef = new AtomicReference(); + assertBusy(() -> { + var response = admin().cluster().prepareListTasks(nodes).get(); + var authPollerTask = response.getTasks().stream().filter(task -> task.action().equals(taskAction)).findFirst(); + assertTrue(authPollerTask.isPresent()); + taskRef.set(authPollerTask.get()); + }); + + return taskRef.get(); + } + + private record NodeNameMapping(Map nodeNamesMap) { + public String get(String rawNodeName) { + var nodeName = nodeNamesMap.get(rawNodeName); + if (nodeName == null) { + throw new IllegalArgumentException("No node name found for raw node name: " + rawNodeName); + } + + return nodeName; + } + } + + /** + * The node names created by the integration test framework take the form of "node_#", but the task api gives a raw node name + * like 02PT2SBzRxC3cG-9mKCigQ, so we need to map between them to be able to act on a node that the task is currently running on. + */ + private static NodeNameMapping getNodeNames(String[] nodes) { + var nodeNamesMap = new HashMap(); + for (var node : nodes) { + var nodeTasks = admin().cluster().prepareListTasks(node).get(); + assertThat(nodeTasks.getTasks().size(), greaterThanOrEqualTo(1)); + nodeNamesMap.put(nodeTasks.getTasks().getFirst().node(), node); + } + + return new NodeNameMapping(nodeNamesMap); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParams.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParams.java index a93428e35ae3c..976a1e8307427 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParams.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParams.java @@ -8,7 +8,6 @@ package 
org.elasticsearch.xpack.inference.services.elastic.authorization; import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.persistent.PersistentTaskParams; @@ -25,6 +24,9 @@ public class AuthorizationTaskParams implements PersistentTaskParams { public static final AuthorizationTaskParams INSTANCE = new AuthorizationTaskParams(); private static final ObjectParser PARSER = new ObjectParser<>(TASK_NAME, true, () -> INSTANCE); + private static final TransportVersion INFERENCE_API_EIS_AUTHORIZATION_PERSISTENT_TASK = TransportVersion.fromName( + "inference_api_eis_authorization_persistent_task" + ); AuthorizationTaskParams() {} @@ -44,7 +46,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_8_5_0; + return INFERENCE_API_EIS_AUTHORIZATION_PERSISTENT_TASK; } @Override From 0b1551d2096c173b07d10623cb1662aaa90d2c80 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Wed, 29 Oct 2025 21:18:29 +0000 Subject: [PATCH 17/32] [CI] Auto commit changes from spotless --- .../integration/AuthorizationTaskExecutorRelocationIT.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorRelocationIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorRelocationIT.java index 7a17d8a331809..7a56e8f87b2e6 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorRelocationIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorRelocationIT.java @@ -81,8 +81,7 @@ 
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { @Override public Settings indexSettings() { - return Settings.builder() - .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(3, 10)).build(); + return Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(3, 10)).build(); } public void testAuthorizationTaskGetsRelocatedToAnotherNode_WhenTheNodeThatIsRunningItShutsDown() throws Exception { From 6f2c27b0cd1bdcf4a709ad90fac00b2f4ae6cd57 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 29 Oct 2025 18:54:44 -0400 Subject: [PATCH 18/32] working test --- ...AuthorizationTaskExecutorRelocationIT.java | 28 +++++++++++++++---- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorRelocationIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorRelocationIT.java index 7a17d8a331809..6dc28b3cee8bd 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorRelocationIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorRelocationIT.java @@ -81,8 +81,7 @@ protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { @Override public Settings indexSettings() { - return Settings.builder() - .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(3, 10)).build(); + return Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(3, 10)).build(); } public void testAuthorizationTaskGetsRelocatedToAnotherNode_WhenTheNodeThatIsRunningItShutsDown() throws Exception { @@ -97,8 +96,7 @@ public void testAuthorizationTaskGetsRelocatedToAnotherNode_WhenTheNodeThatIsRun var pollerTask = 
waitForTask(internalCluster().getNodeNames(), AUTH_TASK_ACTION); - var getAllEndpointsRequest = new GetInferenceModelAction.Request("*", TaskType.ANY, true); - var endpoints = client().execute(GetInferenceModelAction.INSTANCE, getAllEndpointsRequest).actionGet(); + var endpoints = getAllEndpoints(); assertTrue( "expected no authorized EIS endpoints", endpoints.getEndpoints().stream().noneMatch(endpoint -> endpoint.getService().equals(ElasticInferenceService.NAME)) @@ -116,7 +114,8 @@ public void testAuthorizationTaskGetsRelocatedToAnotherNode_WhenTheNodeThatIsRun }); assertBusy(() -> { - var allEndpoints = client().execute(GetInferenceModelAction.INSTANCE, getAllEndpointsRequest).actionGet(); + var allEndpoints = getAllEndpoints(); + var eisEndpoints = allEndpoints.getEndpoints() .stream() .filter(endpoint -> endpoint.getService().equals(ElasticInferenceService.NAME)) @@ -131,6 +130,7 @@ public void testAuthorizationTaskGetsRelocatedToAnotherNode_WhenTheNodeThatIsRun ); assertThat(rainbowSprinklesEndpoint.getTaskType(), is(TaskType.CHAT_COMPLETION)); }); + } private TaskInfo waitForTask(String[] nodes, String taskAction) throws Exception { @@ -171,4 +171,22 @@ private static NodeNameMapping getNodeNames(String[] nodes) { return new NodeNameMapping(nodeNamesMap); } + private GetInferenceModelAction.Response getAllEndpoints() throws Exception { + var getAllEndpointsRequest = new GetInferenceModelAction.Request("*", TaskType.ANY, true); + + var allEndpointsRef = new AtomicReference(); + assertBusy(() -> { + try { + allEndpointsRef.set( + internalCluster().masterClient().execute(GetInferenceModelAction.INSTANCE, getAllEndpointsRequest).actionGet() + ); + } catch (Exception e) { + // We probably got a shards failed exception because the indices aren't ready yet, we'll just try again + logger.warn("Failed to retrieve endpoints", e); + fail("Failed to retrieve endpoints"); + } + }); + + return allEndpointsRef.get(); + } } From fd0b0cfec0f11f930a53b7598540d566bd50b4ef Mon 
Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 30 Oct 2025 09:16:25 -0400 Subject: [PATCH 19/32] Some clean up --- docs/changelog/136713.yaml | 2 +- .../action/StoreInferenceEndpointsAction.java | 5 + .../xpack/core/inference/ModelTests.java | 120 +++++++++++++++++- .../AuthorizationTaskExecutorIT.java | 18 +-- ...orizationTaskExecutorMultipleNodesIT.java} | 38 ++++-- .../AuthorizationTaskExecutor.java | 8 +- .../AuthorizationTaskExecutorTests.java | 12 +- .../AuthorizationTaskParamsTests.java | 1 + .../xpack/security/operator/Constants.java | 4 +- 9 files changed, 171 insertions(+), 37 deletions(-) rename x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/{AuthorizationTaskExecutorRelocationIT.java => AuthorizationTaskExecutorMultipleNodesIT.java} (83%) diff --git a/docs/changelog/136713.yaml b/docs/changelog/136713.yaml index 9b88a8aed1111..45dedd222f07e 100644 --- a/docs/changelog/136713.yaml +++ b/docs/changelog/136713.yaml @@ -1,5 +1,5 @@ pr: 136713 -summary: Transition EIS auth polling to master node +summary: Transition EIS auth polling to persistent task on a single node area: Machine Learning type: enhancement issues: [] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsAction.java index 1b41e72e2af6f..aa613cda60399 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsAction.java @@ -20,6 +20,11 @@ import java.util.List; import java.util.Objects; +/** + * Internal action to store inference endpoints and return the results of the store operation. This should only be used internally and not + * exposed via a REST API. 
+ * For the exposed REST API action see {@link PutInferenceModelAction}. + */ public class StoreInferenceEndpointsAction extends ActionType { public static final StoreInferenceEndpointsAction INSTANCE = new StoreInferenceEndpointsAction(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/ModelTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/ModelTests.java index 86787473edd3a..05a9227ae278e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/ModelTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/ModelTests.java @@ -8,9 +8,11 @@ package org.elasticsearch.xpack.core.inference; import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Nullable; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.EmptySecretSettings; @@ -18,18 +20,24 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.XPackClientPlugin; import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsTests; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; 
import java.io.IOException; +import java.util.ArrayList; import java.util.List; +import java.util.Map; +import java.util.Objects; -public class ModelTests extends ESTestCase { +public class ModelTests extends AbstractBWCWireSerializationTestCase { public static Model randomModel() { return new Model( new ModelConfigurations( @@ -121,7 +129,113 @@ public String modelId() { } } + public record SimpleSecretSettings(String field) implements SecretSettings { + public static final String NAME = "simple_secret_settings"; + private static final String FIELD_KEY = "field"; + + public SimpleSecretSettings { + Objects.requireNonNull(field); + } + + public SimpleSecretSettings(StreamInput in) throws IOException { + this(in.readString()); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(field); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(FIELD_KEY, field); + builder.endObject(); + return builder; + } + + @Override + public SecretSettings newSecretSettings(Map newSecrets) { + if (newSecrets == null) { + return null; + } + + ValidationException validationException = new ValidationException(); + var value = newSecrets.get(FIELD_KEY); + if (value == null) { + validationException.addValidationError("Missing required secret setting: " + FIELD_KEY); + throw validationException; + } else if (value instanceof String == false) { + validationException.addValidationError("Expected secret setting [" + FIELD_KEY + "] to be of type String"); + throw validationException; + } + return new SimpleSecretSettings((String) value); + } + } + public static List getNamedWriteables() { - return 
List.of(new NamedWriteableRegistry.Entry(ServiceSettings.class, TestServiceSettings.NAME, TestServiceSettings::new)); + return List.of( + new NamedWriteableRegistry.Entry(ServiceSettings.class, TestServiceSettings.NAME, TestServiceSettings::new), + new NamedWriteableRegistry.Entry(SecretSettings.class, SimpleSecretSettings.NAME, SimpleSecretSettings::new) + ); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + var namedWriteables = new ArrayList(); + namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, EmptyTaskSettings.NAME, EmptyTaskSettings::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, EmptySecretSettings.NAME, EmptySecretSettings::new)); + namedWriteables.addAll(getNamedWriteables()); + namedWriteables.addAll(XPackClientPlugin.getChunkingSettingsNamedWriteables()); + + return new NamedWriteableRegistry(namedWriteables); + } + + @Override + protected Model mutateInstanceForVersion(Model instance, TransportVersion version) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return Model::new; + } + + @Override + protected Model createTestInstance() { + return randomModel(); + } + + @Override + protected Model mutateInstance(Model instance) throws IOException { + int choice = randomIntBetween(0, 1); + switch (choice) { + case 0 -> { + var originalConfig = instance.getConfigurations(); + ModelConfigurations mutatedConfig = new ModelConfigurations( + originalConfig.getInferenceEntityId() + "_mutated", + originalConfig.getTaskType(), + originalConfig.getService(), + originalConfig.getServiceSettings(), + originalConfig.getTaskSettings(), + originalConfig.getChunkingSettings() + ); + return new Model(mutatedConfig, instance.getSecrets()); + } + case 1 -> { + return new Model(instance.getConfigurations(), new ModelSecrets(new SimpleSecretSettings(randomAlphaOfLength(10)))); + } + default -> throw new IllegalStateException("Unexpected value: 
" + choice); + } } } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java index f2a2dae30d10f..e33a07f14c34e 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java @@ -94,6 +94,8 @@ public static void cleanUpClass() { protected Settings nodeSettings() { return Settings.builder() .put(ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_URL.getKey(), gatewayUrl) + // Ensure that the polling logic only occurs once so we can deterministically control when an authorization response is + // received .put(ElasticInferenceServiceSettings.PERIODIC_AUTHORIZATION_ENABLED.getKey(), false) .build(); } @@ -107,7 +109,7 @@ public void testCreatesEisChatCompletionEndpoint() throws Exception { assertNoAuthorizedEisEndpoints(); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE)); - waitForNewAuthorizationResponse(); + restartPollingTaskAndWaitForAuthResponse(); assertChatCompletionEndpointExists(); } @@ -131,7 +133,7 @@ private List getEisEndpoints() { return endpoints.stream().filter(m -> m.service().equals(ElasticInferenceService.NAME)).toList(); } - private void waitForNewAuthorizationResponse() throws Exception { + private void restartPollingTaskAndWaitForAuthResponse() throws Exception { var taskListener = new PlainActionFuture(); authorizationTaskExecutor.abortTask(TimeValue.THIRTY_SECONDS, taskListener); @@ -150,13 +152,13 @@ public void testCreatesEisChatCompletion_DoesNotRemoveEndpointWhenNoLongerAuthor assertNoAuthorizedEisEndpoints(); webServer.enqueue(new 
MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE)); - waitForNewAuthorizationResponse(); + restartPollingTaskAndWaitForAuthResponse(); assertChatCompletionEndpointExists(); // Simulate that the model is no longer authorized webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE)); - waitForNewAuthorizationResponse(); + restartPollingTaskAndWaitForAuthResponse(); assertChatCompletionEndpointExists(); } @@ -179,13 +181,13 @@ public void testCreatesChatCompletion_AndThenCreatesTextEmbedding() throws Excep assertNoAuthorizedEisEndpoints(); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE)); - waitForNewAuthorizationResponse(); + restartPollingTaskAndWaitForAuthResponse(); assertChatCompletionEndpointExists(); // Simulate that the model is no longer authorized webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE)); - waitForNewAuthorizationResponse(); + restartPollingTaskAndWaitForAuthResponse(); assertChatCompletionEndpointExists(); @@ -202,7 +204,7 @@ public void testCreatesChatCompletion_AndThenCreatesTextEmbedding() throws Excep """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(authorizedTextEmbeddingResponse)); - waitForNewAuthorizationResponse(); + restartPollingTaskAndWaitForAuthResponse(); var eisEndpoints = getEisEndpoints().stream().collect(Collectors.toMap(UnparsedModel::inferenceEntityId, Function.identity())); assertThat(eisEndpoints.size(), is(2)); @@ -223,6 +225,6 @@ public void testRestartsTaskAfterAbort() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE)); // Abort the task and ensure it is restarted - waitForNewAuthorizationResponse(); + restartPollingTaskAndWaitForAuthResponse(); } } diff --git 
a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorRelocationIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java similarity index 83% rename from x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorRelocationIT.java rename to x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java index 6dc28b3cee8bd..73f5c118087d6 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorRelocationIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java @@ -12,10 +12,8 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.license.LicenseSettings; import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.reindex.ReindexPlugin; import org.elasticsearch.tasks.TaskInfo; import org.elasticsearch.test.ESIntegTestCase; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; @@ -42,11 +40,11 @@ import static org.hamcrest.Matchers.not; /** - * These tests ensure that when a node is shutdown that is running an AuthorizationTaskExecutor, - * the task is properly relocated to another node. + * These tests handle testing task relocation and cancellation. + * If the task is running on a node that is shutdown, it should be relocated to another node. + * If the task is cancelled it should be restarted automatically. 
*/ -@ESTestCase.WithoutEntitlements // due to dependency issue ES-12435 -public class AuthorizationTaskExecutorRelocationIT extends ESIntegTestCase { +public class AuthorizationTaskExecutorMultipleNodesIT extends ESIntegTestCase { private static final String AUTH_TASK_ACTION = AuthorizationPoller.TASK_NAME + "[c]"; private static final MockWebServer webServer = new MockWebServer(); @@ -66,7 +64,7 @@ public static void cleanUpClass() { @Override protected Collection> nodePlugins() { - return List.of(ReindexPlugin.class, LocalStateInferencePlugin.class); + return List.of(LocalStateInferencePlugin.class); } @Override @@ -84,13 +82,28 @@ public Settings indexSettings() { return Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(3, 10)).build(); } + public void testCancellingAuthorizationTaskRestartsIt() throws Exception { + var pollerTask = waitForTask(internalCluster().getNodeNames(), AUTH_TASK_ACTION); + + assertBusy(() -> { + var cancelTaskResponse = admin().cluster() + .prepareCancelTasks(internalCluster().getNodeNames()) + .setActions(AUTH_TASK_ACTION) + .get(); + assertThat(cancelTaskResponse.getTasks().size(), is(1)); + assertThat(cancelTaskResponse.getTasks().get(0).action(), is(AUTH_TASK_ACTION)); + }); + + var newPollerTask = waitForTask(internalCluster().getNodeNames(), AUTH_TASK_ACTION); + assertThat(newPollerTask.taskId(), is(not(pollerTask.taskId()))); + } + public void testAuthorizationTaskGetsRelocatedToAnotherNode_WhenTheNodeThatIsRunningItShutsDown() throws Exception { // Ensure we have multiple master and data nodes so we have somewhere to place the inference indices and so that we can safely - // shut down the node that is running the authorization task. If there is only one master we'll get an error that we can't shut - // down the only eligible master node + // shut down the node that is running the authorization task. 
If there is only one master and it is running the task, + // we'll get an error that we can't shut down the only eligible master node internalCluster().startMasterOnlyNodes(2); internalCluster().ensureAtLeastNumDataNodes(2); - awaitMasterNode(); var nodeNameMapping = getNodeNames(internalCluster().getNodeNames()); @@ -105,8 +118,7 @@ public void testAuthorizationTaskGetsRelocatedToAnotherNode_WhenTheNodeThatIsRun // queue a response that authorizes one model webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE)); - assertTrue(internalCluster().stopNode(nodeNameMapping.get(pollerTask.node()))); - awaitMasterNode(); + assertTrue("expected the node to shutdown properly", internalCluster().stopNode(nodeNameMapping.get(pollerTask.node()))); assertBusy(() -> { var relocatedPollerTask = waitForTask(internalCluster().getNodeNames(), AUTH_TASK_ACTION); @@ -181,7 +193,7 @@ private GetInferenceModelAction.Response getAllEndpoints() throws Exception { internalCluster().masterClient().execute(GetInferenceModelAction.INSTANCE, getAllEndpointsRequest).actionGet() ); } catch (Exception e) { - // We probably got a shards failed exception because the indices aren't ready yet, we'll just try again + // We probably got an all shards failed exception because the indices aren't ready yet, we'll just try again logger.warn("Failed to retrieve endpoints", e); fail("Failed to retrieve endpoints"); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java index b637700dbc754..f80fa1b0345fa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java @@ -156,10 +156,10 @@ public void clusterChanged(ClusterChangedEvent event) { TASK_NAME, new AuthorizationTaskParams(), TimeValue.THIRTY_SECONDS, - ActionListener.wrap(persistentTask -> logger.debug("Created authorization poller task"), e -> { - var t = e instanceof RemoteTransportException ? e.getCause() : e; - if (t instanceof ResourceAlreadyExistsException == false) { - logger.error("Failed to create authorization poller task", e); + ActionListener.wrap(persistentTask -> logger.debug("Created authorization poller task"), exception -> { + var thrownException = exception instanceof RemoteTransportException ? exception.getCause() : exception; + if (thrownException instanceof ResourceAlreadyExistsException == false) { + logger.error("Failed to create authorization poller task", exception); } }) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutorTests.java index 1e375c3df1a2e..15b586d62890d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutorTests.java @@ -129,9 +129,9 @@ public void testDoesNotRegisterClusterStateListener_DoesNotCreateTask_WhenTheEis ); executor.init(); - var listener1 = new PlainActionFuture(); - clusterService.getClusterApplierService().onNewClusterState("initialization", this::initialState, listener1); - listener1.actionGet(TimeValue.THIRTY_SECONDS); + var listener = new PlainActionFuture(); + 
clusterService.getClusterApplierService().onNewClusterState("initialization", this::initialState, listener); + listener.actionGet(TimeValue.THIRTY_SECONDS); verify(persistentTasksService, never()).sendClusterStartRequest( eq(AuthorizationPoller.TASK_NAME), eq(AuthorizationPoller.TASK_NAME), @@ -156,9 +156,9 @@ public void testDoesNotRegisterClusterStateListener_DoesNotCreateTask_WhenTheEis ); executor.init(); - var listener1 = new PlainActionFuture(); - clusterService.getClusterApplierService().onNewClusterState("initialization", this::initialState, listener1); - listener1.actionGet(TimeValue.THIRTY_SECONDS); + var listener = new PlainActionFuture(); + clusterService.getClusterApplierService().onNewClusterState("initialization", this::initialState, listener); + listener.actionGet(TimeValue.THIRTY_SECONDS); verify(persistentTasksService, never()).sendClusterStartRequest( eq(AuthorizationPoller.TASK_NAME), eq(AuthorizationPoller.TASK_NAME), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParamsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParamsTests.java index b07036bb2a612..e57bcb6c99a49 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParamsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParamsTests.java @@ -32,6 +32,7 @@ protected AuthorizationTaskParams createTestInstance() { @Override protected AuthorizationTaskParams mutateInstance(AuthorizationTaskParams instance) throws IOException { + // need to return null here because the instances will always be identical return null; } } diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java 
b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index 621ae74ecd4aa..b5ec2b85212f6 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -174,8 +174,6 @@ public class Constants { "cluster:admin/xpack/enrich/get", "cluster:admin/xpack/enrich/put", "cluster:admin/xpack/enrich/reindex", - "cluster:internal/xpack/inference/clear_inference_endpoint_cache", - "cluster:internal/xpack/inference/create_endpoints", "cluster:admin/xpack/inference/delete", "cluster:admin/xpack/inference/put", "cluster:admin/xpack/inference/update", @@ -328,6 +326,8 @@ public class Constants { "cluster:admin/xpack/watcher/watch/put", "cluster:internal/remote_cluster/nodes", "cluster:internal/xpack/inference", + "cluster:internal/xpack/inference/clear_inference_endpoint_cache", + "cluster:internal/xpack/inference/create_endpoints", "cluster:internal/xpack/inference/rerankwindowsize/get", "cluster:internal/xpack/inference/unified", "cluster:internal/xpack/ml/auditor/reset", From 5f99fe278574205c8343ca8c9ca550e1d46bbabc Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 30 Oct 2025 14:11:21 -0400 Subject: [PATCH 20/32] Removing unneeded tests --- .../InferenceRevokeDefaultEndpointsIT.java | 357 -------------- .../xpack/inference/InferencePlugin.java | 2 - .../mapper/SemanticTextFieldMapper.java | 2 +- .../elastic/ElasticInferenceService.java | 141 +----- ...cInferenceServiceAuthorizationHandler.java | 336 -------------- .../elastic/ElasticInferenceServiceTests.java | 436 +++--------------- ...renceServiceAuthorizationHandlerTests.java | 283 ------------ 7 files changed, 80 insertions(+), 1477 deletions(-) delete mode 100644 
x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java deleted file mode 100644 index 8c4a2b1b2504c..0000000000000 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java +++ /dev/null @@ -1,357 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.integration; - -import org.elasticsearch.ResourceNotFoundException; -import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.core.TimeValue; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.MinimalServiceSettings; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.UnparsedModel; -import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.reindex.ReindexPlugin; -import org.elasticsearch.test.ESSingleNodeTestCase; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.test.http.MockResponse; -import org.elasticsearch.test.http.MockWebServer; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; -import org.elasticsearch.xpack.inference.external.http.HttpClientManager; -import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; -import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; -import org.junit.After; -import org.junit.Before; - -import java.util.Collection; -import java.util.EnumSet; -import java.util.List; -import java.util.concurrent.TimeUnit; - -import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; -import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; -import static 
org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; -import static org.hamcrest.CoreMatchers.is; -import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.mockito.Mockito.mock; - -@ESTestCase.WithoutEntitlements // due to dependency issue ES-12435 -public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase { - private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); - - private ModelRegistry modelRegistry; - private final MockWebServer webServer = new MockWebServer(); - private ThreadPool threadPool; - private String gatewayUrl; - - @Before - public void createComponents() throws Exception { - threadPool = createThreadPool(inferenceUtilityExecutors()); - webServer.start(); - gatewayUrl = getUrl(webServer); - modelRegistry = node().injector().getInstance(ModelRegistry.class); - } - - @After - public void shutdown() { - terminate(threadPool); - webServer.close(); - } - - @Override - protected boolean resetNodeAfterTest() { - return true; - } - - @Override - protected Collection> getPlugins() { - return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class); - } - - public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCorrect() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "rainbow-sprinkles", - "task_types": ["chat"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - try (var service = createElasticInferenceService()) { - ensureAuthorizationCallFinished(service); - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertThat( - service.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - service - ) - 
) - ) - ); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - - PlainActionFuture> listener = new PlainActionFuture<>(); - service.defaultConfigs(listener); - assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); - } - } - - public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty() throws Exception { - { - String responseJson = """ - { - "models": [ - { - "model_name": "rainbow-sprinkles", - "task_types": ["chat"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - try (var service = createElasticInferenceService()) { - ensureAuthorizationCallFinished(service); - - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertThat( - service.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - service - ) - ) - ) - ); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - - PlainActionFuture> listener = new PlainActionFuture<>(); - service.defaultConfigs(listener); - assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); - - var getModelListener = new PlainActionFuture(); - // persists the default endpoints - modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener); - - var inferenceEntity = getModelListener.actionGet(TIMEOUT); - assertThat(inferenceEntity.inferenceEntityId(), is(".rainbow-sprinkles-elastic")); - assertThat(inferenceEntity.taskType(), is(TaskType.CHAT_COMPLETION)); - } - } - { - String noAuthorizationResponseJson = """ - { - "models": [] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson)); - - try (var service = 
createElasticInferenceService()) { - ensureAuthorizationCallFinished(service); - - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertTrue(service.defaultConfigIds().isEmpty()); - assertThat(service.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class))); - - var getModelListener = new PlainActionFuture(); - modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener); - - var exception = expectThrows(ResourceNotFoundException.class, () -> getModelListener.actionGet(TIMEOUT)); - assertThat(exception.getMessage(), is("Inference endpoint not found [.rainbow-sprinkles-elastic]")); - } - } - } - - public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnAuthForIt() throws Exception { - { - String responseJson = """ - { - "models": [ - { - "model_name": "elser_model_2", - "task_types": ["embed/text/sparse"] - }, - { - "model_name": "rainbow-sprinkles", - "task_types": ["chat"] - }, - { - "model_name": "jina-embeddings-v3", - "task_types": ["embed/text/dense"] - }, - { - "model_name": "elastic-rerank-v1", - "task_types": ["rerank/text/text-similarity"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - try (var service = createElasticInferenceService()) { - ensureAuthorizationCallFinished(service); - - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertThat( - service.defaultConfigIds(), - containsInAnyOrder( - new InferenceService.DefaultConfigId( - ".elser-2-elastic", - MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( - ".jina-embeddings-v3", - MinimalServiceSettings.textEmbedding( - ElasticInferenceService.NAME, - 
ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, - ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), - DenseVectorFieldMapper.ElementType.FLOAT - ), - service - ), - new InferenceService.DefaultConfigId( - ".elastic-rerank-v1", - MinimalServiceSettings.rerank(ElasticInferenceService.NAME), - service - ) - ) - ); - assertThat( - service.supportedTaskTypes(), - is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.TEXT_EMBEDDING)) - ); - - PlainActionFuture> listener = new PlainActionFuture<>(); - service.defaultConfigs(listener); - - assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elastic-rerank-v1")); - assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".elser-2-elastic")); - assertThat(listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(), is(".jina-embeddings-v3")); - assertThat(listener.actionGet(TIMEOUT).get(3).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); - - var getModelListener = new PlainActionFuture(); - // persists the default endpoints - modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener); - - var inferenceEntity = getModelListener.actionGet(TIMEOUT); - assertThat(inferenceEntity.inferenceEntityId(), is(".rainbow-sprinkles-elastic")); - assertThat(inferenceEntity.taskType(), is(TaskType.CHAT_COMPLETION)); - } - } - { - String noAuthorizationResponseJson = """ - { - "models": [ - { - "model_name": "elser_model_2", - "task_types": ["embed/text/sparse"] - }, - { - "model_name": "elastic-rerank-v1", - "task_types": ["rerank/text/text-similarity"] - }, - { - "model_name": "jina-embeddings-v3", - "task_types": ["embed/text/dense"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson)); - - try (var service = createElasticInferenceService()) { - 
ensureAuthorizationCallFinished(service); - - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertThat( - service.defaultConfigIds(), - containsInAnyOrder( - new InferenceService.DefaultConfigId( - ".elser-2-elastic", - MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( - ".jina-embeddings-v3", - MinimalServiceSettings.textEmbedding( - ElasticInferenceService.NAME, - ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, - ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), - DenseVectorFieldMapper.ElementType.FLOAT - ), - service - ), - new InferenceService.DefaultConfigId( - ".elastic-rerank-v1", - MinimalServiceSettings.rerank(ElasticInferenceService.NAME), - service - ) - ) - ); - assertThat( - service.supportedTaskTypes(), - is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK)) - ); - - var getModelListener = new PlainActionFuture(); - modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener); - var exception = expectThrows(ResourceNotFoundException.class, () -> getModelListener.actionGet(TIMEOUT)); - assertThat(exception.getMessage(), is("Inference endpoint not found [.rainbow-sprinkles-elastic]")); - } - } - } - - private void ensureAuthorizationCallFinished(ElasticInferenceService service) { - service.onNodeStarted(); - service.waitForFirstAuthorizationToComplete(TIMEOUT); - } - - private ElasticInferenceService createElasticInferenceService() { - var httpManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, httpManager); - - return new ElasticInferenceService( - senderFactory, - createWithEmptySettings(threadPool), - ElasticInferenceServiceSettingsTests.create(gatewayUrl), - modelRegistry, - new 
ElasticInferenceServiceAuthorizationRequestHandler(gatewayUrl, threadPool), - mockClusterServiceEmpty() - ); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 90f72dba305f6..6125a79c01b58 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -352,8 +352,6 @@ public Collection createComponents(PluginServices services) { elasticInferenceServiceFactory.get(), serviceComponents.get(), inferenceServiceSettings, - modelRegistry.get(), - authorizationHandler, context ), context -> new SageMakerService( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 109646cf0e827..8f6a902773394 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -121,7 +121,7 @@ import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getEmbeddingsFieldName; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOffsetsFieldName; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOriginalTextFieldName; -import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.DEFAULT_ELSER_ENDPOINT_ID_V2; +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.DEFAULT_ELSER_ID; /** 
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 9fa47cdb23b2d..0e89b697734af 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -16,17 +16,13 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; -import org.elasticsearch.inference.EmptySecretSettings; -import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; @@ -48,22 +44,16 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import 
org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.elastic.action.ElasticInferenceServiceActionCreator; -import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationHandler; -import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; -import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; -import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; -import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import java.util.EnumSet; @@ -108,21 +98,6 @@ public class ElasticInferenceService extends SenderService { // This mirrors the memory constraints observed with sparse embeddings private static final Integer DENSE_TEXT_EMBEDDINGS_MAX_BATCH_SIZE = 16; - // rainbow-sprinkles - static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles"; - static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); - - // elser-2 - static final 
String DEFAULT_ELSER_2_MODEL_ID = "elser_model_2"; - public static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId("elser-2"); - - // multilingual-text-embed - static final String DEFAULT_MULTILINGUAL_EMBED_MODEL_ID = "jina-embeddings-v3"; - static final String DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID = ".jina-embeddings-v3"; - - // rerank-v1 - static final String DEFAULT_RERANK_MODEL_ID_V1 = "elastic-rerank-v1"; - static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = ".elastic-rerank-v1"; /** * The task types that the {@link InferenceAction.Request} can accept. @@ -133,27 +108,18 @@ public class ElasticInferenceService extends SenderService { TaskType.TEXT_EMBEDDING ); - public static String defaultEndpointId(String modelId) { - return Strings.format(".%s-elastic", modelId); - } - private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; - private final ElasticInferenceServiceAuthorizationHandler authorizationHandler; public ElasticInferenceService( HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ElasticInferenceServiceSettings elasticInferenceServiceSettings, - ModelRegistry modelRegistry, - ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, InferenceServiceExtension.InferenceServiceFactoryContext context ) { this( factory, serviceComponents, elasticInferenceServiceSettings, - modelRegistry, - authorizationRequestHandler, context.clusterService() ); } @@ -162,95 +128,12 @@ public ElasticInferenceService( HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ElasticInferenceServiceSettings elasticInferenceServiceSettings, - ModelRegistry modelRegistry, - ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, ClusterService clusterService ) { super(factory, serviceComponents, clusterService); this.elasticInferenceServiceComponents = new ElasticInferenceServiceComponents( elasticInferenceServiceSettings.getElasticInferenceServiceUrl() ); 
- authorizationHandler = new ElasticInferenceServiceAuthorizationHandler( - serviceComponents, - modelRegistry, - authorizationRequestHandler, - initDefaultEndpoints(elasticInferenceServiceComponents), - IMPLEMENTED_TASK_TYPES, - this, - getSender(), - elasticInferenceServiceSettings - ); - } - - private static Map initDefaultEndpoints( - ElasticInferenceServiceComponents elasticInferenceServiceComponents - ) { - return Map.of( - DEFAULT_CHAT_COMPLETION_MODEL_ID_V1, - new DefaultModelConfig( - new ElasticInferenceServiceCompletionModel( - DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, - TaskType.CHAT_COMPLETION, - NAME, - new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - elasticInferenceServiceComponents - ), - MinimalServiceSettings.chatCompletion(NAME) - ), - DEFAULT_ELSER_2_MODEL_ID, - new DefaultModelConfig( - new ElasticInferenceServiceSparseEmbeddingsModel( - DEFAULT_ELSER_ENDPOINT_ID_V2, - TaskType.SPARSE_EMBEDDING, - NAME, - new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_2_MODEL_ID, null), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - elasticInferenceServiceComponents, - ChunkingSettingsBuilder.DEFAULT_SETTINGS - ), - MinimalServiceSettings.sparseEmbedding(NAME) - ), - DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, - new DefaultModelConfig( - new ElasticInferenceServiceDenseTextEmbeddingsModel( - DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, - TaskType.TEXT_EMBEDDING, - NAME, - new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( - DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, - defaultDenseTextEmbeddingsSimilarity(), - null, - null - ), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - elasticInferenceServiceComponents, - ChunkingSettingsBuilder.DEFAULT_SETTINGS - ), - MinimalServiceSettings.textEmbedding( - NAME, - DENSE_TEXT_EMBEDDINGS_DIMENSIONS, - defaultDenseTextEmbeddingsSimilarity(), - 
DenseVectorFieldMapper.ElementType.FLOAT - ) - ), - DEFAULT_RERANK_MODEL_ID_V1, - new DefaultModelConfig( - new ElasticInferenceServiceRerankModel( - DEFAULT_RERANK_ENDPOINT_ID_V1, - TaskType.RERANK, - NAME, - new ElasticInferenceServiceRerankServiceSettings(DEFAULT_RERANK_MODEL_ID_V1), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - elasticInferenceServiceComponents - ), - MinimalServiceSettings.rerank(NAME) - ) - ); } @Override @@ -265,31 +148,11 @@ protected void validateRerankParameters(Boolean returnDocuments, Integer topN, V } } - /** - * Only use this in tests. - * - * Waits the specified amount of time for the authorization call to complete. This is mainly to make testing easier. - * @param waitTime the max time to wait - * @throws IllegalStateException if the wait time is exceeded or the call receives an {@link InterruptedException} - */ - public void waitForFirstAuthorizationToComplete(TimeValue waitTime) { - authorizationHandler.waitForAuthorizationToComplete(waitTime); - } - @Override public Set supportedStreamingTasks() { return EnumSet.of(TaskType.CHAT_COMPLETION); } - @Override - public List defaultConfigIds() { - return authorizationHandler.defaultConfigIds(); - } - - @Override - public void defaultConfigs(ActionListener> defaultsListener) { - authorizationHandler.defaultConfigs(defaultsListener); - } @Override protected void doUnifiedCompletionInfer( @@ -467,7 +330,9 @@ public InferenceServiceConfiguration getConfiguration() { @Override public EnumSet supportedTaskTypes() { - return authorizationHandler.supportedTaskTypes(); + throw new UnsupportedOperationException( + "The EIS supported task types change depending on authorization, requests should be made directly to EIS instead" + ); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java deleted file mode 100644 index f83542e7fe740..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java +++ /dev/null @@ -1,336 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elastic.authorization; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.common.Randomness; -import org.elasticsearch.common.Strings; -import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.threadpool.Scheduler; -import org.elasticsearch.xpack.inference.external.http.sender.Sender; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.services.ServiceComponents; -import org.elasticsearch.xpack.inference.services.elastic.DefaultModelConfig; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; - -import java.io.Closeable; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Comparator; -import java.util.EnumSet; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.TreeSet; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import 
java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Collectors; - -import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; - -public class ElasticInferenceServiceAuthorizationHandler implements Closeable { - private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationHandler.class); - - private record AuthorizedContent( - ElasticInferenceServiceAuthorizationModel taskTypesAndModels, - List configIds, - List defaultModelConfigs - ) { - static AuthorizedContent empty() { - return new AuthorizedContent(ElasticInferenceServiceAuthorizationModel.newDisabledService(), List.of(), List.of()); - } - } - - private final ServiceComponents serviceComponents; - private final AtomicReference authorizedContent = new AtomicReference<>(AuthorizedContent.empty()); - private final ModelRegistry modelRegistry; - private final ElasticInferenceServiceAuthorizationRequestHandler authorizationHandler; - private final Map defaultModelsConfigs; - private final CountDownLatch firstAuthorizationCompletedLatch = new CountDownLatch(1); - private final EnumSet implementedTaskTypes; - private final InferenceService inferenceService; - private final Sender sender; - private final Runnable callback; - private final AtomicReference lastAuthTask = new AtomicReference<>(null); - private final AtomicBoolean shutdown = new AtomicBoolean(false); - private final ElasticInferenceServiceSettings elasticInferenceServiceSettings; - - public ElasticInferenceServiceAuthorizationHandler( - ServiceComponents serviceComponents, - ModelRegistry modelRegistry, - ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, - Map defaultModelsConfigs, - EnumSet implementedTaskTypes, - InferenceService inferenceService, - Sender sender, - ElasticInferenceServiceSettings elasticInferenceServiceSettings - ) { - this( - serviceComponents, - modelRegistry, - 
authorizationRequestHandler, - defaultModelsConfigs, - implementedTaskTypes, - Objects.requireNonNull(inferenceService), - sender, - elasticInferenceServiceSettings, - null - ); - } - - // default for testing - ElasticInferenceServiceAuthorizationHandler( - ServiceComponents serviceComponents, - ModelRegistry modelRegistry, - ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, - Map defaultModelsConfigs, - EnumSet implementedTaskTypes, - InferenceService inferenceService, - Sender sender, - ElasticInferenceServiceSettings elasticInferenceServiceSettings, - // this is a hack to facilitate testing - Runnable callback - ) { - this.serviceComponents = Objects.requireNonNull(serviceComponents); - this.modelRegistry = Objects.requireNonNull(modelRegistry); - this.authorizationHandler = Objects.requireNonNull(authorizationRequestHandler); - this.defaultModelsConfigs = Objects.requireNonNull(defaultModelsConfigs); - this.implementedTaskTypes = Objects.requireNonNull(implementedTaskTypes); - // allow the service to be null for testing - this.inferenceService = inferenceService; - this.sender = Objects.requireNonNull(sender); - this.elasticInferenceServiceSettings = Objects.requireNonNull(elasticInferenceServiceSettings); - this.callback = callback; - } - - public void init() { - logger.debug("Initializing authorization logic"); - serviceComponents.threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(this::scheduleAndSendAuthorizationRequest); - } - - /** - * Waits the specified amount of time for the first authorization call to complete. This is mainly to make testing easier. 
- * @param waitTime the max time to wait - * @throws IllegalStateException if the wait time is exceeded or the call receives an {@link InterruptedException} - */ - public void waitForAuthorizationToComplete(TimeValue waitTime) { - try { - if (firstAuthorizationCompletedLatch.await(waitTime.getSeconds(), TimeUnit.SECONDS) == false) { - throw new IllegalStateException("The wait time has expired for authorization to complete."); - } - } catch (InterruptedException e) { - throw new IllegalStateException("Waiting for authorization to complete was interrupted"); - } - } - - public synchronized Set supportedStreamingTasks() { - var authorizedStreamingTaskTypes = EnumSet.of(TaskType.CHAT_COMPLETION); - authorizedStreamingTaskTypes.retainAll(authorizedContent.get().taskTypesAndModels.getAuthorizedTaskTypes()); - - return authorizedStreamingTaskTypes; - } - - public synchronized List defaultConfigIds() { - return authorizedContent.get().configIds; - } - - public synchronized void defaultConfigs(ActionListener> defaultsListener) { - var models = authorizedContent.get().defaultModelConfigs.stream().map(DefaultModelConfig::model).toList(); - defaultsListener.onResponse(models); - } - - public synchronized EnumSet supportedTaskTypes() { - return authorizedContent.get().taskTypesAndModels.getAuthorizedTaskTypes(); - } - - public synchronized boolean hideFromConfigurationApi() { - return authorizedContent.get().taskTypesAndModels.isAuthorized() == false; - } - - @Override - public void close() throws IOException { - shutdown.set(true); - if (lastAuthTask.get() != null) { - lastAuthTask.get().cancel(); - } - } - - private void scheduleAuthorizationRequest() { - try { - if (elasticInferenceServiceSettings.isPeriodicAuthorizationEnabled() == false) { - return; - } - - // this call has to be on the individual thread otherwise we get an exception - var random = Randomness.get(); - var jitter = (long) (elasticInferenceServiceSettings.getMaxAuthorizationRequestJitter().millis() * 
random.nextDouble()); - var waitTime = TimeValue.timeValueMillis(elasticInferenceServiceSettings.getAuthRequestInterval().millis() + jitter); - - logger.debug( - () -> Strings.format( - "Scheduling the next authorization call with request interval: %s ms, jitter: %d ms", - elasticInferenceServiceSettings.getAuthRequestInterval().millis(), - jitter - ) - ); - logger.debug(() -> Strings.format("Next authorization call in %d minutes", waitTime.getMinutes())); - - lastAuthTask.set( - serviceComponents.threadPool() - .schedule( - this::scheduleAndSendAuthorizationRequest, - waitTime, - serviceComponents.threadPool().executor(UTILITY_THREAD_POOL_NAME) - ) - ); - } catch (Exception e) { - logger.warn("Failed scheduling authorization request", e); - } - } - - private void scheduleAndSendAuthorizationRequest() { - if (shutdown.get()) { - return; - } - - scheduleAuthorizationRequest(); - sendAuthorizationRequest(); - } - - private void sendAuthorizationRequest() { - try { - ActionListener listener = ActionListener.wrap((model) -> { - setAuthorizedContent(model); - if (callback != null) { - callback.run(); - } - }, e -> { - // we don't need to do anything if there was a failure, everything is disabled by default - firstAuthorizationCompletedLatch.countDown(); - }); - - authorizationHandler.getAuthorization(listener, sender); - } catch (Exception e) { - logger.warn("Failure while sending the request to retrieve authorization", e); - // we don't need to do anything if there was a failure, everything is disabled by default - firstAuthorizationCompletedLatch.countDown(); - } - } - - private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorizationModel auth) { - logger.debug(() -> Strings.format("Received authorization response, %s", auth)); - - var authorizedTaskTypesAndModels = auth.newLimitedToTaskTypes(EnumSet.copyOf(implementedTaskTypes)); - logger.debug(() -> Strings.format("Authorization entity limited to service task types, %s", 
authorizedTaskTypesAndModels)); - - // recalculate which default config ids and models are authorized now - var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(authorizedTaskTypesAndModels); - - var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(authorizedDefaultModelIds, authorizedTaskTypesAndModels); - var authorizedDefaultModelObjects = getAuthorizedDefaultModelsObjects(authorizedDefaultModelIds); - authorizedContent.set( - new AuthorizedContent(authorizedTaskTypesAndModels, authorizedDefaultConfigIds, authorizedDefaultModelObjects) - ); - - authorizedContent.get().configIds().forEach(modelRegistry::putDefaultIdIfAbsent); - handleRevokedDefaultConfigs(authorizedDefaultModelIds); - } - - private Set getAuthorizedDefaultModelIds(ElasticInferenceServiceAuthorizationModel auth) { - var authorizedModels = auth.getAuthorizedModelIds(); - var authorizedDefaultModelIds = new TreeSet<>(defaultModelsConfigs.keySet()); - authorizedDefaultModelIds.retainAll(authorizedModels); - - return authorizedDefaultModelIds; - } - - private List getAuthorizedDefaultConfigIds( - Set authorizedDefaultModelIds, - ElasticInferenceServiceAuthorizationModel auth - ) { - var authorizedConfigIds = new ArrayList(); - for (var id : authorizedDefaultModelIds) { - var modelConfig = defaultModelsConfigs.get(id); - if (modelConfig != null) { - if (auth.getAuthorizedTaskTypes().contains(modelConfig.model().getTaskType()) == false) { - logger.warn( - org.elasticsearch.common.Strings.format( - "The authorization response included the default model: %s, " - + "but did not authorize the assumed task type of the model: %s. 
Enabling model.", - id, - modelConfig.model().getTaskType() - ) - ); - } - authorizedConfigIds.add( - new InferenceService.DefaultConfigId( - modelConfig.model().getInferenceEntityId(), - modelConfig.settings(), - inferenceService - ) - ); - } - } - - authorizedConfigIds.sort(Comparator.comparing(InferenceService.DefaultConfigId::inferenceId)); - return authorizedConfigIds; - } - - private List getAuthorizedDefaultModelsObjects(Set authorizedDefaultModelIds) { - var authorizedModels = new ArrayList(); - for (var id : authorizedDefaultModelIds) { - var modelConfig = defaultModelsConfigs.get(id); - if (modelConfig != null) { - authorizedModels.add(modelConfig); - } - } - - authorizedModels.sort(Comparator.comparing(modelConfig -> modelConfig.model().getInferenceEntityId())); - return authorizedModels; - } - - private void handleRevokedDefaultConfigs(Set authorizedDefaultModelIds) { - // if a model was initially returned in the authorization response but is absent, then we'll assume authorization was revoked - var unauthorizedDefaultModelIds = new HashSet<>(defaultModelsConfigs.keySet()); - unauthorizedDefaultModelIds.removeAll(authorizedDefaultModelIds); - - // get all the default inference endpoint ids for the unauthorized model ids - var unauthorizedDefaultInferenceEndpointIds = unauthorizedDefaultModelIds.stream() - .map(defaultModelsConfigs::get) // get all the model configs - .filter(Objects::nonNull) // limit to only non-null - .map(modelConfig -> modelConfig.model().getInferenceEntityId()) // get the inference ids - .collect(Collectors.toSet()); - - var deleteInferenceEndpointsListener = ActionListener.wrap(result -> { - logger.debug(Strings.format("Successfully revoked access to default inference endpoint IDs: %s", unauthorizedDefaultModelIds)); - firstAuthorizationCompletedLatch.countDown(); - }, e -> { - logger.warn( - Strings.format("Failed to revoke access to default inference endpoint IDs: %s, error: %s", unauthorizedDefaultModelIds, e) - ); - 
firstAuthorizationCompletedLatch.countDown(); - }); - - logger.debug( - () -> Strings.format( - "Synchronizing default inference endpoints, attempting to remove ids: %s", - unauthorizedDefaultInferenceEndpointIds - ) - ); - modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 8a23057195f4e..4b17cab04471a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -17,16 +17,13 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; -import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; @@ -49,12 +46,10 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import 
org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationModel; import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationModelTests; -import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModelTests; @@ -98,7 +93,6 @@ import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.isA; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -110,7 +104,6 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private final MockWebServer webServer = new MockWebServer(); - private ModelRegistry modelRegistry; private ThreadPool threadPool; private HttpClientManager clientManager; @@ -123,7 +116,6 @@ protected Collection> getPlugins() { @Before public void init() throws Exception { webServer.start(); - modelRegistry = node().injector().getInstance(ModelRegistry.class); threadPool = createThreadPool(inferenceUtilityExecutors()); clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), 
mock(ThrottlerManager.class)); } @@ -921,8 +913,6 @@ public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOExceptio public void testHideFromConfigurationApi_ThrowsUnsupported_WithNoAvailableModels() throws Exception { try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorizationModel.newDisabledService())) { - ensureAuthorizationCallFinished(service); - expectThrows(UnsupportedOperationException.class, service::hideFromConfigurationApi); } } @@ -942,119 +932,86 @@ public void testHideFromConfigurationApi_ThrowsUnsupported_WithAvailableModels() ) ) ) { - ensureAuthorizationCallFinished(service); - expectThrows(UnsupportedOperationException.class, service::hideFromConfigurationApi); } } public void testCreateConfiguration() throws Exception { - try ( - var service = createServiceWithMockSender( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING) - ) - ) - ) - ) - ) - ) { - ensureAuthorizationCallFinished(service); - - String content = XContentHelper.stripWhitespace(""" - { - "service": "elastic", - "name": "Elastic", - "task_types": ["sparse_embedding", "chat_completion", "text_embedding"], - "configurations": { - "model_id": { - "description": "The name of the model to use for the inference task.", - "label": "Model ID", - "required": true, - "sensitive": false, - "updatable": false, - "type": "str", - "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"] - }, - "max_input_tokens": { - "description": "Allows you to specify the maximum number of tokens per input.", - "label": "Maximum Input Tokens", - "required": false, - "sensitive": false, - "updatable": false, - "type": "int", - "supported_task_types": ["text_embedding", "sparse_embedding"] - } + 
String content = XContentHelper.stripWhitespace(""" + { + "service": "elastic", + "name": "Elastic", + "task_types": ["sparse_embedding", "chat_completion", "text_embedding"], + "configurations": { + "model_id": { + "description": "The name of the model to use for the inference task.", + "label": "Model ID", + "required": true, + "sensitive": false, + "updatable": false, + "type": "str", + "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"] + }, + "max_input_tokens": { + "description": "Allows you to specify the maximum number of tokens per input.", + "label": "Maximum Input Tokens", + "required": false, + "sensitive": false, + "updatable": false, + "type": "int", + "supported_task_types": ["text_embedding", "sparse_embedding"] } } - """); - InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( - new BytesArray(content), - XContentType.JSON - ); - boolean humanReadable = true; - BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); - InferenceServiceConfiguration serviceConfiguration = ElasticInferenceService.createConfiguration( - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING) - ); - assertToXContentEquivalent( - originalBytes, - toXContent(serviceConfiguration, XContentType.JSON, humanReadable), - XContentType.JSON - ); - } + } + """); + InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( + new BytesArray(content), + XContentType.JSON + ); + boolean humanReadable = true; + BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); + InferenceServiceConfiguration serviceConfiguration = ElasticInferenceService.createConfiguration( + EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING) + ); + 
assertToXContentEquivalent(originalBytes, toXContent(serviceConfiguration, XContentType.JSON, humanReadable), XContentType.JSON); } public void testGetConfiguration_WithoutSupportedTaskTypes() throws Exception { - try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorizationModel.newDisabledService())) { - ensureAuthorizationCallFinished(service); - - String content = XContentHelper.stripWhitespace(""" - { - "service": "elastic", - "name": "Elastic", - "task_types": [], - "configurations": { - "model_id": { - "description": "The name of the model to use for the inference task.", - "label": "Model ID", - "required": true, - "sensitive": false, - "updatable": false, - "type": "str", - "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"] - }, - "max_input_tokens": { - "description": "Allows you to specify the maximum number of tokens per input.", - "label": "Maximum Input Tokens", - "required": false, - "sensitive": false, - "updatable": false, - "type": "int", - "supported_task_types": ["text_embedding", "sparse_embedding"] - } + String content = XContentHelper.stripWhitespace(""" + { + "service": "elastic", + "name": "Elastic", + "task_types": [], + "configurations": { + "model_id": { + "description": "The name of the model to use for the inference task.", + "label": "Model ID", + "required": true, + "sensitive": false, + "updatable": false, + "type": "str", + "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"] + }, + "max_input_tokens": { + "description": "Allows you to specify the maximum number of tokens per input.", + "label": "Maximum Input Tokens", + "required": false, + "sensitive": false, + "updatable": false, + "type": "int", + "supported_task_types": ["text_embedding", "sparse_embedding"] } } - """); - InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( - new BytesArray(content), - XContentType.JSON - ); - 
boolean humanReadable = true; - BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); - InferenceServiceConfiguration serviceConfiguration = ElasticInferenceService.createConfiguration( - EnumSet.noneOf(TaskType.class) - ); - assertToXContentEquivalent( - originalBytes, - toXContent(serviceConfiguration, XContentType.JSON, humanReadable), - XContentType.JSON - ); - } + } + """); + InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( + new BytesArray(content), + XContentType.JSON + ); + var humanReadable = true; + BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); + InferenceServiceConfiguration serviceConfiguration = ElasticInferenceService.createConfiguration(EnumSet.noneOf(TaskType.class)); + assertToXContentEquivalent(originalBytes, toXContent(serviceConfiguration, XContentType.JSON, humanReadable), XContentType.JSON); } public void testGetConfiguration_ThrowsUnsupported() throws Exception { @@ -1073,30 +1030,13 @@ public void testGetConfiguration_ThrowsUnsupported() throws Exception { ) ) ) { - ensureAuthorizationCallFinished(service); - expectThrows(UnsupportedOperationException.class, service::getConfiguration); } } - public void testSupportedStreamingTasks_ReturnsChatCompletion_WhenAuthRespondsWithAValidModel() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "model-a", - "task_types": ["embed/text/sparse", "chat"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - + public void testSupportedStreamingTasks_ReturnsChatCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - ensureAuthorizationCallFinished(service); - + try (var 
service = createService(senderFactory)) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); assertTrue(service.defaultConfigIds().isEmpty()); @@ -1107,79 +1047,10 @@ public void testSupportedStreamingTasks_ReturnsChatCompletion_WhenAuthRespondsWi } } - public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes_IgnoresUnimplementedTaskTypes() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "model-a", - "task_types": ["embed/text/sparse"] - }, - { - "model_name": "model-b", - "task_types": ["embed"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - ensureAuthorizationCallFinished(service); - - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); - } - } - - public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "model-a", - "task_types": ["embed/text/sparse"] - }, - { - "model_name": "model-b", - "task_types": ["chat"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - ensureAuthorizationCallFinished(service); - - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); - } - } - - public void testSupportedStreamingTasks_ReturnsEmpty_WhenAuthRespondsWithoutChatCompletion() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "model-a", - "task_types": 
["embed/text/sparse"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - + public void testDefaultConfigs_ReturnsEmptyLists() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - ensureAuthorizationCallFinished(service); - - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); + try (var service = createService(senderFactory)) { assertTrue(service.defaultConfigIds().isEmpty()); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); PlainActionFuture> listener = new PlainActionFuture<>(); service.defaultConfigs(listener); @@ -1187,120 +1058,10 @@ public void testSupportedStreamingTasks_ReturnsEmpty_WhenAuthRespondsWithoutChat } } - public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsIncorrect() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "rainbow-sprinkles", - "task_types": ["embed/text/sparse"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - + public void testSupportedTaskTypes_Returns_Unsupported() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - ensureAuthorizationCallFinished(service); - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertThat( - service.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - service - ) - ) - ) - ); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); - - PlainActionFuture> listener = new 
PlainActionFuture<>(); - service.defaultConfigs(listener); - assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); - } - } - - public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "rainbow-sprinkles", - "task_types": ["chat"] - }, - { - "model_name": "elser_model_2", - "task_types": ["embed/text/sparse"] - }, - { - "model_name": "jina-embeddings-v3", - "task_types": ["embed/text/dense"] - }, - { - "model_name": "elastic-rerank-v1", - "task_types": ["rerank/text/text-similarity"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - ensureAuthorizationCallFinished(service); - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertFalse(service.canStream(TaskType.ANY)); - assertThat( - service.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".elastic-rerank-v1", - MinimalServiceSettings.rerank(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( - ".elser-2-elastic", - MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( - ".jina-embeddings-v3", - MinimalServiceSettings.textEmbedding( - ElasticInferenceService.NAME, - ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, - ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), - DenseVectorFieldMapper.ElementType.FLOAT - ), - service - ), - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - service - ) - ) - ) - ); - assertThat( 
- service.supportedTaskTypes(), - is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.TEXT_EMBEDDING)) - ); - - PlainActionFuture> listener = new PlainActionFuture<>(); - service.defaultConfigs(listener); - var models = listener.actionGet(TIMEOUT); - assertThat(models.size(), is(4)); - assertThat(models.get(0).getConfigurations().getInferenceEntityId(), is(".elastic-rerank-v1")); - assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".elser-2-elastic")); - assertThat(models.get(2).getConfigurations().getInferenceEntityId(), is(".jina-embeddings-v3")); - assertThat(models.get(3).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); + try (var service = createService(senderFactory)) { + expectThrows(UnsupportedOperationException.class, service::supportedTaskTypes); } } @@ -1392,23 +1153,11 @@ private InferenceEventsAssertion testUnifiedStream(int responseCode, String resp } } - private void ensureAuthorizationCallFinished(ElasticInferenceService service) { - service.onNodeStarted(); - service.waitForFirstAuthorizationToComplete(TIMEOUT); - } - private ElasticInferenceService createServiceWithMockSender() { return createServiceWithMockSender(ElasticInferenceServiceAuthorizationModelTests.createEnabledAuth()); } private ElasticInferenceService createServiceWithMockSender(ElasticInferenceServiceAuthorizationModel auth) { - var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse(auth); - return Void.TYPE; - }).when(mockAuthHandler).getAuthorization(any(), any()); - var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); @@ -1417,52 +1166,19 @@ private ElasticInferenceService createServiceWithMockSender(ElasticInferenceServ factory, createWithEmptySettings(threadPool), new ElasticInferenceServiceSettings(Settings.EMPTY), - 
modelRegistry, - mockAuthHandler, mockClusterServiceEmpty() ); } private ElasticInferenceService createService(HttpRequestSender.Factory senderFactory) { - return createService(senderFactory, ElasticInferenceServiceAuthorizationModelTests.createEnabledAuth(), null); + return createService(senderFactory, null); } private ElasticInferenceService createService(HttpRequestSender.Factory senderFactory, String elasticInferenceServiceURL) { - return createService(senderFactory, ElasticInferenceServiceAuthorizationModelTests.createEnabledAuth(), elasticInferenceServiceURL); - } - - private ElasticInferenceService createService( - HttpRequestSender.Factory senderFactory, - ElasticInferenceServiceAuthorizationModel auth, - String elasticInferenceServiceURL - ) { - var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse(auth); - return Void.TYPE; - }).when(mockAuthHandler).getAuthorization(any(), any()); - - return new ElasticInferenceService( - senderFactory, - createWithEmptySettings(threadPool), - ElasticInferenceServiceSettingsTests.create(elasticInferenceServiceURL), - modelRegistry, - mockAuthHandler, - mockClusterServiceEmpty() - ); - } - - private ElasticInferenceService createServiceWithAuthHandler( - HttpRequestSender.Factory senderFactory, - String elasticInferenceServiceURL - ) { return new ElasticInferenceService( senderFactory, createWithEmptySettings(threadPool), ElasticInferenceServiceSettingsTests.create(elasticInferenceServiceURL), - modelRegistry, - new ElasticInferenceServiceAuthorizationRequestHandler(elasticInferenceServiceURL, threadPool), mockClusterServiceEmpty() ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java deleted file mode 100644 index fd7bf5c4c56c4..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java +++ /dev/null @@ -1,283 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elastic.authorization; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue; -import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.EmptySecretSettings; -import org.elasticsearch.inference.EmptyTaskSettings; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.MinimalServiceSettings; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.test.ESSingleNodeTestCase; -import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder; -import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; -import org.elasticsearch.xpack.inference.Utils; -import org.elasticsearch.xpack.inference.external.http.sender.Sender; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.services.elastic.DefaultModelConfig; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; -import 
org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; -import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; -import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity; -import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; -import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; -import org.junit.Before; - -import java.io.IOException; -import java.util.Collection; -import java.util.EnumSet; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; - -import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; -import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.defaultEndpointId; -import static org.hamcrest.CoreMatchers.is; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; - -public class ElasticInferenceServiceAuthorizationHandlerTests extends ESSingleNodeTestCase { - private DeterministicTaskQueue taskQueue; - private ModelRegistry modelRegistry; - - @Override - protected Collection> getPlugins() { - return List.of(LocalStateInferencePlugin.class); - } - - @Before - public void init() throws Exception { - taskQueue = new DeterministicTaskQueue(); - modelRegistry = getInstanceFromNode(ModelRegistry.class); - } - - public void testSecondAuthResultRevokesAuthorization() throws Exception { - var callbackCount = new 
AtomicInteger(0); - // we're only interested in two authorization calls which is why I'm using a value of 2 here - var latch = new CountDownLatch(2); - final AtomicReference handlerRef = new AtomicReference<>(); - - Runnable callback = () -> { - // the first authorization response contains a streaming task so we're expecting to support streaming here - if (callbackCount.incrementAndGet() == 1) { - assertThat(handlerRef.get().supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - } - latch.countDown(); - - // we only want to run the tasks twice, so advance the time on the queue - // which flags the scheduled authorization request to be ready to run - if (callbackCount.get() == 1) { - taskQueue.advanceTime(); - } else { - try { - handlerRef.get().close(); - } catch (IOException e) { - // ignore - } - } - }; - - var requestHandler = mockAuthorizationRequestHandler( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "rainbow-sprinkles", - EnumSet.of(TaskType.CHAT_COMPLETION) - ) - ) - ) - ), - ElasticInferenceServiceAuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntity(List.of())) - ); - - handlerRef.set( - new ElasticInferenceServiceAuthorizationHandler( - createWithEmptySettings(taskQueue.getThreadPool()), - modelRegistry, - requestHandler, - initDefaultEndpoints(), - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION), - null, - mock(Sender.class), - ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), - callback - ) - ); - - var handler = handlerRef.get(); - handler.init(); - taskQueue.runAllRunnableTasks(); - latch.await(Utils.TIMEOUT.getSeconds(), TimeUnit.SECONDS); - - // this should be after we've received both authorization responses, the second response will revoke authorization - - 
assertThat(handler.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); - assertThat(handler.defaultConfigIds(), is(List.of())); - assertThat(handler.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class))); - - PlainActionFuture> listener = new PlainActionFuture<>(); - handler.defaultConfigs(listener); - - var configs = listener.actionGet(); - assertThat(configs.size(), is(0)); - } - - public void testSendsAnAuthorizationRequestTwice() throws Exception { - var callbackCount = new AtomicInteger(0); - // we're only interested in two authorization calls which is why I'm using a value of 2 here - var latch = new CountDownLatch(2); - final AtomicReference handlerRef = new AtomicReference<>(); - - Runnable callback = () -> { - // the first authorization response does not contain a streaming task so we're expecting to not support streaming here - if (callbackCount.incrementAndGet() == 1) { - assertThat(handlerRef.get().supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); - } - latch.countDown(); - - // we only want to run the tasks twice, so advance the time on the queue - // which flags the scheduled authorization request to be ready to run - if (callbackCount.get() == 1) { - taskQueue.advanceTime(); - } else { - try { - handlerRef.get().close(); - } catch (IOException e) { - // ignore - } - } - }; - - var requestHandler = mockAuthorizationRequestHandler( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("abc", EnumSet.of(TaskType.SPARSE_EMBEDDING)) - ) - ) - ), - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "abc", - EnumSet.of(TaskType.SPARSE_EMBEDDING) - ), - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "rainbow-sprinkles", - 
EnumSet.of(TaskType.CHAT_COMPLETION) - ) - ) - ) - ) - ); - - handlerRef.set( - new ElasticInferenceServiceAuthorizationHandler( - createWithEmptySettings(taskQueue.getThreadPool()), - modelRegistry, - requestHandler, - initDefaultEndpoints(), - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION), - null, - mock(Sender.class), - ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), - callback - ) - ); - - var handler = handlerRef.get(); - handler.init(); - taskQueue.runAllRunnableTasks(); - latch.await(Utils.TIMEOUT.getSeconds(), TimeUnit.SECONDS); - // this should be after we've received both authorization responses - - assertThat(handler.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertThat( - handler.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - null - ) - ) - ) - ); - assertThat(handler.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); - - PlainActionFuture> listener = new PlainActionFuture<>(); - handler.defaultConfigs(listener); - - var configs = listener.actionGet(); - assertThat(configs.get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); - } - - private static ElasticInferenceServiceAuthorizationRequestHandler mockAuthorizationRequestHandler( - ElasticInferenceServiceAuthorizationModel firstAuthResponse, - ElasticInferenceServiceAuthorizationModel secondAuthResponse - ) { - var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse(firstAuthResponse); - return Void.TYPE; - }).doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse(secondAuthResponse); - return Void.TYPE; - 
}).when(mockAuthHandler).getAuthorization(any(), any()); - - return mockAuthHandler; - } - - private static Map initDefaultEndpoints() { - return Map.of( - "rainbow-sprinkles", - new DefaultModelConfig( - new ElasticInferenceServiceCompletionModel( - defaultEndpointId("rainbow-sprinkles"), - TaskType.CHAT_COMPLETION, - "test", - new ElasticInferenceServiceCompletionServiceSettings("rainbow-sprinkles"), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - ElasticInferenceServiceComponents.EMPTY_INSTANCE - ), - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME) - ), - "elser-2", - new DefaultModelConfig( - new ElasticInferenceServiceSparseEmbeddingsModel( - defaultEndpointId("elser-2"), - TaskType.SPARSE_EMBEDDING, - "test", - new ElasticInferenceServiceSparseEmbeddingsServiceSettings("elser-2", null), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - ElasticInferenceServiceComponents.EMPTY_INSTANCE, - ChunkingSettingsBuilder.DEFAULT_SETTINGS - ), - MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME) - ) - ); - } -} From 75005b76217f18f19a99b4c0272a0938aa404858 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 30 Oct 2025 18:18:37 +0000 Subject: [PATCH 21/32] [CI] Auto commit changes from spotless --- .../services/elastic/ElasticInferenceService.java | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 0e89b697734af..8a6d2626ad521 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -98,7 +98,6 @@ public class ElasticInferenceService 
extends SenderService { // This mirrors the memory constraints observed with sparse embeddings private static final Integer DENSE_TEXT_EMBEDDINGS_MAX_BATCH_SIZE = 16; - /** * The task types that the {@link InferenceAction.Request} can accept. */ @@ -116,12 +115,7 @@ public ElasticInferenceService( ElasticInferenceServiceSettings elasticInferenceServiceSettings, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - this( - factory, - serviceComponents, - elasticInferenceServiceSettings, - context.clusterService() - ); + this(factory, serviceComponents, elasticInferenceServiceSettings, context.clusterService()); } public ElasticInferenceService( @@ -153,7 +147,6 @@ public Set supportedStreamingTasks() { return EnumSet.of(TaskType.CHAT_COMPLETION); } - @Override protected void doUnifiedCompletionInfer( Model model, From c34ac072413ee7575d3a45d8f34a64c50d689e34 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 30 Oct 2025 15:06:47 -0400 Subject: [PATCH 22/32] Refactoring tests --- .../AuthorizationTaskExecutorIT.java | 46 +++++++++++++++--- ...horizationTaskExecutorMultipleNodesIT.java | 39 ++------------- .../AuthorizationTaskExecutor.java | 32 ------------- .../AuthorizationTaskParams.java | 3 ++ .../AuthorizationPollerTests.java | 48 +++++++++++++++++-- 5 files changed, 91 insertions(+), 77 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java index e33a07f14c34e..2c0efd0e0699d 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java @@ -7,13 +7,17 @@ package 
org.elasticsearch.xpack.inference.integration; +import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequestBuilder; +import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksRequestBuilder; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.client.internal.AdminClient; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.reindex.ReindexPlugin; +import org.elasticsearch.tasks.TaskInfo; import org.elasticsearch.test.ESSingleNodeTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; @@ -22,6 +26,7 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; +import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller; import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationTaskExecutor; import org.junit.After; import org.junit.AfterClass; @@ -31,14 +36,17 @@ import java.io.IOException; import java.util.Collection; import java.util.List; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.stream.Collectors; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; public class AuthorizationTaskExecutorIT extends ESSingleNodeTestCase { + public static final String AUTH_TASK_ACTION = AuthorizationPoller.TASK_NAME + "[c]"; public static final String EMPTY_AUTH_RESPONSE = """ { @@ -115,6 +123,8 @@ 
public void testCreatesEisChatCompletionEndpoint() throws Exception { } private void assertNoAuthorizedEisEndpoints() throws Exception { + waitForTask(AUTH_TASK_ACTION, admin()); + assertBusy(() -> { var newPoller = authorizationTaskExecutor.getCurrentPollerTask(); assertNotNull(newPoller); @@ -125,6 +135,20 @@ private void assertNoAuthorizedEisEndpoints() throws Exception { assertThat(eisEndpoints, empty()); } + public static TaskInfo waitForTask(String taskAction, AdminClient adminClient) throws Exception { + var taskRef = new AtomicReference(); + var builder = new ListTasksRequestBuilder(adminClient.cluster()); + + assertBusy(() -> { + var response = builder.get(); + var authPollerTask = response.getTasks().stream().filter(task -> task.action().equals(taskAction)).findFirst(); + assertTrue(authPollerTask.isPresent()); + taskRef.set(authPollerTask.get()); + }); + + return taskRef.get(); + } + private List getEisEndpoints() { var listener = new PlainActionFuture>(); modelRegistry.getAllModels(false, listener); @@ -134,13 +158,9 @@ private List getEisEndpoints() { } private void restartPollingTaskAndWaitForAuthResponse() throws Exception { - var taskListener = new PlainActionFuture(); + cancelAuthorizationTask(admin()); - authorizationTaskExecutor.abortTask(TimeValue.THIRTY_SECONDS, taskListener); - // Ensure that the listener doesn't return a failure - assertNull(taskListener.actionGet(TimeValue.THIRTY_SECONDS)); - - // wait for the new task to be recreated + // wait for the new task to be recreated and an authorization response to be processed assertBusy(() -> { var newPoller = authorizationTaskExecutor.getCurrentPollerTask(); assertNotNull(newPoller); @@ -148,6 +168,20 @@ private void restartPollingTaskAndWaitForAuthResponse() throws Exception { }); } + public static void cancelAuthorizationTask(AdminClient adminClient) throws Exception { + var pollerTask = waitForTask(AUTH_TASK_ACTION, adminClient); + var builder = new 
CancelTasksRequestBuilder(adminClient.cluster()); + + assertBusy(() -> { + var cancelTaskResponse = builder.setActions(AUTH_TASK_ACTION).get(); + assertThat(cancelTaskResponse.getTasks().size(), is(1)); + assertThat(cancelTaskResponse.getTasks().get(0).action(), is(AUTH_TASK_ACTION)); + }); + + var newPollerTask = waitForTask(AUTH_TASK_ACTION, adminClient); + assertThat(newPollerTask.taskId(), is(not(pollerTask.taskId()))); + } + public void testCreatesEisChatCompletion_DoesNotRemoveEndpointWhenNoLongerAuthorized() throws Exception { assertNoAuthorizedEisEndpoints(); diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java index 73f5c118087d6..597fcef4a10fe 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java @@ -7,12 +7,10 @@ package org.elasticsearch.xpack.inference.integration; -import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.license.LicenseSettings; import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.tasks.TaskInfo; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; @@ -35,6 +33,8 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE; import static 
org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.EMPTY_AUTH_RESPONSE; +import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.cancelAuthorizationTask; +import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.waitForTask; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.not; @@ -77,25 +77,8 @@ protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { .build(); } - @Override - public Settings indexSettings() { - return Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(3, 10)).build(); - } - public void testCancellingAuthorizationTaskRestartsIt() throws Exception { - var pollerTask = waitForTask(internalCluster().getNodeNames(), AUTH_TASK_ACTION); - - assertBusy(() -> { - var cancelTaskResponse = admin().cluster() - .prepareCancelTasks(internalCluster().getNodeNames()) - .setActions(AUTH_TASK_ACTION) - .get(); - assertThat(cancelTaskResponse.getTasks().size(), is(1)); - assertThat(cancelTaskResponse.getTasks().get(0).action(), is(AUTH_TASK_ACTION)); - }); - - var newPollerTask = waitForTask(internalCluster().getNodeNames(), AUTH_TASK_ACTION); - assertThat(newPollerTask.taskId(), is(not(pollerTask.taskId()))); + cancelAuthorizationTask(admin()); } public void testAuthorizationTaskGetsRelocatedToAnotherNode_WhenTheNodeThatIsRunningItShutsDown() throws Exception { @@ -107,7 +90,7 @@ public void testAuthorizationTaskGetsRelocatedToAnotherNode_WhenTheNodeThatIsRun var nodeNameMapping = getNodeNames(internalCluster().getNodeNames()); - var pollerTask = waitForTask(internalCluster().getNodeNames(), AUTH_TASK_ACTION); + var pollerTask = waitForTask(AUTH_TASK_ACTION, admin()); var endpoints = getAllEndpoints(); assertTrue( @@ -121,7 +104,7 @@ public void testAuthorizationTaskGetsRelocatedToAnotherNode_WhenTheNodeThatIsRun assertTrue("expected the node to 
shutdown properly", internalCluster().stopNode(nodeNameMapping.get(pollerTask.node()))); assertBusy(() -> { - var relocatedPollerTask = waitForTask(internalCluster().getNodeNames(), AUTH_TASK_ACTION); + var relocatedPollerTask = waitForTask(AUTH_TASK_ACTION, admin()); assertThat(relocatedPollerTask.node(), not(is(pollerTask.node()))); }); @@ -145,18 +128,6 @@ public void testAuthorizationTaskGetsRelocatedToAnotherNode_WhenTheNodeThatIsRun } - private TaskInfo waitForTask(String[] nodes, String taskAction) throws Exception { - var taskRef = new AtomicReference(); - assertBusy(() -> { - var response = admin().cluster().prepareListTasks(nodes).get(); - var authPollerTask = response.getTasks().stream().filter(task -> task.action().equals(taskAction)).findFirst(); - assertTrue(authPollerTask.isPresent()); - taskRef.set(authPollerTask.get()); - }); - - return taskRef.get(); - } - private record NodeNameMapping(Map nodeNamesMap) { public String get(String rawNodeName) { var nodeName = nodeNamesMap.get(rawNodeName); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java index f80fa1b0345fa..161267d86f559 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java @@ -79,38 +79,6 @@ void init() { } } - /** - * This method should only be used for testing purposes to simulate a task being recreated. 
- */ - public void abortTask(TimeValue timeout, ActionListener listener) { - var task = currentTask.get(); - if (task != null && task.isCancelled() == false) { - task.markAsLocallyAborted("testing task cancellation"); - currentTask.set(null); - waitForNullTask(task, timeout, listener); - } else { - listener.onFailure(new IllegalStateException("Authorization poller task was not created yet, or was already aborted")); - } - } - - private void waitForNullTask(AllocatedPersistentTask task, TimeValue timeout, ActionListener listener) { - task.waitForPersistentTask( - Objects::isNull, - timeout, - new PersistentTasksService.WaitForPersistentTaskListener() { - @Override - public void onResponse(PersistentTasksCustomMetadata.PersistentTask persistentTask) { - listener.onResponse(null); - } - - @Override - public void onFailure(Exception e) { - listener.onFailure(e); - } - } - ); - } - /** * This method should only be used for testing purposes to get the current running task. */ diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParams.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParams.java index 976a1e8307427..7b2791169e872 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParams.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParams.java @@ -20,6 +20,9 @@ import static org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller.TASK_NAME; +/** + * Empty parameters for the authorization persistent task. 
+ */ public class AuthorizationTaskParams implements PersistentTaskParams { public static final AuthorizationTaskParams INSTANCE = new AuthorizationTaskParams(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java index 0e96f2d7f5310..df2aa7b8cb821 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java @@ -128,6 +128,47 @@ public void testSendsAuthorizationRequest_WhenModelRegistryIsReady() { ); } + public void testSendsAuthorizationRequest_ButDoesNotStoreAnyModels_WhenTheirInferenceIdAlreadyExists() { + var mockRegistry = mock(ModelRegistry.class); + when(mockRegistry.isReady()).thenReturn(true); + when(mockRegistry.getInferenceIds()).thenReturn(Set.of(InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2, "id2")); + + var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + InternalPreconfiguredEndpoints.DEFAULT_ELSER_2_MODEL_ID, + EnumSet.of(TaskType.SPARSE_EMBEDDING) + ) + ) + ) + ) + ); + return Void.TYPE; + }).when(mockAuthHandler).getAuthorization(any(), any()); + + var mockClient = mock(Client.class); + when(mockClient.threadPool()).thenReturn(taskQueue.getThreadPool()); + + var poller = new AuthorizationPoller( + new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new 
TaskId("abc", 0), Map.of()), + createWithEmptySettings(taskQueue.getThreadPool()), + mockAuthHandler, + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + mockRegistry, + mockClient, + null + ); + + poller.sendAuthorizationRequest(); + verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), any(), any()); + } + public void testDoesNotAttemptToStoreModelIds_ThatDoNotExistInThePreconfiguredMapping() { var mockRegistry = mock(ModelRegistry.class); when(mockRegistry.isReady()).thenReturn(true); @@ -166,10 +207,9 @@ public void testDoesNotAttemptToStoreModelIds_ThatDoNotExistInThePreconfiguredMa null ); - var requestArgCaptor = ArgumentCaptor.forClass(StoreInferenceEndpointsAction.Request.class); poller.sendAuthorizationRequest(); - verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), requestArgCaptor.capture(), any()); + verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), any(), any()); } public void testDoesNotAttemptToStoreModelIds_ThatHaveATaskTypeThatTheEISIntegration_DoesNotSupport() { @@ -210,10 +250,8 @@ public void testDoesNotAttemptToStoreModelIds_ThatHaveATaskTypeThatTheEISIntegra null ); - var requestArgCaptor = ArgumentCaptor.forClass(StoreInferenceEndpointsAction.Request.class); - poller.sendAuthorizationRequest(); - verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), requestArgCaptor.capture(), any()); + verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), any(), any()); } public void testSendsTwoAuthorizationRequests() throws InterruptedException { From df94ce951037536128db4a264bd0ab9a0429b2a3 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 30 Oct 2025 15:09:41 -0400 Subject: [PATCH 23/32] updating transport version --- .../inference_api_eis_authorization_persistent_task.csv | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/server/src/main/resources/transport/definitions/referable/inference_api_eis_authorization_persistent_task.csv b/server/src/main/resources/transport/definitions/referable/inference_api_eis_authorization_persistent_task.csv index bdb12ee6b228e..0135b6d28dffd 100644 --- a/server/src/main/resources/transport/definitions/referable/inference_api_eis_authorization_persistent_task.csv +++ b/server/src/main/resources/transport/definitions/referable/inference_api_eis_authorization_persistent_task.csv @@ -1 +1 @@ -9205000 +9208000 From a7e67f8522082a6771bac79c67ab3078e70455bb Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 30 Oct 2025 19:16:28 +0000 Subject: [PATCH 24/32] [CI] Auto commit changes from spotless --- .../services/elastic/authorization/AuthorizationPollerTests.java | 1 - 1 file changed, 1 deletion(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java index df2aa7b8cb821..d276419570d0e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java @@ -207,7 +207,6 @@ public void testDoesNotAttemptToStoreModelIds_ThatDoNotExistInThePreconfiguredMa null ); - poller.sendAuthorizationRequest(); verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), any(), any()); } From 4b2b33f6ed0afcdf13b14b5f8f39719fca55d47a Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 30 Oct 2025 16:04:33 -0400 Subject: [PATCH 25/32] Fixing transport version --- .../inference_api_eis_authorization_persistent_task.csv | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff 
--git a/server/src/main/resources/transport/definitions/referable/inference_api_eis_authorization_persistent_task.csv b/server/src/main/resources/transport/definitions/referable/inference_api_eis_authorization_persistent_task.csv index 0135b6d28dffd..d21a6d4514613 100644 --- a/server/src/main/resources/transport/definitions/referable/inference_api_eis_authorization_persistent_task.csv +++ b/server/src/main/resources/transport/definitions/referable/inference_api_eis_authorization_persistent_task.csv @@ -1 +1 @@ -9208000 +9209000 From 4b7e6cf5685293f355faa1f5cb8711acd6559caf Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 3 Nov 2025 13:06:16 -0500 Subject: [PATCH 26/32] Fixing check for preconfigured endpoints --- .../AuthorizationTaskExecutorIT.java | 7 ++ ...ransportDeleteInferenceEndpointAction.java | 2 +- .../action/TransportInferenceUsageAction.java | 2 +- .../TransportPutInferenceModelAction.java | 2 +- .../mapper/SemanticTextFieldMapper.java | 2 +- .../inference/registry/ModelRegistry.java | 59 +++++++------- .../registry/ModelRegistryMetadata.java | 52 +++++++++++-- ...ortDeleteInferenceEndpointActionTests.java | 16 ++-- .../TransportInferenceUsageActionTests.java | 2 +- .../registry/ModelRegistryMetadataTests.java | 77 +++++++++++++++++++ .../registry/ModelRegistryTests.java | 6 +- 11 files changed, 179 insertions(+), 48 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java index 2c0efd0e0699d..8a731a6d6681a 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java @@ -133,6 +133,10 @@ private 
void assertNoAuthorizedEisEndpoints() throws Exception { var eisEndpoints = getEisEndpoints(); assertThat(eisEndpoints, empty()); + + for (String eisPreconfiguredEndpoints : InternalPreconfiguredEndpoints.EIS_PRECONFIGURED_ENDPOINT_IDS) { + assertFalse(modelRegistry.containsPreconfiguredInferenceEndpointId(eisPreconfiguredEndpoints)); + } } public static TaskInfo waitForTask(String taskAction, AdminClient adminClient) throws Exception { @@ -203,6 +207,9 @@ private void assertChatCompletionEndpointExists() { var rainbowSprinklesModel = eisEndpoints.get(0); assertChatCompletionUnparsedModel(rainbowSprinklesModel); + assertTrue( + modelRegistry.containsPreconfiguredInferenceEndpointId(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1) + ); } private void assertChatCompletionUnparsedModel(UnparsedModel rainbowSprinklesModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java index c100c9926b451..c604bbacb4010 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java @@ -224,7 +224,7 @@ private static String endpointIsReferencedInPipelinesOrIndexes(final ClusterStat } private boolean isInferenceIdReserved(String inferenceEndpointId) { - return modelRegistry.containsDefaultConfigId(inferenceEndpointId); + return modelRegistry.containsPreconfiguredInferenceEndpointId(inferenceEndpointId); } private static String buildErrorString(String inferenceEndpointId, Set pipelines, Set indexes) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java index 609a1e4df62d8..4415da2c1b99a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java @@ -265,7 +265,7 @@ private Map createStatsKeysWithEndpointCountsForDefa // may only happen for external services. Set modelIds = endpoints.stream() .filter(endpoint -> TASK_TYPES_WITH_SEMANTIC_TEXT_SUPPORT.contains(endpoint.getTaskType())) - .filter(endpoint -> modelRegistry.containsDefaultConfigId(endpoint.getInferenceEntityId())) + .filter(endpoint -> modelRegistry.containsPreconfiguredInferenceEndpointId(endpoint.getInferenceEntityId())) .filter(endpoint -> endpoint.getServiceSettings().modelId() != null) .map(endpoint -> stripLinuxSuffix(endpoint.getServiceSettings().modelId())) .collect(Collectors.toSet()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 80d57f888ef6e..64a58b672fd41 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -111,7 +111,7 @@ protected void masterOperation( return; } - if (modelRegistry.containsDefaultConfigId(request.getInferenceEntityId())) { + if (modelRegistry.containsPreconfiguredInferenceEndpointId(request.getInferenceEntityId())) { listener.onFailure( new ElasticsearchStatusException( "[{}] is a reserved inference ID. 
Cannot create a new inference endpoint with a reserved ID.", diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 8f6a902773394..6436442749c0c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -171,7 +171,7 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie * This enables automatic selection of EIS for better performance while maintaining compatibility with on-prem deployments. */ private static String getPreferredElserInferenceId(ModelRegistry modelRegistry) { - if (modelRegistry != null && modelRegistry.containsDefaultConfigId(DEFAULT_EIS_ELSER_INFERENCE_ID)) { + if (modelRegistry != null && modelRegistry.containsPreconfiguredInferenceEndpointId(DEFAULT_EIS_ELSER_INFERENCE_ID)) { return DEFAULT_EIS_ELSER_INFERENCE_ID; } return DEFAULT_FALLBACK_ELSER_INFERENCE_ID; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 7968745fff2a8..ba665b4a0e1ad 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -77,6 +77,7 @@ import org.elasticsearch.xpack.inference.InferenceIndex; import org.elasticsearch.xpack.inference.InferenceSecretsIndex; import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import java.io.IOException; import java.util.ArrayList; @@ 
-93,6 +94,7 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.stream.Collectors; @@ -150,8 +152,7 @@ public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) private final AtomicBoolean upgradeMetadataInProgress = new AtomicBoolean(false); private final Set preventDeletionLock = Collections.newSetFromMap(new ConcurrentHashMap<>()); private final ClusterService clusterService; - - private volatile Metadata lastMetadata; + private final AtomicReference lastMetadata = new AtomicReference<>(); public ModelRegistry(ClusterService clusterService, Client client) { this.clusterService = Objects.requireNonNull(clusterService); @@ -170,13 +171,24 @@ public Tuple executeTask(MetadataTask tas } /** - * Returns true if the provided inference entity id is the same as one of the default - * endpoints ids. + * Returns true if the model registry contains (whether it has persisted it or not) the provided inference entity id. + * EIS preconfigured endpoints are also considered. 
* @param inferenceEntityId the id to search for * @return true if we find a match and false if not */ - public boolean containsDefaultConfigId(String inferenceEntityId) { - return defaultConfigIds.containsKey(inferenceEntityId); + public boolean containsPreconfiguredInferenceEndpointId(String inferenceEntityId) { + if (defaultConfigIds.containsKey(inferenceEntityId)) { + return true; + } + + if (lastMetadata.get() != null) { + var project = lastMetadata.get().getProject(ProjectId.DEFAULT); + var state = ModelRegistryMetadata.fromState(project); + var eisPreconfiguredEndpoints = state.getServiceInferenceIds(ElasticInferenceService.NAME); + return eisPreconfiguredEndpoints.contains(inferenceEntityId); + } + + return false; } /** @@ -229,16 +241,15 @@ public void clearDefaultIds() { * @throws ResourceNotFoundException if the specified id is guaranteed to not exist in the cluster. */ public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId) throws ResourceNotFoundException { - synchronized (this) { - if (lastMetadata == null) { - throw new IllegalStateException("initial cluster state not set yet"); - } + if (lastMetadata.get() == null) { + throw new IllegalStateException("initial cluster state not set yet"); } + var config = defaultConfigIds.get(inferenceEntityId); if (config != null) { return config.settings(); } - var project = lastMetadata.getProject(ProjectId.DEFAULT); + var project = lastMetadata.get().getProject(ProjectId.DEFAULT); var state = ModelRegistryMetadata.fromState(project); var existing = state.getMinimalServiceSettings(inferenceEntityId); if (state.isUpgraded() && existing == null) { @@ -248,14 +259,14 @@ public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId } public Set getInferenceIds() { - synchronized (this) { - if (lastMetadata == null) { - throw new IllegalStateException("initial cluster state not set yet"); - } + Set metadataInferenceIds = Set.of(); + if (lastMetadata.get() != null) { + var 
project = lastMetadata.get().getProject(ProjectId.DEFAULT); + var state = ModelRegistryMetadata.fromState(project); + metadataInferenceIds = state.getInferenceIds(); } - var project = lastMetadata.getProject(ProjectId.DEFAULT); - var state = ModelRegistryMetadata.fromState(project); - var ids = new HashSet<>(state.getInferenceIds()); + + var ids = new HashSet<>(metadataInferenceIds); ids.addAll(Set.copyOf(defaultConfigIds.keySet())); return ids; } @@ -953,10 +964,8 @@ private void updateClusterState(List models, ActionListener taskTypeMatchedDefaults( @Override public void clusterChanged(ClusterChangedEvent event) { - if (lastMetadata == null || event.metadataChanged()) { + if (lastMetadata.get() == null || event.metadataChanged()) { // keep track of the last applied cluster state - synchronized (this) { - lastMetadata = event.state().metadata(); - } + lastMetadata.set(event.state().metadata()); } if (event.localNodeMaster() == false) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java index 9ba6bf38416c9..857dd2d79476c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java @@ -30,6 +30,7 @@ import java.util.Collection; import java.util.Collections; import java.util.EnumSet; +import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; @@ -150,24 +151,52 @@ public ModelRegistryMetadata withUpgradedModels(Map modelMap; + private final Map> serviceToInferenceEndpointIds; private final Set tombstones; public ModelRegistryMetadata(ImmutableOpenMap modelMap) { - this.isUpgraded = true; - this.modelMap = modelMap; - this.tombstones = null; + this(modelMap, null, true); } 
public ModelRegistryMetadata(ImmutableOpenMap modelMap, Set tombstone) { - this.isUpgraded = false; - this.modelMap = modelMap; - this.tombstones = Collections.unmodifiableSet(tombstone); + this(modelMap, Collections.unmodifiableSet(tombstone), false); } public ModelRegistryMetadata(StreamInput in) throws IOException { this.isUpgraded = in.readBoolean(); this.modelMap = in.readImmutableOpenMap(StreamInput::readString, MinimalServiceSettings::new); this.tombstones = isUpgraded ? null : in.readCollectionAsSet(StreamInput::readString); + this.serviceToInferenceEndpointIds = buildServiceToInferenceEndpointIdsMap(modelMap); + } + + private ModelRegistryMetadata( + ImmutableOpenMap modelMap, + Set tombstones, + boolean isUpgraded + ) { + this.isUpgraded = isUpgraded; + this.modelMap = modelMap; + this.tombstones = tombstones; + this.serviceToInferenceEndpointIds = buildServiceToInferenceEndpointIdsMap(modelMap); + } + + private static Map> buildServiceToInferenceEndpointIdsMap( + ImmutableOpenMap modelMap + ) { + var serviceToInferenceIds = new HashMap>(); + for (var entry : modelMap.entrySet()) { + var settings = entry.getValue(); + var serviceName = settings.service(); + + var existingSet = serviceToInferenceIds.get(serviceName); + if (existingSet == null) { + existingSet = new HashSet<>(); + } + + existingSet.add(entry.getKey()); + serviceToInferenceIds.put(serviceName, existingSet); + } + return serviceToInferenceIds; } @Override @@ -221,6 +250,17 @@ public ImmutableOpenMap getModelMap() { return modelMap; } + /** + * Returns all inference entity IDs for a given service. 
+ */ + public Set getServiceInferenceIds(String service) { + if (serviceToInferenceEndpointIds.containsKey(service) == false) { + return Set.of(); + } + + return Set.copyOf(serviceToInferenceEndpointIds.get(service)); + } + public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId) { return modelMap.get(inferenceEntityId); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java index 27952f23f37f8..d53c9d5eebbc9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java @@ -85,7 +85,7 @@ public void testFailsToDelete_ADefaultEndpoint_WithoutPassingForceQueryParameter listener.onResponse(new UnparsedModel("model_id", TaskType.COMPLETION, "service", Map.of(), Map.of())); return Void.TYPE; }).when(mockModelRegistry).getModel(anyString(), any()); - when(mockModelRegistry.containsDefaultConfigId(anyString())).thenReturn(true); + when(mockModelRegistry.containsPreconfiguredInferenceEndpointId(anyString())).thenReturn(true); var listener = new PlainActionFuture(); @@ -109,7 +109,7 @@ public void testDeletesDefaultEndpoint_WhenForceIsTrue() { listener.onResponse(new UnparsedModel("model_id", TaskType.COMPLETION, "service", Map.of(), Map.of())); return Void.TYPE; }).when(mockModelRegistry).getModel(anyString(), any()); - when(mockModelRegistry.containsDefaultConfigId(anyString())).thenReturn(true); + when(mockModelRegistry.containsPreconfiguredInferenceEndpointId(anyString())).thenReturn(true); doAnswer(invocationOnMock -> { ActionListener listener = invocationOnMock.getArgument(1); listener.onResponse(true); @@ -145,7 +145,7 
@@ public void testFailsToDeleteUnparsableEndpoint_WhenForceIsFalse() { var taskType = randomFrom(TaskType.values()); var mockService = mock(InferenceService.class); mockUnparsableModel(inferenceEndpointId, serviceName, taskType, mockService); - when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false); + when(mockModelRegistry.containsPreconfiguredInferenceEndpointId(inferenceEndpointId)).thenReturn(false); var listener = new PlainActionFuture(); action.masterOperation( @@ -160,7 +160,7 @@ public void testFailsToDeleteUnparsableEndpoint_WhenForceIsFalse() { verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any()); verify(mockInferenceServiceRegistry).getService(eq(serviceName)); - verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId)); + verify(mockModelRegistry).containsPreconfiguredInferenceEndpointId(eq(inferenceEndpointId)); verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any()); verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService); } @@ -240,7 +240,7 @@ public void testFailsToDeleteEndpointWithNoService_WhenForceIsFalse() { var serviceName = randomAlphanumericOfLength(10); var taskType = randomFrom(TaskType.values()); mockNoService(inferenceEndpointId, serviceName, taskType); - when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false); + when(mockModelRegistry.containsPreconfiguredInferenceEndpointId(inferenceEndpointId)).thenReturn(false); var listener = new PlainActionFuture(); @@ -255,7 +255,7 @@ public void testFailsToDeleteEndpointWithNoService_WhenForceIsFalse() { assertThat(exception.getMessage(), containsString("No service found for this inference endpoint")); verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any()); verify(mockInferenceServiceRegistry).getService(eq(serviceName)); - verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId)); + 
verify(mockModelRegistry).containsPreconfiguredInferenceEndpointId(eq(inferenceEndpointId)); verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry); } @@ -275,7 +275,7 @@ public void testFailsToDeleteEndpointIfModelDeploymentStopFails_WhenForceIsFalse var mockService = mock(InferenceService.class); var mockModel = mock(Model.class); mockStopDeploymentFails(inferenceEndpointId, serviceName, taskType, mockService, mockModel); - when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false); + when(mockModelRegistry.containsPreconfiguredInferenceEndpointId(inferenceEndpointId)).thenReturn(false); var listener = new PlainActionFuture(); action.masterOperation( @@ -289,7 +289,7 @@ public void testFailsToDeleteEndpointIfModelDeploymentStopFails_WhenForceIsFalse assertThat(exception.getMessage(), containsString("Failed to stop model deployment")); verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any()); verify(mockInferenceServiceRegistry).getService(eq(serviceName)); - verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId)); + verify(mockModelRegistry).containsPreconfiguredInferenceEndpointId(eq(inferenceEndpointId)); verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any()); verify(mockService).stop(eq(mockModel), any()); verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService, mockModel); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageActionTests.java index d56b3fd8037c7..6d25b37649772 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageActionTests.java @@ -509,7 +509,7 
@@ private XContentSource executeAction() throws ExecutionException, InterruptedExc private void givenDefaultEndpoints(String... ids) { for (String id : ids) { - when(modelRegistry.containsDefaultConfigId(id)).thenReturn(true); + when(modelRegistry.containsPreconfiguredInferenceEndpointId(id)).thenReturn(true); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadataTests.java index 9af21386e93d3..5a59ab56efa88 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadataTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadataTests.java @@ -24,6 +24,7 @@ import java.util.Set; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; @@ -314,4 +315,80 @@ public void testWithAddedModels_ReturnsNewMetadataInstance_ForNewInferenceId_Wit ) ); } + + public void testGetServiceInferenceIds_ReturnsCorrectIdsForKnownService() { + var serviceA = "service_a"; + var endpointId1 = "endpointId1"; + var endpointId2 = "endpointId2"; + + var settings1 = MinimalServiceSettings.chatCompletion(serviceA); + var settings2 = MinimalServiceSettings.sparseEmbedding(serviceA); + var models = Map.of(endpointId1, settings1, endpointId2, settings2); + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()); + + var serviceEndpoints = metadata.getServiceInferenceIds(serviceA); + assertThat(serviceEndpoints, is(Set.of(endpointId1, endpointId2))); + } + + public void testGetServiceInferenceIds_AcceptsNullKeys() { + var serviceA = "service_a"; + var endpointId1 = "endpointId1"; + var endpointId2 = "endpointId2"; + var 
nullEndpoint1 = "nullEndpoint1"; + var nullEndpoint2 = "nullEndpoint2"; + + var settings1 = MinimalServiceSettings.chatCompletion(serviceA); + var settings2 = MinimalServiceSettings.sparseEmbedding(serviceA); + // I'm not sure why minimal service settings would have a null service name, but testing it anyway + var nullServiceNameSettings1 = MinimalServiceSettings.sparseEmbedding(null); + var nullServiceNameSettings2 = MinimalServiceSettings.sparseEmbedding(null); + var models = Map.of( + endpointId1, + settings1, + endpointId2, + settings2, + nullEndpoint1, + nullServiceNameSettings1, + nullEndpoint2, + nullServiceNameSettings2 + ); + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()); + + var serviceEndpoints = metadata.getServiceInferenceIds(serviceA); + assertThat(serviceEndpoints, is(Set.of(endpointId1, endpointId2))); + assertThat(metadata.getServiceInferenceIds(null), is(Set.of(nullEndpoint1, nullEndpoint2))); + } + + public void testGetServiceInferenceIds_ReturnsEmptySetForUnknownService() { + var serviceA = "service_a"; + var serviceB = "service_b"; + var endpointId = "endpointId1"; + + var settings = MinimalServiceSettings.chatCompletion(serviceA); + var models = Map.of(endpointId, settings); + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()); + + var serviceEndpoints = metadata.getServiceInferenceIds(serviceB); + assertThat(serviceEndpoints, is(empty())); + } + + public void testGetServiceInferenceIds_ReturnsEmptySetForEmptyModelMap() { + var serviceA = "service_a"; + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.of()); + + var serviceEndpoints = metadata.getServiceInferenceIds(serviceA); + assertThat(serviceEndpoints, is(empty())); + } + + public void testGetServiceInferenceIds_ReturnedSetIsImmutable_WhenAttemptingToModifyIt() { + var serviceA = "service_a"; + var endpointId = "endpointId1"; + + var settings = MinimalServiceSettings.chatCompletion(serviceA); + var models = 
Map.of(endpointId, settings); + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()); + + var serviceEndpoints = metadata.getServiceInferenceIds(serviceA); + expectThrows(UnsupportedOperationException.class, () -> serviceEndpoints.add("newId")); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index 44f0dcc1d8962..a54eb379a054c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -90,15 +90,15 @@ public void testIdMatchedDefault() { assertFalse(matched.isPresent()); } - public void testContainsDefaultConfigId() { + public void testContainsPreconfiguredInferenceEndpointId() { registry.addDefaultIds( new InferenceService.DefaultConfigId("foo", MinimalServiceSettings.sparseEmbedding("my_service"), mock(InferenceService.class)) ); registry.addDefaultIds( new InferenceService.DefaultConfigId("bar", MinimalServiceSettings.sparseEmbedding("my_service"), mock(InferenceService.class)) ); - assertTrue(registry.containsDefaultConfigId("foo")); - assertFalse(registry.containsDefaultConfigId("baz")); + assertTrue(registry.containsPreconfiguredInferenceEndpointId("foo")); + assertFalse(registry.containsPreconfiguredInferenceEndpointId("baz")); } public void testTaskTypeMatchedDefaults() { From 909ef5e88e874b1806db6ab5f0f597a6687bbf28 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Mon, 3 Nov 2025 18:16:19 +0000 Subject: [PATCH 27/32] [CI] Auto commit changes from spotless --- .../xpack/inference/registry/ModelRegistryMetadata.java | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java index 857dd2d79476c..359be95d8a4b4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java @@ -169,11 +169,7 @@ public ModelRegistryMetadata(StreamInput in) throws IOException { this.serviceToInferenceEndpointIds = buildServiceToInferenceEndpointIdsMap(modelMap); } - private ModelRegistryMetadata( - ImmutableOpenMap modelMap, - Set tombstones, - boolean isUpgraded - ) { + private ModelRegistryMetadata(ImmutableOpenMap modelMap, Set tombstones, boolean isUpgraded) { this.isUpgraded = isUpgraded; this.modelMap = modelMap; this.tombstones = tombstones; From f16b912b973cec727d2414ad0f55f3b1a58a1900 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 3 Nov 2025 15:29:21 -0500 Subject: [PATCH 28/32] Fixing tests --- .../inference/BaseMockEISAuthServerTest.java | 20 ++++++++++++++++++ .../inference/InferenceBaseRestTest.java | 2 +- .../InternalPreconfiguredEndpoints.java | 21 ++++++------------- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java index 09834e6a91210..a2241d7e93ce7 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java @@ -17,11 +17,15 @@ 
import org.elasticsearch.test.cluster.ElasticsearchCluster; import org.elasticsearch.test.cluster.local.distribution.DistributionType; import org.elasticsearch.test.rest.ESRestTestCase; +import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; +import org.junit.Before; import org.junit.ClassRule; import org.junit.Rule; import org.junit.rules.RuleChain; import org.junit.rules.TestRule; +import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getModel; + public class BaseMockEISAuthServerTest extends ESRestTestCase { protected static final MockElasticInferenceServiceAuthorizationServer mockEISServer = @@ -71,4 +75,20 @@ protected Settings restClientSettings() { String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray())); return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build(); } + + @Override + protected boolean preserveClusterUponCompletion() { + // Keep the cluster around so the EIS preconfigured endpoints still exist between tests. Otherwise, the inference indices will + // be removed when the cluster is wiped which causes the tests after the first one to fail. 
+ return true; + } + + @Before + public void ensureEisPreconfiguredEndpointsExist() throws Exception { + // Ensure that the authorization logic has completed prior to running each test so we have the correct EIS preconfigured endpoints + // available + // Technically this only needs to be done before the suite runs but the underlying client is created in @Before and not statically + // for the suite + assertBusy(() -> getModel(InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2)); + } } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 69256d49fe1d2..cb1103081f31a 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -331,7 +331,7 @@ protected Map deployE5TrainedModels() throws IOException { } @SuppressWarnings("unchecked") - protected Map getModel(String modelId) throws IOException { + static Map getModel(String modelId) throws IOException { var endpoint = Strings.format("_inference/%s?error_trace", modelId); return ((List>) getInternalAsMap(endpoint).get("endpoints")).get(0); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpoints.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpoints.java index e1751b5edbe7a..f74eac700465e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpoints.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpoints.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.elastic; -import org.elasticsearch.common.Strings; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -30,20 +29,20 @@ public class InternalPreconfiguredEndpoints { // rainbow-sprinkles public static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles"; - public static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); + public static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = ".rainbow-sprinkles-elastic"; // elser-2 public static final String DEFAULT_ELSER_2_MODEL_ID = "elser_model_2"; - public static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId("elser-2"); + public static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = ".elser-2-elastic"; // multilingual-text-embed public static final Integer DENSE_TEXT_EMBEDDINGS_DIMENSIONS = 1024; - public static final String DEFAULT_MULTILINGUAL_EMBED_MODEL_ID = "multilingual-embed-v1"; - public static final String DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID = defaultEndpointId(DEFAULT_MULTILINGUAL_EMBED_MODEL_ID); + public static final String DEFAULT_MULTILINGUAL_EMBED_MODEL_ID = "jina-embeddings-v3"; + public static final String DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID = ".jina-embeddings-v3"; // rerank-v1 - public static final String DEFAULT_RERANK_MODEL_ID_V1 = "rerank-v1"; - public static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_RERANK_MODEL_ID_V1); + public static final String DEFAULT_RERANK_MODEL_ID_V1 = "elastic-rerank-v1"; + public static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = ".elastic-rerank-v1"; public record MinimalModel( ModelConfigurations configurations, @@ -121,14 +120,6 @@ public static 
SimilarityMeasure defaultDenseTextEmbeddingsSimilarity() { return SimilarityMeasure.COSINE; } - public static String defaultEndpointId(String modelId) { - return Strings.format(".%s-elastic", modelId); - } - - public static boolean containsModelName(String modelName) { - return MODEL_NAME_TO_MINIMAL_MODEL.containsKey(modelName); - } - public static MinimalModel getWithModelName(String modelName) { return MODEL_NAME_TO_MINIMAL_MODEL.get(modelName); } From 45d167d6a0a8c7199f5f75803ef8f3db886db091 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 3 Nov 2025 17:54:28 -0500 Subject: [PATCH 29/32] Fixing text embedding test --- .../inference/integration/AuthorizationTaskExecutorIT.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java index 8a731a6d6681a..8450ceab04848 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java @@ -237,7 +237,7 @@ public void testCreatesChatCompletion_AndThenCreatesTextEmbedding() throws Excep { "models": [ { - "model_name": "multilingual-embed-v1", + "model_name": "jina-embeddings-v3", "task_types": ["embed/text/dense"] } ] From a0a07bc623d7411b3739c3ef9d5ba391ad85eeae Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 6 Nov 2025 15:04:05 -0500 Subject: [PATCH 30/32] Addressing feedback --- .../authorization/AuthorizationPoller.java | 26 +++++++- .../AuthorizationTaskExecutor.java | 5 ++ .../AuthorizationPollerTests.java | 63 +++++++++++++++++++ 3 files changed, 92 insertions(+), 2 deletions(-) diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java index c8bc51603b0e1..67668b718c171 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java @@ -17,7 +17,9 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.persistent.AllocatedPersistentTask; +import org.elasticsearch.persistent.PersistentTasksService; import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.threadpool.Scheduler; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.inference.action.StoreInferenceEndpointsAction; @@ -133,6 +135,17 @@ public void waitForAuthorizationToComplete(TimeValue waitTime) { } } + // Overriding so tests in the same package can access + @Override + protected void init( + PersistentTasksService persistentTasksService, + TaskManager taskManager, + String persistentTaskId, + long allocationId + ) { + super.init(persistentTasksService, taskManager, persistentTaskId, allocationId); + } + @Override protected void onCancelled() { shutdown(); @@ -142,11 +155,18 @@ protected void onCancelled() { // default for testing void shutdown() { shutdown.set(true); - if (lastAuthTask.get() != null) { - lastAuthTask.get().cancel(); + + var authTask = lastAuthTask.get(); + if (authTask != null) { + authTask.cancel(); } } + // default for testing + boolean isShutdown() { + return shutdown.get(); + } + private void scheduleAuthorizationRequest() { try { if (elasticInferenceServiceSettings.isPeriodicAuthorizationEnabled() == false) { @@ 
-177,6 +197,8 @@ private void scheduleAuthorizationRequest() { ); } catch (Exception e) { logger.warn("Failed scheduling authorization request", e); + // Shutdown and complete the task so it will be restarted + onCancelled(); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java index 161267d86f559..bca830eb7b948 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java @@ -16,6 +16,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.core.FixForMultiProject; import org.elasticsearch.core.TimeValue; import org.elasticsearch.persistent.AllocatedPersistentTask; import org.elasticsearch.persistent.ClusterPersistentTasksCustomMetadata; @@ -93,6 +94,10 @@ protected void nodeOperation(AllocatedPersistentTask task, AuthorizationTaskPara authPoller.start(); } + @FixForMultiProject( + description = "A single cluster can have multiple projects, " + + "we'll need to either make a call per project/org or use a bulk authorization api that EIS provides" + ) @Override public Scope scope() { return Scope.CLUSTER; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java index d276419570d0e..937d6ba20b3aa 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java @@ -12,12 +12,15 @@ import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.persistent.PersistentTasksService; import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.inference.action.StoreInferenceEndpointsAction; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests; import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity; @@ -315,4 +318,64 @@ public void testSendsTwoAuthorizationRequests() throws InterruptedException { assertThat(callbackCount.get(), is(2)); verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), any(), any()); } + + public void testCallsShutdownAndMarksTaskAsCompleted_WhenSchedulingFails() throws InterruptedException { + var mockRegistry = mock(ModelRegistry.class); + when(mockRegistry.isReady()).thenReturn(true); + when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2")); + + var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); + doAnswer(invocation -> { + ActionListener 
listener = invocation.getArgument(0); + listener.onResponse( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + // this is an unknown model id so it won't trigger storing an inference endpoint because + // it doesn't map to a known one + "abc", + EnumSet.of(TaskType.SPARSE_EMBEDDING) + ) + ) + ) + ) + ); + return Void.TYPE; + }).when(mockAuthHandler).getAuthorization(any(), any()); + + var mockClient = mock(Client.class); + + var callbackCount = new AtomicInteger(0); + var latch = new CountDownLatch(1); + + Runnable callback = () -> { + callbackCount.incrementAndGet(); + latch.countDown(); + }; + + // Simulate scheduling failure by having the settings throw an exception when queried + // Throwing an exception should cause the poller to shutdown and mark itself as completed + var settingsMock = mock(ElasticInferenceServiceSettings.class); + when(settingsMock.isPeriodicAuthorizationEnabled()).thenThrow(new IllegalStateException("failing")); + + var poller = new AuthorizationPoller( + new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()), + createWithEmptySettings(taskQueue.getThreadPool()), + mockAuthHandler, + mock(Sender.class), + settingsMock, + mockRegistry, + mockClient, + callback + ); + poller.init(mock(PersistentTasksService.class), mock(TaskManager.class), "id", 0); + poller.start(); + taskQueue.runAllRunnableTasks(); + latch.await(TimeValue.THIRTY_SECONDS.getSeconds(), TimeUnit.SECONDS); + + assertThat(callbackCount.get(), is(1)); + assertTrue(poller.isShutdown()); + verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), any(), any()); + } } From f994811da414f3fc99a72587979fb24899d16b3b Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 6 Nov 2025 15:30:52 -0500 Subject: [PATCH 31/32] Marking task as failed --- 
.../authorization/AuthorizationPoller.java | 7 ++++++- .../AuthorizationPollerTests.java | 19 +++++++++++++++++-- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java index 67668b718c171..16646cbde4e89 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java @@ -152,6 +152,11 @@ protected void onCancelled() { markAsCompleted(); } + private void shutdownAndMarkTaskAsFailed(Exception e) { + shutdown(); + markAsFailed(e); + } + // default for testing void shutdown() { shutdown.set(true); @@ -198,7 +203,7 @@ private void scheduleAuthorizationRequest() { } catch (Exception e) { logger.warn("Failed scheduling authorization request", e); // Shutdown and complete the task so it will be restarted - onCancelled(); + shutdownAndMarkTaskAsFailed(e); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java index 937d6ba20b3aa..d0d3a67b2d9d5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java @@ -43,6 +43,7 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static 
org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -354,10 +355,11 @@ public void testCallsShutdownAndMarksTaskAsCompleted_WhenSchedulingFails() throw latch.countDown(); }; + var exception = new IllegalStateException("failing"); // Simulate scheduling failure by having the settings throw an exception when queried // Throwing an exception should cause the poller to shutdown and mark itself as completed var settingsMock = mock(ElasticInferenceServiceSettings.class); - when(settingsMock.isPeriodicAuthorizationEnabled()).thenThrow(new IllegalStateException("failing")); + when(settingsMock.isPeriodicAuthorizationEnabled()).thenThrow(exception); var poller = new AuthorizationPoller( new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()), @@ -369,13 +371,26 @@ public void testCallsShutdownAndMarksTaskAsCompleted_WhenSchedulingFails() throw mockClient, callback ); - poller.init(mock(PersistentTasksService.class), mock(TaskManager.class), "id", 0); + + var persistentTaskId = "id"; + var allocationId = 0L; + + var mockPersistentTasksService = mock(PersistentTasksService.class); + poller.init(mockPersistentTasksService, mock(TaskManager.class), persistentTaskId, allocationId); poller.start(); taskQueue.runAllRunnableTasks(); latch.await(TimeValue.THIRTY_SECONDS.getSeconds(), TimeUnit.SECONDS); assertThat(callbackCount.get(), is(1)); assertTrue(poller.isShutdown()); + verify(mockPersistentTasksService, times(1)).sendCompletionRequest( + eq(persistentTaskId), + eq(allocationId), + eq(exception), + eq(null), + any(), + any() + ); verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), any(), any()); } } From 415c23bdb770ce39b71931eab4461e3d4955c8da Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 7 Nov 2025 16:55:58 -0500 Subject: [PATCH 32/32] Fixing flaky test --- ...horizationTaskExecutorMultipleNodesIT.java | 21 ++++++++++++------- 1 file 
changed, 14 insertions(+), 7 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java index 597fcef4a10fe..cb92c70d27442 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java @@ -21,6 +21,7 @@ import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller; import org.junit.AfterClass; +import org.junit.Before; import org.junit.BeforeClass; import java.io.IOException; @@ -44,8 +45,11 @@ * If the task is running on a node that is shutdown, it should be relocated to another node. * If the task is cancelled it should be restarted automatically. 
*/ +@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0) public class AuthorizationTaskExecutorMultipleNodesIT extends ESIntegTestCase { + private static final int NUM_DATA_NODES = 2; + private static final int NUM_MASTER_NODES = 2; private static final String AUTH_TASK_ACTION = AuthorizationPoller.TASK_NAME + "[c]"; private static final MockWebServer webServer = new MockWebServer(); private static String gatewayUrl; @@ -57,6 +61,16 @@ public static void initClass() throws IOException { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE)); } + @Before + public void startNodes() { + // Ensure we have multiple master and data nodes so we have somewhere to place the inference indices and so that we can safely + // shut down the node that is running the authorization task. If there is only one master and it is running the task, + // we'll get an error that we can't shut down the only eligible master node + internalCluster().startMasterOnlyNodes(NUM_MASTER_NODES); + internalCluster().ensureAtLeastNumDataNodes(NUM_DATA_NODES); + ensureStableCluster(NUM_MASTER_NODES + NUM_DATA_NODES); + } + @AfterClass public static void cleanUpClass() { webServer.close(); @@ -82,12 +96,6 @@ public void testCancellingAuthorizationTaskRestartsIt() throws Exception { } public void testAuthorizationTaskGetsRelocatedToAnotherNode_WhenTheNodeThatIsRunningItShutsDown() throws Exception { - // Ensure we have multiple master and data nodes so we have somewhere to place the inference indices and so that we can safely - // shut down the node that is running the authorization task. 
If there is only one master and it is running the task, - // we'll get an error that we can't shut down the only eligible master node - internalCluster().startMasterOnlyNodes(2); - internalCluster().ensureAtLeastNumDataNodes(2); - var nodeNameMapping = getNodeNames(internalCluster().getNodeNames()); var pollerTask = waitForTask(AUTH_TASK_ACTION, admin()); @@ -125,7 +133,6 @@ public void testAuthorizationTaskGetsRelocatedToAnotherNode_WhenTheNodeThatIsRun ); assertThat(rainbowSprinklesEndpoint.getTaskType(), is(TaskType.CHAT_COMPLETION)); }); - } private record NodeNameMapping(Map nodeNamesMap) {