Skip to content

Commit 5584aaa

Browse files
[ML] Inference API disable partial search results (#132362)
* Working on tests * Update docs/changelog/132362.yaml * Adding integration test * Wrapping exception * Fixing flaky tests * Removing assert * Refactoring testing functions
1 parent fe79a6e commit 5584aaa

File tree

6 files changed

+304
-16
lines changed

6 files changed

+304
-16
lines changed

docs/changelog/132362.yaml

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,5 @@
1+
pr: 132362
2+
summary: Inference API disable partial search results
3+
area: Machine Learning
4+
type: bug
5+
issues: []
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,252 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.integration;
9+
10+
import org.elasticsearch.ElasticsearchException;
11+
import org.elasticsearch.action.ActionFuture;
12+
import org.elasticsearch.action.search.SearchPhaseExecutionException;
13+
import org.elasticsearch.common.bytes.BytesReference;
14+
import org.elasticsearch.common.settings.Settings;
15+
import org.elasticsearch.core.TimeValue;
16+
import org.elasticsearch.inference.InferenceServiceExtension;
17+
import org.elasticsearch.inference.TaskType;
18+
import org.elasticsearch.license.LicenseSettings;
19+
import org.elasticsearch.license.XPackLicenseState;
20+
import org.elasticsearch.plugins.Plugin;
21+
import org.elasticsearch.test.ESIntegTestCase;
22+
import org.elasticsearch.test.ESTestCase;
23+
import org.elasticsearch.xcontent.XContentBuilder;
24+
import org.elasticsearch.xcontent.XContentFactory;
25+
import org.elasticsearch.xcontent.XContentType;
26+
import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin;
27+
import org.elasticsearch.xpack.core.inference.InferenceContext;
28+
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
29+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
30+
import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy;
31+
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
32+
import org.elasticsearch.xpack.core.ssl.SSLService;
33+
import org.elasticsearch.xpack.inference.InferenceIndex;
34+
import org.elasticsearch.xpack.inference.InferencePlugin;
35+
import org.elasticsearch.xpack.inference.InferenceSecretsIndex;
36+
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
37+
import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin;
38+
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
39+
40+
import java.io.IOException;
41+
import java.nio.file.Path;
42+
import java.util.Collection;
43+
import java.util.List;
44+
import java.util.Map;
45+
46+
import static org.hamcrest.CoreMatchers.containsString;
47+
import static org.hamcrest.CoreMatchers.equalTo;
48+
import static org.hamcrest.Matchers.instanceOf;
49+
50+
@ESTestCase.WithoutEntitlements // due to dependency issue ES-12435
51+
public class InferenceIndicesIT extends ESIntegTestCase {
52+
53+
private static final String INDEX_ROUTER_ATTRIBUTE = "node.attr.index_router";
54+
private static final String CONFIG_ROUTER = "config";
55+
private static final String SECRETS_ROUTER = "secrets";
56+
57+
private static final Map<String, Object> TEST_SERVICE_SETTINGS = Map.of(
58+
"model",
59+
"my_model",
60+
"dimensions",
61+
256,
62+
"similarity",
63+
"cosine",
64+
"api_key",
65+
"my_api_key"
66+
);
67+
68+
public static class LocalStateIndexSettingsInferencePlugin extends LocalStateCompositeXPackPlugin {
69+
private final InferencePlugin inferencePlugin;
70+
71+
public LocalStateIndexSettingsInferencePlugin(final Settings settings, final Path configPath) throws Exception {
72+
super(settings, configPath);
73+
var thisVar = this;
74+
this.inferencePlugin = new InferencePlugin(settings) {
75+
@Override
76+
protected SSLService getSslService() {
77+
return thisVar.getSslService();
78+
}
79+
80+
@Override
81+
protected XPackLicenseState getLicenseState() {
82+
return thisVar.getLicenseState();
83+
}
84+
85+
@Override
86+
public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
87+
return List.of(
88+
TestSparseInferenceServiceExtension.TestInferenceService::new,
89+
TestDenseInferenceServiceExtension.TestInferenceService::new
90+
);
91+
}
92+
93+
@Override
94+
public Settings getIndexSettings() {
95+
return InferenceIndex.builder()
96+
.put(Settings.builder().put("index.routing.allocation.require.index_router", "config").build())
97+
.build();
98+
}
99+
100+
@Override
101+
public Settings getSecretsIndexSettings() {
102+
return InferenceSecretsIndex.builder()
103+
.put(Settings.builder().put("index.routing.allocation.require.index_router", "secrets").build())
104+
.build();
105+
}
106+
};
107+
plugins.add(inferencePlugin);
108+
}
109+
110+
}
111+
112+
@Override
113+
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
114+
return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build();
115+
}
116+
117+
@Override
118+
protected Collection<Class<? extends Plugin>> nodePlugins() {
119+
return List.of(LocalStateIndexSettingsInferencePlugin.class, TestInferenceServicePlugin.class);
120+
}
121+
122+
public void testRetrievingInferenceEndpoint_ThrowsException_WhenIndexNodeIsNotAvailable() throws Exception {
123+
final var configIndexNodeAttributes = Settings.builder().put(INDEX_ROUTER_ATTRIBUTE, CONFIG_ROUTER).build();
124+
125+
internalCluster().startMasterOnlyNode(configIndexNodeAttributes);
126+
final var configIndexDataNodes = internalCluster().startDataOnlyNode(configIndexNodeAttributes);
127+
128+
internalCluster().startDataOnlyNode(Settings.builder().put(INDEX_ROUTER_ATTRIBUTE, SECRETS_ROUTER).build());
129+
130+
final var inferenceId = "test-index-id";
131+
createInferenceEndpoint(TaskType.TEXT_EMBEDDING, inferenceId, TEST_SERVICE_SETTINGS);
132+
133+
// Ensure the inference indices are created and we can retrieve the inference endpoint
134+
var getInferenceEndpointRequest = new GetInferenceModelAction.Request(inferenceId, TaskType.TEXT_EMBEDDING, true);
135+
var responseFuture = client().execute(GetInferenceModelAction.INSTANCE, getInferenceEndpointRequest);
136+
assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getEndpoints().get(0).getInferenceEntityId(), equalTo(inferenceId));
137+
138+
// stop the node that holds the inference index
139+
internalCluster().stopNode(configIndexDataNodes);
140+
141+
var responseFailureFuture = client().execute(GetInferenceModelAction.INSTANCE, getInferenceEndpointRequest);
142+
var exception = expectThrows(ElasticsearchException.class, () -> responseFailureFuture.actionGet(TEST_REQUEST_TIMEOUT));
143+
assertThat(exception.toString(), containsString("Failed to load inference endpoint [test-index-id]"));
144+
145+
var causeException = exception.getCause();
146+
assertThat(causeException, instanceOf(SearchPhaseExecutionException.class));
147+
}
148+
149+
public void testRetrievingInferenceEndpoint_ThrowsException_WhenIndexNodeIsNotAvailable_ForInferenceAction() throws Exception {
150+
final var configIndexNodeAttributes = Settings.builder().put(INDEX_ROUTER_ATTRIBUTE, CONFIG_ROUTER).build();
151+
152+
internalCluster().startMasterOnlyNode(configIndexNodeAttributes);
153+
final var configIndexDataNodes = internalCluster().startDataOnlyNode(configIndexNodeAttributes);
154+
155+
internalCluster().startDataOnlyNode(Settings.builder().put(INDEX_ROUTER_ATTRIBUTE, SECRETS_ROUTER).build());
156+
157+
final var inferenceId = "test-index-id-2";
158+
createInferenceEndpoint(TaskType.TEXT_EMBEDDING, inferenceId, TEST_SERVICE_SETTINGS);
159+
160+
// Ensure the inference indices are created and we can retrieve the inference endpoint
161+
var getInferenceEndpointRequest = new GetInferenceModelAction.Request(inferenceId, TaskType.TEXT_EMBEDDING, true);
162+
var responseFuture = client().execute(GetInferenceModelAction.INSTANCE, getInferenceEndpointRequest);
163+
assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getEndpoints().get(0).getInferenceEntityId(), equalTo(inferenceId));
164+
165+
// stop the node that holds the inference index
166+
internalCluster().stopNode(configIndexDataNodes);
167+
168+
var proxyResponse = sendInferenceProxyRequest(inferenceId);
169+
var exception = expectThrows(ElasticsearchException.class, () -> proxyResponse.actionGet(TEST_REQUEST_TIMEOUT));
170+
assertThat(exception.toString(), containsString("Failed to load inference endpoint with secrets [test-index-id-2]"));
171+
172+
var causeException = exception.getCause();
173+
assertThat(causeException, instanceOf(SearchPhaseExecutionException.class));
174+
}
175+
176+
public void testRetrievingInferenceEndpoint_ThrowsException_WhenSecretsIndexNodeIsNotAvailable() throws Exception {
177+
final var configIndexNodeAttributes = Settings.builder().put(INDEX_ROUTER_ATTRIBUTE, CONFIG_ROUTER).build();
178+
internalCluster().startMasterOnlyNode(configIndexNodeAttributes);
179+
internalCluster().startDataOnlyNode(configIndexNodeAttributes);
180+
181+
var secretIndexDataNodes = internalCluster().startDataOnlyNode(
182+
Settings.builder().put(INDEX_ROUTER_ATTRIBUTE, SECRETS_ROUTER).build()
183+
);
184+
185+
final var inferenceId = "test-secrets-index-id";
186+
createInferenceEndpoint(TaskType.TEXT_EMBEDDING, inferenceId, TEST_SERVICE_SETTINGS);
187+
188+
// Ensure the inference indices are created and we can retrieve the inference endpoint
189+
var getInferenceEndpointRequest = new GetInferenceModelAction.Request(inferenceId, TaskType.TEXT_EMBEDDING, true);
190+
var responseFuture = client().execute(GetInferenceModelAction.INSTANCE, getInferenceEndpointRequest);
191+
assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getEndpoints().get(0).getInferenceEntityId(), equalTo(inferenceId));
192+
193+
// stop the node that holds the inference secrets index
194+
internalCluster().stopNode(secretIndexDataNodes);
195+
196+
var proxyResponse = sendInferenceProxyRequest(inferenceId);
197+
198+
var exception = expectThrows(ElasticsearchException.class, () -> proxyResponse.actionGet(TEST_REQUEST_TIMEOUT));
199+
assertThat(exception.toString(), containsString("Failed to load inference endpoint with secrets [test-secrets-index-id]"));
200+
201+
var causeException = exception.getCause();
202+
203+
assertThat(causeException, instanceOf(SearchPhaseExecutionException.class));
204+
}
205+
206+
private ActionFuture<InferenceAction.Response> sendInferenceProxyRequest(String inferenceId) throws IOException {
207+
final BytesReference content;
208+
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
209+
builder.startObject();
210+
builder.field("input", List.of("test input"));
211+
builder.endObject();
212+
213+
content = BytesReference.bytes(builder);
214+
}
215+
216+
var inferenceRequest = new InferenceActionProxy.Request(
217+
TaskType.TEXT_EMBEDDING,
218+
inferenceId,
219+
content,
220+
XContentType.JSON,
221+
TimeValue.THIRTY_SECONDS,
222+
false,
223+
InferenceContext.EMPTY_INSTANCE
224+
);
225+
226+
return client().execute(InferenceActionProxy.INSTANCE, inferenceRequest);
227+
}
228+
229+
private void createInferenceEndpoint(TaskType taskType, String inferenceId, Map<String, Object> serviceSettings) throws IOException {
230+
var responseFuture = createInferenceEndpointAsync(taskType, inferenceId, serviceSettings);
231+
assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(), equalTo(inferenceId));
232+
}
233+
234+
private ActionFuture<PutInferenceModelAction.Response> createInferenceEndpointAsync(
235+
TaskType taskType,
236+
String inferenceId,
237+
Map<String, Object> serviceSettings
238+
) throws IOException {
239+
final BytesReference content;
240+
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
241+
builder.startObject();
242+
builder.field("service", TestDenseInferenceServiceExtension.TestInferenceService.NAME);
243+
builder.field("service_settings", serviceSettings);
244+
builder.endObject();
245+
246+
content = BytesReference.bytes(builder);
247+
}
248+
249+
var request = new PutInferenceModelAction.Request(taskType, inferenceId, content, XContentType.JSON, TEST_REQUEST_TIMEOUT);
250+
return client().execute(PutInferenceModelAction.INSTANCE, request);
251+
}
252+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceIndex.java

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -30,10 +30,12 @@ private InferenceIndex() {}
3030
private static final int INDEX_MAPPING_VERSION = 2;
3131

