21
21
import org .elasticsearch .test .ESTestCase ;
22
22
import org .elasticsearch .threadpool .FixedExecutorBuilder ;
23
23
import org .elasticsearch .threadpool .TestThreadPool ;
24
+ import org .elasticsearch .threadpool .ThreadPool ;
24
25
import org .elasticsearch .xpack .core .inference .action .GetInferenceModelAction ;
25
26
import org .elasticsearch .xpack .esql .parser .EsqlParser ;
26
- import org .elasticsearch .xpack .esql .plugin .EsqlPlugin ;
27
27
import org .junit .After ;
28
28
import org .junit .Before ;
29
29
@@ -51,7 +51,7 @@ public void setThreadPool() {
51
51
getTestClass ().getSimpleName (),
52
52
new FixedExecutorBuilder (
53
53
Settings .EMPTY ,
54
- EsqlPlugin . ESQL_WORKER_THREAD_POOL_NAME ,
54
+ "inference_utility" ,
55
55
between (1 , 10 ),
56
56
1024 ,
57
57
"esql" ,
@@ -101,9 +101,12 @@ public void testResolveInferenceIds() throws Exception {
101
101
List <String > inferenceIds = List .of ("rerank-plan" );
102
102
SetOnce <InferenceResolution > inferenceResolutionSetOnce = new SetOnce <>();
103
103
104
- inferenceResolver .resolveInferenceIds (inferenceIds , ActionListener .wrap (inferenceResolutionSetOnce ::set , e -> {
105
- throw new RuntimeException (e );
106
- }));
104
+ inferenceResolver .resolveInferenceIds (
105
+ inferenceIds ,
106
+ assertAnswerUsingThreadPool (ActionListener .wrap (inferenceResolutionSetOnce ::set , e -> {
107
+ throw new RuntimeException (e );
108
+ }))
109
+ );
107
110
108
111
assertBusy (() -> {
109
112
InferenceResolution inferenceResolution = inferenceResolutionSetOnce .get ();
@@ -118,9 +121,12 @@ public void testResolveMultipleInferenceIds() throws Exception {
118
121
List <String > inferenceIds = List .of ("rerank-plan" , "rerank-plan" , "completion-plan" );
119
122
SetOnce <InferenceResolution > inferenceResolutionSetOnce = new SetOnce <>();
120
123
121
- inferenceResolver .resolveInferenceIds (inferenceIds , ActionListener .wrap (inferenceResolutionSetOnce ::set , e -> {
122
- throw new RuntimeException (e );
123
- }));
124
+ inferenceResolver .resolveInferenceIds (
125
+ inferenceIds ,
126
+ assertAnswerUsingThreadPool (ActionListener .wrap (inferenceResolutionSetOnce ::set , e -> {
127
+ throw new RuntimeException (e );
128
+ }))
129
+ );
124
130
125
131
assertBusy (() -> {
126
132
InferenceResolution inferenceResolution = inferenceResolutionSetOnce .get ();
@@ -143,9 +149,12 @@ public void testResolveMissingInferenceIds() throws Exception {
143
149
144
150
SetOnce <InferenceResolution > inferenceResolutionSetOnce = new SetOnce <>();
145
151
146
- inferenceResolver .resolveInferenceIds (inferenceIds , ActionListener .wrap (inferenceResolutionSetOnce ::set , e -> {
147
- throw new RuntimeException (e );
148
- }));
152
+ inferenceResolver .resolveInferenceIds (
153
+ inferenceIds ,
154
+ assertAnswerUsingThreadPool (ActionListener .wrap (inferenceResolutionSetOnce ::set , e -> {
155
+ throw new RuntimeException (e );
156
+ }))
157
+ );
149
158
150
159
assertBusy (() -> {
151
160
InferenceResolution inferenceResolution = inferenceResolutionSetOnce .get ();
@@ -173,21 +182,17 @@ private Client mockClient() {
173
182
}
174
183
};
175
184
176
- if (randomBoolean ()) {
177
- sendResponse .run ();
178
- } else {
179
- threadPool .schedule (
180
- sendResponse ,
181
- TimeValue .timeValueNanos (between (1 , 1_000 )),
182
- threadPool .executor (EsqlPlugin .ESQL_WORKER_THREAD_POOL_NAME )
183
- );
184
- }
185
+ threadPool .schedule (sendResponse , TimeValue .timeValueNanos (between (1 , 1_000 )), threadPool .executor ("inference_utility" ));
185
186
186
187
return null ;
187
188
}).when (client ).execute (eq (GetInferenceModelAction .INSTANCE ), any (), any ());
188
189
return client ;
189
190
}
190
191
192
+ private <T > ActionListener <T > assertAnswerUsingThreadPool (ActionListener <T > actionListener ) {
193
+ return ActionListener .runBefore (actionListener , () -> ThreadPool .assertCurrentThreadPool (ThreadPool .Names .SEARCH_COORDINATION ));
194
+ }
195
+
191
196
private static ActionResponse getInferenceModelResponse (GetInferenceModelAction .Request request ) {
192
197
GetInferenceModelAction .Response response = mock (GetInferenceModelAction .Response .class );
193
198
@@ -205,7 +210,7 @@ private static ActionResponse getInferenceModelResponse(GetInferenceModelAction.
205
210
}
206
211
207
212
private InferenceResolver inferenceResolver () {
208
- return new InferenceResolver (mockClient ());
213
+ return new InferenceResolver (mockClient (), threadPool );
209
214
}
210
215
211
216
private static ModelConfigurations mockModelConfig (String inferenceId , TaskType taskType ) {
0 commit comments