Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/129003.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 129003
summary: Allow timeout during trained model download process
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.assignment;
package org.elasticsearch.xpack.core.ml.inference;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.rest.RestStatus;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ElasticsearchTimeoutException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
Expand All @@ -22,13 +23,15 @@
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.ModelDeploymentTimeoutException;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
Expand All @@ -41,12 +44,14 @@
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;

public abstract class BaseElasticsearchInternalService implements InferenceService {

protected final OriginSettingClient client;
protected final ThreadPool threadPool;
protected final ExecutorService inferenceExecutor;
protected final Consumer<ActionListener<PreferredModelVariant>> preferredModelVariantFn;
private final ClusterService clusterService;
Expand All @@ -60,6 +65,7 @@ public enum PreferredModelVariant {

public BaseElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) {
this.client = new OriginSettingClient(context.client(), ClientHelper.INFERENCE_ORIGIN);
this.threadPool = context.threadPool();
this.inferenceExecutor = context.threadPool().executor(InferencePlugin.UTILITY_THREAD_POOL_NAME);
this.preferredModelVariantFn = this::preferredVariantFromPlatformArchitecture;
this.clusterService = context.clusterService();
Expand All @@ -75,6 +81,7 @@ public BaseElasticsearchInternalService(
Consumer<ActionListener<PreferredModelVariant>> preferredModelVariantFn
) {
this.client = new OriginSettingClient(context.client(), ClientHelper.INFERENCE_ORIGIN);
this.threadPool = context.threadPool();
this.inferenceExecutor = context.threadPool().executor(InferencePlugin.UTILITY_THREAD_POOL_NAME);
this.preferredModelVariantFn = preferredModelVariantFn;
this.clusterService = context.clusterService();
Expand All @@ -96,20 +103,38 @@ public void start(Model model, TimeValue timeout, ActionListener<Boolean> finalL
return;
}

SubscribableListener.<Boolean>newForked(forkedListener -> { isBuiltinModelPut(model, forkedListener); })
.<Boolean>andThen((l, modelConfigExists) -> {
if (modelConfigExists == false) {
putModel(model, l);
} else {
l.onResponse(true);
}
})
.<Boolean>andThen((l2, modelDidPut) -> {
var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout);
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, l2);
client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener);
})
.addListener(finalListener);
// instead of a subscribably listener, use some wait to wait for the first one.
var subscribableListener = SubscribableListener.<Boolean>newForked(
forkedListener -> { isBuiltinModelPut(model, forkedListener); }
).<Boolean>andThen((l, modelConfigExists) -> {
if (modelConfigExists == false) {
putModel(model, l);
} else {
l.onResponse(true);
}
}).<Boolean>andThen((l2, modelDidPut) -> {
var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout);
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, l2);
client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener);
});
subscribableListener.addTimeout(timeout, threadPool, inferenceExecutor);
subscribableListener.addListener(finalListener.delegateResponse((l, e) -> {
if (e instanceof ElasticsearchTimeoutException) {
l.onFailure(
new ModelDeploymentTimeoutException(
format(
"Timed out after [%s] waiting for trained model deployment for inference endpoint [%s] to start. "
+ "Use the trained model stats API to track the state of the deployment "
+ "and try again once it has started.",
timeout,
model.getInferenceEntityId()
)
)
);
} else {
l.onFailure(e);
}
}));

} else {
finalListener.onFailure(notElasticsearchModelException(model));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams;
import org.elasticsearch.xpack.core.ml.inference.ModelDeploymentTimeoutException;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus;
Expand All @@ -65,7 +66,6 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.TransportVersionUtils;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.assignment.ModelDeploymentTimeoutException;
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentService;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.ModelDeploymentTimeoutException;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.assignment.ModelDeploymentTimeoutException;
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentService;

import java.util.HashMap;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction;
import org.elasticsearch.xpack.core.ml.inference.ModelDeploymentTimeoutException;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;

Expand Down
Loading