diff --git a/docs/changelog/132362.yaml b/docs/changelog/132362.yaml new file mode 100644 index 0000000000000..8cdf915346136 --- /dev/null +++ b/docs/changelog/132362.yaml @@ -0,0 +1,5 @@ +pr: 132362 +summary: Inference API disable partial search results +area: Machine Learning +type: bug +issues: [] diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceIndicesIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceIndicesIT.java new file mode 100644 index 0000000000000..e59f0617851c3 --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceIndicesIT.java @@ -0,0 +1,252 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionFuture; +import org.elasticsearch.action.search.SearchPhaseExecutionException; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceExtension; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.license.LicenseSettings; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin; +import org.elasticsearch.xpack.core.inference.InferenceContext; +import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; +import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.core.ssl.SSLService; +import org.elasticsearch.xpack.inference.InferenceIndex; +import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.InferenceSecretsIndex; +import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; +import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; +import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.Collection; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; + +@ESTestCase.WithoutEntitlements // due to dependency issue ES-12435 +public class InferenceIndicesIT extends ESIntegTestCase { + + private static final String INDEX_ROUTER_ATTRIBUTE = "node.attr.index_router"; + private static final String CONFIG_ROUTER = "config"; + private static final String SECRETS_ROUTER = "secrets"; + + private static final Map TEST_SERVICE_SETTINGS = Map.of( + "model", + "my_model", + "dimensions", + 256, + "similarity", + "cosine", + "api_key", + "my_api_key" + ); + + public static class LocalStateIndexSettingsInferencePlugin extends LocalStateCompositeXPackPlugin { + private final InferencePlugin inferencePlugin; + + public LocalStateIndexSettingsInferencePlugin(final Settings settings, final Path configPath) throws Exception { + super(settings, configPath); + var thisVar = this; + this.inferencePlugin = new InferencePlugin(settings) { + @Override + protected SSLService getSslService() { + return thisVar.getSslService(); + } + + @Override + protected XPackLicenseState getLicenseState() { + return thisVar.getLicenseState(); + } + + @Override + public List getInferenceServiceFactories() { + return List.of( + TestSparseInferenceServiceExtension.TestInferenceService::new, + TestDenseInferenceServiceExtension.TestInferenceService::new + ); + } + + @Override + public Settings getIndexSettings() { + return InferenceIndex.builder() + .put(Settings.builder().put("index.routing.allocation.require.index_router", "config").build()) + .build(); + } + + @Override + public Settings getSecretsIndexSettings() { + return InferenceSecretsIndex.builder() + .put(Settings.builder().put("index.routing.allocation.require.index_router", "secrets").build()) + .build(); + } + }; + plugins.add(inferencePlugin); + } + + } + + @Override + protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { + return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build(); + } + + @Override + protected Collection> nodePlugins() { + return List.of(LocalStateIndexSettingsInferencePlugin.class, TestInferenceServicePlugin.class); + } + + public void testRetrievingInferenceEndpoint_ThrowsException_WhenIndexNodeIsNotAvailable() throws Exception { + final var configIndexNodeAttributes = Settings.builder().put(INDEX_ROUTER_ATTRIBUTE, CONFIG_ROUTER).build(); + + internalCluster().startMasterOnlyNode(configIndexNodeAttributes); + final var configIndexDataNodes = internalCluster().startDataOnlyNode(configIndexNodeAttributes); + + internalCluster().startDataOnlyNode(Settings.builder().put(INDEX_ROUTER_ATTRIBUTE, SECRETS_ROUTER).build()); + + final var inferenceId = "test-index-id"; + createInferenceEndpoint(TaskType.TEXT_EMBEDDING, inferenceId, TEST_SERVICE_SETTINGS); + + // Ensure the inference indices are created and we can retrieve the inference endpoint + var getInferenceEndpointRequest = new GetInferenceModelAction.Request(inferenceId, TaskType.TEXT_EMBEDDING, true); + var responseFuture = client().execute(GetInferenceModelAction.INSTANCE, getInferenceEndpointRequest); + assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getEndpoints().get(0).getInferenceEntityId(), equalTo(inferenceId)); + + // stop the node that holds the inference index + internalCluster().stopNode(configIndexDataNodes); + + var responseFailureFuture = client().execute(GetInferenceModelAction.INSTANCE, getInferenceEndpointRequest); + var exception = expectThrows(ElasticsearchException.class, () -> responseFailureFuture.actionGet(TEST_REQUEST_TIMEOUT)); + assertThat(exception.toString(), containsString("Failed to load inference endpoint [test-index-id]")); + + var causeException = exception.getCause(); + assertThat(causeException, instanceOf(SearchPhaseExecutionException.class)); + } + + public void testRetrievingInferenceEndpoint_ThrowsException_WhenIndexNodeIsNotAvailable_ForInferenceAction() throws Exception { + final var configIndexNodeAttributes = Settings.builder().put(INDEX_ROUTER_ATTRIBUTE, CONFIG_ROUTER).build(); + + internalCluster().startMasterOnlyNode(configIndexNodeAttributes); + final var configIndexDataNodes = internalCluster().startDataOnlyNode(configIndexNodeAttributes); + + internalCluster().startDataOnlyNode(Settings.builder().put(INDEX_ROUTER_ATTRIBUTE, SECRETS_ROUTER).build()); + + final var inferenceId = "test-index-id-2"; + createInferenceEndpoint(TaskType.TEXT_EMBEDDING, inferenceId, TEST_SERVICE_SETTINGS); + + // Ensure the inference indices are created and we can retrieve the inference endpoint + var getInferenceEndpointRequest = new GetInferenceModelAction.Request(inferenceId, TaskType.TEXT_EMBEDDING, true); + var responseFuture = client().execute(GetInferenceModelAction.INSTANCE, getInferenceEndpointRequest); + assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getEndpoints().get(0).getInferenceEntityId(), equalTo(inferenceId)); + + // stop the node that holds the inference index + internalCluster().stopNode(configIndexDataNodes); + + var proxyResponse = sendInferenceProxyRequest(inferenceId); + var exception = expectThrows(ElasticsearchException.class, () -> proxyResponse.actionGet(TEST_REQUEST_TIMEOUT)); + assertThat(exception.toString(), containsString("Failed to load inference endpoint with secrets [test-index-id-2]")); + + var causeException = exception.getCause(); + assertThat(causeException, instanceOf(SearchPhaseExecutionException.class)); + } + + public void testRetrievingInferenceEndpoint_ThrowsException_WhenSecretsIndexNodeIsNotAvailable() throws Exception { + final var configIndexNodeAttributes = Settings.builder().put(INDEX_ROUTER_ATTRIBUTE, CONFIG_ROUTER).build(); + internalCluster().startMasterOnlyNode(configIndexNodeAttributes); + internalCluster().startDataOnlyNode(configIndexNodeAttributes); + + var secretIndexDataNodes = internalCluster().startDataOnlyNode( + Settings.builder().put(INDEX_ROUTER_ATTRIBUTE, SECRETS_ROUTER).build() + ); + + final var inferenceId = "test-secrets-index-id"; + createInferenceEndpoint(TaskType.TEXT_EMBEDDING, inferenceId, TEST_SERVICE_SETTINGS); + + // Ensure the inference indices are created and we can retrieve the inference endpoint + var getInferenceEndpointRequest = new GetInferenceModelAction.Request(inferenceId, TaskType.TEXT_EMBEDDING, true); + var responseFuture = client().execute(GetInferenceModelAction.INSTANCE, getInferenceEndpointRequest); + assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getEndpoints().get(0).getInferenceEntityId(), equalTo(inferenceId)); + + // stop the node that holds the inference secrets index + internalCluster().stopNode(secretIndexDataNodes); + + var proxyResponse = sendInferenceProxyRequest(inferenceId); + + var exception = expectThrows(ElasticsearchException.class, () -> proxyResponse.actionGet(TEST_REQUEST_TIMEOUT)); + assertThat(exception.toString(), containsString("Failed to load inference endpoint with secrets [test-secrets-index-id]")); + + var causeException = exception.getCause(); + + assertThat(causeException, instanceOf(SearchPhaseExecutionException.class)); + } + + private ActionFuture sendInferenceProxyRequest(String inferenceId) throws IOException { + final BytesReference content; + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + builder.startObject(); + builder.field("input", List.of("test input")); + builder.endObject(); + + content = BytesReference.bytes(builder); + } + + var inferenceRequest = new InferenceActionProxy.Request( + TaskType.TEXT_EMBEDDING, + inferenceId, + content, + XContentType.JSON, + TimeValue.THIRTY_SECONDS, + false, + InferenceContext.EMPTY_INSTANCE + ); + + return client().execute(InferenceActionProxy.INSTANCE, inferenceRequest); + } + + private void createInferenceEndpoint(TaskType taskType, String inferenceId, Map serviceSettings) throws IOException { + var responseFuture = createInferenceEndpointAsync(taskType, inferenceId, serviceSettings); + assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(), equalTo(inferenceId)); + } + + private ActionFuture createInferenceEndpointAsync( + TaskType taskType, + String inferenceId, + Map serviceSettings + ) throws IOException { + final BytesReference content; + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + builder.startObject(); + builder.field("service", TestDenseInferenceServiceExtension.TestInferenceService.NAME); + builder.field("service_settings", serviceSettings); + builder.endObject(); + + content = BytesReference.bytes(builder); + } + + var request = new PutInferenceModelAction.Request(taskType, inferenceId, content, XContentType.JSON, TEST_REQUEST_TIMEOUT); + return client().execute(PutInferenceModelAction.INSTANCE, request); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceIndex.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceIndex.java index 1c93494d78636..eb79fc08bd1a0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceIndex.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceIndex.java @@ -30,10 +30,12 @@ private InferenceIndex() {} private static final int INDEX_MAPPING_VERSION = 2; public static Settings settings() { - return Settings.builder() - .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) - .put(IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS, "0-1") - .build(); + return builder().build(); + } + + // Public to allow tests to create the index with custom settings + public static Settings.Builder builder() { + return Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1).put(IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS, "0-1"); } /** diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 1a2de4cc6b31f..c3ae4f0d9d6d6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -466,7 +466,7 @@ public Collection getSystemIndexDescriptors(Settings sett .setPrimaryIndex(InferenceIndex.INDEX_NAME) .setDescription("Contains inference service and model configuration") .setMappings(InferenceIndex.mappings()) - .setSettings(InferenceIndex.settings()) + .setSettings(getIndexSettings()) .setOrigin(ClientHelper.INFERENCE_ORIGIN) .setPriorSystemIndexDescriptors(List.of(inferenceIndexV1Descriptor)) .build(), @@ -476,13 +476,23 @@ public Collection getSystemIndexDescriptors(Settings sett .setPrimaryIndex(InferenceSecretsIndex.INDEX_NAME) .setDescription("Contains inference service secrets") .setMappings(InferenceSecretsIndex.mappings()) - .setSettings(InferenceSecretsIndex.settings()) + .setSettings(getSecretsIndexSettings()) .setOrigin(ClientHelper.INFERENCE_ORIGIN) .setNetNew() .build() ); } + // Overridable for tests + protected Settings getIndexSettings() { + return InferenceIndex.settings(); + } + + // Overridable for tests + protected Settings getSecretsIndexSettings() { + return InferenceSecretsIndex.settings(); + } + @Override public List> getExecutorBuilders(Settings settingsToUse) { return List.of(inferenceUtilityExecutor(settings)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceSecretsIndex.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceSecretsIndex.java index f11864eb9f068..649dc27e4a493 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceSecretsIndex.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceSecretsIndex.java @@ -29,10 +29,12 @@ private InferenceSecretsIndex() {} private static final int INDEX_MAPPING_VERSION = 1; public static Settings settings() { - return Settings.builder() - .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) - .put(IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS, "0-1") - .build(); + return builder().build(); + } + + // Public to allow tests to create the index with custom settings + public static Settings.Builder builder() { + return Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1).put(IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS, "0-1"); } /** diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 37a82b2160595..fe7c4a9395cd1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -249,25 +249,34 @@ public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId * @param listener Model listener */ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { - ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { + ActionListener searchListener = ActionListener.wrap((searchResponse) -> { // There should be a hit for the configurations if (searchResponse.getHits().getHits().length == 0) { var maybeDefault = defaultConfigIds.get(inferenceEntityId); if (maybeDefault != null) { getDefaultConfig(true, maybeDefault, listener); } else { - delegate.onFailure(inferenceNotFoundException(inferenceEntityId)); + listener.onFailure(inferenceNotFoundException(inferenceEntityId)); } return; } - delegate.onResponse(unparsedModelFromMap(createModelConfigMap(searchResponse.getHits(), inferenceEntityId))); + listener.onResponse(unparsedModelFromMap(createModelConfigMap(searchResponse.getHits(), inferenceEntityId))); + }, (e) -> { + logger.warn(format("Failed to load inference endpoint with secrets [%s]", inferenceEntityId), e); + listener.onFailure( + new ElasticsearchException( + format("Failed to load inference endpoint with secrets [%s], error: [%s]", inferenceEntityId, e.getMessage()), + e + ) + ); }); QueryBuilder queryBuilder = documentIdQuery(inferenceEntityId); SearchRequest modelSearch = client.prepareSearch(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN) .setQuery(queryBuilder) .setSize(2) + .setAllowPartialSearchResults(false) .request(); client.search(modelSearch, searchListener); @@ -280,21 +289,29 @@ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { - ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { + ActionListener searchListener = ActionListener.wrap((searchResponse) -> { // There should be a hit for the configurations if (searchResponse.getHits().getHits().length == 0) { var maybeDefault = defaultConfigIds.get(inferenceEntityId); if (maybeDefault != null) { getDefaultConfig(true, maybeDefault, listener); } else { - delegate.onFailure(inferenceNotFoundException(inferenceEntityId)); + listener.onFailure(inferenceNotFoundException(inferenceEntityId)); } return; } var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList(); assert modelConfigs.size() == 1; - delegate.onResponse(modelConfigs.get(0)); + listener.onResponse(modelConfigs.get(0)); + }, e -> { + logger.warn(format("Failed to load inference endpoint [%s]", inferenceEntityId), e); + listener.onFailure( + new ElasticsearchException( + format("Failed to load inference endpoint [%s], error: [%s]", inferenceEntityId, e.getMessage()), + e + ) + ); }); QueryBuilder queryBuilder = documentIdQuery(inferenceEntityId);