Skip to content

Commit 4fac584

Browse files
authored
[ML] Protect against multiple concurrent downloads of the same model (#116869) (#117007)
Check for current downloading tasks in the download action. # Conflicts: # x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java
1 parent 33dfe55 commit 4fac584

File tree

12 files changed

+295
-142
lines changed

12 files changed

+295
-142
lines changed

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
package org.elasticsearch.xpack.inference;
99

10+
import org.elasticsearch.client.Response;
11+
import org.elasticsearch.client.ResponseListener;
12+
import org.elasticsearch.common.Strings;
1013
import org.elasticsearch.inference.TaskType;
1114
import org.elasticsearch.threadpool.TestThreadPool;
1215
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
@@ -15,9 +18,12 @@
1518
import org.junit.Before;
1619

1720
import java.io.IOException;
21+
import java.util.ArrayList;
1822
import java.util.List;
1923
import java.util.Map;
24+
import java.util.concurrent.CountDownLatch;
2025

26+
import static org.hamcrest.Matchers.empty;
2127
import static org.hamcrest.Matchers.hasSize;
2228
import static org.hamcrest.Matchers.is;
2329
import static org.hamcrest.Matchers.oneOf;
@@ -100,4 +106,37 @@ private static void assertDefaultE5Config(Map<String, Object> modelConfig) {
100106
Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 0, "max_number_of_allocations", 32))
101107
);
102108
}
109+
110+
public void testMultipleInferencesTriggeringDownloadAndDeploy() throws InterruptedException {
111+
int numParallelRequests = 4;
112+
var latch = new CountDownLatch(numParallelRequests);
113+
var errors = new ArrayList<Exception>();
114+
115+
var listener = new ResponseListener() {
116+
@Override
117+
public void onSuccess(Response response) {
118+
latch.countDown();
119+
}
120+
121+
@Override
122+
public void onFailure(Exception exception) {
123+
errors.add(exception);
124+
latch.countDown();
125+
}
126+
};
127+
128+
var inputs = List.of("Hello World", "Goodnight moon");
129+
var queryParams = Map.of("timeout", "120s");
130+
for (int i = 0; i < numParallelRequests; i++) {
131+
var request = createInferenceRequest(
132+
Strings.format("_inference/%s", ElasticsearchInternalService.DEFAULT_ELSER_ID),
133+
inputs,
134+
queryParams
135+
);
136+
client().performRequestAsync(request, listener);
137+
}
138+
139+
latch.await();
140+
assertThat(errors.toString(), errors, empty());
141+
}
103142
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,12 +373,17 @@ protected Map<String, Object> infer(String modelId, TaskType taskType, List<Stri
373373
return inferInternal(endpoint, input, queryParameters);
374374
}
375375

376-
private Map<String, Object> inferInternal(String endpoint, List<String> input, Map<String, String> queryParameters) throws IOException {
376+
protected Request createInferenceRequest(String endpoint, List<String> input, Map<String, String> queryParameters) {
377377
var request = new Request("POST", endpoint);
378378
request.setJsonEntity(jsonBody(input));
379379
if (queryParameters.isEmpty() == false) {
380380
request.addParameters(queryParameters);
381381
}
382+
return request;
383+
}
384+
385+
private Map<String, Object> inferInternal(String endpoint, List<String> input, Map<String, String> queryParameters) throws IOException {
386+
var request = createInferenceRequest(endpoint, input, queryParameters);
382387
var response = client().performRequest(request);
383388
assertOkOrCreated(response);
384389
return entityAsMap(response);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandModel.java

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,9 @@
77

88
package org.elasticsearch.xpack.inference.services.elasticsearch;
99

10-
import org.elasticsearch.ResourceNotFoundException;
11-
import org.elasticsearch.action.ActionListener;
1210
import org.elasticsearch.inference.ChunkingSettings;
13-
import org.elasticsearch.inference.Model;
1411
import org.elasticsearch.inference.TaskSettings;
1512
import org.elasticsearch.inference.TaskType;
16-
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
17-
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1813

1914
public class CustomElandModel extends ElasticsearchInternalModel {
2015

@@ -39,31 +34,10 @@ public CustomElandModel(
3934
}
4035

4136
@Override
42-
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
43-
Model model,
44-
ActionListener<Boolean> listener
45-
) {
46-
47-
return new ActionListener<>() {
48-
@Override
49-
public void onResponse(CreateTrainedModelAssignmentAction.Response response) {
50-
listener.onResponse(Boolean.TRUE);
51-
}
52-
53-
@Override
54-
public void onFailure(Exception e) {
55-
if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
56-
listener.onFailure(
57-
new ResourceNotFoundException(
58-
"Could not start the inference as the custom eland model [{0}] for this platform cannot be found."
59-
+ " Custom models need to be loaded into the cluster with eland before they can be started.",
60-
internalServiceSettings.modelId()
61-
)
62-
);
63-
return;
64-
}
65-
listener.onFailure(e);
66-
}
67-
};
37+
protected String modelNotFoundErrorMessage(String modelId) {
38+
return "Could not deploy model ["
39+
+ modelId
40+
+ "] as the model cannot be found."
41+
+ " Custom models need to be loaded into the cluster with Eland before they can be started.";
6842
}
6943
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticDeployedModel.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentA
3636
throw new IllegalStateException("cannot start model that uses an existing deployment");
3737
}
3838

39+
@Override
40+
protected String modelNotFoundErrorMessage(String modelId) {
41+
throw new IllegalStateException("cannot start model [" + modelId + "] that uses an existing deployment");
42+
}
43+
3944
@Override
4045
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
4146
Model model,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
package org.elasticsearch.xpack.inference.services.elasticsearch;
99

10+
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.ResourceAlreadyExistsException;
12+
import org.elasticsearch.ResourceNotFoundException;
1013
import org.elasticsearch.action.ActionListener;
1114
import org.elasticsearch.common.Strings;
1215
import org.elasticsearch.core.TimeValue;
@@ -15,8 +18,10 @@
1518
import org.elasticsearch.inference.ModelConfigurations;
1619
import org.elasticsearch.inference.TaskSettings;
1720
import org.elasticsearch.inference.TaskType;
21+
import org.elasticsearch.rest.RestStatus;
1822
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
1923
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
24+
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
2025

2126
import static org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus.State.STARTED;
2227

@@ -79,10 +84,38 @@ public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentA
7984
return startRequest;
8085
}
8186

