Skip to content

Commit e42c118

Browse files
[ML] Adding missing onFailure call for Inference API start model request (#126930)
* Adding missing onFailure call * Update docs/changelog/126930.yaml
1 parent 7d6fda5 commit e42c118

File tree

4 files changed

+55
-1
lines changed

4 files changed

+55
-1
lines changed

docs/changelog/126930.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 126930
2+
summary: Adding missing `onFailure` call for Inference API start model request
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ public void start(Model model, TimeValue timeout, ActionListener<Boolean> finalL
106106
})
107107
.<Boolean>andThen((l2, modelDidPut) -> {
108108
var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout);
109-
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, finalListener);
109+
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, l2);
110110
client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener);
111111
})
112112
.addListener(finalListener);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ public void onFailure(Exception e) {
105105
&& statusException.getRootCause() instanceof ResourceAlreadyExistsException) {
106106
// Deployment is already started
107107
listener.onResponse(Boolean.TRUE);
108+
} else {
109+
listener.onFailure(e);
108110
}
109111
return;
110112
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.elasticsearch.inference.ModelConfigurations;
3838
import org.elasticsearch.inference.SimilarityMeasure;
3939
import org.elasticsearch.inference.TaskType;
40+
import org.elasticsearch.rest.RestStatus;
4041
import org.elasticsearch.test.ESTestCase;
4142
import org.elasticsearch.threadpool.ThreadPool;
4243
import org.elasticsearch.xcontent.ParseField;
@@ -49,13 +50,16 @@
4950
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
5051
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
5152
import org.elasticsearch.xpack.core.ml.MachineLearningField;
53+
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
5254
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
5355
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
5456
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
5557
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
5658
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
59+
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
5760
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
5861
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
62+
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
5963
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
6064
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
6165
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
@@ -1858,6 +1862,49 @@ public void testUpdateWithMlEnabled() throws IOException, InterruptedException {
18581862
}
18591863
}
18601864

1865+
public void testStart_OnFailure_WhenTimeoutOccurs() throws IOException {
1866+
var model = new ElserInternalModel(
1867+
"inference_id",
1868+
TaskType.SPARSE_EMBEDDING,
1869+
"elasticsearch",
1870+
new ElserInternalServiceSettings(
1871+
new ElasticsearchInternalServiceSettings(1, 1, "id", new AdaptiveAllocationsSettings(false, 0, 0), null)
1872+
),
1873+
new ElserMlNodeTaskSettings(),
1874+
null
1875+
);
1876+
1877+
var client = mock(Client.class);
1878+
when(client.threadPool()).thenReturn(threadPool);
1879+
1880+
doAnswer(invocationOnMock -> {
1881+
ActionListener<GetTrainedModelsAction.Response> listener = invocationOnMock.getArgument(2);
1882+
var builder = GetTrainedModelsAction.Response.builder();
1883+
builder.setModels(List.of(mock(TrainedModelConfig.class)));
1884+
builder.setTotalCount(1);
1885+
1886+
listener.onResponse(builder.build());
1887+
return Void.TYPE;
1888+
}).when(client).execute(eq(GetTrainedModelsAction.INSTANCE), any(), any());
1889+
1890+
doAnswer(invocationOnMock -> {
1891+
ActionListener<CreateTrainedModelAssignmentAction.Response> listener = invocationOnMock.getArgument(2);
1892+
listener.onFailure(new ElasticsearchStatusException("failed", RestStatus.GATEWAY_TIMEOUT));
1893+
return Void.TYPE;
1894+
}).when(client).execute(eq(StartTrainedModelDeploymentAction.INSTANCE), any(), any());
1895+
1896+
try (var service = createService(client)) {
1897+
var actionListener = new PlainActionFuture<Boolean>();
1898+
service.start(model, TimeValue.timeValueSeconds(30), actionListener);
1899+
var exception = expectThrows(
1900+
ElasticsearchStatusException.class,
1901+
() -> actionListener.actionGet(TimeValue.timeValueSeconds(30))
1902+
);
1903+
1904+
assertThat(exception.getMessage(), is("failed"));
1905+
}
1906+
}
1907+
18611908
private ElasticsearchInternalService createService(Client client) {
18621909
var cs = mock(ClusterService.class);
18631910
var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES));

0 commit comments

Comments
 (0)