Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/129003.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 129003
summary: Allow timeout during trained model download process
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.assignment;
package org.elasticsearch.xpack.core.ml.inference;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.rest.RestStatus;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ElasticsearchTimeoutException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
Expand All @@ -22,13 +23,15 @@
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.ModelDeploymentTimeoutException;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
Expand All @@ -41,12 +44,14 @@
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;

public abstract class BaseElasticsearchInternalService implements InferenceService {

protected final OriginSettingClient client;
protected final ThreadPool threadPool;
protected final ExecutorService inferenceExecutor;
protected final Consumer<ActionListener<PreferredModelVariant>> preferredModelVariantFn;
private final ClusterService clusterService;
Expand All @@ -60,6 +65,7 @@ public enum PreferredModelVariant {

public BaseElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) {
this.client = new OriginSettingClient(context.client(), ClientHelper.INFERENCE_ORIGIN);
this.threadPool = context.threadPool();
this.inferenceExecutor = context.threadPool().executor(InferencePlugin.UTILITY_THREAD_POOL_NAME);
this.preferredModelVariantFn = this::preferredVariantFromPlatformArchitecture;
this.clusterService = context.clusterService();
Expand All @@ -75,6 +81,7 @@ public BaseElasticsearchInternalService(
Consumer<ActionListener<PreferredModelVariant>> preferredModelVariantFn
) {
this.client = new OriginSettingClient(context.client(), ClientHelper.INFERENCE_ORIGIN);
this.threadPool = context.threadPool();
this.inferenceExecutor = context.threadPool().executor(InferencePlugin.UTILITY_THREAD_POOL_NAME);
this.preferredModelVariantFn = preferredModelVariantFn;
this.clusterService = context.clusterService();
Expand All @@ -96,20 +103,38 @@ public void start(Model model, TimeValue timeout, ActionListener<Boolean> finalL
return;
}

SubscribableListener.<Boolean>newForked(forkedListener -> { isBuiltinModelPut(model, forkedListener); })
.<Boolean>andThen((l, modelConfigExists) -> {
if (modelConfigExists == false) {
putModel(model, l);
} else {
l.onResponse(true);
}
})
.<Boolean>andThen((l2, modelDidPut) -> {
var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout);
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, l2);
client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener);
})
.addListener(finalListener);
// instead of a subscribably listener, use some wait to wait for the first one.
var subscribableListener = SubscribableListener.<Boolean>newForked(
forkedListener -> { isBuiltinModelPut(model, forkedListener); }
).<Boolean>andThen((l, modelConfigExists) -> {
if (modelConfigExists == false) {
putModel(model, l);
} else {
l.onResponse(true);
}
}).<Boolean>andThen((l2, modelDidPut) -> {
var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout);
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, l2);
client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener);
});
subscribableListener.addTimeout(timeout, threadPool, inferenceExecutor);
subscribableListener.addListener(finalListener.delegateResponse((l, e) -> {
if (e instanceof ElasticsearchTimeoutException) {
l.onFailure(
new ModelDeploymentTimeoutException(
format(
"Timed out after [%s] waiting for trained model deployment for inference endpoint [%s] to start. "
+ "The inference endpoint can not be used to perform inference until the deployment has started. "
+ "Use the trained model stats API to track the state of the deployment.",
timeout,
model.getInferenceEntityId()
)
)
);
} else {
l.onFailure(e);
}
}));

} else {
finalListener.onFailure(notElasticsearchModelException(model));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams;
import org.elasticsearch.xpack.core.ml.inference.ModelDeploymentTimeoutException;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus;
Expand All @@ -65,7 +66,6 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.TransportVersionUtils;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.assignment.ModelDeploymentTimeoutException;
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentService;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.ModelDeploymentTimeoutException;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.assignment.ModelDeploymentTimeoutException;
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentService;

import java.util.HashMap;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction;
import org.elasticsearch.xpack.core.ml.inference.ModelDeploymentTimeoutException;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;

Expand Down
Loading