82-
public abstract ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
87+
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
8388
Model model,
8489
ActionListener<Boolean> listener
85-
);
90+
) {
91+
return new ActionListener<>() {
92+
@Override
93+
public void onResponse(CreateTrainedModelAssignmentAction.Response response) {
94+
listener.onResponse(Boolean.TRUE);
95+
}
96+
97+
@Override
98+
public void onFailure(Exception e) {
99+
var cause = ExceptionsHelper.unwrapCause(e);
100+
if (cause instanceof ResourceNotFoundException) {
101+
listener.onFailure(new ResourceNotFoundException(modelNotFoundErrorMessage(internalServiceSettings.modelId())));
102+
return;
103+
} else if (cause instanceof ElasticsearchStatusException statusException) {
104+
if (statusException.status() == RestStatus.CONFLICT
105+
&& statusException.getRootCause() instanceof ResourceAlreadyExistsException) {
106+
// Deployment is already started
107+
listener.onResponse(Boolean.TRUE);
108+
}
109+
return;
110+
}
111+
listener.onFailure(e);
112+
}
113+
};
114+
}
115+
116+
protected String modelNotFoundErrorMessage(String modelId) {
117+
return "Could not deploy model [" + modelId + "] as the model cannot be found.";
118+
}
86119

87120
public boolean usesExistingDeployment() {
88121
return internalServiceSettings.getDeploymentId() != null;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModel.java

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,8 @@
77

88
package org.elasticsearch.xpack.inference.services.elasticsearch;
99

10-
import org.elasticsearch.ResourceNotFoundException;
11-
import org.elasticsearch.action.ActionListener;
1210
import org.elasticsearch.inference.ChunkingSettings;
13-
import org.elasticsearch.inference.Model;
1411
import org.elasticsearch.inference.TaskType;
15-
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
16-
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1712

1813
public class ElserInternalModel extends ElasticsearchInternalModel {
1914

@@ -37,31 +32,4 @@ public ElserInternalServiceSettings getServiceSettings() {
3732
public ElserMlNodeTaskSettings getTaskSettings() {
3833
return (ElserMlNodeTaskSettings) super.getTaskSettings();
3934
}
40-
41-
@Override
42-
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
43-
Model model,
44-
ActionListener<Boolean> listener
45-
) {
46-
return new ActionListener<>() {
47-
@Override
48-
public void onResponse(CreateTrainedModelAssignmentAction.Response response) {
49-
listener.onResponse(Boolean.TRUE);
50-
}
51-
52-
@Override
53-
public void onFailure(Exception e) {
54-
if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
55-
listener.onFailure(
56-
new ResourceNotFoundException(
57-
"Could not start the ELSER service as the ELSER model for this platform cannot be found."
58-
+ " ELSER needs to be downloaded before it can be started."
59-
)
60-
);
61-
return;
62-
}
63-
listener.onFailure(e);
64-
}
65-
};
66-
}
6735
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallModel.java

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,8 @@
77

88
package org.elasticsearch.xpack.inference.services.elasticsearch;
99

10-
import org.elasticsearch.ResourceNotFoundException;
11-
import org.elasticsearch.action.ActionListener;
1210
import org.elasticsearch.inference.ChunkingSettings;
13-
import org.elasticsearch.inference.Model;
1411
import org.elasticsearch.inference.TaskType;
15-
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
16-
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1712

1813
public class MultilingualE5SmallModel extends ElasticsearchInternalModel {
1914

@@ -31,34 +26,4 @@ public MultilingualE5SmallModel(
3126
public MultilingualE5SmallInternalServiceSettings getServiceSettings() {
3227
return (MultilingualE5SmallInternalServiceSettings) super.getServiceSettings();
3328
}
34-
35-
@Override
36-
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
37-
Model model,
38-
ActionListener<Boolean> listener
39-
) {
40-
41-
return new ActionListener<>() {
42-
@Override
43-
public void onResponse(CreateTrainedModelAssignmentAction.Response response) {
44-
listener.onResponse(Boolean.TRUE);
45-
}
46-
47-
@Override
48-
public void onFailure(Exception e) {
49-
if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
50-
listener.onFailure(
51-
new ResourceNotFoundException(
52-
"Could not start the TextEmbeddingService service as the "
53-
+ "Multilingual-E5-Small model for this platform cannot be found."
54-
+ " Multilingual-E5-Small needs to be downloaded before it can be started"
55-
)
56-
);
57-
return;
58-
}
59-
listener.onFailure(e);
60-
}
61-
};
62-
}
63-
6429
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.ml.packageloader.action;
9+
10+
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.action.support.master.AcknowledgedResponse;
12+
import org.elasticsearch.tasks.RemovedTaskListener;
13+
import org.elasticsearch.tasks.Task;
14+
15+
public record DownloadTaskRemovedListener(ModelDownloadTask trackedTask, ActionListener<AcknowledgedResponse> listener)
16+
implements
17+
RemovedTaskListener {
18+
19+
@Override
20+
public void onRemoved(Task task) {
21+
if (task.getId() == trackedTask.getId()) {
22+
if (trackedTask.getTaskException() == null) {
23+
listener.onResponse(AcknowledgedResponse.TRUE);
24+
} else {
25+
listener.onFailure(trackedTask.getTaskException());
26+
}
27+
}
28+
}
29+
}

x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTask.java

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.tasks.Task;
1414
import org.elasticsearch.tasks.TaskId;
1515
import org.elasticsearch.xcontent.XContentBuilder;
16+
import org.elasticsearch.xpack.core.ml.MlTasks;
1617

1718
import java.io.IOException;
1819
import java.util.Map;
@@ -51,9 +52,12 @@ public void writeTo(StreamOutput out) throws IOException {
5152
}
5253

5354
private final AtomicReference<DownLoadProgress> downloadProgress = new AtomicReference<>(new DownLoadProgress(0, 0));
55+
private final String modelId;
56+
private volatile Exception taskException;
5457

55-
public ModelDownloadTask(long id, String type, String action, String description, TaskId parentTaskId, Map<String, String> headers) {
56-
super(id, type, action, description, parentTaskId, headers);
58+
/**
 * Creates a download task for the given model. The task description is derived
 * from the model id via {@link #taskDescription(String)}.
 *
 * @param id           the task id
 * @param type         the task type
 * @param action       the action name
 * @param modelId      the id of the model being downloaded
 * @param parentTaskId the parent task id
 * @param headers      the task headers
 */
public ModelDownloadTask(long id, String type, String action, String modelId, TaskId parentTaskId, Map<String, String> headers) {
    super(id, type, action, taskDescription(modelId), parentTaskId, headers);
    this.modelId = modelId;
}
5862

5963
void setProgress(int totalParts, int downloadedParts) {
@@ -65,4 +69,19 @@ public DownloadStatus getStatus() {
6569
return new DownloadStatus(downloadProgress.get());
6670
}
6771

72+
public String getModelId() {
73+
return modelId;
74+
}
75+
76+
public void setTaskException(Exception exception) {
77+
this.taskException = exception;
78+
}
79+
80+
public Exception getTaskException() {
81+
return taskException;
82+
}
83+
84+
public static String taskDescription(String modelId) {
85+
return MlTasks.downloadModelTaskDescription(modelId);
86+
}
6887
}

0 commit comments

Comments
 (0)