Skip to content

Commit da9e149

Browse files
[ML] Adding missing onFailure call for Inference API start model request (#126930) (#126945)
* Adding missing onFailure call * Update docs/changelog/126930.yaml
1 parent f333089 commit da9e149

File tree

4 files changed

+55
-1
lines changed

4 files changed

+55
-1
lines changed

docs/changelog/126930.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 126930
2+
summary: Adding missing `onFailure` call for Inference API start model request
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ public void start(Model model, TimeValue timeout, ActionListener<Boolean> finalL
107107
})
108108
.<Boolean>andThen((l2, modelDidPut) -> {
109109
var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout);
110-
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, finalListener);
110+
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, l2);
111111
client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener);
112112
})
113113
.addListener(finalListener);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ public void onFailure(Exception e) {
105105
&& statusException.getRootCause() instanceof ResourceAlreadyExistsException) {
106106
// Deployment is already started
107107
listener.onResponse(Boolean.TRUE);
108+
} else {
109+
listener.onFailure(e);
108110
}
109111
return;
110112
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.elasticsearch.inference.ModelConfigurations;
3838
import org.elasticsearch.inference.SimilarityMeasure;
3939
import org.elasticsearch.inference.TaskType;
40+
import org.elasticsearch.rest.RestStatus;
4041
import org.elasticsearch.test.ESTestCase;
4142
import org.elasticsearch.threadpool.ThreadPool;
4243
import org.elasticsearch.xcontent.ParseField;
@@ -48,13 +49,16 @@
4849
import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
4950
import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
5051
import org.elasticsearch.xpack.core.ml.MachineLearningField;
52+
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
5153
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
5254
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
5355
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
5456
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
5557
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
58+
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
5659
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
5760
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
61+
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
5862
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
5963
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
6064
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
@@ -1792,6 +1796,49 @@ public void testUpdateWithMlEnabled() throws IOException, InterruptedException {
17921796
}
17931797
}
17941798

1799+
public void testStart_OnFailure_WhenTimeoutOccurs() throws IOException {
1800+
var model = new ElserInternalModel(
1801+
"inference_id",
1802+
TaskType.SPARSE_EMBEDDING,
1803+
"elasticsearch",
1804+
new ElserInternalServiceSettings(
1805+
new ElasticsearchInternalServiceSettings(1, 1, "id", new AdaptiveAllocationsSettings(false, 0, 0), null)
1806+
),
1807+
new ElserMlNodeTaskSettings(),
1808+
null
1809+
);
1810+
1811+
var client = mock(Client.class);
1812+
when(client.threadPool()).thenReturn(threadPool);
1813+
1814+
doAnswer(invocationOnMock -> {
1815+
ActionListener<GetTrainedModelsAction.Response> listener = invocationOnMock.getArgument(2);
1816+
var builder = GetTrainedModelsAction.Response.builder();
1817+
builder.setModels(List.of(mock(TrainedModelConfig.class)));
1818+
builder.setTotalCount(1);
1819+
1820+
listener.onResponse(builder.build());
1821+
return Void.TYPE;
1822+
}).when(client).execute(eq(GetTrainedModelsAction.INSTANCE), any(), any());
1823+
1824+
doAnswer(invocationOnMock -> {
1825+
ActionListener<CreateTrainedModelAssignmentAction.Response> listener = invocationOnMock.getArgument(2);
1826+
listener.onFailure(new ElasticsearchStatusException("failed", RestStatus.GATEWAY_TIMEOUT));
1827+
return Void.TYPE;
1828+
}).when(client).execute(eq(StartTrainedModelDeploymentAction.INSTANCE), any(), any());
1829+
1830+
try (var service = createService(client)) {
1831+
var actionListener = new PlainActionFuture<Boolean>();
1832+
service.start(model, TimeValue.timeValueSeconds(30), actionListener);
1833+
var exception = expectThrows(
1834+
ElasticsearchStatusException.class,
1835+
() -> actionListener.actionGet(TimeValue.timeValueSeconds(30))
1836+
);
1837+
1838+
assertThat(exception.getMessage(), is("failed"));
1839+
}
1840+
}
1841+
17951842
private ElasticsearchInternalService createService(Client client) {
17961843
var cs = mock(ClusterService.class);
17971844
var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES));

0 commit comments

Comments
 (0)