Skip to content

Commit 59602a9

Browse files
authored
[ML] Pass inference timeout to start deployment (#116725)
Default inference endpoints automatically deploy the model on the first inference request. The inference timeout is now passed to the start model deployment request, so users can control that timeout.
1 parent bd091d3 commit 59602a9

File tree

7 files changed

+31
-29
lines changed

7 files changed

+31
-29
lines changed

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,10 @@ void chunkedInfer(
139139
/**
140140
* Start or prepare the model for use.
141141
* @param model The model
142+
* @param timeout Start timeout
142143
* @param listener The listener
143144
*/
144-
void start(Model model, ActionListener<Boolean> listener);
145+
void start(Model model, TimeValue timeout, ActionListener<Boolean> listener);
145146

146147
/**
147148
* Stop the model deployment.
@@ -153,17 +154,6 @@ default void stop(UnparsedModel unparsedModel, ActionListener<Boolean> listener)
153154
listener.onResponse(true);
154155
}
155156

156-
/**
157-
* Put the model definition (if applicable)
158-
* The main purpose of this function is to download ELSER
159-
* The default action does nothing except acknowledge the request (true).
160-
* @param modelVariant The configuration of the model variant to be downloaded
161-
* @param listener The listener
162-
*/
163-
default void putModel(Model modelVariant, ActionListener<Boolean> listener) {
164-
listener.onResponse(true);
165-
}
166-
167157
/**
168158
* Optionally test the new model configuration in the inference service.
169159
* This function should be called when the model is first created, the

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.common.ValidationException;
1313
import org.elasticsearch.common.io.stream.StreamInput;
1414
import org.elasticsearch.common.io.stream.StreamOutput;
15+
import org.elasticsearch.core.TimeValue;
1516
import org.elasticsearch.inference.InferenceService;
1617
import org.elasticsearch.inference.Model;
1718
import org.elasticsearch.inference.ModelConfigurations;
@@ -90,7 +91,7 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String,
9091
protected abstract ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap);
9192

9293
@Override
93-
public void start(Model model, ActionListener<Boolean> listener) {
94+
public void start(Model model, TimeValue timeout, ActionListener<Boolean> listener) {
9495
listener.onResponse(true);
9596
}
9697

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.common.settings.Settings;
2323
import org.elasticsearch.common.util.concurrent.EsExecutors;
2424
import org.elasticsearch.common.xcontent.XContentHelper;
25+
import org.elasticsearch.core.TimeValue;
2526
import org.elasticsearch.index.mapper.StrictDynamicMappingException;
2627
import org.elasticsearch.inference.InferenceService;
2728
import org.elasticsearch.inference.InferenceServiceRegistry;
@@ -159,20 +160,21 @@ protected void masterOperation(
159160
return;
160161
}
161162

162-
parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, listener);
163+
parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.ackTimeout(), listener);
163164
}
164165

165166
private void parseAndStoreModel(
166167
InferenceService service,
167168
String inferenceEntityId,
168169
TaskType taskType,
169170
Map<String, Object> config,
171+
TimeValue timeout,
170172
ActionListener<PutInferenceModelAction.Response> listener
171173
) {
172174
ActionListener<Model> storeModelListener = listener.delegateFailureAndWrap(
173175
(delegate, verifiedModel) -> modelRegistry.storeModel(
174176
verifiedModel,
175-
ActionListener.wrap(r -> startInferenceEndpoint(service, verifiedModel, delegate), e -> {
177+
ActionListener.wrap(r -> startInferenceEndpoint(service, timeout, verifiedModel, delegate), e -> {
176178
if (e.getCause() instanceof StrictDynamicMappingException && e.getCause().getMessage().contains("chunking_settings")) {
177179
delegate.onFailure(
178180
new ElasticsearchStatusException(
@@ -199,11 +201,16 @@ private void parseAndStoreModel(
199201
service.parseRequestConfig(inferenceEntityId, taskType, config, parsedModelListener);
200202
}
201203

202-
private void startInferenceEndpoint(InferenceService service, Model model, ActionListener<PutInferenceModelAction.Response> listener) {
204+
private void startInferenceEndpoint(
205+
InferenceService service,
206+
TimeValue timeout,
207+
Model model,
208+
ActionListener<PutInferenceModelAction.Response> listener
209+
) {
203210
if (skipValidationAndStart) {
204211
listener.onResponse(new PutInferenceModelAction.Response(model.getConfigurations()));
205212
} else {
206-
service.start(model, listener.map(started -> new PutInferenceModelAction.Response(model.getConfigurations())));
213+
service.start(model, timeout, listener.map(started -> new PutInferenceModelAction.Response(model.getConfigurations())));
207214
}
208215
}
209216

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,16 @@ protected abstract void doChunkedInfer(
104104
ActionListener<List<ChunkedInferenceServiceResults>> listener
105105
);
106106

107-
@Override
108107
public void start(Model model, ActionListener<Boolean> listener) {
109108
init();
110-
111109
doStart(model, listener);
112110
}
113111

112+
@Override
113+
public void start(Model model, @Nullable TimeValue unused, ActionListener<Boolean> listener) {
114+
start(model, listener);
115+
}
116+
114117
protected void doStart(Model model, ActionListener<Boolean> listener) {
115118
listener.onResponse(true);
116119
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ public BaseElasticsearchInternalService(
8383
}
8484

8585
@Override
86-
public void start(Model model, ActionListener<Boolean> finalListener) {
86+
public void start(Model model, TimeValue timeout, ActionListener<Boolean> finalListener) {
8787
if (model instanceof ElasticsearchInternalModel esModel) {
8888
if (supportedTaskTypes().contains(model.getTaskType()) == false) {
8989
finalListener.onFailure(
@@ -107,7 +107,7 @@ public void start(Model model, ActionListener<Boolean> finalListener) {
107107
}
108108
})
109109
.<Boolean>andThen((l2, modelDidPut) -> {
110-
var startRequest = esModel.getStartTrainedModelDeploymentActionRequest();
110+
var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout);
111111
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, finalListener);
112112
client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener);
113113
})
@@ -149,8 +149,7 @@ protected static IllegalStateException notElasticsearchModelException(Model mode
149149
);
150150
}
151151

152-
@Override
153-
public void putModel(Model model, ActionListener<Boolean> listener) {
152+
protected void putModel(Model model, ActionListener<Boolean> listener) {
154153
if (model instanceof ElasticsearchInternalModel == false) {
155154
listener.onFailure(notElasticsearchModelException(model));
156155
return;
@@ -303,10 +302,9 @@ protected void maybeStartDeployment(
303302
}
304303

305304
if (isDefaultId(model.getInferenceEntityId()) && ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
306-
this.start(
307-
model,
308-
listener.delegateFailureAndWrap((l, started) -> { client.execute(InferModelAction.INSTANCE, request, listener); })
309-
);
305+
this.start(model, request.getInferenceTimeout(), listener.delegateFailureAndWrap((l, started) -> {
306+
client.execute(InferModelAction.INSTANCE, request, listener);
307+
}));
310308
} else {
311309
listener.onFailure(e);
312310
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticDeployedModel.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference.services.elasticsearch;
99

1010
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.core.TimeValue;
1112
import org.elasticsearch.inference.ChunkingSettings;
1213
import org.elasticsearch.inference.Model;
1314
import org.elasticsearch.inference.TaskType;
@@ -31,7 +32,7 @@ public boolean usesExistingDeployment() {
3132
}
3233

3334
@Override
34-
public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentActionRequest() {
35+
public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentActionRequest(TimeValue timeout) {
3536
throw new IllegalStateException("cannot start model that uses an existing deployment");
3637
}
3738

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.action.ActionListener;
1111
import org.elasticsearch.common.Strings;
12+
import org.elasticsearch.core.TimeValue;
1213
import org.elasticsearch.inference.ChunkingSettings;
1314
import org.elasticsearch.inference.Model;
1415
import org.elasticsearch.inference.ModelConfigurations;
@@ -67,11 +68,12 @@ public ElasticsearchInternalModel(
6768
this.internalServiceSettings = internalServiceSettings;
6869
}
6970

70-
public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentActionRequest() {
71+
public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentActionRequest(TimeValue timeout) {
7172
var startRequest = new StartTrainedModelDeploymentAction.Request(internalServiceSettings.modelId(), this.getInferenceEntityId());
7273
startRequest.setNumberOfAllocations(internalServiceSettings.getNumAllocations());
7374
startRequest.setThreadsPerAllocation(internalServiceSettings.getNumThreads());
7475
startRequest.setAdaptiveAllocationsSettings(internalServiceSettings.getAdaptiveAllocationsSettings());
76+
startRequest.setTimeout(timeout);
7577
startRequest.setWaitForState(STARTED);
7678

7779
return startRequest;

0 commit comments

Comments
 (0)