Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/132362.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 132362
summary: Disable partial search results in Inference API index searches
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.integration;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.license.LicenseSettings;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin;
import org.elasticsearch.xpack.core.inference.InferenceContext;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy;
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
import org.elasticsearch.xpack.core.ssl.SSLService;
import org.elasticsearch.xpack.inference.InferenceIndex;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.InferenceSecretsIndex;
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin;
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;

import java.io.IOException;
import java.nio.file.Path;
import java.util.Collection;
import java.util.List;
import java.util.Map;

import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;

@ESTestCase.WithoutEntitlements // due to dependency issue ES-12435
public class InferenceIndicesIT extends ESIntegTestCase {

private static final String INDEX_ROUTER_ATTRIBUTE = "node.attr.index_router";
private static final String CONFIG_ROUTER = "config";
private static final String SECRETS_ROUTER = "secrets";

private static final Map<String, Object> TEST_SERVICE_SETTINGS = Map.of(
"model",
"my_model",
"dimensions",
256,
"similarity",
"cosine",
"api_key",
"my_api_key"
);

public static class LocalStateIndexSettingsInferencePlugin extends LocalStateCompositeXPackPlugin {
private final InferencePlugin inferencePlugin;

public LocalStateIndexSettingsInferencePlugin(final Settings settings, final Path configPath) throws Exception {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you can think of a better way to direct the inference configuration index and secrets index to separate nodes let me know. I'm injecting some settings that won't be present in production but allows us to direct the documents to specific nodes for easier testing.

super(settings, configPath);
var thisVar = this;
this.inferencePlugin = new InferencePlugin(settings) {
@Override
protected SSLService getSslService() {
return thisVar.getSslService();
}

@Override
protected XPackLicenseState getLicenseState() {
return thisVar.getLicenseState();
}

@Override
public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
return List.of(
TestSparseInferenceServiceExtension.TestInferenceService::new,
TestDenseInferenceServiceExtension.TestInferenceService::new
);
}

@Override
public Settings getIndexSettings() {
return InferenceIndex.builder()
.put(Settings.builder().put("index.routing.allocation.require.index_router", "config").build())
.build();
}

@Override
public Settings getSecretsIndexSettings() {
return InferenceSecretsIndex.builder()
.put(Settings.builder().put("index.routing.allocation.require.index_router", "secrets").build())
.build();
}
};
plugins.add(inferencePlugin);
}

}

@Override
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build();
}

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return List.of(LocalStateIndexSettingsInferencePlugin.class, TestInferenceServicePlugin.class);
}

public void testRetrievingInferenceEndpoint_ThrowsException_WhenIndexNodeIsNotAvailable() throws Exception {
final var configIndexNodeAttributes = Settings.builder().put(INDEX_ROUTER_ATTRIBUTE, CONFIG_ROUTER).build();

internalCluster().startMasterOnlyNode(configIndexNodeAttributes);
final var configIndexDataNodes = internalCluster().startDataOnlyNode(configIndexNodeAttributes);

internalCluster().startDataOnlyNode(Settings.builder().put(INDEX_ROUTER_ATTRIBUTE, SECRETS_ROUTER).build());

final var inferenceId = "test-index-id";
createInferenceEndpoint(TaskType.TEXT_EMBEDDING, inferenceId, TEST_SERVICE_SETTINGS);

// Ensure the inference indices are created and we can retrieve the inference endpoint
var getInferenceEndpointRequest = new GetInferenceModelAction.Request(inferenceId, TaskType.TEXT_EMBEDDING, true);
var responseFuture = client().execute(GetInferenceModelAction.INSTANCE, getInferenceEndpointRequest);
assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getEndpoints().get(0).getInferenceEntityId(), equalTo(inferenceId));

// stop the node that holds the inference index
internalCluster().stopNode(configIndexDataNodes);

var responseFailureFuture = client().execute(GetInferenceModelAction.INSTANCE, getInferenceEndpointRequest);
var exception = expectThrows(ElasticsearchException.class, () -> responseFailureFuture.actionGet(TEST_REQUEST_TIMEOUT));
assertThat(exception.toString(), containsString("Failed to load inference endpoint [test-index-id]"));

var causeException = exception.getCause();
assertThat(causeException, instanceOf(SearchPhaseExecutionException.class));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried adding more assertThat's for looking for certain text in the search phase execution exception, but they exact wording changes during my test runs and was causing the test to be flaky.

}

