Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/132973.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 132973
summary: Preserve lost thread context in node inference action
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.TaskOperationFailure;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.action.support.tasks.TransportTasksAction;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.util.concurrent.AtomicArray;
Expand All @@ -20,6 +21,7 @@
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
Expand All @@ -37,11 +39,14 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
InferTrainedModelDeploymentAction.Response,
InferTrainedModelDeploymentAction.Response> {

private final ThreadPool threadPool;

@Inject
public TransportInferTrainedModelDeploymentAction(
ClusterService clusterService,
TransportService transportService,
ActionFilters actionFilters
ActionFilters actionFilters,
ThreadPool threadPool
) {
super(
InferTrainedModelDeploymentAction.NAME,
Expand All @@ -52,6 +57,7 @@ public TransportInferTrainedModelDeploymentAction(
InferTrainedModelDeploymentAction.Response::new,
EsExecutors.DIRECT_EXECUTOR_SERVICE
);
this.threadPool = threadPool;
}

@Override
Expand Down Expand Up @@ -99,6 +105,9 @@ protected void taskOperation(
// and return the results ordered to match the request order
AtomicInteger count = new AtomicInteger();
AtomicArray<InferenceResults> results = new AtomicArray<>(nlpInputs.size());

var contextPreservingListener = ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext());

int slot = 0;
for (var input : nlpInputs) {
task.infer(
Expand All @@ -109,7 +118,7 @@ protected void taskOperation(
request.getPrefixType(),
actionTask,
request.isChunkResults(),
orderedListener(count, results, slot++, nlpInputs.size(), listener)
orderedListener(count, results, slot++, nlpInputs.size(), contextPreservingListener)
);
}
}
Expand Down