3232
public static Settings settings() {
33-
return Settings.builder()
34-
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1)
35-
.put(IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS, "0-1")
36-
.build();
33+
return builder().build();
34+
}
35+
36+
// Public to allow tests to create the index with custom settings
37+
public static Settings.Builder builder() {
38+
return Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1).put(IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS, "0-1");
3739
}
3840

3941
/**

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -466,7 +466,7 @@ public Collection<SystemIndexDescriptor> getSystemIndexDescriptors(Settings sett
466466
.setPrimaryIndex(InferenceIndex.INDEX_NAME)
467467
.setDescription("Contains inference service and model configuration")
468468
.setMappings(InferenceIndex.mappings())
469-
.setSettings(InferenceIndex.settings())
469+
.setSettings(getIndexSettings())
470470
.setOrigin(ClientHelper.INFERENCE_ORIGIN)
471471
.setPriorSystemIndexDescriptors(List.of(inferenceIndexV1Descriptor))
472472
.build(),
@@ -476,13 +476,23 @@ public Collection<SystemIndexDescriptor> getSystemIndexDescriptors(Settings sett
476476
.setPrimaryIndex(InferenceSecretsIndex.INDEX_NAME)
477477
.setDescription("Contains inference service secrets")
478478
.setMappings(InferenceSecretsIndex.mappings())
479-
.setSettings(InferenceSecretsIndex.settings())
479+
.setSettings(getSecretsIndexSettings())
480480
.setOrigin(ClientHelper.INFERENCE_ORIGIN)
481481
.setNetNew()
482482
.build()
483483
);
484484
}
485485

486+
// Overridable for tests
487+
protected Settings getIndexSettings() {
488+
return InferenceIndex.settings();
489+
}
490+
491+
// Overridable for tests
492+
protected Settings getSecretsIndexSettings() {
493+
return InferenceSecretsIndex.settings();
494+
}
495+
486496
@Override
487497
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settingsToUse) {
488498
return List.of(inferenceUtilityExecutor(settings));

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceSecretsIndex.java

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -29,10 +29,12 @@ private InferenceSecretsIndex() {}
2929
private static final int INDEX_MAPPING_VERSION = 1;
3030

3131
public static Settings settings() {
32-
return Settings.builder()
33-
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1)
34-
.put(IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS, "0-1")
35-
.build();
32+
return builder().build();
33+
}
34+
35+
// Public to allow tests to create the index with custom settings
36+
public static Settings.Builder builder() {
37+
return Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1).put(IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS, "0-1");
3638
}
3739

3840
/**

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java

Lines changed: 23 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -249,25 +249,34 @@ public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId
249249
* @param listener Model listener
250250
*/
251251
public void getModelWithSecrets(String inferenceEntityId, ActionListener<UnparsedModel> listener) {
252-
ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
252+
ActionListener<SearchResponse> searchListener = ActionListener.wrap((searchResponse) -> {
253253
// There should be a hit for the configurations
254254
if (searchResponse.getHits().getHits().length == 0) {
255255
var maybeDefault = defaultConfigIds.get(inferenceEntityId);
256256
if (maybeDefault != null) {
257257
getDefaultConfig(true, maybeDefault, listener);
258258
} else {
259-
delegate.onFailure(inferenceNotFoundException(inferenceEntityId));
259+
listener.onFailure(inferenceNotFoundException(inferenceEntityId));
260260
}
261261
return;
262262
}
263263

264-
delegate.onResponse(unparsedModelFromMap(createModelConfigMap(searchResponse.getHits(), inferenceEntityId)));
264+
listener.onResponse(unparsedModelFromMap(createModelConfigMap(searchResponse.getHits(), inferenceEntityId)));
265+
}, (e) -> {
266+
logger.warn(format("Failed to load inference endpoint with secrets [%s]", inferenceEntityId), e);
267+
listener.onFailure(
268+
new ElasticsearchException(
269+
format("Failed to load inference endpoint with secrets [%s], error: [%s]", inferenceEntityId, e.getMessage()),
270+
e
271+
)
272+
);
265273
});
266274

267275
QueryBuilder queryBuilder = documentIdQuery(inferenceEntityId);
268276
SearchRequest modelSearch = client.prepareSearch(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN)
269277
.setQuery(queryBuilder)
270278
.setSize(2)
279+
.setAllowPartialSearchResults(false)
271280
.request();
272281

273282
client.search(modelSearch, searchListener);
@@ -280,21 +289,29 @@ public void getModelWithSecrets(String inferenceEntityId, ActionListener<Unparse
280289
* @param listener Model listener
281290
*/
282291
public void getModel(String inferenceEntityId, ActionListener<UnparsedModel> listener) {
283-
ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
292+
ActionListener<SearchResponse> searchListener = ActionListener.wrap((searchResponse) -> {
284293
// There should be a hit for the configurations
285294
if (searchResponse.getHits().getHits().length == 0) {
286295
var maybeDefault = defaultConfigIds.get(inferenceEntityId);
287296
if (maybeDefault != null) {
288297
getDefaultConfig(true, maybeDefault, listener);
289298
} else {
290-
delegate.onFailure(inferenceNotFoundException(inferenceEntityId));
299+
listener.onFailure(inferenceNotFoundException(inferenceEntityId));
291300
}
292301
return;
293302
}
294303

295304
var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList();
296305
assert modelConfigs.size() == 1;
297-
delegate.onResponse(modelConfigs.get(0));
306+
listener.onResponse(modelConfigs.get(0));
307+
}, e -> {
308+
logger.warn(format("Failed to load inference endpoint [%s]", inferenceEntityId), e);
309+
listener.onFailure(
310+
new ElasticsearchException(
311+
format("Failed to load inference endpoint [%s], error: [%s]", inferenceEntityId, e.getMessage()),
312+
e
313+
)
314+
);
298315
});
299316

300317
QueryBuilder queryBuilder = documentIdQuery(inferenceEntityId);

0 commit comments

Comments
 (0)