public void testRetrievingInferenceEndpoint_ThrowsException_WhenIndexNodeIsNotAvailable_ForInferenceAction() throws Exception {
final var configIndexNodeAttributes = Settings.builder().put(INDEX_ROUTER_ATTRIBUTE, CONFIG_ROUTER).build();

internalCluster().startMasterOnlyNode(configIndexNodeAttributes);
final var configIndexDataNodes = internalCluster().startDataOnlyNode(configIndexNodeAttributes);

internalCluster().startDataOnlyNode(Settings.builder().put(INDEX_ROUTER_ATTRIBUTE, SECRETS_ROUTER).build());

final var inferenceId = "test-index-id-2";
createInferenceEndpoint(TaskType.TEXT_EMBEDDING, inferenceId, TEST_SERVICE_SETTINGS);

// Ensure the inference indices are created and we can retrieve the inference endpoint
var getInferenceEndpointRequest = new GetInferenceModelAction.Request(inferenceId, TaskType.TEXT_EMBEDDING, true);
var responseFuture = client().execute(GetInferenceModelAction.INSTANCE, getInferenceEndpointRequest);
assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getEndpoints().get(0).getInferenceEntityId(), equalTo(inferenceId));

// stop the node that holds the inference index
internalCluster().stopNode(configIndexDataNodes);

var proxyResponse = sendInferenceProxyRequest(inferenceId);
var exception = expectThrows(ElasticsearchException.class, () -> proxyResponse.actionGet(TEST_REQUEST_TIMEOUT));
assertThat(exception.toString(), containsString("Failed to load inference endpoint with secrets [test-index-id-2]"));

var causeException = exception.getCause();
assertThat(causeException, instanceOf(SearchPhaseExecutionException.class));
}

public void testRetrievingInferenceEndpoint_ThrowsException_WhenSecretsIndexNodeIsNotAvailable() throws Exception {
final var configIndexNodeAttributes = Settings.builder().put(INDEX_ROUTER_ATTRIBUTE, CONFIG_ROUTER).build();
internalCluster().startMasterOnlyNode(configIndexNodeAttributes);
internalCluster().startDataOnlyNode(configIndexNodeAttributes);

var secretIndexDataNodes = internalCluster().startDataOnlyNode(
Settings.builder().put(INDEX_ROUTER_ATTRIBUTE, SECRETS_ROUTER).build()
);

final var inferenceId = "test-secrets-index-id";
createInferenceEndpoint(TaskType.TEXT_EMBEDDING, inferenceId, TEST_SERVICE_SETTINGS);

// Ensure the inference indices are created and we can retrieve the inference endpoint
var getInferenceEndpointRequest = new GetInferenceModelAction.Request(inferenceId, TaskType.TEXT_EMBEDDING, true);
var responseFuture = client().execute(GetInferenceModelAction.INSTANCE, getInferenceEndpointRequest);
assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getEndpoints().get(0).getInferenceEntityId(), equalTo(inferenceId));

// stop the node that holds the inference secrets index
internalCluster().stopNode(secretIndexDataNodes);

var proxyResponse = sendInferenceProxyRequest(inferenceId);

var exception = expectThrows(ElasticsearchException.class, () -> proxyResponse.actionGet(TEST_REQUEST_TIMEOUT));
assertThat(exception.toString(), containsString("Failed to load inference endpoint with secrets [test-secrets-index-id]"));

var causeException = exception.getCause();

assertThat(causeException, instanceOf(SearchPhaseExecutionException.class));
}

private ActionFuture<InferenceAction.Response> sendInferenceProxyRequest(String inferenceId) throws IOException {
final BytesReference content;
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
builder.startObject();
builder.field("input", List.of("test input"));
builder.endObject();

content = BytesReference.bytes(builder);
}

var inferenceRequest = new InferenceActionProxy.Request(
TaskType.TEXT_EMBEDDING,
inferenceId,
content,
XContentType.JSON,
TimeValue.THIRTY_SECONDS,
false,
InferenceContext.EMPTY_INSTANCE
);

return client().execute(InferenceActionProxy.INSTANCE, inferenceRequest);
}

private void createInferenceEndpoint(TaskType taskType, String inferenceId, Map<String, Object> serviceSettings) throws IOException {
var responseFuture = createInferenceEndpointAsync(taskType, inferenceId, serviceSettings);
assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(), equalTo(inferenceId));
}

private ActionFuture<PutInferenceModelAction.Response> createInferenceEndpointAsync(
TaskType taskType,
String inferenceId,
Map<String, Object> serviceSettings
) throws IOException {
final BytesReference content;
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
builder.startObject();
builder.field("service", TestDenseInferenceServiceExtension.TestInferenceService.NAME);
builder.field("service_settings", serviceSettings);
builder.endObject();

content = BytesReference.bytes(builder);
}

var request = new PutInferenceModelAction.Request(taskType, inferenceId, content, XContentType.JSON, TEST_REQUEST_TIMEOUT);
return client().execute(PutInferenceModelAction.INSTANCE, request);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@ private InferenceIndex() {}
private static final int INDEX_MAPPING_VERSION = 2;

public static Settings settings() {
return Settings.builder()
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1)
.put(IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS, "0-1")
.build();
return builder().build();
}

