Skip to content

Commit 926c4fe

Browse files
Wrapping exception
1 parent 796d94e commit 926c4fe

File tree

2 files changed

+64
-14
lines changed

2 files changed

+64
-14
lines changed

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceIndicesIT.java

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.inference.integration;
99

10+
import org.elasticsearch.ElasticsearchException;
1011
import org.elasticsearch.action.ActionFuture;
1112
import org.elasticsearch.action.search.SearchPhaseExecutionException;
1213
import org.elasticsearch.common.bytes.BytesReference;
@@ -44,6 +45,7 @@
4445

4546
import static org.hamcrest.CoreMatchers.containsString;
4647
import static org.hamcrest.CoreMatchers.equalTo;
48+
import static org.hamcrest.Matchers.instanceOf;
4749

4850
@ESTestCase.WithoutEntitlements // due to dependency issue ES-12435
4951
public class InferenceIndicesIT extends ESIntegTestCase {
@@ -137,11 +139,40 @@ public void testRetrievingInferenceEndpoint_ThrowsException_WhenIndexNodeIsNotAv
137139
internalCluster().stopNode(configIndexDataNodes);
138140

139141
var responseFailureFuture = client().execute(GetInferenceModelAction.INSTANCE, getInferenceEndpointRequest);
140-
var exception = expectThrows(SearchPhaseExecutionException.class, () -> responseFailureFuture.actionGet(TEST_REQUEST_TIMEOUT));
142+
var exception = expectThrows(ElasticsearchException.class, () -> responseFailureFuture.actionGet(TEST_REQUEST_TIMEOUT));
143+
assertThat(exception.toString(), containsString("Failed to load inference endpoint [test-index-id]"));
141144

142-
assertThat(exception.toString(), containsString("all shards failed"));
143-
assertThat(exception.toString(), containsString("Node not connected"));
144-
assertThat(exception.toString(), containsString(".inference"));
145+
var causeException = exception.getCause();
146+
assertThat(causeException, instanceOf(SearchPhaseExecutionException.class));
147+
assertThat(causeException.toString(), containsString(".inference"));
148+
}
149+
150+
/**
 * Verifies the failure reported by an inference request (sent through {@code sendInferenceProxyRequest})
 * when the data node started with the {@code CONFIG_ROUTER} attribute has been stopped.
 * <p>
 * The test expects an {@link ElasticsearchException} whose text contains
 * "Failed to load inference endpoint [test-index-id-2]" and whose cause is a
 * {@link SearchPhaseExecutionException} that mentions ".inference".
 */
public void testRetrievingInferenceEndpoint_ThrowsException_WhenIndexNodeIsNotAvailable_ForInferenceAction() throws Exception {
151+
// NOTE(review): presumably INDEX_ROUTER_ATTRIBUTE pins each inference index to the nodes carrying the matching
// attribute value; the allocation settings are not visible here, so confirm against the class setup.
final var configIndexNodeAttributes = Settings.builder().put(INDEX_ROUTER_ATTRIBUTE, CONFIG_ROUTER).build();
152+
153+
internalCluster().startMasterOnlyNode(configIndexNodeAttributes);
154+
// Keep the handle of this data node: it is the one stopped later to provoke the failure.
final var configIndexDataNodes = internalCluster().startDataOnlyNode(configIndexNodeAttributes);
155+
156+
// A second data node tagged with SECRETS_ROUTER stays up for the whole test.
internalCluster().startDataOnlyNode(Settings.builder().put(INDEX_ROUTER_ATTRIBUTE, SECRETS_ROUTER).build());
157+
158+
final var inferenceId = "test-index-id-2";
159+
createInferenceEndpoint(TaskType.TEXT_EMBEDDING, inferenceId, TEST_SERVICE_SETTINGS);
160+
161+
// Ensure the inference indices are created and we can retrieve the inference endpoint
162+
var getInferenceEndpointRequest = new GetInferenceModelAction.Request(inferenceId, TaskType.TEXT_EMBEDDING, true);
163+
var responseFuture = client().execute(GetInferenceModelAction.INSTANCE, getInferenceEndpointRequest);
164+
assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getEndpoints().get(0).getInferenceEntityId(), equalTo(inferenceId));
165+
166+
// stop the node that holds the inference index
167+
internalCluster().stopNode(configIndexDataNodes);
168+
169+
// Unlike the GetInferenceModelAction call above, the failing call goes through the inference proxy helper.
var proxyResponse = sendInferenceProxyRequest(inferenceId);
170+
// The expected top-level failure is the generic ElasticsearchException carrying the "Failed to load" message.
var exception = expectThrows(ElasticsearchException.class, () -> proxyResponse.actionGet(TEST_REQUEST_TIMEOUT));
171+
assertThat(exception.toString(), containsString("Failed to load inference endpoint [test-index-id-2]"));
172+
173+
// The original search failure must still be reachable as the cause and must name the ".inference" index.
var causeException = exception.getCause();
174+
assertThat(causeException, instanceOf(SearchPhaseExecutionException.class));
175+
assertThat(causeException.toString(), containsString(".inference"));
145176
}
146177

