Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/132973.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 132973
summary: Preserve lost thread context in node inference action. A lost context causes a memory leak if APM tracing is enabled.
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.TaskOperationFailure;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.action.support.tasks.TransportTasksAction;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.util.concurrent.AtomicArray;
Expand All @@ -20,6 +21,7 @@
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
Expand All @@ -37,11 +39,14 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
InferTrainedModelDeploymentAction.Response,
InferTrainedModelDeploymentAction.Response> {

private final ThreadPool threadPool;

@Inject
public TransportInferTrainedModelDeploymentAction(
ClusterService clusterService,
TransportService transportService,
ActionFilters actionFilters
ActionFilters actionFilters,
ThreadPool threadPool
) {
super(
InferTrainedModelDeploymentAction.NAME,
Expand All @@ -52,6 +57,7 @@ public TransportInferTrainedModelDeploymentAction(
InferTrainedModelDeploymentAction.Response::new,
EsExecutors.DIRECT_EXECUTOR_SERVICE
);
this.threadPool = threadPool;
}

@Override
Expand Down Expand Up @@ -99,6 +105,9 @@ protected void taskOperation(
// and return the results ordered to match the request order
AtomicInteger count = new AtomicInteger();
AtomicArray<InferenceResults> results = new AtomicArray<>(nlpInputs.size());

var contextPreservingListener = ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext());

int slot = 0;
for (var input : nlpInputs) {
task.infer(
Expand All @@ -109,7 +118,7 @@ protected void taskOperation(
request.getPrefixType(),
actionTask,
request.isChunkResults(),
orderedListener(count, results, slot++, nlpInputs.size(), listener)
orderedListener(count, results, slot++, nlpInputs.size(), contextPreservingListener)
);
}
}
Expand Down