1212import org .elasticsearch .action .FailedNodeException ;
1313import org .elasticsearch .action .TaskOperationFailure ;
1414import org .elasticsearch .action .support .ActionFilters ;
15+ import org .elasticsearch .action .support .ContextPreservingActionListener ;
1516import org .elasticsearch .action .support .tasks .TransportTasksAction ;
1617import org .elasticsearch .cluster .service .ClusterService ;
1718import org .elasticsearch .common .util .concurrent .AtomicArray ;
2021import org .elasticsearch .injection .guice .Inject ;
2122import org .elasticsearch .rest .RestStatus ;
2223import org .elasticsearch .tasks .CancellableTask ;
24+ import org .elasticsearch .threadpool .ThreadPool ;
2325import org .elasticsearch .transport .TransportService ;
2426import org .elasticsearch .xpack .core .ml .action .InferTrainedModelDeploymentAction ;
2527import org .elasticsearch .xpack .core .ml .inference .results .ErrorInferenceResults ;
@@ -37,11 +39,14 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
3739 InferTrainedModelDeploymentAction .Response ,
3840 InferTrainedModelDeploymentAction .Response > {
3941
42+ private final ThreadPool threadPool ;
43+
4044 @ Inject
4145 public TransportInferTrainedModelDeploymentAction (
4246 ClusterService clusterService ,
4347 TransportService transportService ,
44- ActionFilters actionFilters
48+ ActionFilters actionFilters ,
49+ ThreadPool threadPool
4550 ) {
4651 super (
4752 InferTrainedModelDeploymentAction .NAME ,
@@ -52,6 +57,7 @@ public TransportInferTrainedModelDeploymentAction(
5257 InferTrainedModelDeploymentAction .Response ::new ,
5358 EsExecutors .DIRECT_EXECUTOR_SERVICE
5459 );
60+ this .threadPool = threadPool ;
5561 }
5662
5763 @ Override
@@ -99,6 +105,9 @@ protected void taskOperation(
99105 // and return the results ordered to match the request order
100106 AtomicInteger count = new AtomicInteger ();
101107 AtomicArray <InferenceResults > results = new AtomicArray <>(nlpInputs .size ());
108+
109+ var contextPreservingListener = ContextPreservingActionListener .wrapPreservingContext (listener , threadPool .getThreadContext ());
110+
102111 int slot = 0 ;
103112 for (var input : nlpInputs ) {
104113 task .infer (
@@ -109,7 +118,7 @@ protected void taskOperation(
109118 request .getPrefixType (),
110119 actionTask ,
111120 request .isChunkResults (),
112- orderedListener (count , results , slot ++, nlpInputs .size (), listener )
121+ orderedListener (count , results , slot ++, nlpInputs .size (), contextPreservingListener )
113122 );
114123 }
115124 }
0 commit comments