147178
public void testRetrievingInferenceEndpoint_ThrowsException_WhenSecretsIndexNodeIsNotAvailable() throws Exception {
@@ -165,11 +196,14 @@ public void testRetrievingInferenceEndpoint_ThrowsException_WhenSecretsIndexNode
165196
internalCluster().stopNode(secretIndexDataNodes);
166197

167198
var proxyResponse = sendInferenceProxyRequest(inferenceId);
168-
var exception = expectThrows(SearchPhaseExecutionException.class, () -> proxyResponse.actionGet(TEST_REQUEST_TIMEOUT));
169199

170-
assertThat(exception.toString(), containsString("shards failure"));
171-
assertThat(exception.toString(), containsString("Node not connected"));
172-
assertThat(exception.toString(), containsString(".secrets-inference"));
200+
var exception = expectThrows(ElasticsearchException.class, () -> proxyResponse.actionGet(TEST_REQUEST_TIMEOUT));
201+
assertThat(exception.toString(), containsString("Failed to load inference endpoint [test-secrets-index-id]"));
202+
203+
var causeException = exception.getCause();
204+
205+
assertThat(causeException, instanceOf(SearchPhaseExecutionException.class));
206+
assertThat(causeException.toString(), containsString(".secrets-inference"));
173207
}
174208

175209
private ActionFuture<InferenceAction.Response> sendInferenceProxyRequest(String inferenceId) throws IOException {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -249,19 +249,27 @@ public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId
249249
* @param listener Model listener
250250
*/
251251
public void getModelWithSecrets(String inferenceEntityId, ActionListener<UnparsedModel> listener) {
252-
ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
252+
ActionListener<SearchResponse> searchListener = ActionListener.wrap((searchResponse) -> {
253253
// There should be a hit for the configurations
254254
if (searchResponse.getHits().getHits().length == 0) {
255255
var maybeDefault = defaultConfigIds.get(inferenceEntityId);
256256
if (maybeDefault != null) {
257257
getDefaultConfig(true, maybeDefault, listener);
258258
} else {
259-
delegate.onFailure(inferenceNotFoundException(inferenceEntityId));
259+
listener.onFailure(inferenceNotFoundException(inferenceEntityId));
260260
}
261261
return;
262262
}
263263

264-
delegate.onResponse(unparsedModelFromMap(createModelConfigMap(searchResponse.getHits(), inferenceEntityId)));
264+
listener.onResponse(unparsedModelFromMap(createModelConfigMap(searchResponse.getHits(), inferenceEntityId)));
265+
}, (e) -> {
266+
logger.warn(format("Failed to load inference endpoint [%s]", inferenceEntityId), e);
267+
listener.onFailure(
268+
new ElasticsearchException(
269+
format("Failed to load inference endpoint [%s], error: [%s]", inferenceEntityId, e.getMessage()),
270+
e
271+
)
272+
);
265273
});
266274

267275
QueryBuilder queryBuilder = documentIdQuery(inferenceEntityId);
@@ -281,21 +289,29 @@ public void getModelWithSecrets(String inferenceEntityId, ActionListener<Unparse
281289
* @param listener Model listener
282290
*/
283291
public void getModel(String inferenceEntityId, ActionListener<UnparsedModel> listener) {
284-
ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
292+
ActionListener<SearchResponse> searchListener = ActionListener.wrap((searchResponse) -> {
285293
// There should be a hit for the configurations
286294
if (searchResponse.getHits().getHits().length == 0) {
287295
var maybeDefault = defaultConfigIds.get(inferenceEntityId);
288296
if (maybeDefault != null) {
289297
getDefaultConfig(true, maybeDefault, listener);
290298
} else {
291-
delegate.onFailure(inferenceNotFoundException(inferenceEntityId));
299+
listener.onFailure(inferenceNotFoundException(inferenceEntityId));
292300
}
293301
return;
294302
}
295303

296304
var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList();
297305
assert modelConfigs.size() == 1;
298-
delegate.onResponse(modelConfigs.get(0));
306+
listener.onResponse(modelConfigs.get(0));
307+
}, e -> {
308+
logger.warn(format("Failed to load inference endpoint [%s]", inferenceEntityId), e);
309+
listener.onFailure(
310+
new ElasticsearchException(
311+
format("Failed to load inference endpoint [%s], error: [%s]", inferenceEntityId, e.getMessage()),
312+
e
313+
)
314+
);
299315
});
300316

301317
QueryBuilder queryBuilder = documentIdQuery(inferenceEntityId);

0 commit comments

Comments
 (0)