12
12
import org .elasticsearch .action .FailedNodeException ;
13
13
import org .elasticsearch .action .TaskOperationFailure ;
14
14
import org .elasticsearch .action .support .ActionFilters ;
15
+ import org .elasticsearch .action .support .ContextPreservingActionListener ;
15
16
import org .elasticsearch .action .support .tasks .TransportTasksAction ;
16
17
import org .elasticsearch .cluster .service .ClusterService ;
17
18
import org .elasticsearch .common .util .concurrent .AtomicArray ;
20
21
import org .elasticsearch .injection .guice .Inject ;
21
22
import org .elasticsearch .rest .RestStatus ;
22
23
import org .elasticsearch .tasks .CancellableTask ;
24
+ import org .elasticsearch .threadpool .ThreadPool ;
23
25
import org .elasticsearch .transport .TransportService ;
24
26
import org .elasticsearch .xpack .core .ml .action .InferTrainedModelDeploymentAction ;
25
27
import org .elasticsearch .xpack .core .ml .inference .results .ErrorInferenceResults ;
@@ -37,11 +39,14 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
37
39
InferTrainedModelDeploymentAction .Response ,
38
40
InferTrainedModelDeploymentAction .Response > {
39
41
42
+ private final ThreadPool threadPool ;
43
+
40
44
@ Inject
41
45
public TransportInferTrainedModelDeploymentAction (
42
46
ClusterService clusterService ,
43
47
TransportService transportService ,
44
- ActionFilters actionFilters
48
+ ActionFilters actionFilters ,
49
+ ThreadPool threadPool
45
50
) {
46
51
super (
47
52
InferTrainedModelDeploymentAction .NAME ,
@@ -52,6 +57,7 @@ public TransportInferTrainedModelDeploymentAction(
52
57
InferTrainedModelDeploymentAction .Response ::new ,
53
58
EsExecutors .DIRECT_EXECUTOR_SERVICE
54
59
);
60
+ this .threadPool = threadPool ;
55
61
}
56
62
57
63
@ Override
@@ -99,6 +105,9 @@ protected void taskOperation(
99
105
// and return order the results to match the request order
100
106
AtomicInteger count = new AtomicInteger ();
101
107
AtomicArray <InferenceResults > results = new AtomicArray <>(nlpInputs .size ());
108
+
109
+ var contextPreservingListener = ContextPreservingActionListener .wrapPreservingContext (listener , threadPool .getThreadContext ());
110
+
102
111
int slot = 0 ;
103
112
for (var input : nlpInputs ) {
104
113
task .infer (
@@ -109,7 +118,7 @@ protected void taskOperation(
109
118
request .getPrefixType (),
110
119
actionTask ,
111
120
request .isChunkResults (),
112
- orderedListener (count , results , slot ++, nlpInputs .size (), listener )
121
+ orderedListener (count , results , slot ++, nlpInputs .size (), contextPreservingListener )
113
122
);
114
123
}
115
124
}
0 commit comments