2222import org .elasticsearch .xpack .esql .plugin .EsqlPlugin ;
2323import org .junit .After ;
2424import org .junit .Before ;
25+ import org .mockito .stubbing .Answer ;
2526
2627import java .util .ArrayList ;
2728import java .util .Iterator ;
3031import java .util .concurrent .atomic .AtomicReference ;
3132import java .util .stream .Stream ;
3233
34+ import static org .hamcrest .Matchers .allOf ;
3335import static org .hamcrest .Matchers .contains ;
3436import static org .hamcrest .Matchers .empty ;
3537import static org .hamcrest .Matchers .equalTo ;
3638import static org .hamcrest .Matchers .hasSize ;
39+ import static org .hamcrest .Matchers .notNullValue ;
3740import static org .mockito .ArgumentMatchers .any ;
38- import static org .mockito .ArgumentMatchers .eq ;
3941import static org .mockito .Mockito .doAnswer ;
4042import static org .mockito .Mockito .doThrow ;
4143import static org .mockito .Mockito .mock ;
44+ import static org .mockito .Mockito .never ;
4245import static org .mockito .Mockito .verify ;
4346import static org .mockito .Mockito .when ;
4447
@@ -65,42 +68,54 @@ public void shutdownThreadPool() {
6568 terminate (threadPool );
6669 }
6770
71+ @ SuppressWarnings ("unchecked" )
72+ private <T extends InferenceServiceResults > BulkInferenceOutputBuilder <T , List <T >> mockOutputBuilder (Class <T > resultClass )
73+ throws Exception {
74+ BulkInferenceOutputBuilder <T , List <T >> outputBuilder = mock (BulkInferenceOutputBuilder .class );
75+ List <T > output = new ArrayList <>();
76+ doAnswer (invocation -> {
77+ output .add (invocation .getArgument (0 , resultClass ));
78+ return null ;
79+ }).when (outputBuilder ).onInferenceResults (any ());
80+ when (outputBuilder .buildOutput ()).thenReturn (output );
81+ when (outputBuilder .inferenceResultsClass ()).thenReturn (resultClass );
82+
83+ return outputBuilder ;
84+ }
85+
6886 @ SuppressWarnings ("unchecked" )
6987 public void testSuccessfulExecution () throws Exception {
70- List <InferenceAction .Request > requests = Stream .generate (this ::mockInferenceRequest ).limit (between (1 , 50 )).toList ();
71- BulkInferenceRequestIterator requestIterator = requestIterator (requests );
72- List <InferenceAction .Response > responses = Stream .generate (() -> mockInferenceResponse (RankedDocsResults .class ))
73- .limit (requests .size ())
74- .toList ();
88+ List <InferenceAction .Request > requests = randomInferenceRequestList (between (1 , 50 ));
89+ List <InferenceAction .Response > responses = randomInferenceResponseList (requests .size (), RankedDocsResults .class );
7590
76- InferenceRunner inferenceRunner = mock (InferenceRunner .class );
77- doAnswer ((invocation ) -> {
91+ InferenceRunner inferenceRunner = mockInferenceRunner (invocation -> {
7892 ActionListener <InferenceAction .Response > l = invocation .getArgument (1 );
79- if (randomBoolean ()) {
80- Thread .sleep (between (0 , 5 ));
81- }
8293 l .onResponse (responses .get (requests .indexOf (invocation .getArgument (0 , InferenceAction .Request .class ))));
8394 return null ;
84- }). when ( inferenceRunner ). doInference ( any (), any ()) ;
95+ });
8596
97+ AtomicReference <List <RankedDocsResults >> output = new AtomicReference <>();
8698 ActionListener <List <RankedDocsResults >> listener = mock (ActionListener .class );
87-
88- List <RankedDocsResults > output = new ArrayList <>();
89- BulkInferenceOutputBuilder <RankedDocsResults , List <RankedDocsResults >> outputBuilder = mock (BulkInferenceOutputBuilder .class );
9099 doAnswer (invocation -> {
91- output .add (invocation .getArgument (0 , RankedDocsResults .class ));
100+ output .set (invocation .getArgument (0 , List .class ));
92101 return null ;
93- }).when (outputBuilder ).onInferenceResults (any ());
94- when (outputBuilder .buildOutput ()).thenReturn (output );
95- when (outputBuilder .inferenceResultsClass ()).thenReturn (RankedDocsResults .class );
102+ }).when (listener ).onResponse (any ());
96103
97- BulkInferenceExecutor executor = bulkExecutor (inferenceRunner );
98- executor .execute (requestIterator , outputBuilder , listener );
104+ BulkInferenceOutputBuilder <RankedDocsResults , List <RankedDocsResults >> outputBuilder = mockOutputBuilder (RankedDocsResults .class );
105+
106+ bulkExecutor (inferenceRunner ).execute (requestIterator (requests ), outputBuilder , listener );
99107
100108 assertBusy (() -> {
101- assertThat (output , hasSize (requests .size ()));
102- assertThat (output , contains (responses .stream ().map (InferenceAction .Response ::getResults ).toArray ()));
103- verify (listener ).onResponse (eq (output ));
109+ verify (listener ).onResponse (any ());
110+ verify (listener , never ()).onFailure (any ());
111+ assertThat (
112+ output .get (),
113+ allOf (
114+ notNullValue (),
115+ hasSize (requests .size ()),
116+ contains (responses .stream ().map (InferenceAction .Response ::getResults ).toArray ())
117+ )
118+ );
104119 });
105120 }
106121
@@ -109,35 +124,33 @@ public void testSuccessfulExecutionOnEmptyRequest() throws Exception {
109124 BulkInferenceRequestIterator requestIterator = mock (BulkInferenceRequestIterator .class );
110125 when (requestIterator .hasNext ()).thenReturn (false );
111126
127+ AtomicReference <List <RankedDocsResults >> output = new AtomicReference <>();
112128 ActionListener <List <RankedDocsResults >> listener = mock (ActionListener .class );
129+ doAnswer (invocation -> {
130+ output .set (invocation .getArgument (0 , List .class ));
131+ return null ;
132+ }).when (listener ).onResponse (any ());
113133
114- List <RankedDocsResults > output = new ArrayList <>();
115- BulkInferenceOutputBuilder <RankedDocsResults , List <RankedDocsResults >> outputBuilder = mock (BulkInferenceOutputBuilder .class );
116- when (outputBuilder .buildOutput ()).thenReturn (output );
134+ BulkInferenceOutputBuilder <RankedDocsResults , List <RankedDocsResults >> outputBuilder = mockOutputBuilder (RankedDocsResults .class );
117135
118- BulkInferenceExecutor executor = bulkExecutor (mock (InferenceRunner .class ));
119- executor .execute (requestIterator , outputBuilder , listener );
136+ bulkExecutor (mock (InferenceRunner .class )).execute (requestIterator , outputBuilder , listener );
120137
121138 assertBusy (() -> {
122- assertThat (output , empty ());
123- verify (listener ).onResponse (eq (output ));
139+ verify (listener ).onResponse (any ());
140+ verify (listener , never ()).onFailure (any ());
141+ assertThat (output .get (), allOf (notNullValue (), empty ()));
124142 });
125143 }
126144
127145 @ SuppressWarnings ("unchecked" )
128146 public void testInferenceRunnerAlwaysFails () throws Exception {
129- List <InferenceAction .Request > requests = Stream .generate (this ::mockInferenceRequest ).limit (between (1 , 30 )).toList ();
130- BulkInferenceRequestIterator requestIterator = requestIterator (requests );
147+ List <InferenceAction .Request > requests = randomInferenceRequestList (between (1 , 30 ));
131148
132- InferenceRunner inferenceRunner = mock (InferenceRunner .class );
133- doAnswer (invocation -> {
149+ InferenceRunner inferenceRunner = mock (invocation -> {
134150 ActionListener <InferenceAction .Response > listener = invocation .getArgument (1 );
135- if (randomBoolean ()) {
136- Thread .sleep (between (0 , 5 ));
137- }
138151 listener .onFailure (new RuntimeException ("inference failure" ));
139152 return null ;
140- }). when ( inferenceRunner ). doInference ( any (), any ()) ;
153+ });
141154
142155 ActionListener <List <RankedDocsResults >> listener = mock (ActionListener .class );
143156 AtomicReference <Exception > e = new AtomicReference <>();
@@ -146,37 +159,31 @@ public void testInferenceRunnerAlwaysFails() throws Exception {
146159 return null ;
147160 }).when (listener ).onFailure (any ());
148161
149- BulkInferenceOutputBuilder <RankedDocsResults , List <RankedDocsResults >> outputBuilder = mock ( BulkInferenceOutputBuilder .class );
162+ BulkInferenceOutputBuilder <RankedDocsResults , List <RankedDocsResults >> outputBuilder = mockOutputBuilder ( RankedDocsResults .class );
150163
151- BulkInferenceExecutor executor = bulkExecutor (inferenceRunner );
152- executor .execute (requestIterator , outputBuilder , listener );
164+ bulkExecutor (inferenceRunner ).execute (requestIterator (requests ), outputBuilder , listener );
153165
154166 assertBusy (() -> {
155167 verify (listener ).onFailure (any (RuntimeException .class ));
168+ verify (listener , never ()).onResponse (any ());
156169 assertThat (e .get ().getMessage (), equalTo ("inference failure" ));
157170 });
158171 }
159172
160173 @ SuppressWarnings ("unchecked" )
161174 public void testInferenceRunnerSometimesFails () throws Exception {
162- List <InferenceAction .Request > requests = Stream .generate (this ::mockInferenceRequest ).limit (between (2 , 30 )).toList ();
163- BulkInferenceRequestIterator requestIterator = requestIterator (requests );
175+ List <InferenceAction .Request > requests = randomInferenceRequestList (between (2 , 30 ));
164176
165- InferenceRunner inferenceRunner = mock (InferenceRunner .class );
166- doAnswer (invocation -> {
177+ InferenceRunner inferenceRunner = mockInferenceRunner (invocation -> {
167178 ActionListener <InferenceAction .Response > listener = invocation .getArgument (1 );
168- if (randomBoolean ()) {
169- Thread .sleep (between (0 , 5 ));
170- }
171-
172179 if ((requests .indexOf (invocation .getArgument (0 , InferenceAction .Request .class )) % requests .size ()) == 0 ) {
173180 listener .onFailure (new RuntimeException ("inference failure" ));
174181 } else {
175182 listener .onResponse (mockInferenceResponse (RankedDocsResults .class ));
176183 }
177184
178185 return null ;
179- }). when ( inferenceRunner ). doInference ( any (), any ()) ;
186+ });
180187
181188 ActionListener <List <RankedDocsResults >> listener = mock (ActionListener .class );
182189 AtomicReference <Exception > e = new AtomicReference <>();
@@ -185,29 +192,25 @@ public void testInferenceRunnerSometimesFails() throws Exception {
185192 return null ;
186193 }).when (listener ).onFailure (any ());
187194
188- BulkInferenceOutputBuilder <RankedDocsResults , List <RankedDocsResults >> outputBuilder = mock (BulkInferenceOutputBuilder .class );
189- when (outputBuilder .inferenceResultsClass ()).thenReturn (RankedDocsResults .class );
190-
191- BulkInferenceExecutor executor = bulkExecutor (inferenceRunner );
192- executor .execute (requestIterator , outputBuilder , listener );
195+ BulkInferenceOutputBuilder <RankedDocsResults , List <RankedDocsResults >> outputBuilder = mockOutputBuilder (RankedDocsResults .class );
196+ bulkExecutor (inferenceRunner ).execute (requestIterator (requests ), outputBuilder , listener );
193197
194198 assertBusy (() -> {
195199 verify (listener ).onFailure (any (RuntimeException .class ));
200+ verify (listener , never ()).onResponse (any ());
196201 assertThat (e .get ().getMessage (), equalTo ("inference failure" ));
197202 });
198203 }
199204
200205 @ SuppressWarnings ("unchecked" )
201206 public void testBuildOutputFailure () throws Exception {
202- List <InferenceAction .Request > requests = Stream .generate (this ::mockInferenceRequest ).limit (between (1 , 30 )).toList ();
203- BulkInferenceRequestIterator requestIterator = requestIterator (requests );
207+ List <InferenceAction .Request > requests = randomInferenceRequestList (between (1 , 30 ));
204208
205- InferenceRunner inferenceRunner = mock (InferenceRunner .class );
206- doAnswer (invocation -> {
209+ InferenceRunner inferenceRunner = mockInferenceRunner (invocation -> {
207210 ActionListener <InferenceAction .Response > listener = invocation .getArgument (1 );
208211 listener .onResponse (mockInferenceResponse (RankedDocsResults .class ));
209212 return null ;
210- }). when ( inferenceRunner ). doInference ( any (), any ()) ;
213+ });
211214
212215 ActionListener <List <RankedDocsResults >> listener = mock (ActionListener .class );
213216 AtomicReference <Exception > e = new AtomicReference <>();
@@ -222,10 +225,11 @@ public void testBuildOutputFailure() throws Exception {
222225
223226 BulkInferenceExecutor executor = bulkExecutor (inferenceRunner );
224227
225- executor .execute (requestIterator , outputBuilder , listener );
228+ bulkExecutor ( inferenceRunner ) .execute (requestIterator ( requests ) , outputBuilder , listener );
226229
227230 assertBusy (() -> {
228231 verify (listener ).onFailure (any (IllegalStateException .class ));
232+ verify (listener , never ()).onResponse (any ());
229233 assertThat (e .get ().getMessage (), equalTo ("build output failure" ));
230234 });
231235 }
@@ -255,4 +259,18 @@ private BulkInferenceRequestIterator requestIterator(List<InferenceAction.Reques
255259 doAnswer (i -> delegate .next ()).when (iterator ).next ();
256260 return iterator ;
257261 }
262+
263+ private List <InferenceAction .Request > randomInferenceRequestList (int size ) {
264+ return Stream .generate (this ::mockInferenceRequest ).limit (size ).toList ();
265+ }
266+
267+ private List <InferenceAction .Response > randomInferenceResponseList (int size , Class <? extends InferenceServiceResults > resultClass ) {
268+ return Stream .generate (() -> this .mockInferenceResponse (resultClass )).limit (size ).toList ();
269+ }
270+
271+ private InferenceRunner mockInferenceRunner (Answer <Void > doInferenceAnswer ) {
272+ InferenceRunner inferenceRunner = mock (InferenceRunner .class );
273+ doAnswer (doInferenceAnswer ).when (inferenceRunner ).doInference (any (), any ());
274+ return inferenceRunner ;
275+ }
258276}
0 commit comments