Skip to content

Commit 1a7e1d5

Browse files
[ML] Adding missing onFailure call for Inference API start model request (elastic#126930)
* Adding missing onFailure call
* Update docs/changelog/126930.yaml
1 parent 6763e06 commit 1a7e1d5

File tree

4 files changed

+55
-1
lines changed

4 files changed

+55
-1
lines changed

docs/changelog/126930.yaml

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,5 @@
1+
pr: 126930
2+
summary: Adding missing `onFailure` call for Inference API start model request
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -106,7 +106,7 @@ public void start(Model model, TimeValue timeout, ActionListener<Boolean> finalL
106106
})
107107
.<Boolean>andThen((l2, modelDidPut) -> {
108108
var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout);
109-
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, finalListener);
109+
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, l2);
110110
client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener);
111111
})
112112
.addListener(finalListener);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -105,6 +105,8 @@ public void onFailure(Exception e) {
105105
&& statusException.getRootCause() instanceof ResourceAlreadyExistsException) {
106106
// Deployment is already started
107107
listener.onResponse(Boolean.TRUE);
108+
} else {
109+
listener.onFailure(e);
108110
}
109111
return;
110112
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

Lines changed: 47 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@
3737
import org.elasticsearch.inference.ModelConfigurations;
3838
import org.elasticsearch.inference.SimilarityMeasure;
3939
import org.elasticsearch.inference.TaskType;
40+
import org.elasticsearch.rest.RestStatus;
4041
import org.elasticsearch.test.ESTestCase;
4142
import org.elasticsearch.threadpool.ThreadPool;
4243
import org.elasticsearch.xcontent.ParseField;
@@ -49,13 +50,16 @@
4950
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
5051
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
5152
import org.elasticsearch.xpack.core.ml.MachineLearningField;
53+
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
5254
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
5355
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
5456
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
5557
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
5658
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
59+
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
5760
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
5861
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
62+
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
5963
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
6064
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
6165
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
@@ -1832,6 +1836,49 @@ public void testUpdateWithMlEnabled() throws IOException, InterruptedException {
18321836
}
18331837
}
18341838

1839+
public void testStart_OnFailure_WhenTimeoutOccurs() throws IOException {
1840+
var model = new ElserInternalModel(
1841+
"inference_id",
1842+
TaskType.SPARSE_EMBEDDING,
1843+
"elasticsearch",
1844+
new ElserInternalServiceSettings(
1845+
new ElasticsearchInternalServiceSettings(1, 1, "id", new AdaptiveAllocationsSettings(false, 0, 0), null)
1846+
),
1847+
new ElserMlNodeTaskSettings(),
1848+
null
1849+
);
1850+
1851+
var client = mock(Client.class);
1852+
when(client.threadPool()).thenReturn(threadPool);
1853+
1854+
doAnswer(invocationOnMock -> {
1855+
ActionListener<GetTrainedModelsAction.Response> listener = invocationOnMock.getArgument(2);
1856+
var builder = GetTrainedModelsAction.Response.builder();
1857+
builder.setModels(List.of(mock(TrainedModelConfig.class)));
1858+
builder.setTotalCount(1);
1859+
1860+
listener.onResponse(builder.build());
1861+
return Void.TYPE;
1862+
}).when(client).execute(eq(GetTrainedModelsAction.INSTANCE), any(), any());
1863+
1864+
doAnswer(invocationOnMock -> {
1865+
ActionListener<CreateTrainedModelAssignmentAction.Response> listener = invocationOnMock.getArgument(2);
1866+
listener.onFailure(new ElasticsearchStatusException("failed", RestStatus.GATEWAY_TIMEOUT));
1867+
return Void.TYPE;
1868+
}).when(client).execute(eq(StartTrainedModelDeploymentAction.INSTANCE), any(), any());
1869+
1870+
try (var service = createService(client)) {
1871+
var actionListener = new PlainActionFuture<Boolean>();
1872+
service.start(model, TimeValue.timeValueSeconds(30), actionListener);
1873+
var exception = expectThrows(
1874+
ElasticsearchStatusException.class,
1875+
() -> actionListener.actionGet(TimeValue.timeValueSeconds(30))
1876+
);
1877+
1878+
assertThat(exception.getMessage(), is("failed"));
1879+
}
1880+
}
1881+
18351882
private ElasticsearchInternalService createService(Client client) {
18361883
var cs = mock(ClusterService.class);
18371884
var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES));

0 commit comments

Comments
 (0)