From f2605af02394826980ad47c7f7cfd47a37e7709e Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 19 Sep 2025 11:42:15 -0400 Subject: [PATCH 01/18] Adding todos --- .../inference/registry/ModelRegistry.java | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 7cd1cf5999d11..3a7171e5211eb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -249,6 +249,12 @@ public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId * @param listener Model listener */ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { + // TODO add a SubscribableListener here + // 1. Do search + // 2. If we don't find it, check in defaultConfigIds + // 3. If we don't find it, make a call to EIS? maybe only if it has -elastic in the name? + // 4. If we still don't find it, return not found + ActionListener searchListener = ActionListener.wrap((searchResponse) -> { // There should be a hit for the configurations if (searchResponse.getHits().getHits().length == 0) { @@ -289,6 +295,12 @@ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { + // TODO add a SubscribableListener here + // 1. Do search + // 2. If we don't find it, check in defaultConfigIds + // 3. If we don't find it, make a call to EIS? maybe only if it has -elastic in the name? + // 4. If we still don't find it, return not found + ActionListener searchListener = ActionListener.wrap((searchResponse) -> { // There should be a hit for the configurations if (searchResponse.getHits().getHits().length == 0) { @@ -325,6 +337,7 @@ public void getModel(String inferenceEntityId, ActionListener lis } private ResourceNotFoundException inferenceNotFoundException(String inferenceEntityId) { + // TODO add some logic here to check if it is EIS related and return a message indicating that the endpoint may not be authorized return new ResourceNotFoundException("Inference endpoint not found [{}]", inferenceEntityId); } @@ -335,6 +348,12 @@ private ResourceNotFoundException inferenceNotFoundException(String inferenceEnt * @param listener Models listener */ public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { + // TODO add a SubscribableListener here + // 1. Do search + // 2. If we don't find it, check in defaultConfigIds + // 3. If we don't find it, make a call to EIS? maybe only if it has -elastic in the name? + // 4. If we still don't find it, return not found + ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList(); var defaultConfigsForTaskType = taskTypeMatchedDefaults(taskType, defaultConfigIds.values()); @@ -366,6 +385,13 @@ public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { + // TODO add a SubscribableListener here + // 1. Do search + // 2. If we don't find it, check in defaultConfigIds + // 3. If we don't find it, make a call to EIS? maybe only if it has -elastic in the name? + // 4. If we still don't find it, return not found + + ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { var foundConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList(); addAllDefaultConfigsIfMissing(persistDefaultEndpoints, foundConfigs, defaultConfigIds.values(), delegate); From 4750b48992f50631d9637860d44e640793a18fbb Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 26 Sep 2025 10:27:06 -0400 Subject: [PATCH 02/18] Starting changes --- .../inference/registry/ModelRegistry.java | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 3a7171e5211eb..56c431cedc4ca 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -249,6 +249,38 @@ public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId * @param listener Model listener */ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { + SubscribableListener.newForked(searchResponseListener -> { + QueryBuilder queryBuilder = documentIdQuery(inferenceEntityId); + SearchRequest modelSearch = client.prepareSearch(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN) + .setQuery(queryBuilder) + .setSize(2) + .setAllowPartialSearchResults(false) + .request(); + + client.search(modelSearch, searchResponseListener); + }).andThen((unparsedModelListener, searchResponse) -> { + // There should be a hit for the configurations + if (searchResponse.getHits().getHits().length == 0) { + var maybeDefault = defaultConfigIds.get(inferenceEntityId); + if (maybeDefault != null) { + getDefaultConfig(true, maybeDefault, unparsedModelListener); + } else { + unparsedModelListener.onFailure(inferenceNotFoundException(inferenceEntityId)); + } + return; + } + + unparsedModelListener.onResponse(unparsedModelFromMap(createModelConfigMap(searchResponse.getHits(), inferenceEntityId))); + } ).addListener(listener.delegateResponse((failureListener, e) -> { + logger.warn(format("Failed to load inference endpoint with secrets [%s]", inferenceEntityId), e); + failureListener.onFailure( + new ElasticsearchException( + format("Failed to load inference endpoint with secrets [%s], error: [%s]", inferenceEntityId, e.getMessage()), + e + ) + ); + })); + // TODO add a SubscribableListener here // 1. Do search // 2. If we don't find it, check in defaultConfigIds From a009e366c95eb7db7dfd67821d6bf446db4c2ef3 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 3 Oct 2025 15:31:56 -0400 Subject: [PATCH 03/18] Creating conversion functionality --- .../inference/src/main/java/module-info.java | 1 + .../xpack/inference/InferencePlugin.java | 11 +- .../TransportGetInferenceServicesAction.java | 9 +- .../inference/registry/ModelRegistry.java | 111 ++++++++------ .../elastic/ElasticInferenceService.java | 43 ++---- ...lasticInferenceServiceMinimalSettings.java | 108 ++++++++++++++ ...cInferenceServiceAuthorizationHandler.java | 1 + .../PreconfiguredEndpointsModel.java | 137 ++++++++++++++++++ .../PreconfiguredEndpointsRequestHandler.java | 41 ++++++ 9 files changed, 377 insertions(+), 85 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceMinimalSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsRequestHandler.java diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index 6aae961d45048..182e462360eb4 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -37,6 +37,7 @@ requires org.elasticsearch.sslconfig; requires org.apache.commons.text; requires software.amazon.awssdk.services.sagemakerruntime; + requires org.elasticsearch.inference; exports org.elasticsearch.xpack.inference.action; exports org.elasticsearch.xpack.inference.registry; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 60592c5dd1dbd..447f33308284d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -134,6 +134,7 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; +import org.elasticsearch.xpack.inference.services.elastic.authorization.PreconfiguredEndpointsRequestHandler; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService; @@ -288,9 +289,6 @@ public Collection createComponents(PluginServices services) { var amazonBedrockRequestSenderFactory = new AmazonBedrockRequestSender.Factory(serviceComponents.get(), services.clusterService()); amazonBedrockFactory.set(amazonBedrockRequestSenderFactory); - modelRegistry.set(new ModelRegistry(services.clusterService(), services.client())); - services.clusterService().addListener(modelRegistry.get()); - if (inferenceServiceExtensions == null) { inferenceServiceExtensions = new ArrayList<>(); } @@ -322,6 +320,11 @@ public Collection createComponents(PluginServices services) { services.threadPool() ); + var eisSender = elasicInferenceServiceFactory.get().createSender(); + var preconfigEndpointsHandler = new PreconfiguredEndpointsRequestHandler(authorizationHandler, eisSender); + modelRegistry.set(new ModelRegistry(services.clusterService(), services.client(), preconfigEndpointsHandler)); + services.clusterService().addListener(modelRegistry.get()); + var sageMakerSchemas = new SageMakerSchemas(); var sageMakerConfigurations = new LazyInitializable<>(new SageMakerConfiguration(sageMakerSchemas)); inferenceServices.add( @@ -385,7 +388,7 @@ public Collection createComponents(PluginServices services) { ); components.add(inferenceStatsBinding); components.add(authorizationHandler); - components.add(new PluginComponentBinding<>(Sender.class, elasicInferenceServiceFactory.get().createSender())); + components.add(new PluginComponentBinding<>(Sender.class, eisSender)); components.add( new InferenceEndpointRegistry( services.clusterService(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java index 18c83df4067ed..b07de7434f36a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java @@ -21,7 +21,6 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.tasks.Task; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction; import org.elasticsearch.xpack.inference.external.http.sender.Sender; @@ -35,8 +34,6 @@ import java.util.Map; import java.util.stream.Collectors; -import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; - public class TransportGetInferenceServicesAction extends HandledTransportAction< GetInferenceServicesAction.Request, GetInferenceServicesAction.Response> { @@ -46,13 +43,11 @@ public class TransportGetInferenceServicesAction extends HandledTransportAction< private final InferenceServiceRegistry serviceRegistry; private final ElasticInferenceServiceAuthorizationRequestHandler eisAuthorizationRequestHandler; private final Sender eisSender; - private final ThreadPool threadPool; @Inject public TransportGetInferenceServicesAction( TransportService transportService, ActionFilters actionFilters, - ThreadPool threadPool, InferenceServiceRegistry serviceRegistry, ElasticInferenceServiceAuthorizationRequestHandler eisAuthorizationRequestHandler, Sender sender @@ -67,7 +62,6 @@ public TransportGetInferenceServicesAction( this.serviceRegistry = serviceRegistry; this.eisAuthorizationRequestHandler = eisAuthorizationRequestHandler; this.eisSender = sender; - this.threadPool = threadPool; } @Override @@ -123,8 +117,7 @@ private void getServiceConfigurationsForServicesAndEis( @Nullable TaskType requestedTaskType ) { SubscribableListener.newForked(authModelListener -> { - // Executing on a separate thread because there's a chance the authorization call needs to do some initialization for the Sender - threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> getEisAuthorization(authModelListener, eisSender)); + getEisAuthorization(authModelListener, eisSender); }).>andThen((configurationListener, authorizationModel) -> { var serviceConfigs = getServiceConfigurationsForServices(availableServices); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 56c431cedc4ca..c6f4a959463b1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -76,6 +76,8 @@ import org.elasticsearch.xpack.inference.InferenceIndex; import org.elasticsearch.xpack.inference.InferenceSecretsIndex; import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings; +import org.elasticsearch.xpack.inference.services.elastic.authorization.PreconfiguredEndpointsRequestHandler; import java.io.IOException; import java.util.ArrayList; @@ -87,6 +89,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -147,10 +150,15 @@ public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) private final MasterServiceTaskQueue metadataTaskQueue; private final AtomicBoolean upgradeMetadataInProgress = new AtomicBoolean(false); private final Set preventDeletionLock = Collections.newSetFromMap(new ConcurrentHashMap<>()); + private final PreconfiguredEndpointsRequestHandler preconfiguredEndpointsRequestHandler; private volatile Metadata lastMetadata; - public ModelRegistry(ClusterService clusterService, Client client) { + public ModelRegistry( + ClusterService clusterService, + Client client, + PreconfiguredEndpointsRequestHandler preconfiguredEndpointsRequestHandler + ) { this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN); this.defaultConfigIds = new ConcurrentHashMap<>(); var executor = new SimpleBatchedAckListenerTaskExecutor() { @@ -163,6 +171,7 @@ public Tuple executeTask(MetadataTask tas } }; this.metadataTaskQueue = clusterService.createTaskQueue("model_registry", Priority.NORMAL, executor); + this.preconfiguredEndpointsRequestHandler = Objects.requireNonNull(preconfiguredEndpointsRequestHandler); } /** @@ -249,6 +258,13 @@ public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId * @param listener Model listener */ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { + // If we know it's an EIS preconfigured endpoint, skip looking in the index because it could have an outdated version of the + // endpoint and go directly to EIS to retrieve it + if (ElasticInferenceServiceMinimalSettings.isEisPreconfiguredEndpoint(inferenceEntityId)) { + retrievePreconfiguredEndpointFromEisElseNotAuthorized(listener, inferenceEntityId); + return; + } + SubscribableListener.newForked(searchResponseListener -> { QueryBuilder queryBuilder = documentIdQuery(inferenceEntityId); SearchRequest modelSearch = client.prepareSearch(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN) @@ -259,19 +275,21 @@ public void getModelWithSecrets(String inferenceEntityId, ActionListenerandThen((unparsedModelListener, searchResponse) -> { - // There should be a hit for the configurations - if (searchResponse.getHits().getHits().length == 0) { - var maybeDefault = defaultConfigIds.get(inferenceEntityId); - if (maybeDefault != null) { - getDefaultConfig(true, maybeDefault, unparsedModelListener); - } else { - unparsedModelListener.onFailure(inferenceNotFoundException(inferenceEntityId)); - } + // We likely found the configuration, so parse it and return it + if (searchResponse.getHits().getHits().length != 0) { + unparsedModelListener.onResponse(unparsedModelFromMap(createModelConfigMap(searchResponse.getHits(), inferenceEntityId))); return; } - unparsedModelListener.onResponse(unparsedModelFromMap(createModelConfigMap(searchResponse.getHits(), inferenceEntityId))); - } ).addListener(listener.delegateResponse((failureListener, e) -> { + // we didn't find the configuration in the inference index, so check if it is a pre-configured endpoint + var maybeDefault = defaultConfigIds.get(inferenceEntityId); + if (maybeDefault != null) { + getDefaultConfig(true, maybeDefault, unparsedModelListener); + return; + } + + retrievePreconfiguredEndpointFromEisElseNotFound(unparsedModelListener, inferenceEntityId); + }).addListener(listener.delegateResponse((failureListener, e) -> { logger.warn(format("Failed to load inference endpoint with secrets [%s]", inferenceEntityId), e); failureListener.onFailure( new ElasticsearchException( @@ -280,44 +298,39 @@ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener, String inferenceEntityId) { + var eisFailureListener = listener.delegateResponse( + (delegate, e) -> { delegate.onFailure(eisNotAuthorizedException(inferenceEntityId)); } + ); - ActionListener searchListener = ActionListener.wrap((searchResponse) -> { - // There should be a hit for the configurations - if (searchResponse.getHits().getHits().length == 0) { - var maybeDefault = defaultConfigIds.get(inferenceEntityId); - if (maybeDefault != null) { - getDefaultConfig(true, maybeDefault, listener); - } else { - listener.onFailure(inferenceNotFoundException(inferenceEntityId)); - } - return; - } + retrieveEisPreconfiguredEndpoint(eisFailureListener, inferenceEntityId); + } - listener.onResponse(unparsedModelFromMap(createModelConfigMap(searchResponse.getHits(), inferenceEntityId))); - }, (e) -> { - logger.warn(format("Failed to load inference endpoint with secrets [%s]", inferenceEntityId), e); - listener.onFailure( - new ElasticsearchException( - format("Failed to load inference endpoint with secrets [%s], error: [%s]", inferenceEntityId, e.getMessage()), - e - ) - ); - }); + private ElasticsearchStatusException eisNotAuthorizedException(String inferenceEntityId) { + return new ElasticsearchStatusException( + "Unauthorized to access inference endpoint [{}]", + RestStatus.UNAUTHORIZED, + inferenceEntityId + ); + } - QueryBuilder queryBuilder = documentIdQuery(inferenceEntityId); - SearchRequest modelSearch = client.prepareSearch(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN) - .setQuery(queryBuilder) - .setSize(2) - .setAllowPartialSearchResults(false) - .request(); + private void retrievePreconfiguredEndpointFromEisElseNotFound(ActionListener listener, String inferenceEntityId) { + var eisFailureListener = listener.delegateResponse( + (delegate, e) -> { delegate.onFailure(inferenceNotFoundException(inferenceEntityId)); } + ); - client.search(modelSearch, searchListener); + retrieveEisPreconfiguredEndpoint(eisFailureListener, inferenceEntityId); + } + + private void retrieveEisPreconfiguredEndpoint(ActionListener listener, String inferenceEntityId) { + var eisFailureListener = listener.delegateResponse((delegate, e) -> { + logger.debug("Failed to retrieve preconfigured endpoint from EIS", e); + delegate.onFailure(e); + }); + + preconfiguredEndpointsRequestHandler.getPreconfiguredEndpointAsUnparsedModel(inferenceEntityId, eisFailureListener); } /** @@ -327,6 +340,7 @@ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { + // TODO add a SubscribableListener here // 1. Do search // 2. If we don't find it, check in defaultConfigIds @@ -370,7 +384,7 @@ public void getModel(String inferenceEntityId, ActionListener lis private ResourceNotFoundException inferenceNotFoundException(String inferenceEntityId) { // TODO add some logic here to check if it is EIS related and return a message indicating that the endpoint may not be authorized - return new ResourceNotFoundException("Inference endpoint not found [{}]", inferenceEntityId); + return new ResourceNotFoundException("Inference endpoint [{}] not found or you are not authorized to access it", inferenceEntityId); } /** @@ -380,6 +394,9 @@ private ResourceNotFoundException inferenceNotFoundException(String inferenceEnt * @param listener Models listener */ public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { + // TODO we need to explicitly filter out any existing EIS PIEs because after we move the authorization logic to the master node we + // won't be removing them from the index anymore + // TODO add a SubscribableListener here // 1. Do search // 2. If we don't find it, check in defaultConfigIds @@ -418,12 +435,14 @@ public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { // TODO add a SubscribableListener here + // TODO we need to explicitly filter out any existing EIS PIEs because after we move the authorization logic to the master node we + // won't be removing them from the index anymore + // 1. Do search // 2. If we don't find it, check in defaultConfigIds // 3. If we don't find it, make a call to EIS? maybe only if it has -elastic in the name? // 4. If we still don't find it, return not found - ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { var foundConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList(); addAllDefaultConfigsIfMissing(persistDefaultEndpoints, foundConfigs, defaultConfigIds.values(), delegate); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index c4156c0bfd6b9..aae03ed13cd05 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -16,7 +16,6 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; @@ -26,7 +25,6 @@ import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; @@ -85,6 +83,18 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.useChatCompletionUrlMessage; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.CHAT_COMPLETION_V1_MINIMAL_SETTINGS; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_MODEL_ID_V1; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.DEFAULT_ELSER_2_MODEL_ID; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.DEFAULT_ELSER_ENDPOINT_ID_V2; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.DEFAULT_MULTILINGUAL_EMBED_MODEL_ID; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_ENDPOINT_ID_V1; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_MODEL_ID_V1; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.ELSER_V2_MINIMAL_SETTINGS; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.MULTILINGUAL_EMBED_MINIMAL_SETTINGS; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.RERANK_V1_MINIMAL_SETTINGS; public class ElasticInferenceService extends SenderService { @@ -109,22 +119,6 @@ public class ElasticInferenceService extends SenderService { // This mirrors the memory constraints observed with sparse embeddings private static final Integer DENSE_TEXT_EMBEDDINGS_MAX_BATCH_SIZE = 16; - // rainbow-sprinkles - static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles"; - static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); - - // elser-2 - static final String DEFAULT_ELSER_2_MODEL_ID = "elser_model_2"; - static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId("elser-2"); - - // multilingual-text-embed - static final String DEFAULT_MULTILINGUAL_EMBED_MODEL_ID = "multilingual-embed-v1"; - static final String DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID = defaultEndpointId(DEFAULT_MULTILINGUAL_EMBED_MODEL_ID); - - // rerank-v1 - static final String DEFAULT_RERANK_MODEL_ID_V1 = "rerank-v1"; - static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_RERANK_MODEL_ID_V1); - /** * The task types that the {@link InferenceAction.Request} can accept. */ @@ -198,7 +192,7 @@ private static Map initDefaultEndpoints( EmptySecretSettings.INSTANCE, elasticInferenceServiceComponents ), - MinimalServiceSettings.chatCompletion(NAME) + CHAT_COMPLETION_V1_MINIMAL_SETTINGS ), DEFAULT_ELSER_2_MODEL_ID, new DefaultModelConfig( @@ -212,7 +206,7 @@ private static Map initDefaultEndpoints( elasticInferenceServiceComponents, ChunkingSettingsBuilder.DEFAULT_SETTINGS ), - MinimalServiceSettings.sparseEmbedding(NAME) + ELSER_V2_MINIMAL_SETTINGS ), DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, new DefaultModelConfig( @@ -231,12 +225,7 @@ private static Map initDefaultEndpoints( elasticInferenceServiceComponents, ChunkingSettingsBuilder.DEFAULT_SETTINGS ), - MinimalServiceSettings.textEmbedding( - NAME, - DENSE_TEXT_EMBEDDINGS_DIMENSIONS, - defaultDenseTextEmbeddingsSimilarity(), - DenseVectorFieldMapper.ElementType.FLOAT - ) + MULTILINGUAL_EMBED_MINIMAL_SETTINGS ), DEFAULT_RERANK_MODEL_ID_V1, new DefaultModelConfig( @@ -249,7 +238,7 @@ private static Map initDefaultEndpoints( EmptySecretSettings.INSTANCE, elasticInferenceServiceComponents ), - MinimalServiceSettings.rerank(NAME) + RERANK_V1_MINIMAL_SETTINGS ) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceMinimalSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceMinimalSettings.java new file mode 100644 index 0000000000000..e115bae97c207 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceMinimalSettings.java @@ -0,0 +1,108 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.MinimalServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; + +import java.util.Map; +import java.util.Set; + +import static java.util.stream.Collectors.toMap; + +public class ElasticInferenceServiceMinimalSettings { + + // rainbow-sprinkles + static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles"; + static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); + + // elser-2 + static final String DEFAULT_ELSER_2_MODEL_ID = "elser_model_2"; + static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId("elser-2"); + + // multilingual-text-embed + static final Integer DENSE_TEXT_EMBEDDINGS_DIMENSIONS = 1024; + static final String DEFAULT_MULTILINGUAL_EMBED_MODEL_ID = "multilingual-embed-v1"; + static final String DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID = defaultEndpointId(DEFAULT_MULTILINGUAL_EMBED_MODEL_ID); + + // rerank-v1 + static final String DEFAULT_RERANK_MODEL_ID_V1 = "rerank-v1"; + static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_RERANK_MODEL_ID_V1); + + private static final Set EIS_PRECONFIGURED_ENDPOINTS = Set.of( + DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + DEFAULT_ELSER_ENDPOINT_ID_V2, + DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, + DEFAULT_RERANK_ENDPOINT_ID_V1 + ); + + static final MinimalServiceSettings CHAT_COMPLETION_V1_MINIMAL_SETTINGS = MinimalServiceSettings.chatCompletion( + ElasticInferenceService.NAME + ); + static final MinimalServiceSettings ELSER_V2_MINIMAL_SETTINGS = MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME); + static final MinimalServiceSettings MULTILINGUAL_EMBED_MINIMAL_SETTINGS = MinimalServiceSettings.textEmbedding( + ElasticInferenceService.NAME, + DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + defaultDenseTextEmbeddingsSimilarity(), + DenseVectorFieldMapper.ElementType.FLOAT + ); + static final MinimalServiceSettings RERANK_V1_MINIMAL_SETTINGS = MinimalServiceSettings.rerank(ElasticInferenceService.NAME); + + public record SettingsWithEndpointInfo(String inferenceId, String modelId, MinimalServiceSettings minimalSettings) {} + + private static final Map MODEL_NAME_TO_MINIMAL_SETTINGS = Map.of( + DEFAULT_CHAT_COMPLETION_MODEL_ID_V1, + new SettingsWithEndpointInfo( + DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + DEFAULT_CHAT_COMPLETION_MODEL_ID_V1, + CHAT_COMPLETION_V1_MINIMAL_SETTINGS + ), + DEFAULT_ELSER_2_MODEL_ID, + new SettingsWithEndpointInfo(DEFAULT_ELSER_ENDPOINT_ID_V2, DEFAULT_ELSER_2_MODEL_ID, ELSER_V2_MINIMAL_SETTINGS), + DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, + new SettingsWithEndpointInfo( + DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, + DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, + MULTILINGUAL_EMBED_MINIMAL_SETTINGS + ), + DEFAULT_RERANK_MODEL_ID_V1, + new SettingsWithEndpointInfo(DEFAULT_RERANK_ENDPOINT_ID_V1, DEFAULT_RERANK_MODEL_ID_V1, RERANK_V1_MINIMAL_SETTINGS) + ); + + private static final Map INFERENCE_ID_TO_MINIMAL_SETTINGS = MODEL_NAME_TO_MINIMAL_SETTINGS.entrySet() + .stream() + .collect(toMap(e -> e.getValue().inferenceId(), Map.Entry::getValue)); + + public static SimilarityMeasure defaultDenseTextEmbeddingsSimilarity() { + return SimilarityMeasure.COSINE; + } + + public static String defaultEndpointId(String modelId) { + return Strings.format(".%s-elastic", modelId); + } + + public static boolean isEisPreconfiguredEndpoint(String inferenceEntityId) { + return EIS_PRECONFIGURED_ENDPOINTS.contains(inferenceEntityId); + } + + public static boolean containsModelName(String modelName) { + return MODEL_NAME_TO_MINIMAL_SETTINGS.containsKey(modelName); + } + + public static SettingsWithEndpointInfo getWithModelName(String modelName) { + return MODEL_NAME_TO_MINIMAL_SETTINGS.get(modelName); + } + + public static SettingsWithEndpointInfo getWithInferenceId(String inferenceId) { + return INFERENCE_ID_TO_MINIMAL_SETTINGS.get(inferenceId); + } + + private ElasticInferenceServiceMinimalSettings() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java index f83542e7fe740..bea9523d9576b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java @@ -246,6 +246,7 @@ private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorizat new AuthorizedContent(authorizedTaskTypesAndModels, authorizedDefaultConfigIds, authorizedDefaultModelObjects) ); + // TODO remove adding it to the registry, I think we can still revoke for now though authorizedContent.get().configIds().forEach(modelRegistry::putDefaultIdIfAbsent); handleRevokedDefaultConfigs(authorizedDefaultModelIds); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsModel.java new file mode 100644 index 0000000000000..0c34545c689e6 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsModel.java @@ -0,0 +1,137 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings; + +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.function.Function; +import java.util.stream.Collectors; + +public record PreconfiguredEndpointsModel(Map preconfiguredEndpoints) { + + public static PreconfiguredEndpointsModel of(ElasticInferenceServiceAuthorizationModel authModel) { + // TODO convert the auth model to a list of preconfigured endpoints + // iterate over the authorized model ids and retrieve the configurations from a new class that has the information + + var endpoints = authModel.getAuthorizedModelIds() + .stream() + .filter(ElasticInferenceServiceMinimalSettings::containsModelName) + .map((modelId) -> of(ElasticInferenceServiceMinimalSettings.getWithModelName(modelId))) + .filter(Objects::nonNull).collect(Collectors.toMap(PreconfiguredEndpoint::inferenceEntityId, Function.identity())); + + return new PreconfiguredEndpointsModel(endpoints); + } + + private static PreconfiguredEndpoint of(ElasticInferenceServiceMinimalSettings.SettingsWithEndpointInfo settings) { + return switch (settings.minimalSettings().taskType()) { + case TEXT_EMBEDDING -> { + if (settings.minimalSettings().dimensions() == null + || settings.minimalSettings().similarity() == null + || settings.minimalSettings().elementType() == null) { + // TODO log a warning + yield null; + } + + yield new EmbeddingPreConfiguredEndpoint( + settings.inferenceId(), + settings.minimalSettings().taskType(), + settings.modelId(), + settings.minimalSettings().similarity(), + settings.minimalSettings().dimensions(), + settings.minimalSettings().elementType() + ); + } + case SPARSE_EMBEDDING, RERANK, COMPLETION, CHAT_COMPLETION -> new BasePreconfiguredEndpoint( + settings.inferenceId(), + settings.minimalSettings().taskType(), + settings.modelId() + ); + case ANY -> null; + }; + } + + public sealed interface PreconfiguredEndpoint permits BasePreconfiguredEndpoint, EmbeddingPreConfiguredEndpoint { + String inferenceEntityId(); + + TaskType taskType(); + + String modelId(); + + UnparsedModel toUnparsedModel(); + } + + private record EmbeddingPreConfiguredEndpoint( + String inferenceEntityId, + TaskType taskType, + String modelId, + SimilarityMeasure similarity, + int dimension, + DenseVectorFieldMapper.ElementType elementType + ) implements PreconfiguredEndpoint { + + @Override + public UnparsedModel toUnparsedModel() { + return new UnparsedModel( + inferenceEntityId, + taskType, + ElasticInferenceService.NAME, + embeddingSettings(modelId, similarity, dimension, elementType), + Map.of() + ); + } + } + + private static Map embeddingSettings( + String modelId, + SimilarityMeasure similarityMeasure, + int dimension, + DenseVectorFieldMapper.ElementType elementType + ) { + return new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + ServiceFields.SIMILARITY, + similarityMeasure.toString(), + ServiceFields.DIMENSIONS, + dimension, + ServiceFields.ELEMENT_TYPE, + elementType.toString() + ) + ); + } + + private record BasePreconfiguredEndpoint(String inferenceEntityId, TaskType taskType, String modelId) implements PreconfiguredEndpoint { + @Override + public UnparsedModel toUnparsedModel() { + return new UnparsedModel(inferenceEntityId, taskType, ElasticInferenceService.NAME, settingsWithModelId(modelId), Map.of()); + } + } + + private static Map settingsWithModelId(String modelId) { + return new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)); + } + + public UnparsedModel toUnparsedModel(String inferenceId) { + PreconfiguredEndpoint endpoint = preconfiguredEndpoints.get(inferenceId); + if (endpoint == null) { + throw new IllegalArgumentException("No EIS preconfigured endpoint found for inference ID: " + inferenceId); + } + + return endpoint.toUnparsedModel(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsRequestHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsRequestHandler.java new file mode 100644 index 0000000000000..59f4e6d1a2cc1 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsRequestHandler.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; + +import java.util.Objects; + +/** + * This class is responsible for converting the current EIS authorization response structure + * into Models that + */ +public class PreconfiguredEndpointsRequestHandler { + private final ElasticInferenceServiceAuthorizationRequestHandler eisAuthorizationRequestHandler; + private final Sender sender; + + public PreconfiguredEndpointsRequestHandler( + ElasticInferenceServiceAuthorizationRequestHandler eisAuthorizationRequestHandler, + Sender sender + ) { + this.eisAuthorizationRequestHandler = Objects.requireNonNull(eisAuthorizationRequestHandler); + this.sender = Objects.requireNonNull(sender); + } + + public void getPreconfiguredEndpointAsUnparsedModel(String inferenceId, ActionListener listener) { + SubscribableListener.newForked(authListener -> { + eisAuthorizationRequestHandler.getAuthorization(authListener, sender); + }) + .andThenApply(PreconfiguredEndpointsModel::of) + .andThenApply(preconfiguredEndpointsModel -> preconfiguredEndpointsModel.toUnparsedModel(inferenceId)) + .addListener(listener); + } +} From e44d20085a053d44ac3ab9a1188f7942e52217d1 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 6 Oct 2025 15:07:09 -0400 Subject: [PATCH 04/18] Trying to figure out bug --- .../inference/src/main/java/module-info.java | 1 - .../inference/registry/ModelRegistry.java | 257 +++++++++--------- ...lasticInferenceServiceMinimalSettings.java | 2 +- ...cInferenceServiceAuthorizationHandler.java | 2 +- .../PreconfiguredEndpointsModel.java | 13 +- .../PreconfiguredEndpointsRequestHandler.java | 10 + 6 files changed, 150 insertions(+), 135 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index 182e462360eb4..6aae961d45048 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -37,7 +37,6 @@ requires org.elasticsearch.sslconfig; requires org.apache.commons.text; requires software.amazon.awssdk.services.sagemakerruntime; - requires org.elasticsearch.inference; exports org.elasticsearch.xpack.inference.action; exports org.elasticsearch.xpack.inference.registry; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index c6f4a959463b1..f8c6fe7442178 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -49,6 +49,7 @@ import org.elasticsearch.core.Tuple; import org.elasticsearch.gateway.GatewayService; import org.elasticsearch.index.engine.VersionConflictEngineException; +import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.reindex.BulkByScrollResponse; @@ -95,9 +96,11 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; +import java.util.function.Supplier; import java.util.stream.Collectors; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.EIS_PRECONFIGURED_ENDPOINTS; /** * A class responsible for persisting and reading inference endpoint configurations. @@ -258,6 +261,49 @@ public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId * @param listener Model listener */ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { + getModelHelper( + inferenceEntityId, + client.prepareSearch(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN) + .setQuery(documentIdQuery(inferenceEntityId)) + .setSize(2) + .setAllowPartialSearchResults(false) + .request(), + searchResponse -> unparsedModelFromMap(createModelConfigMap(searchResponse.getHits(), inferenceEntityId)), + listener + ); + } + + /** + * Get a model. + * Secret settings are not included + * @param inferenceEntityId Model to get + * @param listener Model listener + */ + public void getModel(String inferenceEntityId, ActionListener listener) { + getModelHelper( + inferenceEntityId, + client.prepareSearch(InferenceIndex.INDEX_PATTERN) + .setQuery(documentIdQuery(inferenceEntityId)) + .setSize(1) + .setTrackTotalHits(false) + .request(), + searchResponse -> { + var modelConfigs = parseHitsAsModelsWithoutSecrets(searchResponse.getHits()).stream() + .map(ModelRegistry::unparsedModelFromMap) + .toList(); + assert modelConfigs.size() == 1; + return modelConfigs.get(0); + }, + listener + ); + } + + private void getModelHelper( + String inferenceEntityId, + SearchRequest modelSearch, + Function unparsedModelCreator, + ActionListener listener + ) { // If we know it's an EIS preconfigured endpoint, skip looking in the index because it could have an outdated version of the // endpoint and go directly to EIS to retrieve it if (ElasticInferenceServiceMinimalSettings.isEisPreconfiguredEndpoint(inferenceEntityId)) { @@ -265,39 +311,32 @@ public void getModelWithSecrets(String inferenceEntityId, ActionListenernewForked(searchResponseListener -> { - QueryBuilder queryBuilder = documentIdQuery(inferenceEntityId); - SearchRequest modelSearch = client.prepareSearch(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN) - .setQuery(queryBuilder) - .setSize(2) - .setAllowPartialSearchResults(false) - .request(); - - client.search(modelSearch, searchResponseListener); - }).andThen((unparsedModelListener, searchResponse) -> { - // We likely found the configuration, so parse it and return it - if (searchResponse.getHits().getHits().length != 0) { - unparsedModelListener.onResponse(unparsedModelFromMap(createModelConfigMap(searchResponse.getHits(), inferenceEntityId))); - return; - } + SubscribableListener.newForked(searchResponseListener -> client.search(modelSearch, searchResponseListener)) + .andThen((unparsedModelListener, searchResponse) -> { + // We likely found the configuration, so parse it and return it + if (searchResponse.getHits().getHits().length != 0) { + unparsedModelListener.onResponse(unparsedModelCreator.apply(searchResponse)); + return; + } - // we didn't find the configuration in the inference index, so check if it is a pre-configured endpoint - var maybeDefault = defaultConfigIds.get(inferenceEntityId); - if (maybeDefault != null) { - getDefaultConfig(true, maybeDefault, unparsedModelListener); - return; - } + // we didn't find the configuration in the inference index, so check if it is a pre-configured endpoint + var maybeDefault = defaultConfigIds.get(inferenceEntityId); + if (maybeDefault != null) { + getDefaultConfig(true, maybeDefault, unparsedModelListener); + return; + } - retrievePreconfiguredEndpointFromEisElseNotFound(unparsedModelListener, inferenceEntityId); - }).addListener(listener.delegateResponse((failureListener, e) -> { - logger.warn(format("Failed to load inference endpoint with secrets [%s]", inferenceEntityId), e); - failureListener.onFailure( - new ElasticsearchException( - format("Failed to load inference endpoint with secrets [%s], error: [%s]", inferenceEntityId, e.getMessage()), - e - ) - ); - })); + retrievePreconfiguredEndpointFromEisElseNotFound(unparsedModelListener, inferenceEntityId); + }) + .addListener(listener.delegateResponse((failureListener, e) -> { + logger.warn(format("Failed to load inference endpoint [%s]", inferenceEntityId), e); + failureListener.onFailure( + new ElasticsearchException( + format("Failed to load inference endpoint [%s], error: [%s]", inferenceEntityId, e.getMessage()), + e + ) + ); + })); } private void retrievePreconfiguredEndpointFromEisElseNotAuthorized(ActionListener listener, String inferenceEntityId) { @@ -333,57 +372,7 @@ private void retrieveEisPreconfiguredEndpoint(ActionListener list preconfiguredEndpointsRequestHandler.getPreconfiguredEndpointAsUnparsedModel(inferenceEntityId, eisFailureListener); } - /** - * Get a model. - * Secret settings are not included - * @param inferenceEntityId Model to get - * @param listener Model listener - */ - public void getModel(String inferenceEntityId, ActionListener listener) { - - // TODO add a SubscribableListener here - // 1. Do search - // 2. If we don't find it, check in defaultConfigIds - // 3. If we don't find it, make a call to EIS? maybe only if it has -elastic in the name? - // 4. If we still don't find it, return not found - - ActionListener searchListener = ActionListener.wrap((searchResponse) -> { - // There should be a hit for the configurations - if (searchResponse.getHits().getHits().length == 0) { - var maybeDefault = defaultConfigIds.get(inferenceEntityId); - if (maybeDefault != null) { - getDefaultConfig(true, maybeDefault, listener); - } else { - listener.onFailure(inferenceNotFoundException(inferenceEntityId)); - } - return; - } - - var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList(); - assert modelConfigs.size() == 1; - listener.onResponse(modelConfigs.get(0)); - }, e -> { - logger.warn(format("Failed to load inference endpoint [%s]", inferenceEntityId), e); - listener.onFailure( - new ElasticsearchException( - format("Failed to load inference endpoint [%s], error: [%s]", inferenceEntityId, e.getMessage()), - e - ) - ); - }); - - QueryBuilder queryBuilder = documentIdQuery(inferenceEntityId); - SearchRequest modelSearch = client.prepareSearch(InferenceIndex.INDEX_PATTERN) - .setQuery(queryBuilder) - .setSize(1) - .setTrackTotalHits(false) - .request(); - - client.search(modelSearch, searchListener); - } - private ResourceNotFoundException inferenceNotFoundException(String inferenceEntityId) { - // TODO add some logic here to check if it is EIS related and return a message indicating that the endpoint may not be authorized return new ResourceNotFoundException("Inference endpoint [{}] not found or you are not authorized to access it", inferenceEntityId); } @@ -394,31 +383,58 @@ private ResourceNotFoundException inferenceNotFoundException(String inferenceEnt * @param listener Models listener */ public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { - // TODO we need to explicitly filter out any existing EIS PIEs because after we move the authorization logic to the master node we - // won't be removing them from the index anymore - - // TODO add a SubscribableListener here - // 1. Do search - // 2. If we don't find it, check in defaultConfigIds - // 3. If we don't find it, make a call to EIS? maybe only if it has -elastic in the name? - // 4. If we still don't find it, return not found - - ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { - var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList(); - var defaultConfigsForTaskType = taskTypeMatchedDefaults(taskType, defaultConfigIds.values()); - addAllDefaultConfigsIfMissing(true, modelConfigs, defaultConfigsForTaskType, delegate); - }); + getModelsHelper( + QueryBuilders.boolQuery().filter(QueryBuilders.termsQuery(TASK_TYPE_FIELD, taskType.toString())), + () -> taskTypeMatchedDefaults(taskType, defaultConfigIds.values()), + true, + listener + ); + } + + public void getModelsHelper( + BoolQueryBuilder boolQueryBuilder, + Supplier> defaultConfigIdsSupplier, + boolean persistDefaultEndpoints, + ActionListener> listener + ) { + SubscribableListener.newForked(searchResponseListener -> { + var eisEndpointIds = EIS_PRECONFIGURED_ENDPOINTS.stream().map(Model::documentId).toArray(String[]::new); + + // exclude the EIS preconfigured endpoints so we can query EIS directly for them + var queryBuilder = boolQueryBuilder.filter(QueryBuilders.boolQuery().mustNot(QueryBuilders.idsQuery().addIds(eisEndpointIds))); - QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(TASK_TYPE_FIELD, taskType.toString())); + var modelSearch = client.prepareSearch(InferenceIndex.INDEX_PATTERN) + .setQuery(queryBuilder) + // .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders.existsQuery(TASK_TYPE_FIELD))) + .setSize(10_000) + .setTrackTotalHits(false) + .addSort(MODEL_ID_FIELD, SortOrder.ASC) + .request(); - SearchRequest modelSearch = client.prepareSearch(InferenceIndex.INDEX_PATTERN) - .setQuery(queryBuilder) - .setSize(10_000) - .setTrackTotalHits(false) - .addSort(MODEL_ID_FIELD, SortOrder.ASC) - .request(); + client.search(modelSearch, searchResponseListener); + }).>andThen((missingDefaultConfigsAddedListener, searchResponse) -> { + var modelConfigs = parseHitsAsModelsWithoutSecrets(searchResponse.getHits()).stream() + .map(ModelRegistry::unparsedModelFromMap) + .toList(); + addAllDefaultConfigsIfMissing( + persistDefaultEndpoints, + modelConfigs, + defaultConfigIdsSupplier.get(), + missingDefaultConfigsAddedListener + ); + }).>andThen((eisPreconfiguredEndpointsAddedListener, unparsedModels) -> { + ActionListener> eisListener = ActionListener.wrap(response -> { + var allModels = new ArrayList<>(unparsedModels); + allModels.addAll(response); + allModels.sort(Comparator.comparing(UnparsedModel::inferenceEntityId)); + eisPreconfiguredEndpointsAddedListener.onResponse(allModels); + }, e -> { + logger.debug("Failed to retrieve preconfigured endpoint from EIS", e); + eisPreconfiguredEndpointsAddedListener.onResponse(unparsedModels); + }); - client.search(modelSearch, searchListener); + preconfiguredEndpointsRequestHandler.getAllPreconfiguredEndpointsAsUnparsedModels(eisListener); + }).addListener(listener); } /** @@ -434,33 +450,12 @@ public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { - // TODO add a SubscribableListener here - // TODO we need to explicitly filter out any existing EIS PIEs because after we move the authorization logic to the master node we - // won't be removing them from the index anymore - - // 1. Do search - // 2. If we don't find it, check in defaultConfigIds - // 3. If we don't find it, make a call to EIS? maybe only if it has -elastic in the name? - // 4. If we still don't find it, return not found - - ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { - var foundConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList(); - addAllDefaultConfigsIfMissing(persistDefaultEndpoints, foundConfigs, defaultConfigIds.values(), delegate); - }); - - // In theory the index should only contain model config documents - // and a match all query would be sufficient. But just in case the - // index has been polluted return only docs with a task_type field - QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.existsQuery(TASK_TYPE_FIELD)); - - SearchRequest modelSearch = client.prepareSearch(InferenceIndex.INDEX_PATTERN) - .setQuery(queryBuilder) - .setSize(10_000) - .setTrackTotalHits(false) - .addSort(MODEL_ID_FIELD, SortOrder.ASC) - .request(); - - client.search(modelSearch, searchListener); + getModelsHelper( + QueryBuilders.boolQuery().filter(QueryBuilders.constantScoreQuery(QueryBuilders.existsQuery(TASK_TYPE_FIELD))), + () -> new ArrayList<>(defaultConfigIds.values()), + persistDefaultEndpoints, + listener + ); } private void addAllDefaultConfigsIfMissing( @@ -535,7 +530,7 @@ private void storeDefaultEndpoint(Model preconfigured, Runnable runAfter) { storeModel(preconfigured, false, ActionListener.runAfter(responseListener, runAfter), AcknowledgedRequest.DEFAULT_ACK_TIMEOUT); } - private ArrayList parseHitsAsModels(SearchHits hits) { + private ArrayList parseHitsAsModelsWithoutSecrets(SearchHits hits) { var modelConfigs = new ArrayList(); for (var hit : hits) { modelConfigs.add(new ModelConfigMap(hit.getSourceAsMap(), Map.of())); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceMinimalSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceMinimalSettings.java index e115bae97c207..1b349737a16ac 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceMinimalSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceMinimalSettings.java @@ -36,7 +36,7 @@ public class ElasticInferenceServiceMinimalSettings { static final String DEFAULT_RERANK_MODEL_ID_V1 = "rerank-v1"; static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_RERANK_MODEL_ID_V1); - private static final Set EIS_PRECONFIGURED_ENDPOINTS = Set.of( + public static final Set EIS_PRECONFIGURED_ENDPOINTS = Set.of( DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, DEFAULT_ELSER_ENDPOINT_ID_V2, DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java index bea9523d9576b..d916346863731 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java @@ -247,7 +247,7 @@ private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorizat ); // TODO remove adding it to the registry, I think we can still revoke for now though - authorizedContent.get().configIds().forEach(modelRegistry::putDefaultIdIfAbsent); + // authorizedContent.get().configIds().forEach(modelRegistry::putDefaultIdIfAbsent); handleRevokedDefaultConfigs(authorizedDefaultModelIds); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsModel.java index 0c34545c689e6..9932e041244e6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsModel.java @@ -15,7 +15,9 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings; +import java.util.Comparator; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.function.Function; @@ -31,7 +33,8 @@ public static PreconfiguredEndpointsModel of(ElasticInferenceServiceAuthorizatio .stream() .filter(ElasticInferenceServiceMinimalSettings::containsModelName) .map((modelId) -> of(ElasticInferenceServiceMinimalSettings.getWithModelName(modelId))) - .filter(Objects::nonNull).collect(Collectors.toMap(PreconfiguredEndpoint::inferenceEntityId, Function.identity())); + .filter(Objects::nonNull) + .collect(Collectors.toMap(PreconfiguredEndpoint::inferenceEntityId, Function.identity())); return new PreconfiguredEndpointsModel(endpoints); } @@ -134,4 +137,12 @@ public UnparsedModel toUnparsedModel(String inferenceId) { return endpoint.toUnparsedModel(); } + + public List toUnparsedModels() { + return preconfiguredEndpoints.values() + .stream() + .map(PreconfiguredEndpoint::toUnparsedModel) + .sorted(Comparator.comparing(UnparsedModel::inferenceEntityId)) + .toList(); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsRequestHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsRequestHandler.java index 59f4e6d1a2cc1..6cfe163b21904 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsRequestHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsRequestHandler.java @@ -12,6 +12,7 @@ import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import java.util.List; import java.util.Objects; /** @@ -38,4 +39,13 @@ public void getPreconfiguredEndpointAsUnparsedModel(String inferenceId, ActionLi .andThenApply(preconfiguredEndpointsModel -> preconfiguredEndpointsModel.toUnparsedModel(inferenceId)) .addListener(listener); } + + public void getAllPreconfiguredEndpointsAsUnparsedModels(ActionListener> listener) { + SubscribableListener.newForked(authListener -> { + eisAuthorizationRequestHandler.getAuthorization(authListener, sender); + }) + .andThenApply(PreconfiguredEndpointsModel::of) + .andThenApply(PreconfiguredEndpointsModel::toUnparsedModels) + .addListener(listener); + } } From 5e5e17e9f71672567d676fb13fdf847b04066f18 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 7 Oct 2025 16:55:58 -0400 Subject: [PATCH 05/18] Starting test changes --- .../InferenceRevokeDefaultEndpointsIT.java | 359 ------------------ 1 file changed, 359 deletions(-) delete mode 100644 x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java deleted file mode 100644 index 72109e43bb6ac..0000000000000 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java +++ /dev/null @@ -1,359 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.integration; - -import org.elasticsearch.ResourceNotFoundException; -import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.core.TimeValue; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.MinimalServiceSettings; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.UnparsedModel; -import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.reindex.ReindexPlugin; -import org.elasticsearch.test.ESSingleNodeTestCase; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.test.http.MockResponse; -import org.elasticsearch.test.http.MockWebServer; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; -import org.elasticsearch.xpack.inference.external.http.HttpClientManager; -import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; -import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; -import org.junit.After; -import org.junit.Before; - -import java.util.Collection; -import java.util.EnumSet; -import java.util.List; -import java.util.concurrent.TimeUnit; - -import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; -import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; -import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; -import static org.hamcrest.CoreMatchers.is; -import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.mockito.Mockito.mock; - -@ESTestCase.WithoutEntitlements // due to dependency issue ES-12435 -public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase { - private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); - - private ModelRegistry modelRegistry; - private final MockWebServer webServer = new MockWebServer(); - private ThreadPool threadPool; - private String gatewayUrl; - - @Before - public void createComponents() throws Exception { - threadPool = createThreadPool(inferenceUtilityExecutors()); - webServer.start(); - gatewayUrl = getUrl(webServer); - modelRegistry = node().injector().getInstance(ModelRegistry.class); - } - - @After - public void shutdown() { - terminate(threadPool); - webServer.close(); - } - - @Override - protected boolean resetNodeAfterTest() { - return true; - } - - @Override - protected Collection> getPlugins() { - return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class); - } - - public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCorrect() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "rainbow-sprinkles", - "task_types": ["chat"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - try (var service = createElasticInferenceService()) { - ensureAuthorizationCallFinished(service); - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertThat( - service.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - service - ) - ) - ) - ); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - - PlainActionFuture> listener = new PlainActionFuture<>(); - service.defaultConfigs(listener); - assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); - } - } - - public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty() throws Exception { - { - String responseJson = """ - { - "models": [ - { - "model_name": "rainbow-sprinkles", - "task_types": ["chat"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - try (var service = createElasticInferenceService()) { - ensureAuthorizationCallFinished(service); - - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertThat( - service.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - service - ) - ) - ) - ); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - - PlainActionFuture> listener = new PlainActionFuture<>(); - service.defaultConfigs(listener); - assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); - - var getModelListener = new PlainActionFuture(); - // persists the default endpoints - modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener); - - var inferenceEntity = getModelListener.actionGet(TIMEOUT); - assertThat(inferenceEntity.inferenceEntityId(), is(".rainbow-sprinkles-elastic")); - assertThat(inferenceEntity.taskType(), is(TaskType.CHAT_COMPLETION)); - } - } - { - String noAuthorizationResponseJson = """ - { - "models": [] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson)); - - try (var service = createElasticInferenceService()) { - ensureAuthorizationCallFinished(service); - - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertTrue(service.defaultConfigIds().isEmpty()); - assertThat(service.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class))); - - var getModelListener = new PlainActionFuture(); - modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener); - - var exception = expectThrows(ResourceNotFoundException.class, () -> getModelListener.actionGet(TIMEOUT)); - assertThat(exception.getMessage(), is("Inference endpoint not found [.rainbow-sprinkles-elastic]")); - } - } - } - - public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnAuthForIt() throws Exception { - { - String responseJson = """ - { - "models": [ - { - "model_name": "elser_model_2", - "task_types": ["embed/text/sparse"] - }, - { - "model_name": "rainbow-sprinkles", - "task_types": ["chat"] - }, - { - "model_name": "multilingual-embed-v1", - "task_types": ["embed/text/dense"] - }, - { - "model_name": "rerank-v1", - "task_types": ["rerank/text/text-similarity"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - try (var service = createElasticInferenceService()) { - ensureAuthorizationCallFinished(service); - - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertThat( - service.defaultConfigIds(), - containsInAnyOrder( - new InferenceService.DefaultConfigId( - ".elser-2-elastic", - MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( - ".multilingual-embed-v1-elastic", - MinimalServiceSettings.textEmbedding( - ElasticInferenceService.NAME, - ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, - ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), - DenseVectorFieldMapper.ElementType.FLOAT - ), - service - ), - new InferenceService.DefaultConfigId( - ".rerank-v1-elastic", - MinimalServiceSettings.rerank(ElasticInferenceService.NAME), - service - ) - ) - ); - assertThat( - service.supportedTaskTypes(), - is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.TEXT_EMBEDDING)) - ); - - PlainActionFuture> listener = new PlainActionFuture<>(); - service.defaultConfigs(listener); - assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-2-elastic")); - assertThat( - listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), - is(".multilingual-embed-v1-elastic") - ); - assertThat(listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); - assertThat(listener.actionGet(TIMEOUT).get(3).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic")); - - var getModelListener = new PlainActionFuture(); - // persists the default endpoints - modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener); - - var inferenceEntity = getModelListener.actionGet(TIMEOUT); - assertThat(inferenceEntity.inferenceEntityId(), is(".rainbow-sprinkles-elastic")); - assertThat(inferenceEntity.taskType(), is(TaskType.CHAT_COMPLETION)); - } - } - { - String noAuthorizationResponseJson = """ - { - "models": [ - { - "model_name": "elser_model_2", - "task_types": ["embed/text/sparse"] - }, - { - "model_name": "rerank-v1", - "task_types": ["rerank/text/text-similarity"] - }, - { - "model_name": "multilingual-embed-v1", - "task_types": ["embed/text/dense"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson)); - - try (var service = createElasticInferenceService()) { - ensureAuthorizationCallFinished(service); - - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertThat( - service.defaultConfigIds(), - containsInAnyOrder( - new InferenceService.DefaultConfigId( - ".elser-2-elastic", - MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( - ".multilingual-embed-v1-elastic", - MinimalServiceSettings.textEmbedding( - ElasticInferenceService.NAME, - ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, - ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), - DenseVectorFieldMapper.ElementType.FLOAT - ), - service - ), - new InferenceService.DefaultConfigId( - ".rerank-v1-elastic", - MinimalServiceSettings.rerank(ElasticInferenceService.NAME), - service - ) - ) - ); - assertThat( - service.supportedTaskTypes(), - is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK)) - ); - - var getModelListener = new PlainActionFuture(); - modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener); - var exception = expectThrows(ResourceNotFoundException.class, () -> getModelListener.actionGet(TIMEOUT)); - assertThat(exception.getMessage(), is("Inference endpoint not found [.rainbow-sprinkles-elastic]")); - } - } - } - - private void ensureAuthorizationCallFinished(ElasticInferenceService service) { - service.onNodeStarted(); - service.waitForFirstAuthorizationToComplete(TIMEOUT); - } - - private ElasticInferenceService createElasticInferenceService() { - var httpManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, httpManager); - - return new ElasticInferenceService( - senderFactory, - createWithEmptySettings(threadPool), - ElasticInferenceServiceSettingsTests.create(gatewayUrl), - modelRegistry, - new ElasticInferenceServiceAuthorizationRequestHandler(gatewayUrl, threadPool), - mockClusterServiceEmpty() - ); - } -} From b4b80a4ee53359fb2640c7195949827c17dfba33 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 7 Oct 2025 16:56:10 -0400 Subject: [PATCH 06/18] Test changes --- .../integration/ModelRegistryIT.java | 52 +++- .../xpack/inference/InferencePlugin.java | 4 +- .../inference/registry/ModelRegistry.java | 135 +++++++---- .../elastic/ElasticInferenceService.java | 42 +--- ...lasticInferenceServiceMinimalSettings.java | 4 +- .../ElasticInferenceServiceSettings.java | 3 +- ...cInferenceServiceAuthorizationHandler.java | 6 +- ...nceServiceAuthorizationRequestHandler.java | 2 +- .../mapper/SemanticTextFieldMapperTests.java | 4 +- ...erceptedInferenceQueryBuilderTestCase.java | 6 +- .../SemanticMultiMatchQueryBuilderTests.java | 5 +- .../queries/SemanticQueryBuilderTests.java | 4 +- .../registry/ModelRegistryTests.java | 16 ++ .../elastic/ElasticInferenceServiceTests.java | 224 +++++++----------- ...renceServiceAuthorizationHandlerTests.java | 8 + 15 files changed, 278 insertions(+), 237 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index 92eea9599ec5d..71b2c7acc323e 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -36,6 +36,8 @@ import org.elasticsearch.reindex.ReindexPlugin; import org.elasticsearch.test.ESSingleNodeTestCase; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; @@ -44,11 +46,14 @@ import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.registry.ModelRegistryTests; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettingsTests; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettingsTests; +import org.junit.AfterClass; import org.junit.Before; +import org.junit.BeforeClass; import java.io.IOException; import java.util.ArrayList; @@ -65,6 +70,8 @@ import java.util.function.Function; import java.util.stream.Collectors; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_URL; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.containsString; @@ -81,6 +88,17 @@ public class ModelRegistryIT extends ESSingleNodeTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private ModelRegistry modelRegistry; + private static final MockWebServer webServer = new MockWebServer(); + + @BeforeClass + public static void init() throws Exception { + webServer.start(); + } + + @AfterClass + public static void shutdown() { + webServer.close(); + } @Before public void createComponents() { @@ -93,13 +111,18 @@ protected Collection> getPlugins() { return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class); } - public void testStoreModel() throws Exception { + @Override + protected Settings nodeSettings() { + return Settings.builder().put(super.nodeSettings()).put(ELASTIC_INFERENCE_SERVICE_URL.getKey(), getUrl(webServer)).build(); + } + + public void testStoreModel() { String inferenceEntityId = "test-store-model"; Model model = buildElserModelConfig(inferenceEntityId, TaskType.SPARSE_EMBEDDING); ModelRegistryTests.assertStoreModel(modelRegistry, model); } - public void testStoreModelWithUnknownFields() throws Exception { + public void testStoreModelWithUnknownFields() { String inferenceEntityId = "test-store-model-unknown-field"; Model model = buildModelWithUnknownField(inferenceEntityId); ElasticsearchStatusException statusException = expectThrows( @@ -145,7 +168,7 @@ public void testGetModel() throws Exception { assertEquals(model, roundTripModel); } - public void testStoreModelFailsWhenModelExists() throws Exception { + public void testStoreModelFailsWhenModelExists() { String inferenceEntityId = "test-put-trained-model-config-exists"; Model model = buildElserModelConfig(inferenceEntityId, TaskType.SPARSE_EMBEDDING); ModelRegistryTests.assertStoreModel(modelRegistry, model); @@ -175,7 +198,7 @@ public void testDeleteModel() throws Exception { assertThat(exceptionHolder.get(), not(nullValue())); assertFalse(deleteResponseHolder.get()); - assertThat(exceptionHolder.get().getMessage(), containsString("Inference endpoint not found [model1]")); + assertThat(exceptionHolder.get().getMessage(), containsString("Inference endpoint [model1] not found")); } public void testNonExistentDeleteModel_DoesNotThrowAnException() { @@ -577,6 +600,27 @@ public void testGetByTaskType_WithDefaults() throws Exception { assertReturnModelIsModifiable(modelHolder.get().get(0)); } + public void testGetModel_RetrievesAnEisPreconfiguredEndpoint() { + var responseJson = """ + { + "models": [ + { + "model_name": "rainbow-sprinkles", + "task_types": ["chat"] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture listener = new PlainActionFuture<>(); + modelRegistry.getModel(ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, listener); + + var model = listener.actionGet(TIMEOUT); + assertThat(model.inferenceEntityId(), is(ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); + assertThat(model.taskType(), is(TaskType.CHAT_COMPLETION)); + } + private void assertInferenceIndexExists() { var indexResponse = client().admin().indices().prepareGetIndex(TEST_REQUEST_TIMEOUT).addIndices(".inference").get(); assertNotNull(indexResponse.getSettings()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 447f33308284d..21ce5b79de725 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -322,7 +322,9 @@ public Collection createComponents(PluginServices services) { var eisSender = elasicInferenceServiceFactory.get().createSender(); var preconfigEndpointsHandler = new PreconfiguredEndpointsRequestHandler(authorizationHandler, eisSender); - modelRegistry.set(new ModelRegistry(services.clusterService(), services.client(), preconfigEndpointsHandler)); + modelRegistry.set( + new ModelRegistry(services.clusterService(), services.client(), preconfigEndpointsHandler) + ); services.clusterService().addListener(modelRegistry.get()); var sageMakerSchemas = new SageMakerSchemas(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index f8c6fe7442178..ccc643b668d20 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -242,6 +242,12 @@ public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId throw new IllegalStateException("initial cluster state not set yet"); } } + + var eisConfig = ElasticInferenceServiceMinimalSettings.getWithInferenceId(inferenceEntityId); + if (eisConfig != null) { + return eisConfig.minimalSettings(); + } + var config = defaultConfigIds.get(inferenceEntityId); if (config != null) { return config.settings(); @@ -307,57 +313,81 @@ private void getModelHelper( // If we know it's an EIS preconfigured endpoint, skip looking in the index because it could have an outdated version of the // endpoint and go directly to EIS to retrieve it if (ElasticInferenceServiceMinimalSettings.isEisPreconfiguredEndpoint(inferenceEntityId)) { - retrievePreconfiguredEndpointFromEisElseNotAuthorized(listener, inferenceEntityId); + retrievePreconfiguredEndpointFromEisElseEisError(listener, inferenceEntityId); return; } - SubscribableListener.newForked(searchResponseListener -> client.search(modelSearch, searchResponseListener)) - .andThen((unparsedModelListener, searchResponse) -> { - // We likely found the configuration, so parse it and return it - if (searchResponse.getHits().getHits().length != 0) { - unparsedModelListener.onResponse(unparsedModelCreator.apply(searchResponse)); - return; - } + var failureListener = listener.delegateResponse((delegate, e) -> { + // If the inference endpoint does not exist, we've already created a well-defined exception, so just return it + if (e instanceof ElasticsearchException) { + delegate.onFailure(e); + return; + } - // we didn't find the configuration in the inference index, so check if it is a pre-configured endpoint - var maybeDefault = defaultConfigIds.get(inferenceEntityId); - if (maybeDefault != null) { - getDefaultConfig(true, maybeDefault, unparsedModelListener); - return; - } + logger.warn(format("Failed to load inference endpoint [%s]", inferenceEntityId), e); - retrievePreconfiguredEndpointFromEisElseNotFound(unparsedModelListener, inferenceEntityId); - }) - .addListener(listener.delegateResponse((failureListener, e) -> { - logger.warn(format("Failed to load inference endpoint [%s]", inferenceEntityId), e); - failureListener.onFailure( - new ElasticsearchException( - format("Failed to load inference endpoint [%s], error: [%s]", inferenceEntityId, e.getMessage()), - e - ) - ); - })); + delegate.onFailure( + new ElasticsearchException( + format("Failed to load inference endpoint [%s], error: [%s]", inferenceEntityId, e.getMessage()), + e + ) + ); + }); + + ActionListener searchResponseListener = failureListener.delegateFailureAndWrap( + (delegate, searchResponse) -> searchForEndpointInDefaultAndEis( + inferenceEntityId, + searchResponse, + unparsedModelCreator, + delegate + ) + ); + + client.search(modelSearch, searchResponseListener); } - private void retrievePreconfiguredEndpointFromEisElseNotAuthorized(ActionListener listener, String inferenceEntityId) { + private void searchForEndpointInDefaultAndEis( + String inferenceEntityId, + SearchResponse searchResponse, + Function unparsedModelCreator, + ActionListener listener + ) { + // We likely found the configuration, so parse it and return it + if (searchResponse.getHits().getHits().length != 0) { + listener.onResponse(unparsedModelCreator.apply(searchResponse)); + return; + } + + // we didn't find the configuration in the inference index, so check if it is a pre-configured endpoint + var maybeDefault = defaultConfigIds.get(inferenceEntityId); + if (maybeDefault != null) { + getDefaultConfig(true, maybeDefault, listener); + return; + } + + retrievePreconfiguredEndpointFromEisElseNotFound(listener, inferenceEntityId); + } + + private void retrievePreconfiguredEndpointFromEisElseEisError(ActionListener listener, String inferenceEntityId) { var eisFailureListener = listener.delegateResponse( - (delegate, e) -> { delegate.onFailure(eisNotAuthorizedException(inferenceEntityId)); } + (delegate, e) -> delegate.onFailure(eisBadRequestException(inferenceEntityId, e)) ); retrieveEisPreconfiguredEndpoint(eisFailureListener, inferenceEntityId); } - private ElasticsearchStatusException eisNotAuthorizedException(String inferenceEntityId) { + private ElasticsearchStatusException eisBadRequestException(String inferenceEntityId, Exception exception) { return new ElasticsearchStatusException( - "Unauthorized to access inference endpoint [{}]", - RestStatus.UNAUTHORIZED, + "Unable to retrieve the preconfigured inference endpoint [{}] from the Elastic Inference Service", + RestStatus.BAD_REQUEST, + exception, inferenceEntityId ); } private void retrievePreconfiguredEndpointFromEisElseNotFound(ActionListener listener, String inferenceEntityId) { var eisFailureListener = listener.delegateResponse( - (delegate, e) -> { delegate.onFailure(inferenceNotFoundException(inferenceEntityId)); } + (delegate, e) -> delegate.onFailure(inferenceNotFoundException(inferenceEntityId)) ); retrieveEisPreconfiguredEndpoint(eisFailureListener, inferenceEntityId); @@ -391,28 +421,43 @@ public void getModelsByTaskType(TaskType taskType, ActionListener> defaultConfigIdsSupplier, boolean persistDefaultEndpoints, ActionListener> listener ) { - SubscribableListener.newForked(searchResponseListener -> { - var eisEndpointIds = EIS_PRECONFIGURED_ENDPOINTS.stream().map(Model::documentId).toArray(String[]::new); + ActionListener searchResponseListener = listener.delegateFailureAndWrap( + (delegate, searchResponse) -> includeDefaultAndEisEndpoints( + searchResponse, + defaultConfigIdsSupplier, + persistDefaultEndpoints, + delegate + ) + ); - // exclude the EIS preconfigured endpoints so we can query EIS directly for them - var queryBuilder = boolQueryBuilder.filter(QueryBuilders.boolQuery().mustNot(QueryBuilders.idsQuery().addIds(eisEndpointIds))); + var eisEndpointIds = EIS_PRECONFIGURED_ENDPOINTS.stream().map(Model::documentId).toArray(String[]::new); - var modelSearch = client.prepareSearch(InferenceIndex.INDEX_PATTERN) - .setQuery(queryBuilder) - // .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders.existsQuery(TASK_TYPE_FIELD))) - .setSize(10_000) - .setTrackTotalHits(false) - .addSort(MODEL_ID_FIELD, SortOrder.ASC) - .request(); + // exclude the EIS preconfigured endpoints so we can query EIS directly for them + var queryBuilder = boolQueryBuilder.filter(QueryBuilders.boolQuery().mustNot(QueryBuilders.idsQuery().addIds(eisEndpointIds))); + + var modelSearch = client.prepareSearch(InferenceIndex.INDEX_PATTERN) + .setQuery(queryBuilder) + .setSize(10_000) + .setTrackTotalHits(false) + .addSort(MODEL_ID_FIELD, SortOrder.ASC) + .request(); + + client.search(modelSearch, searchResponseListener); + } - client.search(modelSearch, searchResponseListener); - }).>andThen((missingDefaultConfigsAddedListener, searchResponse) -> { + private void includeDefaultAndEisEndpoints( + SearchResponse searchResponse, + Supplier> defaultConfigIdsSupplier, + boolean persistDefaultEndpoints, + ActionListener> listener + ) { + SubscribableListener.>newForked(missingDefaultConfigsAddedListener -> { var modelConfigs = parseHitsAsModelsWithoutSecrets(searchResponse.getHits()).stream() .map(ModelRegistry::unparsedModelFromMap) .toList(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index aae03ed13cd05..02850015bb632 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -52,7 +52,6 @@ import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.elastic.action.ElasticInferenceServiceActionCreator; -import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationHandler; import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; @@ -133,7 +132,6 @@ public static String defaultEndpointId(String modelId) { } private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; - private final ElasticInferenceServiceAuthorizationHandler authorizationHandler; public ElasticInferenceService( HttpRequestSender.Factory factory, @@ -165,16 +163,6 @@ public ElasticInferenceService( this.elasticInferenceServiceComponents = new ElasticInferenceServiceComponents( elasticInferenceServiceSettings.getElasticInferenceServiceUrl() ); - authorizationHandler = new ElasticInferenceServiceAuthorizationHandler( - serviceComponents, - modelRegistry, - authorizationRequestHandler, - initDefaultEndpoints(elasticInferenceServiceComponents), - IMPLEMENTED_TASK_TYPES, - this, - getSender(), - elasticInferenceServiceSettings - ); } private static Map initDefaultEndpoints( @@ -243,11 +231,6 @@ private static Map initDefaultEndpoints( ); } - @Override - public void onNodeStarted() { - authorizationHandler.init(); - } - @Override protected void validateRerankParameters(Boolean returnDocuments, Integer topN, ValidationException validationException) { if (returnDocuments != null) { @@ -260,32 +243,11 @@ protected void validateRerankParameters(Boolean returnDocuments, Integer topN, V } } - /** - * Only use this in tests. - * - * Waits the specified amount of time for the authorization call to complete. This is mainly to make testing easier. - * @param waitTime the max time to wait - * @throws IllegalStateException if the wait time is exceeded or the call receives an {@link InterruptedException} - */ - public void waitForFirstAuthorizationToComplete(TimeValue waitTime) { - authorizationHandler.waitForAuthorizationToComplete(waitTime); - } - @Override public Set supportedStreamingTasks() { return EnumSet.of(TaskType.CHAT_COMPLETION); } - @Override - public List defaultConfigIds() { - return authorizationHandler.defaultConfigIds(); - } - - @Override - public void defaultConfigs(ActionListener> defaultsListener) { - authorizationHandler.defaultConfigs(defaultsListener); - } - @Override protected void doUnifiedCompletionInfer( Model model, @@ -462,7 +424,9 @@ public InferenceServiceConfiguration getConfiguration() { @Override public EnumSet supportedTaskTypes() { - return authorizationHandler.supportedTaskTypes(); + throw new UnsupportedOperationException( + "The EIS supported task types change depending on authorization, requests should be made directly to EIS instead" + ); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceMinimalSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceMinimalSettings.java index 1b349737a16ac..61aa514ae700a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceMinimalSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceMinimalSettings.java @@ -20,8 +20,8 @@ public class ElasticInferenceServiceMinimalSettings { // rainbow-sprinkles - static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles"; - static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); + public static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles"; + public static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); // elser-2 static final String DEFAULT_ELSER_2_MODEL_ID = "elser_model_2"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java index 0d8bef246b35d..a2144bc9013ca 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java @@ -28,7 +28,8 @@ public class ElasticInferenceServiceSettings { @Deprecated static final Setting EIS_GATEWAY_URL = Setting.simpleString("xpack.inference.eis.gateway.url", Setting.Property.NodeScope); - static final Setting ELASTIC_INFERENCE_SERVICE_URL = Setting.simpleString( + // public so tests can access it + public static final Setting ELASTIC_INFERENCE_SERVICE_URL = Setting.simpleString( "xpack.inference.elastic.url", Setting.Property.NodeScope ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java index d916346863731..b66569ac4db1c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java @@ -42,6 +42,7 @@ import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; +// TODO remove this class public class ElasticInferenceServiceAuthorizationHandler implements Closeable { private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationHandler.class); @@ -246,8 +247,9 @@ private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorizat new AuthorizedContent(authorizedTaskTypesAndModels, authorizedDefaultConfigIds, authorizedDefaultModelObjects) ); - // TODO remove adding it to the registry, I think we can still revoke for now though - // authorizedContent.get().configIds().forEach(modelRegistry::putDefaultIdIfAbsent); + // We are no longer added the authorized preconfigured endpoints to the model registry. The model registry will reach out to + // the EIS gateway directly to get the model information. For now, I'm leaving the revoking logic in place but it will be removed + // when the authorization polling logic is moved to the master node. handleRevokedDefaultConfigs(authorizedDefaultModelIds); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java index 02800105ef83d..3068d44bd526c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java @@ -79,7 +79,7 @@ public void getAuthorization(ActionListener> extends MapperServiceTestCase { @@ -605,7 +607,9 @@ protected static void disableQueryInterception(QueryRewriteContext queryRewriteC private static ModelRegistry createModelRegistry(ThreadPool threadPool) { ClusterService clusterService = ClusterServiceUtils.createClusterService(threadPool); - ModelRegistry modelRegistry = spy(new ModelRegistry(clusterService, new NoOpClient(threadPool))); + ModelRegistry modelRegistry = spy( + new ModelRegistry(clusterService, new NoOpClient(threadPool), mock(PreconfiguredEndpointsRequestHandler.class)) + ); modelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) { @Override public boolean localNodeMaster() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticMultiMatchQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticMultiMatchQueryBuilderTests.java index b54ca946e6179..cce0f8070f7c5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticMultiMatchQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticMultiMatchQueryBuilderTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.elastic.authorization.PreconfiguredEndpointsRequestHandler; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -32,6 +33,8 @@ import java.util.List; import java.util.function.Supplier; +import static org.mockito.Mockito.mock; + public class SemanticMultiMatchQueryBuilderTests extends MapperServiceTestCase { private static TestThreadPool threadPool; private static ModelRegistry modelRegistry; @@ -51,7 +54,7 @@ protected Supplier getModelRegistry() { public static void startModelRegistry() { threadPool = new TestThreadPool(SemanticMultiMatchQueryBuilderTests.class.getName()); var clusterService = ClusterServiceUtils.createClusterService(threadPool); - modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool)); + modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool), mock(PreconfiguredEndpointsRequestHandler.class)); modelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) { @Override public boolean localNodeMaster() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index b2d7218720a57..4e9ccf7607c81 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -75,6 +75,7 @@ import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.elastic.authorization.PreconfiguredEndpointsRequestHandler; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; @@ -100,6 +101,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.notNullValue; +import static org.mockito.Mockito.mock; public class SemanticQueryBuilderTests extends AbstractQueryTestCase { private static final String SEMANTIC_TEXT_FIELD = "semantic"; @@ -144,7 +146,7 @@ public static void setInferenceResultType() { public static void startModelRegistry() { threadPool = new TestThreadPool(SemanticQueryBuilderTests.class.getName()); var clusterService = ClusterServiceUtils.createClusterService(threadPool); - modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool)); + modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool), mock(PreconfiguredEndpointsRequestHandler.class)); modelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) { @Override public boolean localNodeMaster() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index eee8550ec6524..24979ac507cd5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -25,6 +25,8 @@ import org.elasticsearch.test.ESSingleNodeTestCase; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.model.TestModel; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings; import org.junit.Before; import java.util.ArrayList; @@ -237,6 +239,20 @@ public void testDuplicateDefaultIds() { ); } + public void testGetMinimalServiceSettings_ThrowsResourceNotFound_WhenInferenceIdDoesNotExist() { + var exception = expectThrows(ResourceNotFoundException.class, () -> registry.getMinimalServiceSettings("non_existent_id")); + assertThat(exception.getMessage(), containsString("non_existent_id does not exist in this cluster.")); + } + + public void testGetMinimalServiceSettings_ReturnsEisPreconfiguredEndpoint() { + var minimalSettings = registry.getMinimalServiceSettings( + ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 + ); + + assertThat(minimalSettings.service(), is(ElasticInferenceService.NAME)); + assertThat(minimalSettings.taskType(), is(TaskType.CHAT_COMPLETION)); + } + public static void assertStoreModel(ModelRegistry registry, Model model) { PlainActionFuture storeListener = new PlainActionFuture<>(); registry.storeModel(model, storeListener, TimeValue.THIRTY_SECONDS); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 760312fe2d97b..0a806cad3995a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -921,8 +921,6 @@ public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOExceptio public void testHideFromConfigurationApi_ThrowsUnsupported_WithNoAvailableModels() throws Exception { try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorizationModel.newDisabledService())) { - ensureAuthorizationCallFinished(service); - expectThrows(UnsupportedOperationException.class, service::hideFromConfigurationApi); } } @@ -942,137 +940,104 @@ public void testHideFromConfigurationApi_ThrowsUnsupported_WithAvailableModels() ) ) ) { - ensureAuthorizationCallFinished(service); - expectThrows(UnsupportedOperationException.class, service::hideFromConfigurationApi); } } public void testCreateConfiguration() throws Exception { - try ( - var service = createServiceWithMockSender( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING) - ) - ) - ) - ) - ) - ) { - ensureAuthorizationCallFinished(service); - - String content = XContentHelper.stripWhitespace(""" - { - "service": "elastic", - "name": "Elastic", - "task_types": ["sparse_embedding", "chat_completion", "text_embedding"], - "configurations": { - "rate_limit.requests_per_minute": { - "description": "Minimize the number of rate limit errors.", - "label": "Rate Limit", - "required": false, - "sensitive": false, - "updatable": false, - "type": "int", - "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"] - }, - "model_id": { - "description": "The name of the model to use for the inference task.", - "label": "Model ID", - "required": true, - "sensitive": false, - "updatable": false, - "type": "str", - "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"] - }, - "max_input_tokens": { - "description": "Allows you to specify the maximum number of tokens per input.", - "label": "Maximum Input Tokens", - "required": false, - "sensitive": false, - "updatable": false, - "type": "int", - "supported_task_types": ["text_embedding", "sparse_embedding"] - } + String content = XContentHelper.stripWhitespace(""" + { + "service": "elastic", + "name": "Elastic", + "task_types": ["sparse_embedding", "chat_completion", "text_embedding"], + "configurations": { + "rate_limit.requests_per_minute": { + "description": "Minimize the number of rate limit errors.", + "label": "Rate Limit", + "required": false, + "sensitive": false, + "updatable": false, + "type": "int", + "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"] + }, + "model_id": { + "description": "The name of the model to use for the inference task.", + "label": "Model ID", + "required": true, + "sensitive": false, + "updatable": false, + "type": "str", + "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"] + }, + "max_input_tokens": { + "description": "Allows you to specify the maximum number of tokens per input.", + "label": "Maximum Input Tokens", + "required": false, + "sensitive": false, + "updatable": false, + "type": "int", + "supported_task_types": ["text_embedding", "sparse_embedding"] } } - """); - InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( - new BytesArray(content), - XContentType.JSON - ); - boolean humanReadable = true; - BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); - InferenceServiceConfiguration serviceConfiguration = ElasticInferenceService.createConfiguration( - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING) - ); - assertToXContentEquivalent( - originalBytes, - toXContent(serviceConfiguration, XContentType.JSON, humanReadable), - XContentType.JSON - ); - } + } + """); + InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( + new BytesArray(content), + XContentType.JSON + ); + boolean humanReadable = true; + BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); + InferenceServiceConfiguration serviceConfiguration = ElasticInferenceService.createConfiguration( + EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING) + ); + assertToXContentEquivalent(originalBytes, toXContent(serviceConfiguration, XContentType.JSON, humanReadable), XContentType.JSON); } public void testGetConfiguration_WithoutSupportedTaskTypes() throws Exception { - try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorizationModel.newDisabledService())) { - ensureAuthorizationCallFinished(service); - - String content = XContentHelper.stripWhitespace(""" - { - "service": "elastic", - "name": "Elastic", - "task_types": [], - "configurations": { - "rate_limit.requests_per_minute": { - "description": "Minimize the number of rate limit errors.", - "label": "Rate Limit", - "required": false, - "sensitive": false, - "updatable": false, - "type": "int", - "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"] - }, - "model_id": { - "description": "The name of the model to use for the inference task.", - "label": "Model ID", - "required": true, - "sensitive": false, - "updatable": false, - "type": "str", - "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"] - }, - "max_input_tokens": { - "description": "Allows you to specify the maximum number of tokens per input.", - "label": "Maximum Input Tokens", - "required": false, - "sensitive": false, - "updatable": false, - "type": "int", - "supported_task_types": ["text_embedding", "sparse_embedding"] - } + String content = XContentHelper.stripWhitespace(""" + { + "service": "elastic", + "name": "Elastic", + "task_types": [], + "configurations": { + "rate_limit.requests_per_minute": { + "description": "Minimize the number of rate limit errors.", + "label": "Rate Limit", + "required": false, + "sensitive": false, + "updatable": false, + "type": "int", + "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"] + }, + "model_id": { + "description": "The name of the model to use for the inference task.", + "label": "Model ID", + "required": true, + "sensitive": false, + "updatable": false, + "type": "str", + "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"] + }, + "max_input_tokens": { + "description": "Allows you to specify the maximum number of tokens per input.", + "label": "Maximum Input Tokens", + "required": false, + "sensitive": false, + "updatable": false, + "type": "int", + "supported_task_types": ["text_embedding", "sparse_embedding"] } } - """); - InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( - new BytesArray(content), - XContentType.JSON - ); - boolean humanReadable = true; - BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); - InferenceServiceConfiguration serviceConfiguration = ElasticInferenceService.createConfiguration( - EnumSet.noneOf(TaskType.class) - ); - assertToXContentEquivalent( - originalBytes, - toXContent(serviceConfiguration, XContentType.JSON, humanReadable), - XContentType.JSON - ); - } + } + """); + InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( + new BytesArray(content), + XContentType.JSON + ); + boolean humanReadable = true; + BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); + InferenceServiceConfiguration serviceConfiguration = ElasticInferenceService.createConfiguration(EnumSet.noneOf(TaskType.class)); + assertToXContentEquivalent(originalBytes, toXContent(serviceConfiguration, XContentType.JSON, humanReadable), XContentType.JSON); } public void testGetConfiguration_ThrowsUnsupported() throws Exception { @@ -1091,8 +1056,6 @@ public void testGetConfiguration_ThrowsUnsupported() throws Exception { ) ) ) { - ensureAuthorizationCallFinished(service); - expectThrows(UnsupportedOperationException.class, service::getConfiguration); } } @@ -1113,8 +1076,6 @@ public void testSupportedStreamingTasks_ReturnsChatCompletion_WhenAuthRespondsWi var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - ensureAuthorizationCallFinished(service); - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); assertTrue(service.defaultConfigIds().isEmpty()); @@ -1145,8 +1106,6 @@ public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes_IgnoresUnimple var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - ensureAuthorizationCallFinished(service); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); } } @@ -1171,8 +1130,6 @@ public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes() throws Excep var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - ensureAuthorizationCallFinished(service); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); } } @@ -1193,8 +1150,6 @@ public void testSupportedStreamingTasks_ReturnsEmpty_WhenAuthRespondsWithoutChat var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - ensureAuthorizationCallFinished(service); - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); assertTrue(service.defaultConfigIds().isEmpty()); assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); @@ -1221,7 +1176,6 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsIn var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - ensureAuthorizationCallFinished(service); assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); assertThat( service.defaultConfigIds(), @@ -1271,7 +1225,6 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect() var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - ensureAuthorizationCallFinished(service); assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); assertThat( @@ -1410,11 +1363,6 @@ private InferenceEventsAssertion testUnifiedStream(int responseCode, String resp } } - private void ensureAuthorizationCallFinished(ElasticInferenceService service) { - service.onNodeStarted(); - service.waitForFirstAuthorizationToComplete(TIMEOUT); - } - private ElasticInferenceService createServiceWithMockSender() { return createServiceWithMockSender(ElasticInferenceServiceAuthorizationModelTests.createEnabledAuth()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java index e42430b6512f5..6e6bfc8f8c910 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.EmptySecretSettings; @@ -52,6 +53,7 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +// TODO remove this class public class ElasticInferenceServiceAuthorizationHandlerTests extends ESSingleNodeTestCase { private DeterministicTaskQueue taskQueue; private ModelRegistry modelRegistry; @@ -61,6 +63,12 @@ protected Collection> getPlugins() { return List.of(LocalStateInferencePlugin.class); } + // TODO add the EIS url here + @Override + protected Settings nodeSettings() { + return Settings.EMPTY; + } + @Before public void init() throws Exception { taskQueue = new DeterministicTaskQueue(); From e6eed4ffcdfcf56d2741a70f0da9ef5d63023815 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Wed, 8 Oct 2025 14:10:50 +0000 Subject: [PATCH 07/18] [CI] Auto commit changes from spotless --- .../org/elasticsearch/xpack/inference/InferencePlugin.java | 4 +--- .../authorization/PreconfiguredEndpointsRequestHandler.java | 5 +---- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 21ce5b79de725..447f33308284d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -322,9 +322,7 @@ public Collection createComponents(PluginServices services) { var eisSender = elasicInferenceServiceFactory.get().createSender(); var preconfigEndpointsHandler = new PreconfiguredEndpointsRequestHandler(authorizationHandler, eisSender); - modelRegistry.set( - new ModelRegistry(services.clusterService(), services.client(), preconfigEndpointsHandler) - ); + modelRegistry.set(new ModelRegistry(services.clusterService(), services.client(), preconfigEndpointsHandler)); services.clusterService().addListener(modelRegistry.get()); var sageMakerSchemas = new SageMakerSchemas(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsRequestHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsRequestHandler.java index 6cfe163b21904..fdd3a0e290611 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsRequestHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsRequestHandler.java @@ -43,9 +43,6 @@ public void getPreconfiguredEndpointAsUnparsedModel(String inferenceId, ActionLi public void getAllPreconfiguredEndpointsAsUnparsedModels(ActionListener> listener) { SubscribableListener.newForked(authListener -> { eisAuthorizationRequestHandler.getAuthorization(authListener, sender); - }) - .andThenApply(PreconfiguredEndpointsModel::of) - .andThenApply(PreconfiguredEndpointsModel::toUnparsedModels) - .addListener(listener); + }).andThenApply(PreconfiguredEndpointsModel::of).andThenApply(PreconfiguredEndpointsModel::toUnparsedModels).addListener(listener); } } From 5fbb7405f2d19736bdb9ceb210250c9f70047249 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 8 Oct 2025 15:47:18 -0400 Subject: [PATCH 08/18] Adding more tests --- .../integration/ModelRegistryEisBase.java | 89 +++ .../ModelRegistryEisGetModelIT.java | 507 ++++++++++++++++++ .../integration/ModelRegistryEisIT.java | 166 ++++++ .../ModelRegistryEisInvalidUrlIT.java | 41 ++ .../integration/ModelRegistryIT.java | 44 -- .../inference/registry/ModelRegistry.java | 19 +- ...lasticInferenceServiceMinimalSettings.java | 14 +- .../PreconfiguredEndpointsModel.java | 27 +- .../registry/ModelRegistryTests.java | 38 +- 9 files changed, 871 insertions(+), 74 deletions(-) create mode 100644 x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisBase.java create mode 100644 x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisGetModelIT.java create mode 100644 x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisIT.java create mode 100644 x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisInvalidUrlIT.java diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisBase.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisBase.java new file mode 100644 index 0000000000000..83c408fa35012 --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisBase.java @@ -0,0 +1,89 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.reindex.ReindexPlugin; +import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.ModelRegistryTests; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.integration.ModelRegistryIT.createModel; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_URL; + +public class ModelRegistryEisBase extends ESSingleNodeTestCase { + protected static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + protected static final MockWebServer webServer = new MockWebServer(); + + protected ModelRegistry modelRegistry; + private String eisUrl; + + public ModelRegistryEisBase() {} + + public ModelRegistryEisBase(String eisUrl) { + this.eisUrl = eisUrl; + } + + @BeforeClass + public static void init() throws Exception { + webServer.start(); + } + + @AfterClass + public static void shutdown() { + webServer.close(); + } + + @Before + public void createComponents() { + modelRegistry = node().injector().getInstance(ModelRegistry.class); + modelRegistry.clearDefaultIds(); + } + + @Override + protected Collection> getPlugins() { + return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class); + } + + @Override + protected Settings nodeSettings() { + return Settings.builder().put(super.nodeSettings()).put(ELASTIC_INFERENCE_SERVICE_URL.getKey(), getEisUrl()).build(); + } + + private String getEisUrl() { + return eisUrl != null ? eisUrl : getUrl(webServer); + } + + protected void initializeModels() { + var service = "foo"; + var sparseAndTextEmbeddingModels = new ArrayList(); + sparseAndTextEmbeddingModels.add(createModel("sparse-1", TaskType.SPARSE_EMBEDDING, service)); + sparseAndTextEmbeddingModels.add(createModel("sparse-2", TaskType.SPARSE_EMBEDDING, service)); + sparseAndTextEmbeddingModels.add(createModel("sparse-3", TaskType.SPARSE_EMBEDDING, service)); + sparseAndTextEmbeddingModels.add(createModel("embedding-1", TaskType.TEXT_EMBEDDING, service)); + sparseAndTextEmbeddingModels.add(createModel("embedding-2", TaskType.TEXT_EMBEDDING, service)); + + for (var model : sparseAndTextEmbeddingModels) { + ModelRegistryTests.assertStoreModel(modelRegistry, model); + } + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisGetModelIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisGetModelIT.java new file mode 100644 index 0000000000000..4cd097517059d --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisGetModelIT.java @@ -0,0 +1,507 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.integration; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings; + +import java.util.Arrays; +import java.util.Map; +import java.util.function.BiConsumer; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsString; + +public class ModelRegistryEisGetModelIT extends ModelRegistryEisBase { + private final TestCase testCase; + + public ModelRegistryEisGetModelIT(TestCase testCase) { + super(); + this.testCase = testCase; + } + + public record TestCase( + String description, + BiConsumer> registryCall, + String responseJson, + @Nullable UnparsedModel expectedResult, + @Nullable String failureMessage, + @Nullable RestStatus failureStatus + ) {} + + private static class TestCaseBuilder { + private final String description; + private final BiConsumer> registryCall; + private final String responseJson; + private UnparsedModel expectedResult; + private String failureMessage; + private RestStatus failureStatus; + + TestCaseBuilder(String description, BiConsumer> registryCall, String responseJson) { + this.description = description; + this.registryCall = registryCall; + this.responseJson = responseJson; + } + + public TestCaseBuilder withSuccessfulResult(UnparsedModel expectedResult) { + this.expectedResult = expectedResult; + return this; + } + + public TestCaseBuilder withFailure(String failure, RestStatus status) { + this.failureMessage = failure; + this.failureStatus = status; + return this; + } + + public TestCase build() { + return new TestCase(description, registryCall, responseJson, expectedResult, failureMessage, failureStatus); + } + } + + @ParametersFactory + public static Iterable parameters() { + return Arrays.asList( + new TestCase[][] { + // getModel calls + { + new TestCaseBuilder( + "getModel retrieves eis chat completion preconfigured endpoint", + (modelRegistry, listener) -> modelRegistry.getModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + listener + ), + """ + { + "models": [ + { + "model_name": "rainbow-sprinkles", + "task_types": ["chat"] + } + ] + } + """ + ).withSuccessfulResult( + new UnparsedModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + TaskType.CHAT_COMPLETION, + ElasticInferenceService.NAME, + Map.of( + ModelConfigurations.SERVICE_SETTINGS, + Map.of(ServiceFields.MODEL_ID, ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_MODEL_ID_V1) + ), + Map.of() + ) + ).build() }, + { + new TestCaseBuilder( + "getModel throws an exception when retrieving eis " + + "chat completion preconfigured endpoint and it isn't authorized", + (modelRegistry, listener) -> modelRegistry.getModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + listener + ), + """ + { + "models": [ + ] + } + """ + ).withFailure( + "Unable to retrieve the preconfigured inference endpoint " + + "[.rainbow-sprinkles-elastic] from the Elastic Inference Service", + RestStatus.BAD_REQUEST + ).build() }, + { + new TestCaseBuilder( + "getModel retrieves eis elser preconfigured endpoint", + (modelRegistry, listener) -> modelRegistry.getModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_ELSER_ENDPOINT_ID_V2, + listener + ), + """ + { + "models": [ + { + "model_name": "elser_model_2", + "task_types": ["embed/text/sparse"] + } + ] + } + """ + ).withSuccessfulResult( + new UnparsedModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_ELSER_ENDPOINT_ID_V2, + TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, + Map.of( + ModelConfigurations.SERVICE_SETTINGS, + Map.of(ServiceFields.MODEL_ID, ElasticInferenceServiceMinimalSettings.DEFAULT_ELSER_2_MODEL_ID) + ), + Map.of() + ) + ).build() }, + { + new TestCaseBuilder( + "getModel throws exception when retrieving eis elser preconfigured endpoint and not authorized", + (modelRegistry, listener) -> modelRegistry.getModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_ELSER_ENDPOINT_ID_V2, + listener + ), + """ + { + "models": [ + ] + } + """ + ).withFailure( + "Unable to retrieve the preconfigured inference endpoint [.elser-2-elastic] from the Elastic Inference Service", + RestStatus.BAD_REQUEST + ).build() }, + { + new TestCaseBuilder( + "getModel retrieves eis multilingual embed preconfigured endpoint", + (modelRegistry, listener) -> modelRegistry.getModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, + listener + ), + """ + { + "models": [ + { + "model_name": "multilingual-embed-v1", + "task_types": ["embed/text/dense"] + } + ] + } + """ + ).withSuccessfulResult( + new UnparsedModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, + TaskType.TEXT_EMBEDDING, + ElasticInferenceService.NAME, + Map.of( + ModelConfigurations.SERVICE_SETTINGS, + Map.of( + ServiceFields.MODEL_ID, + ElasticInferenceServiceMinimalSettings.DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, + ServiceFields.SIMILARITY, + SimilarityMeasure.COSINE.toString(), + ServiceFields.DIMENSIONS, + ElasticInferenceServiceMinimalSettings.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + ServiceFields.ELEMENT_TYPE, + DenseVectorFieldMapper.ElementType.FLOAT.toString() + ) + ), + Map.of() + ) + ).build() }, + { + new TestCaseBuilder( + "getModel throws exception when retrieving eis multilingual embed preconfigured endpoint and not authorized", + (modelRegistry, listener) -> modelRegistry.getModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, + listener + ), + """ + { + "models": [ + ] + } + """ + ).withFailure( + "Unable to retrieve the preconfigured inference endpoint " + + "[.multilingual-embed-v1-elastic] from the Elastic Inference Service", + RestStatus.BAD_REQUEST + ).build() }, + { + new TestCaseBuilder( + "getModel retrieves eis rerank preconfigured endpoint", + (modelRegistry, listener) -> modelRegistry.getModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_ENDPOINT_ID_V1, + listener + ), + """ + { + "models": [ + { + "model_name": "rerank-v1", + "task_types": ["rerank/text/text-similarity"] + } + ] + } + """ + ).withSuccessfulResult( + new UnparsedModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_ENDPOINT_ID_V1, + TaskType.RERANK, + ElasticInferenceService.NAME, + Map.of( + ModelConfigurations.SERVICE_SETTINGS, + Map.of(ServiceFields.MODEL_ID, ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_MODEL_ID_V1) + ), + Map.of() + ) + ).build() }, + { + new TestCaseBuilder( + "getModel throws exception when retrieving eis rerank preconfigured endpoint and not authorized", + (modelRegistry, listener) -> modelRegistry.getModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_ENDPOINT_ID_V1, + listener + ), + """ + { + "models": [ + ] + } + """ + ).withFailure( + "Unable to retrieve the preconfigured inference endpoint [.rerank-v1-elastic] from the Elastic Inference Service", + RestStatus.BAD_REQUEST + ).build() }, + // getModelWithSecrets calls + { + new TestCaseBuilder( + "getModelWithSecrets retrieves eis chat completion preconfigured endpoint", + (modelRegistry, listener) -> modelRegistry.getModelWithSecrets( + ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + listener + ), + """ + { + "models": [ + { + "model_name": "rainbow-sprinkles", + "task_types": ["chat"] + } + ] + } + """ + ).withSuccessfulResult( + new UnparsedModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + TaskType.CHAT_COMPLETION, + ElasticInferenceService.NAME, + Map.of( + ModelConfigurations.SERVICE_SETTINGS, + Map.of(ServiceFields.MODEL_ID, ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_MODEL_ID_V1) + ), + Map.of() + ) + ).build() }, + { + new TestCaseBuilder( + "getModelWithSecrets throws an exception when retrieving eis " + + "chat completion preconfigured endpoint and it isn't authorized", + (modelRegistry, listener) -> modelRegistry.getModelWithSecrets( + ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + listener + ), + """ + { + "models": [ + ] + } + """ + ).withFailure( + "Unable to retrieve the preconfigured inference endpoint " + + "[.rainbow-sprinkles-elastic] from the Elastic Inference Service", + RestStatus.BAD_REQUEST + ).build() }, + { + new TestCaseBuilder( + "getModelWithSecrets retrieves eis elser preconfigured endpoint", + (modelRegistry, listener) -> modelRegistry.getModelWithSecrets( + ElasticInferenceServiceMinimalSettings.DEFAULT_ELSER_ENDPOINT_ID_V2, + listener + ), + """ + { + "models": [ + { + "model_name": "elser_model_2", + "task_types": ["embed/text/sparse"] + } + ] + } + """ + ).withSuccessfulResult( + new UnparsedModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_ELSER_ENDPOINT_ID_V2, + TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, + Map.of( + ModelConfigurations.SERVICE_SETTINGS, + Map.of(ServiceFields.MODEL_ID, ElasticInferenceServiceMinimalSettings.DEFAULT_ELSER_2_MODEL_ID) + ), + Map.of() + ) + ).build() }, + { + new TestCaseBuilder( + "getModelWithSecrets throws exception when retrieving eis elser preconfigured endpoint and not authorized", + (modelRegistry, listener) -> modelRegistry.getModelWithSecrets( + ElasticInferenceServiceMinimalSettings.DEFAULT_ELSER_ENDPOINT_ID_V2, + listener + ), + """ + { + "models": [ + ] + } + """ + ).withFailure( + "Unable to retrieve the preconfigured inference endpoint [.elser-2-elastic] from the Elastic Inference Service", + RestStatus.BAD_REQUEST + ).build() }, + { + new TestCaseBuilder( + "getModelWithSecrets retrieves eis multilingual embed preconfigured endpoint", + (modelRegistry, listener) -> modelRegistry.getModelWithSecrets( + ElasticInferenceServiceMinimalSettings.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, + listener + ), + """ + { + "models": [ + { + "model_name": "multilingual-embed-v1", + "task_types": ["embed/text/dense"] + } + ] + } + """ + ).withSuccessfulResult( + new UnparsedModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, + TaskType.TEXT_EMBEDDING, + ElasticInferenceService.NAME, + Map.of( + ModelConfigurations.SERVICE_SETTINGS, + Map.of( + ServiceFields.MODEL_ID, + ElasticInferenceServiceMinimalSettings.DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, + ServiceFields.SIMILARITY, + SimilarityMeasure.COSINE.toString(), + ServiceFields.DIMENSIONS, + ElasticInferenceServiceMinimalSettings.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + ServiceFields.ELEMENT_TYPE, + DenseVectorFieldMapper.ElementType.FLOAT.toString() + ) + ), + Map.of() + ) + ).build() }, + { + new TestCaseBuilder( + "getModelWithSecrets throws exception when retrieving eis " + + "multilingual embed preconfigured endpoint and not authorized", + (modelRegistry, listener) -> modelRegistry.getModelWithSecrets( + ElasticInferenceServiceMinimalSettings.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, + listener + ), + """ + { + "models": [ + ] + } + """ + ).withFailure( + "Unable to retrieve the preconfigured inference endpoint " + + "[.multilingual-embed-v1-elastic] from the Elastic Inference Service", + RestStatus.BAD_REQUEST + ).build() }, + { + new TestCaseBuilder( + "getModelWithSecrets retrieves eis rerank preconfigured endpoint", + (modelRegistry, listener) -> modelRegistry.getModelWithSecrets( + ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_ENDPOINT_ID_V1, + listener + ), + """ + { + "models": [ + { + "model_name": "rerank-v1", + "task_types": ["rerank/text/text-similarity"] + } + ] + } + """ + ).withSuccessfulResult( + new UnparsedModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_ENDPOINT_ID_V1, + TaskType.RERANK, + ElasticInferenceService.NAME, + Map.of( + ModelConfigurations.SERVICE_SETTINGS, + Map.of(ServiceFields.MODEL_ID, ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_MODEL_ID_V1) + ), + Map.of() + ) + ).build() }, + { + new TestCaseBuilder( + "getModelWithSecrets throws exception when retrieving eis rerank preconfigured endpoint and not authorized", + (modelRegistry, listener) -> modelRegistry.getModelWithSecrets( + ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_ENDPOINT_ID_V1, + listener + ), + """ + { + "models": [ + ] + } + """ + ).withFailure( + "Unable to retrieve the preconfigured inference endpoint [.rerank-v1-elastic] from the Elastic Inference Service", + RestStatus.BAD_REQUEST + ).build() } } + ); + } + + public void test() { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testCase.responseJson)); + + PlainActionFuture listener = new PlainActionFuture<>(); + testCase.registryCall.accept(modelRegistry, listener); + + if (testCase.expectedResult != null) { + assertSuccessfulTestCase(listener); + } else { + assertFailureTestCase(listener); + } + } + + private void assertSuccessfulTestCase(PlainActionFuture listener) { + var model = listener.actionGet(TIMEOUT); + assertThat(model, is(testCase.expectedResult)); + } + + private void assertFailureTestCase(PlainActionFuture listener) { + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), containsString(testCase.failureMessage)); + assertThat(exception.status(), is(testCase.failureStatus)); + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisIT.java new file mode 100644 index 0000000000000..eb268a6a177f0 --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisIT.java @@ -0,0 +1,166 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.test.http.MockResponse; + +import java.util.List; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.is; + +public class ModelRegistryEisIT extends ModelRegistryEisBase { + + private static final String eisAuthorizedResponse = """ + { + "models": [ + { + "model_name": "rainbow-sprinkles", + "task_types": ["chat"] + }, + { + "model_name": "elser_model_2", + "task_types": ["embed/text/sparse"] + }, + { + "model_name": "multilingual-embed-v1", + "task_types": ["embed/text/dense"] + }, + { + "model_name": "rerank-v1", + "task_types": ["rerank/text/text-similarity"] + } + ] + } + """; + + private static final String eisUnauthorizedResponse = """ + { + "models": [ + ] + } + """; + + public void testGetModelsByTaskType() { + initializeModels(); + + { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(eisAuthorizedResponse)); + PlainActionFuture> listener = new PlainActionFuture<>(); + modelRegistry.getModelsByTaskType(TaskType.COMPLETION, listener); + + assertThat(listener.actionGet(TIMEOUT), is(List.of())); + } + { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(eisAuthorizedResponse)); + PlainActionFuture> listener = new PlainActionFuture<>(); + modelRegistry.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener); + + var results = listener.actionGet(TIMEOUT); + var expected = Stream.of("sparse-1", "sparse-2", "sparse-3", ".elser-2-elastic").toArray(String[]::new); + assertThat(results.size(), is(expected.length)); + assertThat(results.stream().map(UnparsedModel::inferenceEntityId).sorted().toList(), containsInAnyOrder(expected)); + } + { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(eisAuthorizedResponse)); + PlainActionFuture> listener = new PlainActionFuture<>(); + modelRegistry.getModelsByTaskType(TaskType.TEXT_EMBEDDING, listener); + + var results = listener.actionGet(TIMEOUT); + var expected = Stream.of("embedding-1", "embedding-2", ".multilingual-embed-v1-elastic").toArray(String[]::new); + assertThat(results.size(), is(expected.length)); + assertThat(results.stream().map(UnparsedModel::inferenceEntityId).sorted().toList(), containsInAnyOrder(expected)); + } + { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(eisAuthorizedResponse)); + PlainActionFuture> listener = new PlainActionFuture<>(); + modelRegistry.getModelsByTaskType(TaskType.CHAT_COMPLETION, listener); + + var results = listener.actionGet(TIMEOUT); + assertThat(results.size(), is(1)); + assertThat( + results.stream().map(UnparsedModel::inferenceEntityId).sorted().toList(), + containsInAnyOrder(".rainbow-sprinkles-elastic") + ); + } + { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(eisAuthorizedResponse)); + PlainActionFuture> listener = new PlainActionFuture<>(); + modelRegistry.getModelsByTaskType(TaskType.RERANK, listener); + + var results = listener.actionGet(TIMEOUT); + assertThat(results.size(), is(1)); + assertThat(results.stream().map(UnparsedModel::inferenceEntityId).sorted().toList(), containsInAnyOrder(".rerank-v1-elastic")); + } + } + + public void testGetAllModels() { + initializeModels(); + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(eisAuthorizedResponse)); + PlainActionFuture> listener = new PlainActionFuture<>(); + modelRegistry.getAllModels(false, listener); + + var results = listener.actionGet(TIMEOUT); + var expected = Stream.of( + "sparse-1", + "sparse-2", + "sparse-3", + "embedding-1", + "embedding-2", + ".elser-2-elastic", + ".multilingual-embed-v1-elastic", + ".rainbow-sprinkles-elastic", + ".rerank-v1-elastic" + ).toArray(String[]::new); + assertThat(results.size(), is(expected.length)); + assertThat(results.stream().map(UnparsedModel::inferenceEntityId).sorted().toList(), containsInAnyOrder(expected)); + } + + public void testGetAllModelsNoEisResults() { + initializeModels(); + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(eisUnauthorizedResponse)); + PlainActionFuture> listener = new PlainActionFuture<>(); + modelRegistry.getAllModels(false, listener); + + var results = listener.actionGet(TIMEOUT); + var expected = Stream.of( + "sparse-1", + "sparse-2", + "sparse-3", + "embedding-1", + "embedding-2" + ).toArray(String[]::new); + assertThat(results.size(), is(expected.length)); + assertThat(results.stream().map(UnparsedModel::inferenceEntityId).sorted().toList(), containsInAnyOrder(expected)); + } + + public void testGetAllModelsEisReturnsFailureStatusCode() { + initializeModels(); + + webServer.enqueue(new MockResponse().setResponseCode(500).setBody("{}")); + PlainActionFuture> listener = new PlainActionFuture<>(); + modelRegistry.getAllModels(false, listener); + + var results = listener.actionGet(TIMEOUT); + var expected = Stream.of( + "sparse-1", + "sparse-2", + "sparse-3", + "embedding-1", + "embedding-2" + ).toArray(String[]::new); + assertThat(results.size(), is(expected.length)); + assertThat(results.stream().map(UnparsedModel::inferenceEntityId).sorted().toList(), containsInAnyOrder(expected)); + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisInvalidUrlIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisInvalidUrlIT.java new file mode 100644 index 0000000000000..c0d4de0693baf --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisInvalidUrlIT.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.inference.UnparsedModel; + +import java.util.List; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.is; + +public class ModelRegistryEisInvalidUrlIT extends ModelRegistryEisBase { + public ModelRegistryEisInvalidUrlIT() { + super(""); + } + + public void testGetAllModelsDoesNotReturnEisModels_WhenEisUrlIsEmpty() { + initializeModels(); + + PlainActionFuture> listener = new PlainActionFuture<>(); + modelRegistry.getAllModels(false, listener); + + var results = listener.actionGet(TIMEOUT); + var expected = Stream.of( + "sparse-1", + "sparse-2", + "sparse-3", + "embedding-1", + "embedding-2" + ).toArray(String[]::new); + assertThat(results.size(), is(expected.length)); + assertThat(results.stream().map(UnparsedModel::inferenceEntityId).sorted().toList(), containsInAnyOrder(expected)); + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index 71b2c7acc323e..9aac2c44fc2e5 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -36,8 +36,6 @@ import org.elasticsearch.reindex.ReindexPlugin; import org.elasticsearch.test.ESSingleNodeTestCase; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.test.http.MockResponse; -import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; @@ -46,14 +44,11 @@ import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.registry.ModelRegistryTests; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettingsTests; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettingsTests; -import org.junit.AfterClass; import org.junit.Before; -import org.junit.BeforeClass; import java.io.IOException; import java.util.ArrayList; @@ -70,8 +65,6 @@ import java.util.function.Function; import java.util.stream.Collectors; -import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_URL; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.containsString; @@ -88,17 +81,6 @@ public class ModelRegistryIT extends ESSingleNodeTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private ModelRegistry modelRegistry; - private static final MockWebServer webServer = new MockWebServer(); - - @BeforeClass - public static void init() throws Exception { - webServer.start(); - } - - @AfterClass - public static void shutdown() { - webServer.close(); - } @Before public void createComponents() { @@ -111,11 +93,6 @@ protected Collection> getPlugins() { return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class); } - @Override - protected Settings nodeSettings() { - return Settings.builder().put(super.nodeSettings()).put(ELASTIC_INFERENCE_SERVICE_URL.getKey(), getUrl(webServer)).build(); - } - public void testStoreModel() { String inferenceEntityId = "test-store-model"; Model model = buildElserModelConfig(inferenceEntityId, TaskType.SPARSE_EMBEDDING); @@ -600,27 +577,6 @@ public void testGetByTaskType_WithDefaults() throws Exception { assertReturnModelIsModifiable(modelHolder.get().get(0)); } - public void testGetModel_RetrievesAnEisPreconfiguredEndpoint() { - var responseJson = """ - { - "models": [ - { - "model_name": "rainbow-sprinkles", - "task_types": ["chat"] - } - ] - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - PlainActionFuture listener = new PlainActionFuture<>(); - modelRegistry.getModel(ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, listener); - - var model = listener.actionGet(TIMEOUT); - assertThat(model.inferenceEntityId(), is(ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); - assertThat(model.taskType(), is(TaskType.CHAT_COMPLETION)); - } - private void assertInferenceIndexExists() { var indexResponse = client().admin().indices().prepareGetIndex(TEST_REQUEST_TIMEOUT).addIndices(".inference").get(); assertNotNull(indexResponse.getSettings()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index ccc643b668d20..c1781910ede40 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -96,6 +96,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; +import java.util.function.Predicate; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -416,6 +417,7 @@ public void getModelsByTaskType(TaskType taskType, ActionListener taskTypeMatchedDefaults(taskType, defaultConfigIds.values()), + unparsedModel -> unparsedModel.taskType() == taskType, true, listener ); @@ -424,6 +426,7 @@ public void getModelsByTaskType(TaskType taskType, ActionListener> defaultConfigIdsSupplier, + Predicate eisResponseFilter, boolean persistDefaultEndpoints, ActionListener> listener ) { @@ -431,6 +434,7 @@ private void getModelsHelper( (delegate, searchResponse) -> includeDefaultAndEisEndpoints( searchResponse, defaultConfigIdsSupplier, + eisResponseFilter, persistDefaultEndpoints, delegate ) @@ -454,6 +458,7 @@ private void getModelsHelper( private void includeDefaultAndEisEndpoints( SearchResponse searchResponse, Supplier> defaultConfigIdsSupplier, + Predicate eisResponseFilter, boolean persistDefaultEndpoints, ActionListener> listener ) { @@ -467,15 +472,18 @@ private void includeDefaultAndEisEndpoints( defaultConfigIdsSupplier.get(), missingDefaultConfigsAddedListener ); - }).>andThen((eisPreconfiguredEndpointsAddedListener, unparsedModels) -> { - ActionListener> eisListener = ActionListener.wrap(response -> { - var allModels = new ArrayList<>(unparsedModels); - allModels.addAll(response); + }).>andThen((eisPreconfiguredEndpointsAddedListener, defaultModelsAndFromIndex) -> { + ActionListener> eisListener = ActionListener.wrap(allEisAuthorizedModels -> { + var filteredEisModels = allEisAuthorizedModels.stream().filter(eisResponseFilter).toList(); + + var allModels = new ArrayList<>(defaultModelsAndFromIndex); + allModels.addAll(filteredEisModels); allModels.sort(Comparator.comparing(UnparsedModel::inferenceEntityId)); + eisPreconfiguredEndpointsAddedListener.onResponse(allModels); }, e -> { logger.debug("Failed to retrieve preconfigured endpoint from EIS", e); - eisPreconfiguredEndpointsAddedListener.onResponse(unparsedModels); + eisPreconfiguredEndpointsAddedListener.onResponse(defaultModelsAndFromIndex); }); preconfiguredEndpointsRequestHandler.getAllPreconfiguredEndpointsAsUnparsedModels(eisListener); @@ -498,6 +506,7 @@ public void getAllModels(boolean persistDefaultEndpoints, ActionListener new ArrayList<>(defaultConfigIds.values()), + eisResponse -> true, persistDefaultEndpoints, listener ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceMinimalSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceMinimalSettings.java index 61aa514ae700a..52f6f6fa074aa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceMinimalSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceMinimalSettings.java @@ -24,17 +24,17 @@ public class ElasticInferenceServiceMinimalSettings { public static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); // elser-2 - static final String DEFAULT_ELSER_2_MODEL_ID = "elser_model_2"; - static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId("elser-2"); + public static final String DEFAULT_ELSER_2_MODEL_ID = "elser_model_2"; + public static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId("elser-2"); // multilingual-text-embed - static final Integer DENSE_TEXT_EMBEDDINGS_DIMENSIONS = 1024; - static final String DEFAULT_MULTILINGUAL_EMBED_MODEL_ID = "multilingual-embed-v1"; - static final String DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID = defaultEndpointId(DEFAULT_MULTILINGUAL_EMBED_MODEL_ID); + public static final Integer DENSE_TEXT_EMBEDDINGS_DIMENSIONS = 1024; + public static final String DEFAULT_MULTILINGUAL_EMBED_MODEL_ID = "multilingual-embed-v1"; + public static final String DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID = defaultEndpointId(DEFAULT_MULTILINGUAL_EMBED_MODEL_ID); // rerank-v1 - static final String DEFAULT_RERANK_MODEL_ID_V1 = "rerank-v1"; - static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_RERANK_MODEL_ID_V1); + public static final String DEFAULT_RERANK_MODEL_ID_V1 = "rerank-v1"; + public static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_RERANK_MODEL_ID_V1); public static final Set EIS_PRECONFIGURED_ENDPOINTS = Set.of( DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsModel.java index 9932e041244e6..bf75bf1286fe5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsModel.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.elastic.authorization; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; @@ -26,9 +27,6 @@ public record PreconfiguredEndpointsModel(Map preconfiguredEndpoints) { public static PreconfiguredEndpointsModel of(ElasticInferenceServiceAuthorizationModel authModel) { - // TODO convert the auth model to a list of preconfigured endpoints - // iterate over the authorized model ids and retrieve the configurations from a new class that has the information - var endpoints = authModel.getAuthorizedModelIds() .stream() .filter(ElasticInferenceServiceMinimalSettings::containsModelName) @@ -106,14 +104,19 @@ private static Map embeddingSettings( ) { return new HashMap<>( Map.of( - ServiceFields.MODEL_ID, - modelId, - ServiceFields.SIMILARITY, - similarityMeasure.toString(), - ServiceFields.DIMENSIONS, - dimension, - ServiceFields.ELEMENT_TYPE, - elementType.toString() + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + ServiceFields.SIMILARITY, + similarityMeasure.toString(), + ServiceFields.DIMENSIONS, + dimension, + ServiceFields.ELEMENT_TYPE, + elementType.toString() + ) + ) ) ); } @@ -126,7 +129,7 @@ public UnparsedModel toUnparsedModel() { } private static Map settingsWithModelId(String modelId) { - return new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)); + return new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)))); } public UnparsedModel toUnparsedModel(String inferenceId) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index 24979ac507cd5..432f8321b51a1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -245,12 +245,38 @@ public void testGetMinimalServiceSettings_ThrowsResourceNotFound_WhenInferenceId } public void testGetMinimalServiceSettings_ReturnsEisPreconfiguredEndpoint() { - var minimalSettings = registry.getMinimalServiceSettings( - ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 - ); - - assertThat(minimalSettings.service(), is(ElasticInferenceService.NAME)); - assertThat(minimalSettings.taskType(), is(TaskType.CHAT_COMPLETION)); + { + var minimalSettings = registry.getMinimalServiceSettings( + ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 + ); + + assertThat(minimalSettings.service(), is(ElasticInferenceService.NAME)); + assertThat(minimalSettings.taskType(), is(TaskType.CHAT_COMPLETION)); + } + { + var minimalSettings = registry.getMinimalServiceSettings( + ElasticInferenceServiceMinimalSettings.DEFAULT_ELSER_ENDPOINT_ID_V2 + ); + + assertThat(minimalSettings.service(), is(ElasticInferenceService.NAME)); + assertThat(minimalSettings.taskType(), is(TaskType.SPARSE_EMBEDDING)); + } + { + var minimalSettings = registry.getMinimalServiceSettings( + ElasticInferenceServiceMinimalSettings.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID + ); + + assertThat(minimalSettings.service(), is(ElasticInferenceService.NAME)); + assertThat(minimalSettings.taskType(), is(TaskType.TEXT_EMBEDDING)); + } + { + var minimalSettings = registry.getMinimalServiceSettings( + ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_ENDPOINT_ID_V1 + ); + + assertThat(minimalSettings.service(), is(ElasticInferenceService.NAME)); + assertThat(minimalSettings.taskType(), is(TaskType.RERANK)); + } } public static void assertStoreModel(ModelRegistry registry, Model model) { From b29d60d9a0878af0922a45e30b57f32d5bc94d31 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Wed, 8 Oct 2025 20:11:50 +0000 Subject: [PATCH 09/18] [CI] Auto commit changes from spotless --- .../integration/ModelRegistryEisIT.java | 16 ++-------------- .../ModelRegistryEisInvalidUrlIT.java | 8 +------- .../inference/registry/ModelRegistryTests.java | 8 ++------ 3 files changed, 5 insertions(+), 27 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisIT.java index eb268a6a177f0..9891929c66468 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisIT.java @@ -134,13 +134,7 @@ public void testGetAllModelsNoEisResults() { modelRegistry.getAllModels(false, listener); var results = listener.actionGet(TIMEOUT); - var expected = Stream.of( - "sparse-1", - "sparse-2", - "sparse-3", - "embedding-1", - "embedding-2" - ).toArray(String[]::new); + var expected = Stream.of("sparse-1", "sparse-2", "sparse-3", "embedding-1", "embedding-2").toArray(String[]::new); assertThat(results.size(), is(expected.length)); assertThat(results.stream().map(UnparsedModel::inferenceEntityId).sorted().toList(), containsInAnyOrder(expected)); } @@ -153,13 +147,7 @@ public void testGetAllModelsEisReturnsFailureStatusCode() { modelRegistry.getAllModels(false, listener); var results = listener.actionGet(TIMEOUT); - var expected = Stream.of( - "sparse-1", - "sparse-2", - "sparse-3", - "embedding-1", - "embedding-2" - ).toArray(String[]::new); + var expected = Stream.of("sparse-1", "sparse-2", "sparse-3", "embedding-1", "embedding-2").toArray(String[]::new); assertThat(results.size(), is(expected.length)); assertThat(results.stream().map(UnparsedModel::inferenceEntityId).sorted().toList(), containsInAnyOrder(expected)); } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisInvalidUrlIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisInvalidUrlIT.java index c0d4de0693baf..9af4a2c9d9df5 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisInvalidUrlIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisInvalidUrlIT.java @@ -28,13 +28,7 @@ public void testGetAllModelsDoesNotReturnEisModels_WhenEisUrlIsEmpty() { modelRegistry.getAllModels(false, listener); var results = listener.actionGet(TIMEOUT); - var expected = Stream.of( - "sparse-1", - "sparse-2", - "sparse-3", - "embedding-1", - "embedding-2" - ).toArray(String[]::new); + var expected = Stream.of("sparse-1", "sparse-2", "sparse-3", "embedding-1", "embedding-2").toArray(String[]::new); assertThat(results.size(), is(expected.length)); assertThat(results.stream().map(UnparsedModel::inferenceEntityId).sorted().toList(), containsInAnyOrder(expected)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index 432f8321b51a1..4e01b603c8a94 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -254,9 +254,7 @@ public void testGetMinimalServiceSettings_ReturnsEisPreconfiguredEndpoint() { assertThat(minimalSettings.taskType(), is(TaskType.CHAT_COMPLETION)); } { - var minimalSettings = registry.getMinimalServiceSettings( - ElasticInferenceServiceMinimalSettings.DEFAULT_ELSER_ENDPOINT_ID_V2 - ); + var minimalSettings = registry.getMinimalServiceSettings(ElasticInferenceServiceMinimalSettings.DEFAULT_ELSER_ENDPOINT_ID_V2); assertThat(minimalSettings.service(), is(ElasticInferenceService.NAME)); assertThat(minimalSettings.taskType(), is(TaskType.SPARSE_EMBEDDING)); @@ -270,9 +268,7 @@ public void testGetMinimalServiceSettings_ReturnsEisPreconfiguredEndpoint() { assertThat(minimalSettings.taskType(), is(TaskType.TEXT_EMBEDDING)); } { - var minimalSettings = registry.getMinimalServiceSettings( - ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_ENDPOINT_ID_V1 - ); + var minimalSettings = registry.getMinimalServiceSettings(ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_ENDPOINT_ID_V1); assertThat(minimalSettings.service(), is(ElasticInferenceService.NAME)); assertThat(minimalSettings.taskType(), is(TaskType.RERANK)); From 92b79e074a0393abb9aa5859ea64fdbb11dd11f8 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 8 Oct 2025 16:13:24 -0400 Subject: [PATCH 10/18] Removing unnecessary files --- ...cInferenceServiceAuthorizationHandler.java | 339 ------------------ ...renceServiceAuthorizationHandlerTests.java | 291 --------------- 2 files changed, 630 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java deleted file mode 100644 index b66569ac4db1c..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java +++ /dev/null @@ -1,339 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elastic.authorization; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.common.Randomness; -import org.elasticsearch.common.Strings; -import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.threadpool.Scheduler; -import org.elasticsearch.xpack.inference.external.http.sender.Sender; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.services.ServiceComponents; -import org.elasticsearch.xpack.inference.services.elastic.DefaultModelConfig; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; - -import java.io.Closeable; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Comparator; -import java.util.EnumSet; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.TreeSet; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Collectors; - -import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; - -// TODO remove this class -public class ElasticInferenceServiceAuthorizationHandler implements Closeable { - private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationHandler.class); - - private record AuthorizedContent( - ElasticInferenceServiceAuthorizationModel taskTypesAndModels, - List configIds, - List defaultModelConfigs - ) { - static AuthorizedContent empty() { - return new AuthorizedContent(ElasticInferenceServiceAuthorizationModel.newDisabledService(), List.of(), List.of()); - } - } - - private final ServiceComponents serviceComponents; - private final AtomicReference authorizedContent = new AtomicReference<>(AuthorizedContent.empty()); - private final ModelRegistry modelRegistry; - private final ElasticInferenceServiceAuthorizationRequestHandler authorizationHandler; - private final Map defaultModelsConfigs; - private final CountDownLatch firstAuthorizationCompletedLatch = new CountDownLatch(1); - private final EnumSet implementedTaskTypes; - private final InferenceService inferenceService; - private final Sender sender; - private final Runnable callback; - private final AtomicReference lastAuthTask = new AtomicReference<>(null); - private final AtomicBoolean shutdown = new AtomicBoolean(false); - private final ElasticInferenceServiceSettings elasticInferenceServiceSettings; - - public ElasticInferenceServiceAuthorizationHandler( - ServiceComponents serviceComponents, - ModelRegistry modelRegistry, - ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, - Map defaultModelsConfigs, - EnumSet implementedTaskTypes, - InferenceService inferenceService, - Sender sender, - ElasticInferenceServiceSettings elasticInferenceServiceSettings - ) { - this( - serviceComponents, - modelRegistry, - authorizationRequestHandler, - defaultModelsConfigs, - implementedTaskTypes, - Objects.requireNonNull(inferenceService), - sender, - elasticInferenceServiceSettings, - null - ); - } - - // default for testing - ElasticInferenceServiceAuthorizationHandler( - ServiceComponents serviceComponents, - ModelRegistry modelRegistry, - ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, - Map defaultModelsConfigs, - EnumSet implementedTaskTypes, - InferenceService inferenceService, - Sender sender, - ElasticInferenceServiceSettings elasticInferenceServiceSettings, - // this is a hack to facilitate testing - Runnable callback - ) { - this.serviceComponents = Objects.requireNonNull(serviceComponents); - this.modelRegistry = Objects.requireNonNull(modelRegistry); - this.authorizationHandler = Objects.requireNonNull(authorizationRequestHandler); - this.defaultModelsConfigs = Objects.requireNonNull(defaultModelsConfigs); - this.implementedTaskTypes = Objects.requireNonNull(implementedTaskTypes); - // allow the service to be null for testing - this.inferenceService = inferenceService; - this.sender = Objects.requireNonNull(sender); - this.elasticInferenceServiceSettings = Objects.requireNonNull(elasticInferenceServiceSettings); - this.callback = callback; - } - - public void init() { - logger.debug("Initializing authorization logic"); - serviceComponents.threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(this::scheduleAndSendAuthorizationRequest); - } - - /** - * Waits the specified amount of time for the first authorization call to complete. This is mainly to make testing easier. - * @param waitTime the max time to wait - * @throws IllegalStateException if the wait time is exceeded or the call receives an {@link InterruptedException} - */ - public void waitForAuthorizationToComplete(TimeValue waitTime) { - try { - if (firstAuthorizationCompletedLatch.await(waitTime.getSeconds(), TimeUnit.SECONDS) == false) { - throw new IllegalStateException("The wait time has expired for authorization to complete."); - } - } catch (InterruptedException e) { - throw new IllegalStateException("Waiting for authorization to complete was interrupted"); - } - } - - public synchronized Set supportedStreamingTasks() { - var authorizedStreamingTaskTypes = EnumSet.of(TaskType.CHAT_COMPLETION); - authorizedStreamingTaskTypes.retainAll(authorizedContent.get().taskTypesAndModels.getAuthorizedTaskTypes()); - - return authorizedStreamingTaskTypes; - } - - public synchronized List defaultConfigIds() { - return authorizedContent.get().configIds; - } - - public synchronized void defaultConfigs(ActionListener> defaultsListener) { - var models = authorizedContent.get().defaultModelConfigs.stream().map(DefaultModelConfig::model).toList(); - defaultsListener.onResponse(models); - } - - public synchronized EnumSet supportedTaskTypes() { - return authorizedContent.get().taskTypesAndModels.getAuthorizedTaskTypes(); - } - - public synchronized boolean hideFromConfigurationApi() { - return authorizedContent.get().taskTypesAndModels.isAuthorized() == false; - } - - @Override - public void close() throws IOException { - shutdown.set(true); - if (lastAuthTask.get() != null) { - lastAuthTask.get().cancel(); - } - } - - private void scheduleAuthorizationRequest() { - try { - if (elasticInferenceServiceSettings.isPeriodicAuthorizationEnabled() == false) { - return; - } - - // this call has to be on the individual thread otherwise we get an exception - var random = Randomness.get(); - var jitter = (long) (elasticInferenceServiceSettings.getMaxAuthorizationRequestJitter().millis() * random.nextDouble()); - var waitTime = TimeValue.timeValueMillis(elasticInferenceServiceSettings.getAuthRequestInterval().millis() + jitter); - - logger.debug( - () -> Strings.format( - "Scheduling the next authorization call with request interval: %s ms, jitter: %d ms", - elasticInferenceServiceSettings.getAuthRequestInterval().millis(), - jitter - ) - ); - logger.debug(() -> Strings.format("Next authorization call in %d minutes", waitTime.getMinutes())); - - lastAuthTask.set( - serviceComponents.threadPool() - .schedule( - this::scheduleAndSendAuthorizationRequest, - waitTime, - serviceComponents.threadPool().executor(UTILITY_THREAD_POOL_NAME) - ) - ); - } catch (Exception e) { - logger.warn("Failed scheduling authorization request", e); - } - } - - private void scheduleAndSendAuthorizationRequest() { - if (shutdown.get()) { - return; - } - - scheduleAuthorizationRequest(); - sendAuthorizationRequest(); - } - - private void sendAuthorizationRequest() { - try { - ActionListener listener = ActionListener.wrap((model) -> { - setAuthorizedContent(model); - if (callback != null) { - callback.run(); - } - }, e -> { - // we don't need to do anything if there was a failure, everything is disabled by default - firstAuthorizationCompletedLatch.countDown(); - }); - - authorizationHandler.getAuthorization(listener, sender); - } catch (Exception e) { - logger.warn("Failure while sending the request to retrieve authorization", e); - // we don't need to do anything if there was a failure, everything is disabled by default - firstAuthorizationCompletedLatch.countDown(); - } - } - - private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorizationModel auth) { - logger.debug(() -> Strings.format("Received authorization response, %s", auth)); - - var authorizedTaskTypesAndModels = auth.newLimitedToTaskTypes(EnumSet.copyOf(implementedTaskTypes)); - logger.debug(() -> Strings.format("Authorization entity limited to service task types, %s", authorizedTaskTypesAndModels)); - - // recalculate which default config ids and models are authorized now - var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(authorizedTaskTypesAndModels); - - var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(authorizedDefaultModelIds, authorizedTaskTypesAndModels); - var authorizedDefaultModelObjects = getAuthorizedDefaultModelsObjects(authorizedDefaultModelIds); - authorizedContent.set( - new AuthorizedContent(authorizedTaskTypesAndModels, authorizedDefaultConfigIds, authorizedDefaultModelObjects) - ); - - // We are no longer added the authorized preconfigured endpoints to the model registry. The model registry will reach out to - // the EIS gateway directly to get the model information. For now, I'm leaving the revoking logic in place but it will be removed - // when the authorization polling logic is moved to the master node. - handleRevokedDefaultConfigs(authorizedDefaultModelIds); - } - - private Set getAuthorizedDefaultModelIds(ElasticInferenceServiceAuthorizationModel auth) { - var authorizedModels = auth.getAuthorizedModelIds(); - var authorizedDefaultModelIds = new TreeSet<>(defaultModelsConfigs.keySet()); - authorizedDefaultModelIds.retainAll(authorizedModels); - - return authorizedDefaultModelIds; - } - - private List getAuthorizedDefaultConfigIds( - Set authorizedDefaultModelIds, - ElasticInferenceServiceAuthorizationModel auth - ) { - var authorizedConfigIds = new ArrayList(); - for (var id : authorizedDefaultModelIds) { - var modelConfig = defaultModelsConfigs.get(id); - if (modelConfig != null) { - if (auth.getAuthorizedTaskTypes().contains(modelConfig.model().getTaskType()) == false) { - logger.warn( - org.elasticsearch.common.Strings.format( - "The authorization response included the default model: %s, " - + "but did not authorize the assumed task type of the model: %s. Enabling model.", - id, - modelConfig.model().getTaskType() - ) - ); - } - authorizedConfigIds.add( - new InferenceService.DefaultConfigId( - modelConfig.model().getInferenceEntityId(), - modelConfig.settings(), - inferenceService - ) - ); - } - } - - authorizedConfigIds.sort(Comparator.comparing(InferenceService.DefaultConfigId::inferenceId)); - return authorizedConfigIds; - } - - private List getAuthorizedDefaultModelsObjects(Set authorizedDefaultModelIds) { - var authorizedModels = new ArrayList(); - for (var id : authorizedDefaultModelIds) { - var modelConfig = defaultModelsConfigs.get(id); - if (modelConfig != null) { - authorizedModels.add(modelConfig); - } - } - - authorizedModels.sort(Comparator.comparing(modelConfig -> modelConfig.model().getInferenceEntityId())); - return authorizedModels; - } - - private void handleRevokedDefaultConfigs(Set authorizedDefaultModelIds) { - // if a model was initially returned in the authorization response but is absent, then we'll assume authorization was revoked - var unauthorizedDefaultModelIds = new HashSet<>(defaultModelsConfigs.keySet()); - unauthorizedDefaultModelIds.removeAll(authorizedDefaultModelIds); - - // get all the default inference endpoint ids for the unauthorized model ids - var unauthorizedDefaultInferenceEndpointIds = unauthorizedDefaultModelIds.stream() - .map(defaultModelsConfigs::get) // get all the model configs - .filter(Objects::nonNull) // limit to only non-null - .map(modelConfig -> modelConfig.model().getInferenceEntityId()) // get the inference ids - .collect(Collectors.toSet()); - - var deleteInferenceEndpointsListener = ActionListener.wrap(result -> { - logger.debug(Strings.format("Successfully revoked access to default inference endpoint IDs: %s", unauthorizedDefaultModelIds)); - firstAuthorizationCompletedLatch.countDown(); - }, e -> { - logger.warn( - Strings.format("Failed to revoke access to default inference endpoint IDs: %s, error: %s", unauthorizedDefaultModelIds, e) - ); - firstAuthorizationCompletedLatch.countDown(); - }); - - logger.debug( - () -> Strings.format( - "Synchronizing default inference endpoints, attempting to remove ids: %s", - unauthorizedDefaultInferenceEndpointIds - ) - ); - modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java deleted file mode 100644 index 6e6bfc8f8c910..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java +++ /dev/null @@ -1,291 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elastic.authorization; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue; -import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.EmptySecretSettings; -import org.elasticsearch.inference.EmptyTaskSettings; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.MinimalServiceSettings; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.test.ESSingleNodeTestCase; -import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; -import org.elasticsearch.xpack.inference.Utils; -import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; -import org.elasticsearch.xpack.inference.external.http.sender.Sender; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.services.elastic.DefaultModelConfig; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; -import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; -import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity; -import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; -import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; -import org.junit.Before; - -import java.io.IOException; -import java.util.Collection; -import java.util.EnumSet; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; - -import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; -import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.defaultEndpointId; -import static org.hamcrest.CoreMatchers.is; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; - -// TODO remove this class -public class ElasticInferenceServiceAuthorizationHandlerTests extends ESSingleNodeTestCase { - private DeterministicTaskQueue taskQueue; - private ModelRegistry modelRegistry; - - @Override - protected Collection> getPlugins() { - return List.of(LocalStateInferencePlugin.class); - } - - // TODO add the EIS url here - @Override - protected Settings nodeSettings() { - return Settings.EMPTY; - } - - @Before - public void init() throws Exception { - taskQueue = new DeterministicTaskQueue(); - modelRegistry = getInstanceFromNode(ModelRegistry.class); - } - - public void testSecondAuthResultRevokesAuthorization() throws Exception { - var callbackCount = new AtomicInteger(0); - // we're only interested in two authorization calls which is why I'm using a value of 2 here - var latch = new CountDownLatch(2); - final AtomicReference handlerRef = new AtomicReference<>(); - - Runnable callback = () -> { - // the first authorization response contains a streaming task so we're expecting to support streaming here - if (callbackCount.incrementAndGet() == 1) { - assertThat(handlerRef.get().supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - } - latch.countDown(); - - // we only want to run the tasks twice, so advance the time on the queue - // which flags the scheduled authorization request to be ready to run - if (callbackCount.get() == 1) { - taskQueue.advanceTime(); - } else { - try { - handlerRef.get().close(); - } catch (IOException e) { - // ignore - } - } - }; - - var requestHandler = mockAuthorizationRequestHandler( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "rainbow-sprinkles", - EnumSet.of(TaskType.CHAT_COMPLETION) - ) - ) - ) - ), - ElasticInferenceServiceAuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntity(List.of())) - ); - - handlerRef.set( - new ElasticInferenceServiceAuthorizationHandler( - createWithEmptySettings(taskQueue.getThreadPool()), - modelRegistry, - requestHandler, - initDefaultEndpoints(), - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION), - null, - mock(Sender.class), - ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), - callback - ) - ); - - var handler = handlerRef.get(); - handler.init(); - taskQueue.runAllRunnableTasks(); - latch.await(Utils.TIMEOUT.getSeconds(), TimeUnit.SECONDS); - - // this should be after we've received both authorization responses, the second response will revoke authorization - - assertThat(handler.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); - assertThat(handler.defaultConfigIds(), is(List.of())); - assertThat(handler.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class))); - - PlainActionFuture> listener = new PlainActionFuture<>(); - handler.defaultConfigs(listener); - - var configs = listener.actionGet(); - assertThat(configs.size(), is(0)); - } - - public void testSendsAnAuthorizationRequestTwice() throws Exception { - var callbackCount = new AtomicInteger(0); - // we're only interested in two authorization calls which is why I'm using a value of 2 here - var latch = new CountDownLatch(2); - final AtomicReference handlerRef = new AtomicReference<>(); - - Runnable callback = () -> { - // the first authorization response does not contain a streaming task so we're expecting to not support streaming here - if (callbackCount.incrementAndGet() == 1) { - assertThat(handlerRef.get().supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); - } - latch.countDown(); - - // we only want to run the tasks twice, so advance the time on the queue - // which flags the scheduled authorization request to be ready to run - if (callbackCount.get() == 1) { - taskQueue.advanceTime(); - } else { - try { - handlerRef.get().close(); - } catch (IOException e) { - // ignore - } - } - }; - - var requestHandler = mockAuthorizationRequestHandler( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("abc", EnumSet.of(TaskType.SPARSE_EMBEDDING)) - ) - ) - ), - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "abc", - EnumSet.of(TaskType.SPARSE_EMBEDDING) - ), - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "rainbow-sprinkles", - EnumSet.of(TaskType.CHAT_COMPLETION) - ) - ) - ) - ) - ); - - handlerRef.set( - new ElasticInferenceServiceAuthorizationHandler( - createWithEmptySettings(taskQueue.getThreadPool()), - modelRegistry, - requestHandler, - initDefaultEndpoints(), - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION), - null, - mock(Sender.class), - ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), - callback - ) - ); - - var handler = handlerRef.get(); - handler.init(); - taskQueue.runAllRunnableTasks(); - latch.await(Utils.TIMEOUT.getSeconds(), TimeUnit.SECONDS); - // this should be after we've received both authorization responses - - assertThat(handler.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertThat( - handler.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - null - ) - ) - ) - ); - assertThat(handler.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); - - PlainActionFuture> listener = new PlainActionFuture<>(); - handler.defaultConfigs(listener); - - var configs = listener.actionGet(); - assertThat(configs.get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); - } - - private static ElasticInferenceServiceAuthorizationRequestHandler mockAuthorizationRequestHandler( - ElasticInferenceServiceAuthorizationModel firstAuthResponse, - ElasticInferenceServiceAuthorizationModel secondAuthResponse - ) { - var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse(firstAuthResponse); - return Void.TYPE; - }).doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse(secondAuthResponse); - return Void.TYPE; - }).when(mockAuthHandler).getAuthorization(any(), any()); - - return mockAuthHandler; - } - - private static Map initDefaultEndpoints() { - return Map.of( - "rainbow-sprinkles", - new DefaultModelConfig( - new ElasticInferenceServiceCompletionModel( - defaultEndpointId("rainbow-sprinkles"), - TaskType.CHAT_COMPLETION, - "test", - new ElasticInferenceServiceCompletionServiceSettings("rainbow-sprinkles"), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - ElasticInferenceServiceComponents.EMPTY_INSTANCE - ), - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME) - ), - "elser-2", - new DefaultModelConfig( - new ElasticInferenceServiceSparseEmbeddingsModel( - defaultEndpointId("elser-2"), - TaskType.SPARSE_EMBEDDING, - "test", - new ElasticInferenceServiceSparseEmbeddingsServiceSettings("elser-2", null), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - ElasticInferenceServiceComponents.EMPTY_INSTANCE, - ChunkingSettingsBuilder.DEFAULT_SETTINGS - ), - MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME) - ) - ); - } -} From 7db8bafda8d1494da19bf445969912c68f0bed8a Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 9 Oct 2025 10:29:55 -0400 Subject: [PATCH 11/18] Fixing tests --- .../esql/qa/rest/SemanticMatchTestCase.java | 11 +- ...nceServiceAuthorizationRequestHandler.java | 2 +- .../InferenceEndpointRegistryTests.java | 4 +- .../registry/ModelRegistryTests.java | 2 +- .../elastic/ElasticInferenceServiceTests.java | 179 +----------------- ...rviceAuthorizationRequestHandlerTests.java | 12 +- .../test/inference/inference_crud.yml | 2 +- 7 files changed, 26 insertions(+), 186 deletions(-) diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/SemanticMatchTestCase.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/SemanticMatchTestCase.java index aada75f151d66..5cc005c477e90 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/SemanticMatchTestCase.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/SemanticMatchTestCase.java @@ -9,6 +9,7 @@ import org.elasticsearch.client.Request; import org.elasticsearch.client.ResponseException; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.test.rest.ESRestTestCase; import org.elasticsearch.xpack.esql.AssertWarnings; @@ -61,13 +62,15 @@ public void testWithMultipleInferenceIds() throws IOException { public void testWithInferenceNotConfigured() { assumeTrue("semantic text capability not available", EsqlCapabilities.Cap.SEMANTIC_TEXT_FIELD_CAPS.isEnabled()); - String query = """ - from test-semantic3 + var inferenceId = "test-semantic3"; + + String query = Strings.format(""" + from %s | where match(semantic_text_field, "something") - """; + """, inferenceId); ResponseException re = expectThrows(ResponseException.class, () -> runEsqlQuery(query)); - assertThat(re.getMessage(), containsString("Inference endpoint not found")); + assertThat(re.getMessage(), containsString(Strings.format("Inference endpoint [%s] not found", inferenceId))); assertEquals(404, re.getResponse().getStatusLine().getStatusCode()); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java index 3068d44bd526c..e779d45450449 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java @@ -79,7 +79,7 @@ public void getAuthorization(ActionListener listener.actionGet(TIMEOUT)); - assertThat(exception.getMessage(), is("Inference endpoint not found [1]")); + assertThat(exception.getMessage(), containsString("Inference endpoint [1] not found")); } public void testGetModelWithSecrets() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 0a806cad3995a..cae5429e7f4ef 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -17,16 +17,13 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; -import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; @@ -1086,192 +1083,36 @@ public void testSupportedStreamingTasks_ReturnsChatCompletion_WhenAuthRespondsWi } } - public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes_IgnoresUnimplementedTaskTypes() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "model-a", - "task_types": ["embed/text/sparse"] - }, - { - "model_name": "model-b", - "task_types": ["embed"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - + public void testSupportedTaskTypes_ThrowsUnsupportedException() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); + expectThrows(UnsupportedOperationException.class, service::supportedTaskTypes); } } - public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "model-a", - "task_types": ["embed/text/sparse"] - }, - { - "model_name": "model-b", - "task_types": ["chat"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); - } - } - - public void testSupportedStreamingTasks_ReturnsEmpty_WhenAuthRespondsWithoutChatCompletion() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "model-a", - "task_types": ["embed/text/sparse"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - + public void testSupportedStreamingTasks_ReturnsChatCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertTrue(service.defaultConfigIds().isEmpty()); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); + assertThat(service.defaultConfigIds(), is(List.of())); + expectThrows(UnsupportedOperationException.class, service::supportedTaskTypes); PlainActionFuture> listener = new PlainActionFuture<>(); service.defaultConfigs(listener); - assertTrue(listener.actionGet(TIMEOUT).isEmpty()); + assertThat(listener.actionGet(TIMEOUT), is(List.of())); } } - public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsIncorrect() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "rainbow-sprinkles", - "task_types": ["embed/text/sparse"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertThat( - service.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - service - ) - ) - ) - ); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); - - PlainActionFuture> listener = new PlainActionFuture<>(); - service.defaultConfigs(listener); - assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); - } - } - - public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "rainbow-sprinkles", - "task_types": ["chat"] - }, - { - "model_name": "elser_model_2", - "task_types": ["embed/text/sparse"] - }, - { - "model_name": "multilingual-embed-v1", - "task_types": ["embed/text/dense"] - }, - { - "model_name": "rerank-v1", - "task_types": ["rerank/text/text-similarity"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - + public void testDefaultConfigs_ReturnsEmptyList() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertFalse(service.canStream(TaskType.ANY)); - assertThat( - service.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".elser-2-elastic", - MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( - ".multilingual-embed-v1-elastic", - MinimalServiceSettings.textEmbedding( - ElasticInferenceService.NAME, - ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, - ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), - DenseVectorFieldMapper.ElementType.FLOAT - ), - service - ), - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( - ".rerank-v1-elastic", - MinimalServiceSettings.rerank(ElasticInferenceService.NAME), - service - ) - ) - ) - ); - assertThat( - service.supportedTaskTypes(), - is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.TEXT_EMBEDDING)) - ); + assertThat(service.defaultConfigIds(), is(List.of())); + expectThrows(UnsupportedOperationException.class, service::supportedTaskTypes); PlainActionFuture> listener = new PlainActionFuture<>(); service.defaultConfigs(listener); - var models = listener.actionGet(TIMEOUT); - assertThat(models.size(), is(4)); - assertThat(models.get(0).getConfigurations().getInferenceEntityId(), is(".elser-2-elastic")); - assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".multilingual-embed-v1-elastic")); - assertThat(models.get(2).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); - assertThat(models.get(3).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic")); + assertThat(listener.actionGet(TIMEOUT), is(List.of())); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java index e3d24ea2ec8f7..6657bb71f6848 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java @@ -80,10 +80,8 @@ public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsNull() throws PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); - var authResponse = listener.actionGet(TIMEOUT); - assertTrue(authResponse.getAuthorizedTaskTypes().isEmpty()); - assertTrue(authResponse.getAuthorizedModelIds().isEmpty()); - assertFalse(authResponse.isAuthorized()); + var exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), containsString("The Elastic Inference Service URL is not configured.")); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); verify(logger, times(2)).debug(loggerArgsCaptor.capture()); @@ -102,10 +100,8 @@ public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsEmpty() throws PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); - var authResponse = listener.actionGet(TIMEOUT); - assertTrue(authResponse.getAuthorizedTaskTypes().isEmpty()); - assertTrue(authResponse.getAuthorizedModelIds().isEmpty()); - assertFalse(authResponse.isAuthorized()); + var exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), containsString("The Elastic Inference Service URL is not configured.")); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); verify(logger, times(2)).debug(loggerArgsCaptor.capture()); diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml index 62a49422079b8..cb33b432b8afa 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml @@ -5,7 +5,7 @@ inference.get: inference_id: inference_to_get - match: { error.type: "resource_not_found_exception" } - - match: { error.reason: "Inference endpoint not found [inference_to_get]" } + - match: { error.reason: "Inference endpoint [inference_to_get] not found " } --- "Test put inference with bad task type": From 5db227970c882ee2a5a06dda032d992f909c17fe Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 9 Oct 2025 10:38:31 -0400 Subject: [PATCH 12/18] Adding some comments --- .../xpack/inference/registry/ModelRegistry.java | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index c1781910ede40..d00f4e68bbdac 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -188,14 +188,6 @@ public boolean containsDefaultConfigId(String inferenceEntityId) { return defaultConfigIds.containsKey(inferenceEntityId); } - /** - * Adds the default configuration information if it does not already exist internally. - * @param defaultConfigId the default endpoint information - */ - public synchronized void putDefaultIdIfAbsent(InferenceService.DefaultConfigId defaultConfigId) { - defaultConfigIds.putIfAbsent(defaultConfigId.inferenceId(), defaultConfigId); - } - /** * Set the default inference ids provided by the services * @param defaultConfigId The default endpoint information @@ -244,6 +236,8 @@ public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId } } + // this is a temporary solution until the model registry handles polling the EIS authorization endpoint + // to retrieve the preconfigured inference endpoints var eisConfig = ElasticInferenceServiceMinimalSettings.getWithInferenceId(inferenceEntityId); if (eisConfig != null) { return eisConfig.minimalSettings(); From 231463545fe4d987ad879ed9902019ab8bafd571 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 9 Oct 2025 12:43:07 -0400 Subject: [PATCH 13/18] Adding more comments and tests --- .../InferenceGetServicesWithoutEisIT.java | 121 ++++++++++++++++++ ...sBase.java => ModelRegistryEisBaseIT.java} | 7 +- .../ModelRegistryEisGetModelIT.java | 5 +- .../integration/ModelRegistryEisIT.java | 2 +- .../ModelRegistryEisInvalidUrlIT.java | 2 +- .../TransportGetInferenceServicesAction.java | 15 ++- .../inference/registry/ModelRegistry.java | 3 +- .../elastic/ElasticInferenceService.java | 10 -- ...nceServiceAuthorizationRequestHandler.java | 6 +- .../PreconfiguredEndpointsModel.java | 46 ++++--- .../PreconfiguredEndpointsRequestHandler.java | 2 +- .../test/inference/inference_crud.yml | 2 +- 12 files changed, 180 insertions(+), 41 deletions(-) create mode 100644 x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesWithoutEisIT.java rename x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/{ModelRegistryEisBase.java => ModelRegistryEisBaseIT.java} (92%) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesWithoutEisIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesWithoutEisIT.java new file mode 100644 index 0000000000000..fea121f161afd --- /dev/null +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesWithoutEisIT.java @@ -0,0 +1,121 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + * + * this file has been contributed to by a Generative AI + */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.client.Request; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.test.cluster.local.distribution.DistributionType; +import org.elasticsearch.test.rest.ESRestTestCase; +import org.junit.ClassRule; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.not; + +public class InferenceGetServicesWithoutEisIT extends ESRestTestCase { + + @ClassRule + public static ElasticsearchCluster cluster = ElasticsearchCluster.local() + .distribution(DistributionType.DEFAULT) + .setting("xpack.license.self_generated.type", "trial") + .setting("xpack.security.enabled", "true") + // This plugin is located in the inference/qa/test-service-plugin package, look for TestInferenceServicePlugin + .plugin("inference-service-test") + .user("x_pack_rest_user", "x-pack-test-password") + .build(); + + @Override + protected String getTestRestCluster() { + return cluster.getHttpAddresses(); + } + + @Override + protected Settings restClientSettings() { + String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray())); + return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build(); + } + + public void testGetServicesWithoutTaskType() throws IOException { + assertThat(allProviders(), not(hasItem("elastic"))); + } + + private List allProviders() throws IOException { + return providers(getAllServices()); + } + + @SuppressWarnings("unchecked") + private List providers(List services) { + return services.stream().map(service -> { + var serviceConfig = (Map) service; + return (String) serviceConfig.get("service"); + }).toList(); + } + + public void testGetServicesWithTextEmbeddingTaskType() throws IOException { + var providers = providersFor(TaskType.TEXT_EMBEDDING); + assertThat(providers.size(), not(equalTo(0))); + assertThat(providers, not(hasItem("elastic"))); + } + + private List providersFor(TaskType taskType) throws IOException { + return providers(getServices(taskType)); + } + + public void testGetServicesWithRerankTaskType() throws IOException { + var providers = providersFor(TaskType.RERANK); + assertThat(providers.size(), not(equalTo(0))); + assertThat(providersFor(TaskType.RERANK), not(hasItem("elastic"))); + } + + public void testGetServicesWithCompletionTaskType() throws IOException { + var providers = providersFor(TaskType.COMPLETION); + assertThat(providers.size(), not(equalTo(0))); + assertThat(providersFor(TaskType.COMPLETION), not(hasItem("elastic"))); + } + + public void testGetServicesWithChatCompletionTaskType() throws IOException { + var providers = providersFor(TaskType.CHAT_COMPLETION); + assertThat(providers.size(), not(equalTo(0))); + assertThat(providersFor(TaskType.CHAT_COMPLETION), not(hasItem("elastic"))); + } + + public void testGetServicesWithSparseEmbeddingTaskType() throws IOException { + var providers = providersFor(TaskType.SPARSE_EMBEDDING); + assertThat(providers.size(), not(equalTo(0))); + assertThat(providersFor(TaskType.SPARSE_EMBEDDING), not(hasItem("elastic"))); + } + + private List getAllServices() throws IOException { + var endpoint = Strings.format("_inference/_services"); + return getInternalAsList(endpoint); + } + + private List getServices(TaskType taskType) throws IOException { + var endpoint = Strings.format("_inference/_services/%s", taskType); + return getInternalAsList(endpoint); + } + + private List getInternalAsList(String endpoint) throws IOException { + var request = new Request("GET", endpoint); + var response = client().performRequest(request); + assertStatusOkOrCreated(response); + return entityAsList(response); + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisBase.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisBaseIT.java similarity index 92% rename from x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisBase.java rename to x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisBaseIT.java index 83c408fa35012..65229afe6dde4 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisBase.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisBaseIT.java @@ -30,21 +30,22 @@ import static org.elasticsearch.xpack.inference.integration.ModelRegistryIT.createModel; import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_URL; -public class ModelRegistryEisBase extends ESSingleNodeTestCase { +public class ModelRegistryEisBaseIT extends ESSingleNodeTestCase { protected static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); protected static final MockWebServer webServer = new MockWebServer(); protected ModelRegistry modelRegistry; private String eisUrl; - public ModelRegistryEisBase() {} + public ModelRegistryEisBaseIT() {} - public ModelRegistryEisBase(String eisUrl) { + public ModelRegistryEisBaseIT(String eisUrl) { this.eisUrl = eisUrl; } @BeforeClass public static void init() throws Exception { + // This must be called prior to retrieving the hostname and port from the mock server. webServer.start(); } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisGetModelIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisGetModelIT.java index 4cd097517059d..33786f29e8c7e 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisGetModelIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisGetModelIT.java @@ -32,7 +32,10 @@ import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.containsString; -public class ModelRegistryEisGetModelIT extends ModelRegistryEisBase { +/** + * Parameterized tests for {@link ModelRegistry#getModel} and {@link ModelRegistry#getModelWithSecrets}. + */ +public class ModelRegistryEisGetModelIT extends ModelRegistryEisBaseIT { private final TestCase testCase; public ModelRegistryEisGetModelIT(TestCase testCase) { diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisIT.java index 9891929c66468..9eeea11eb1486 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisIT.java @@ -18,7 +18,7 @@ import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.is; -public class ModelRegistryEisIT extends ModelRegistryEisBase { +public class ModelRegistryEisIT extends ModelRegistryEisBaseIT { private static final String eisAuthorizedResponse = """ { diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisInvalidUrlIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisInvalidUrlIT.java index 9af4a2c9d9df5..0db54784851f2 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisInvalidUrlIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisInvalidUrlIT.java @@ -16,7 +16,7 @@ import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.is; -public class ModelRegistryEisInvalidUrlIT extends ModelRegistryEisBase { +public class ModelRegistryEisInvalidUrlIT extends ModelRegistryEisBaseIT { public ModelRegistryEisInvalidUrlIT() { super(""); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java index b07de7434f36a..9ebbeb187ad7b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java @@ -145,11 +145,16 @@ private void getServiceConfigurationsForServicesAndEis( private void getEisAuthorization(ActionListener listener, Sender sender) { var disabledServiceListener = listener.delegateResponse((delegate, e) -> { - logger.warn( - "Failed to retrieve authorization information from the " - + "Elastic Inference Service while determining service configurations. Marking service as disabled.", - e - ); + if (eisAuthorizationRequestHandler.isServiceConfigured()) { + logger.warn( + "Failed to retrieve authorization information from the " + + "Elastic Inference Service while determining service configurations. Marking service as disabled.", + e + ); + } else { + logger.debug("The Elastic Inference Service is not configured. Marking service as disabled.", e); + } + delegate.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService()); }); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index d00f4e68bbdac..8b8e304b6e9ca 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -353,13 +353,14 @@ private void searchForEndpointInDefaultAndEis( return; } - // we didn't find the configuration in the inference index, so check if it is a pre-configured endpoint + // we didn't find the configuration in the inference index, so check if it is a preconfigured endpoint var maybeDefault = defaultConfigIds.get(inferenceEntityId); if (maybeDefault != null) { getDefaultConfig(true, maybeDefault, listener); return; } + // check if the inference id is a preconfigured endpoint available from EIS retrievePreconfiguredEndpointFromEisElseNotFound(listener, inferenceEntityId); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 8715f95263579..4f9a7d680de00 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -104,12 +104,6 @@ public class ElasticInferenceService extends SenderService { // A batch size of 16 provides optimal throughput and stability, especially on lower-tier instance types. public static final Integer SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE = 16; - private static final EnumSet IMPLEMENTED_TASK_TYPES = EnumSet.of( - TaskType.SPARSE_EMBEDDING, - TaskType.CHAT_COMPLETION, - TaskType.RERANK, - TaskType.TEXT_EMBEDDING - ); private static final String SERVICE_NAME = "Elastic"; // TODO: revisit this value once EIS supports dense models @@ -126,10 +120,6 @@ public class ElasticInferenceService extends SenderService { TaskType.TEXT_EMBEDDING ); - public static String defaultEndpointId(String modelId) { - return Strings.format(".%s-elastic", modelId); - } - private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; public ElasticInferenceService( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java index e779d45450449..8f39f1261ee3e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java @@ -77,7 +77,7 @@ public void getAuthorization(ActionListener preconfiguredEndpoints) { + private static final Logger logger = LogManager.getLogger(PreconfiguredEndpointsModel.class); public static PreconfiguredEndpointsModel of(ElasticInferenceServiceAuthorizationModel authModel) { var endpoints = authModel.getAuthorizedModelIds() @@ -43,7 +49,14 @@ private static PreconfiguredEndpoint of(ElasticInferenceServiceMinimalSettings.S if (settings.minimalSettings().dimensions() == null || settings.minimalSettings().similarity() == null || settings.minimalSettings().elementType() == null) { - // TODO log a warning + logger.warn( + "Skipping embedding endpoint [{}] as it is missing required settings. " + + "Dimensions: [{}], Similarity: [{}], Element Type: [{}]", + settings.inferenceId(), + settings.minimalSettings().dimensions(), + settings.minimalSettings().similarity(), + settings.minimalSettings().elementType() + ); yield null; } @@ -102,25 +115,26 @@ private static Map embeddingSettings( int dimension, DenseVectorFieldMapper.ElementType elementType ) { - return new HashMap<>( - Map.of( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - ServiceFields.MODEL_ID, - modelId, - ServiceFields.SIMILARITY, - similarityMeasure.toString(), - ServiceFields.DIMENSIONS, - dimension, - ServiceFields.ELEMENT_TYPE, - elementType.toString() - ) + return wrapWithServiceSettings( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + ServiceFields.SIMILARITY, + similarityMeasure.toString(), + ServiceFields.DIMENSIONS, + dimension, + ServiceFields.ELEMENT_TYPE, + elementType.toString() ) ) ); } + private static Map wrapWithServiceSettings(Map settings) { + return new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, settings)); + } + private record BasePreconfiguredEndpoint(String inferenceEntityId, TaskType taskType, String modelId) implements PreconfiguredEndpoint { @Override public UnparsedModel toUnparsedModel() { @@ -129,7 +143,7 @@ public UnparsedModel toUnparsedModel() { } private static Map settingsWithModelId(String modelId) { - return new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)))); + return wrapWithServiceSettings(new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId))); } public UnparsedModel toUnparsedModel(String inferenceId) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsRequestHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsRequestHandler.java index fdd3a0e290611..1e347049d630b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsRequestHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsRequestHandler.java @@ -17,7 +17,7 @@ /** * This class is responsible for converting the current EIS authorization response structure - * into Models that + * into {@link UnparsedModel}. */ public class PreconfiguredEndpointsRequestHandler { private final ElasticInferenceServiceAuthorizationRequestHandler eisAuthorizationRequestHandler; diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml index cb33b432b8afa..ee848f4fe3840 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml @@ -5,7 +5,7 @@ inference.get: inference_id: inference_to_get - match: { error.type: "resource_not_found_exception" } - - match: { error.reason: "Inference endpoint [inference_to_get] not found " } + - match: { error.reason: "Inference endpoint [inference_to_get] not found or you are not authorized to access it" } --- "Test put inference with bad task type": From e39021363935ce98c21ff4a66231e84ae4f5fef8 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 9 Oct 2025 14:45:21 -0400 Subject: [PATCH 14/18] Trying to fix yaml tests --- x-pack/plugin/build.gradle | 1 + .../xpack/esql/qa/rest/SemanticMatchTestCase.java | 11 ++++------- .../inference/integration/InferenceIndicesIT.java | 4 ++-- .../inference/integration/ModelRegistryEisBaseIT.java | 2 +- .../xpack/inference/registry/ModelRegistry.java | 2 +- .../rest-api-spec/test/inference/inference_crud.yml | 4 ++-- 6 files changed, 11 insertions(+), 13 deletions(-) diff --git a/x-pack/plugin/build.gradle b/x-pack/plugin/build.gradle index ea715b0d5c921..f947a845a056d 100644 --- a/x-pack/plugin/build.gradle +++ b/x-pack/plugin/build.gradle @@ -155,6 +155,7 @@ tasks.named("yamlRestCompatTestTransform").configure({ task -> task.skipTest("esql/46_downsample/Query stats on downsampled index", "Extra function required to enable the field type") task.skipTest("esql/46_downsample/Render stats from downsampled index", "Extra function required to enable the field type") task.skipTest("esql/46_downsample/Sort from multiple indices one with aggregate metric double", "Extra function required to enable the field type") + task.skipTest("inference/inference_crud/Test get missing model", "Error message changed") }) tasks.named('yamlRestCompatTest').configure { diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/SemanticMatchTestCase.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/SemanticMatchTestCase.java index 5cc005c477e90..48d421d970bf5 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/SemanticMatchTestCase.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/SemanticMatchTestCase.java @@ -9,7 +9,6 @@ import org.elasticsearch.client.Request; import org.elasticsearch.client.ResponseException; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.test.rest.ESRestTestCase; import org.elasticsearch.xpack.esql.AssertWarnings; @@ -62,15 +61,13 @@ public void testWithMultipleInferenceIds() throws IOException { public void testWithInferenceNotConfigured() { assumeTrue("semantic text capability not available", EsqlCapabilities.Cap.SEMANTIC_TEXT_FIELD_CAPS.isEnabled()); - var inferenceId = "test-semantic3"; - - String query = Strings.format(""" - from %s + String query = """ + from test-semantic3 | where match(semantic_text_field, "something") - """, inferenceId); + """; ResponseException re = expectThrows(ResponseException.class, () -> runEsqlQuery(query)); - assertThat(re.getMessage(), containsString(Strings.format("Inference endpoint [%s] not found", inferenceId))); + assertThat(re.getMessage(), containsString("Inference endpoint [inexistent] not found")); assertEquals(404, re.getResponse().getStatusLine().getStatusCode()); } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceIndicesIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceIndicesIT.java index e59f0617851c3..bdf4f9318ce60 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceIndicesIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceIndicesIT.java @@ -167,7 +167,7 @@ public void testRetrievingInferenceEndpoint_ThrowsException_WhenIndexNodeIsNotAv var proxyResponse = sendInferenceProxyRequest(inferenceId); var exception = expectThrows(ElasticsearchException.class, () -> proxyResponse.actionGet(TEST_REQUEST_TIMEOUT)); - assertThat(exception.toString(), containsString("Failed to load inference endpoint with secrets [test-index-id-2]")); + assertThat(exception.toString(), containsString("Failed to load inference endpoint [test-index-id-2]")); var causeException = exception.getCause(); assertThat(causeException, instanceOf(SearchPhaseExecutionException.class)); @@ -196,7 +196,7 @@ public void testRetrievingInferenceEndpoint_ThrowsException_WhenSecretsIndexNode var proxyResponse = sendInferenceProxyRequest(inferenceId); var exception = expectThrows(ElasticsearchException.class, () -> proxyResponse.actionGet(TEST_REQUEST_TIMEOUT)); - assertThat(exception.toString(), containsString("Failed to load inference endpoint with secrets [test-secrets-index-id]")); + assertThat(exception.toString(), containsString("Failed to load inference endpoint [test-secrets-index-id]")); var causeException = exception.getCause(); diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisBaseIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisBaseIT.java index 65229afe6dde4..6d320a00d695c 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisBaseIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisBaseIT.java @@ -30,7 +30,7 @@ import static org.elasticsearch.xpack.inference.integration.ModelRegistryIT.createModel; import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_URL; -public class ModelRegistryEisBaseIT extends ESSingleNodeTestCase { +public abstract class ModelRegistryEisBaseIT extends ESSingleNodeTestCase { protected static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); protected static final MockWebServer webServer = new MockWebServer(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 8b8e304b6e9ca..cfe2057a7731b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -314,7 +314,7 @@ private void getModelHelper( var failureListener = listener.delegateResponse((delegate, e) -> { // If the inference endpoint does not exist, we've already created a well-defined exception, so just return it - if (e instanceof ElasticsearchException) { + if (e instanceof ResourceNotFoundException) { delegate.onFailure(e); return; } diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml index ee848f4fe3840..3c1b53e60e0c4 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml @@ -1,11 +1,11 @@ --- -"Test get missing model": +"Test get missing model v2": - do: catch: missing inference.get: inference_id: inference_to_get - match: { error.type: "resource_not_found_exception" } - - match: { error.reason: "Inference endpoint [inference_to_get] not found or you are not authorized to access it" } + - contains: { error.reason: "Inference endpoint [inference_to_get] not found" } --- "Test put inference with bad task type": From d4a1a03842892cae4c71a40b024bb0e880a97f68 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 9 Oct 2025 15:08:46 -0400 Subject: [PATCH 15/18] Adding requirement on contains --- .../resources/rest-api-spec/test/inference/inference_crud.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml index 3c1b53e60e0c4..5444d70539a7b 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml @@ -1,5 +1,8 @@ --- "Test get missing model v2": + - skip: + reason: "contains is a newly added assertion" + features: contains - do: catch: missing inference.get: From 4dfbcd5dca7634bafdfc9e17ea77a33def6b828a Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 9 Oct 2025 15:46:46 -0400 Subject: [PATCH 16/18] Adding test for unauthorized model --- .../integration/ModelRegistryEisIT.java | 33 ++++++++++++++----- .../PreconfiguredEndpointsModel.java | 9 ++++- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisIT.java index 9eeea11eb1486..2450d12064832 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisIT.java @@ -7,19 +7,28 @@ package org.elasticsearch.xpack.inference.integration; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings; +import org.junit.Before; import java.util.List; import java.util.stream.Stream; import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; public class ModelRegistryEisIT extends ModelRegistryEisBaseIT { + @Before + public void setupTest() { + initializeModels(); + } + private static final String eisAuthorizedResponse = """ { "models": [ @@ -51,8 +60,6 @@ public class ModelRegistryEisIT extends ModelRegistryEisBaseIT { """; public void testGetModelsByTaskType() { - initializeModels(); - { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(eisAuthorizedResponse)); PlainActionFuture> listener = new PlainActionFuture<>(); @@ -104,8 +111,6 @@ public void testGetModelsByTaskType() { } public void testGetAllModels() { - initializeModels(); - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(eisAuthorizedResponse)); PlainActionFuture> listener = new PlainActionFuture<>(); modelRegistry.getAllModels(false, listener); @@ -127,8 +132,6 @@ public void testGetAllModels() { } public void testGetAllModelsNoEisResults() { - initializeModels(); - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(eisUnauthorizedResponse)); PlainActionFuture> listener = new PlainActionFuture<>(); modelRegistry.getAllModels(false, listener); @@ -139,9 +142,23 @@ public void testGetAllModelsNoEisResults() { assertThat(results.stream().map(UnparsedModel::inferenceEntityId).sorted().toList(), containsInAnyOrder(expected)); } - public void testGetAllModelsEisReturnsFailureStatusCode() { - initializeModels(); + public void testGetModel_WhenNotAuthorizedForEis() { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(eisUnauthorizedResponse)); + PlainActionFuture listener = new PlainActionFuture<>(); + modelRegistry.getModel(ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_ENDPOINT_ID_V1, listener); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), containsString("Unable to retrieve the preconfigured inference endpoint")); + assertThat( + exception.getCause().getMessage(), + containsString( + "No Elastic Inference Service preconfigured endpoint found for inference ID [.rerank-v1-elastic]. " + + "Either it does not exist, or you are not authorized to access it." + ) + ); + } + public void testGetAllModelsEisReturnsFailureStatusCode() { webServer.enqueue(new MockResponse().setResponseCode(500).setBody("{}")); PlainActionFuture> listener = new PlainActionFuture<>(); modelRegistry.getAllModels(false, listener); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsModel.java index 30189f9e1bc71..cf93393512b4b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsModel.java @@ -9,6 +9,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.Strings; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.SimilarityMeasure; @@ -149,7 +150,13 @@ private static Map settingsWithModelId(String modelId) { public UnparsedModel toUnparsedModel(String inferenceId) { PreconfiguredEndpoint endpoint = preconfiguredEndpoints.get(inferenceId); if (endpoint == null) { - throw new IllegalArgumentException("No EIS preconfigured endpoint found for inference ID: " + inferenceId); + throw new IllegalArgumentException( + Strings.format( + "No Elastic Inference Service preconfigured endpoint found for inference ID [%s]. " + + "Either it does not exist, or you are not authorized to access it.", + inferenceId + ) + ); } return endpoint.toUnparsedModel(); From 028ea33cec2a3b5cde5efb26e9dc480d141ce33d Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 9 Oct 2025 17:11:16 -0400 Subject: [PATCH 17/18] Switching tests for error message change --- x-pack/plugin/inference/build.gradle | 7 +++++++ .../inference/BaseMockEISAuthServerTest.java | 5 ----- ...etModelsWithElasticInferenceServiceIT.java | 20 +++---------------- .../inference/InferenceGetServicesIT.java | 16 --------------- .../test/inference/40_semantic_text_query.yml | 14 +++++++++---- .../70_text_similarity_rank_retriever.yml | 8 ++++---- 6 files changed, 24 insertions(+), 46 deletions(-) diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle index eb9372e675831..7c6d2bc46abb8 100644 --- a/x-pack/plugin/inference/build.gradle +++ b/x-pack/plugin/inference/build.gradle @@ -407,6 +407,13 @@ tasks.named('yamlRestTest') { usesDefaultDistribution("Uses the inference API") } +tasks.named("yamlRestCompatTestTransform").configure({ task -> + task.skipTest("inference/40_semantic_text_query/Query a field with an invalid inference ID", "Error message changed") + task.skipTest("inference/40_semantic_text_query/Query a field with an invalid search inference ID", "Error message changed") + task.skipTest("inference/70_text_similarity_rank_retriever/Text similarity reranking fails if the inference ID does not exist", "Error message changed") + task.skipTest("inference/70_text_similarity_rank_retriever/Text similarity reranking fails if the inference ID does not exist and result set is empty", "Error message changed") +}) + artifacts { restXpackTests(new File(projectDir, "src/yamlRestTest/resources/rest-api-spec/test")) } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java index 09834e6a91210..f809444a2a73d 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java @@ -27,11 +27,6 @@ public class BaseMockEISAuthServerTest extends ESRestTestCase { protected static final MockElasticInferenceServiceAuthorizationServer mockEISServer = new MockElasticInferenceServiceAuthorizationServer(); - static { - // Ensure that the mock EIS server has an authorized response prior to the cluster starting - mockEISServer.enqueueAuthorizeAllModelsResponse(); - } - private static ElasticsearchCluster cluster = ElasticsearchCluster.local() .distribution(DistributionType.DEFAULT) .setting("xpack.license.self_generated.type", "trial") diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java index ecc3bcd508bb6..4d98bb2801c94 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java @@ -10,7 +10,6 @@ package org.elasticsearch.xpack.inference; import org.elasticsearch.inference.TaskType; -import org.junit.BeforeClass; import java.io.IOException; import java.util.List; @@ -22,24 +21,11 @@ import static org.hamcrest.Matchers.is; public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEISAuthServerTest { - - /** - * This is done before the class because I've run into issues where another class that extends {@link BaseMockEISAuthServerTest} - * results in an authorization response not being queued up for the new Elasticsearch Node in time. When the node starts up, it - * retrieves authorization. If the request isn't queued up when that happens the tests will fail. From my testing locally it seems - * like the base class's static functionality to queue a response is only done once and not for each subclass. - * - * My understanding is that the @Before will be run after the node starts up and wouldn't be sufficient to handle - * this scenario. That is why this needs to be @BeforeClass. - */ - @BeforeClass - public static void init() { - // Ensure the mock EIS server has an authorized response ready - mockEISServer.enqueueAuthorizeAllModelsResponse(); - } - public void testGetDefaultEndpoints() throws IOException { + mockEISServer.enqueueAuthorizeAllModelsResponse(); var allModels = getAllModels(); + + mockEISServer.enqueueAuthorizeAllModelsResponse(); var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION); assertThat(allModels, hasSize(7)); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index f86c92c02db48..95cd94cb4b6f2 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -13,7 +13,6 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.inference.TaskType; import org.junit.Before; -import org.junit.BeforeClass; import java.io.IOException; import java.util.List; @@ -32,21 +31,6 @@ public void setUp() throws Exception { mockEISServer.enqueueAuthorizeAllModelsResponse(); } - /** - * This is done before the class because I've run into issues where another class that extends {@link BaseMockEISAuthServerTest} - * results in an authorization response not being queued up for the new Elasticsearch Node in time. When the node starts up, it - * retrieves authorization. If the request isn't queued up when that happens the tests will fail. From my testing locally it seems - * like the base class's static functionality to queue a response is only done once and not for each subclass. - * - * My understanding is that the @Before will be run after the node starts up and wouldn't be sufficient to handle - * this scenario. That is why this needs to be @BeforeClass. - */ - @BeforeClass - public static void init() { - // Ensure the mock EIS server has an authorized response ready - mockEISServer.enqueueAuthorizeAllModelsResponse(); - } - public void testGetServicesWithoutTaskType() throws IOException { assertThat( allProviders(), diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/40_semantic_text_query.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/40_semantic_text_query.yml index 0b1a611bcdf72..a528ac9090168 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/40_semantic_text_query.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/40_semantic_text_query.yml @@ -781,7 +781,10 @@ setup: - match: { hits.total.value: 0 } --- -"Query a field with an invalid inference ID": +"Query a field with an invalid inference ID v2": + - skip: + reason: "contains is a newly added assertion" + features: contains - do: indices.create: index: test-index-with-invalid-inference-id @@ -803,7 +806,7 @@ setup: query: "inference test" - match: { error.type: "resource_not_found_exception" } - - match: { error.reason: "Inference endpoint not found [invalid-inference-id]" } + - contains: { error.reason: "Inference endpoint [invalid-inference-id] not found" } --- "Query a field with a search inference ID that uses the wrong task type": @@ -896,7 +899,10 @@ setup: compatible with the inference endpoint [dense-inference-id]?" } --- -"Query a field with an invalid search inference ID": +"Query a field with an invalid search inference ID v2": + - skip: + reason: "contains is a newly added assertion" + features: contains - do: indices.put_mapping: index: test-dense-index @@ -927,7 +933,7 @@ setup: query: "inference test" - match: { error.type: "resource_not_found_exception" } - - match: { error.reason: "Inference endpoint not found [invalid-inference-id]" } + - contains: { error.reason: "Inference endpoint [invalid-inference-id] not found" } --- "Query a field that uses the default ELSER 2 endpoint": diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml index d971aad2bbc4b..261cf92c3553f 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml @@ -181,9 +181,9 @@ setup: - match: { hits.hits.0._id: "doc_1" } --- -"Text similarity reranking fails if the inference ID does not exist": +"Text similarity reranking fails if the inference ID does not exist v2": - do: - catch: /Inference endpoint not found/ + catch: /Inference endpoint \[.*?\] not found/ search: index: test-index body: @@ -206,13 +206,13 @@ setup: size: 10 --- -"Text similarity reranking fails if the inference ID does not exist and result set is empty": +"Text similarity reranking fails if the inference ID does not exist and result set is empty v2": - requires: cluster_features: "gte_v8.15.1" reason: bug fixed in 8.15.1 - do: - catch: /Inference endpoint not found/ + catch: /Inference endpoint \[.*?\] not found/ search: index: test-index body: From 11b5e3ccd8cec1e507075b6bc3084677f406c603 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 9 Oct 2025 17:19:40 -0400 Subject: [PATCH 18/18] Adding compatability library --- x-pack/plugin/inference/build.gradle | 1 + 1 file changed, 1 insertion(+) diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle index 7c6d2bc46abb8..7afd9903090b3 100644 --- a/x-pack/plugin/inference/build.gradle +++ b/x-pack/plugin/inference/build.gradle @@ -9,6 +9,7 @@ apply plugin: 'elasticsearch.internal-es-plugin' apply plugin: 'elasticsearch.internal-cluster-test' apply plugin: 'elasticsearch.internal-yaml-rest-test' apply plugin: 'elasticsearch.internal-test-artifact' +apply plugin: 'elasticsearch.yaml-rest-compat-test' restResources { restApi {