// Public to allow tests to create the index with custom settings
public static Settings.Builder builder() {
return Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1).put(IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS, "0-1");
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ public Collection<SystemIndexDescriptor> getSystemIndexDescriptors(Settings sett
.setPrimaryIndex(InferenceIndex.INDEX_NAME)
.setDescription("Contains inference service and model configuration")
.setMappings(InferenceIndex.mappings())
.setSettings(InferenceIndex.settings())
.setSettings(getIndexSettings())
.setOrigin(ClientHelper.INFERENCE_ORIGIN)
.setPriorSystemIndexDescriptors(List.of(inferenceIndexV1Descriptor))
.build(),
Expand All @@ -476,13 +476,23 @@ public Collection<SystemIndexDescriptor> getSystemIndexDescriptors(Settings sett
.setPrimaryIndex(InferenceSecretsIndex.INDEX_NAME)
.setDescription("Contains inference service secrets")
.setMappings(InferenceSecretsIndex.mappings())
.setSettings(InferenceSecretsIndex.settings())
.setSettings(getSecretsIndexSettings())
.setOrigin(ClientHelper.INFERENCE_ORIGIN)
.setNetNew()
.build()
);
}

/**
 * Settings applied to the inference configuration system index descriptor.
 * Overridable so tests can inject custom settings (e.g. allocation routing).
 */
protected Settings getIndexSettings() {
return InferenceIndex.settings();
}

/**
 * Settings applied to the inference secrets system index descriptor.
 * Overridable so tests can inject custom settings (e.g. allocation routing).
 */
protected Settings getSecretsIndexSettings() {
return InferenceSecretsIndex.settings();
}

@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settingsToUse) {
return List.of(inferenceUtilityExecutor(settings));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@ private InferenceSecretsIndex() {}
private static final int INDEX_MAPPING_VERSION = 1;

public static Settings settings() {
return Settings.builder()
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1)
.put(IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS, "0-1")
.build();
return builder().build();
}

// Public to allow tests to create the index with custom settings
public static Settings.Builder builder() {
return Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1).put(IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS, "0-1");
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,25 +249,34 @@ public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId
* @param listener Model listener
*/
public void getModelWithSecrets(String inferenceEntityId, ActionListener<UnparsedModel> listener) {
ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
ActionListener<SearchResponse> searchListener = ActionListener.wrap((searchResponse) -> {
// There should be a hit for the configurations
if (searchResponse.getHits().getHits().length == 0) {
var maybeDefault = defaultConfigIds.get(inferenceEntityId);
if (maybeDefault != null) {
getDefaultConfig(true, maybeDefault, listener);
} else {
delegate.onFailure(inferenceNotFoundException(inferenceEntityId));
listener.onFailure(inferenceNotFoundException(inferenceEntityId));
}
return;
}

delegate.onResponse(unparsedModelFromMap(createModelConfigMap(searchResponse.getHits(), inferenceEntityId)));
listener.onResponse(unparsedModelFromMap(createModelConfigMap(searchResponse.getHits(), inferenceEntityId)));
}, (e) -> {
logger.warn(format("Failed to load inference endpoint with secrets [%s]", inferenceEntityId), e);
listener.onFailure(
new ElasticsearchException(
format("Failed to load inference endpoint with secrets [%s], error: [%s]", inferenceEntityId, e.getMessage()),
e
)
);
});

QueryBuilder queryBuilder = documentIdQuery(inferenceEntityId);
SearchRequest modelSearch = client.prepareSearch(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN)
.setQuery(queryBuilder)
.setSize(2)
.setAllowPartialSearchResults(false)
.request();

client.search(modelSearch, searchListener);
Expand All @@ -280,21 +289,29 @@ public void getModelWithSecrets(String inferenceEntityId, ActionListener<Unparse
* @param listener Model listener
*/
public void getModel(String inferenceEntityId, ActionListener<UnparsedModel> listener) {
ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
ActionListener<SearchResponse> searchListener = ActionListener.wrap((searchResponse) -> {
// There should be a hit for the configurations
if (searchResponse.getHits().getHits().length == 0) {
var maybeDefault = defaultConfigIds.get(inferenceEntityId);
if (maybeDefault != null) {
getDefaultConfig(true, maybeDefault, listener);
} else {
delegate.onFailure(inferenceNotFoundException(inferenceEntityId));
listener.onFailure(inferenceNotFoundException(inferenceEntityId));
}
return;
}

var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList();
assert modelConfigs.size() == 1;
delegate.onResponse(modelConfigs.get(0));
listener.onResponse(modelConfigs.get(0));
}, e -> {
logger.warn(format("Failed to load inference endpoint [%s]", inferenceEntityId), e);
listener.onFailure(
new ElasticsearchException(
format("Failed to load inference endpoint [%s], error: [%s]", inferenceEntityId, e.getMessage()),
e
)
);
});

QueryBuilder queryBuilder = documentIdQuery(inferenceEntityId);
Expand Down