diff --git a/docs/changelog/132973.yaml b/docs/changelog/132973.yaml new file mode 100644 index 0000000000000..b94c38a9d3a8d --- /dev/null +++ b/docs/changelog/132973.yaml @@ -0,0 +1,5 @@ +pr: 132973 +summary: Preserve lost thread context in node inference action. A lost context causes a memory leak if APM tracing is enabled +area: Machine Learning +type: bug +issues: [] diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java index c4915ef45c16d..0eb10aa33b9c7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.TaskOperationFailure; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.ContextPreservingActionListener; import org.elasticsearch.action.support.tasks.TransportTasksAction; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.util.concurrent.AtomicArray; @@ -20,6 +21,7 @@ import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; @@ -37,11 +39,14 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc InferTrainedModelDeploymentAction.Response, InferTrainedModelDeploymentAction.Response> { + private final ThreadPool threadPool; + @Inject public TransportInferTrainedModelDeploymentAction( ClusterService clusterService, TransportService transportService, - ActionFilters actionFilters + ActionFilters actionFilters, + ThreadPool threadPool ) { super( InferTrainedModelDeploymentAction.NAME, @@ -52,6 +57,7 @@ public TransportInferTrainedModelDeploymentAction( InferTrainedModelDeploymentAction.Response::new, EsExecutors.DIRECT_EXECUTOR_SERVICE ); + this.threadPool = threadPool; } @Override @@ -99,6 +105,9 @@ protected void taskOperation( // and return order the results to match the request order AtomicInteger count = new AtomicInteger(); AtomicArray results = new AtomicArray<>(nlpInputs.size()); + + var contextPreservingListener = ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext()); + int slot = 0; for (var input : nlpInputs) { task.infer( @@ -109,7 +118,7 @@ protected void taskOperation( request.getPrefixType(), actionTask, request.isChunkResults(), - orderedListener(count, results, slot++, nlpInputs.size(), listener) + orderedListener(count, results, slot++, nlpInputs.size(), contextPreservingListener) ); } }