|
37 | 37 | import org.elasticsearch.inference.ModelConfigurations; |
38 | 38 | import org.elasticsearch.inference.SimilarityMeasure; |
39 | 39 | import org.elasticsearch.inference.TaskType; |
| 40 | +import org.elasticsearch.rest.RestStatus; |
40 | 41 | import org.elasticsearch.test.ESTestCase; |
41 | 42 | import org.elasticsearch.threadpool.ThreadPool; |
42 | 43 | import org.elasticsearch.xcontent.ParseField; |
|
49 | 50 | import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; |
50 | 51 | import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; |
51 | 52 | import org.elasticsearch.xpack.core.ml.MachineLearningField; |
| 53 | +import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; |
52 | 54 | import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; |
53 | 55 | import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; |
54 | 56 | import org.elasticsearch.xpack.core.ml.action.InferModelAction; |
55 | 57 | import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; |
56 | 58 | import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; |
| 59 | +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; |
57 | 60 | import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; |
58 | 61 | import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; |
| 62 | +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; |
59 | 63 | import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; |
60 | 64 | import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; |
61 | 65 | import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; |
@@ -1858,6 +1862,49 @@ public void testUpdateWithMlEnabled() throws IOException, InterruptedException { |
1858 | 1862 | } |
1859 | 1863 | } |
1860 | 1864 |
|
| 1865 | + public void testStart_OnFailure_WhenTimeoutOccurs() throws IOException { |
| 1866 | + var model = new ElserInternalModel( |
| 1867 | + "inference_id", |
| 1868 | + TaskType.SPARSE_EMBEDDING, |
| 1869 | + "elasticsearch", |
| 1870 | + new ElserInternalServiceSettings( |
| 1871 | + new ElasticsearchInternalServiceSettings(1, 1, "id", new AdaptiveAllocationsSettings(false, 0, 0), null) |
| 1872 | + ), |
| 1873 | + new ElserMlNodeTaskSettings(), |
| 1874 | + null |
| 1875 | + ); |
| 1876 | + |
| 1877 | + var client = mock(Client.class); |
| 1878 | + when(client.threadPool()).thenReturn(threadPool); |
| 1879 | + |
| 1880 | + doAnswer(invocationOnMock -> { |
| 1881 | + ActionListener<GetTrainedModelsAction.Response> listener = invocationOnMock.getArgument(2); |
| 1882 | + var builder = GetTrainedModelsAction.Response.builder(); |
| 1883 | + builder.setModels(List.of(mock(TrainedModelConfig.class))); |
| 1884 | + builder.setTotalCount(1); |
| 1885 | + |
| 1886 | + listener.onResponse(builder.build()); |
| 1887 | + return Void.TYPE; |
| 1888 | + }).when(client).execute(eq(GetTrainedModelsAction.INSTANCE), any(), any()); |
| 1889 | + |
| 1890 | + doAnswer(invocationOnMock -> { |
| 1891 | + ActionListener<CreateTrainedModelAssignmentAction.Response> listener = invocationOnMock.getArgument(2); |
| 1892 | + listener.onFailure(new ElasticsearchStatusException("failed", RestStatus.GATEWAY_TIMEOUT)); |
| 1893 | + return Void.TYPE; |
| 1894 | + }).when(client).execute(eq(StartTrainedModelDeploymentAction.INSTANCE), any(), any()); |
| 1895 | + |
| 1896 | + try (var service = createService(client)) { |
| 1897 | + var actionListener = new PlainActionFuture<Boolean>(); |
| 1898 | + service.start(model, TimeValue.timeValueSeconds(30), actionListener); |
| 1899 | + var exception = expectThrows( |
| 1900 | + ElasticsearchStatusException.class, |
| 1901 | + () -> actionListener.actionGet(TimeValue.timeValueSeconds(30)) |
| 1902 | + ); |
| 1903 | + |
| 1904 | + assertThat(exception.getMessage(), is("failed")); |
| 1905 | + } |
| 1906 | + } |
| 1907 | + |
1861 | 1908 | private ElasticsearchInternalService createService(Client client) { |
1862 | 1909 | var cs = mock(ClusterService.class); |
1863 | 1910 | var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES)); |
|
0 commit comments