|  | 
| 24 | 24 | import org.elasticsearch.test.ESTestCase; | 
| 25 | 25 | import org.elasticsearch.threadpool.ThreadPool; | 
| 26 | 26 | import org.elasticsearch.xpack.inference.InferencePlugin; | 
|  | 27 | +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; | 
| 27 | 28 | import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; | 
| 28 | 29 | import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; | 
|  | 30 | +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; | 
| 29 | 31 | import org.elasticsearch.xpack.inference.external.http.sender.Sender; | 
| 30 | 32 | import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; | 
| 31 | 33 | import org.junit.After; | 
|  | 
| 43 | 45 | import static org.elasticsearch.xpack.inference.Utils.mockClusterService; | 
| 44 | 46 | import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; | 
| 45 | 47 | import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; | 
|  | 48 | +import static org.hamcrest.Matchers.containsString; | 
|  | 49 | +import static org.hamcrest.Matchers.is; | 
| 46 | 50 | import static org.mockito.ArgumentMatchers.any; | 
| 47 | 51 | import static org.mockito.Mockito.doAnswer; | 
| 48 | 52 | import static org.mockito.Mockito.mock; | 
| @@ -189,6 +193,101 @@ protected void doInfer( | 
| 189 | 193 |         } | 
| 190 | 194 |     } | 
| 191 | 195 | 
 | 
|  | 196 | +    public void testReturnsValidationException_WhenQueryIsNullForRerankTaskType() throws IOException { | 
|  | 197 | +        var sender = createMockSender(); | 
|  | 198 | + | 
|  | 199 | +        var factory = mock(HttpRequestSender.Factory.class); | 
|  | 200 | +        when(factory.createSender()).thenReturn(sender); | 
|  | 201 | + | 
|  | 202 | +        try (var testService = new TestSenderService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { | 
|  | 203 | +            var model = mock(Model.class); | 
|  | 204 | +            when(model.getTaskType()).thenReturn(TaskType.RERANK); | 
|  | 205 | + | 
|  | 206 | +            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); | 
|  | 207 | + | 
|  | 208 | +            testService.infer(model, null, null, null, List.of("test input"), false, Map.of(), InputType.SEARCH, null, listener); | 
|  | 209 | +            var exception = expectThrows(ValidationException.class, () -> listener.actionGet(TIMEOUT)); | 
|  | 210 | + | 
|  | 211 | +            assertThat(exception.getMessage(), containsString("Rerank task type requires a non-null query field")); | 
|  | 212 | +        } | 
|  | 213 | +    } | 
|  | 214 | + | 
|  | 215 | +    public void testInferSucceeds_WhenQueryIsDefinedForRerankTaskType() throws IOException { | 
|  | 216 | +        var sender = createMockSender(); | 
|  | 217 | + | 
|  | 218 | +        var factory = mock(HttpRequestSender.Factory.class); | 
|  | 219 | +        when(factory.createSender()).thenReturn(sender); | 
|  | 220 | + | 
|  | 221 | +        var queryString = "a query"; | 
|  | 222 | +        var testInput = "test input"; | 
|  | 223 | +        var doInferCalled = new AtomicReference<>(false); | 
|  | 224 | + | 
|  | 225 | +        var testService = new TestSenderService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()) { | 
|  | 226 | +            @Override | 
|  | 227 | +            protected void doInfer( | 
|  | 228 | +                Model model, | 
|  | 229 | +                InferenceInputs inputs, | 
|  | 230 | +                Map<String, Object> taskSettings, | 
|  | 231 | +                TimeValue timeout, | 
|  | 232 | +                ActionListener<InferenceServiceResults> listener | 
|  | 233 | +            ) { | 
|  | 234 | +                var queryDocs = inputs.castTo(QueryAndDocsInputs.class); | 
|  | 235 | +                assertThat(queryDocs.getQuery(), is(queryString)); | 
|  | 236 | +                assertThat(queryDocs.getChunks(), is(List.of(testInput))); | 
|  | 237 | +                doInferCalled.set(true); | 
|  | 238 | +                listener.onResponse(mock(InferenceServiceResults.class)); | 
|  | 239 | +            } | 
|  | 240 | +        }; | 
|  | 241 | + | 
|  | 242 | +        try (testService) { | 
|  | 243 | +            var model = mock(Model.class); | 
|  | 244 | +            when(model.getTaskType()).thenReturn(TaskType.RERANK); | 
|  | 245 | + | 
|  | 246 | +            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); | 
|  | 247 | + | 
|  | 248 | +            testService.infer(model, queryString, null, null, List.of(testInput), false, Map.of(), null, null, listener); | 
|  | 249 | +            assertNotNull(listener.actionGet(TIMEOUT)); | 
|  | 250 | +            assertTrue(doInferCalled.get()); | 
|  | 251 | +        } | 
|  | 252 | +    } | 
|  | 253 | + | 
|  | 254 | +    public void testInferSucceeds_WhenQueryIsNotDefinedForCompletionTaskType() throws IOException { | 
|  | 255 | +        var sender = createMockSender(); | 
|  | 256 | + | 
|  | 257 | +        var factory = mock(HttpRequestSender.Factory.class); | 
|  | 258 | +        when(factory.createSender()).thenReturn(sender); | 
|  | 259 | + | 
|  | 260 | +        var testInput = "test input"; | 
|  | 261 | +        var doInferCalled = new AtomicReference<>(false); | 
|  | 262 | + | 
|  | 263 | +        var testService = new TestSenderService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()) { | 
|  | 264 | +            @Override | 
|  | 265 | +            protected void doInfer( | 
|  | 266 | +                Model model, | 
|  | 267 | +                InferenceInputs inputs, | 
|  | 268 | +                Map<String, Object> taskSettings, | 
|  | 269 | +                TimeValue timeout, | 
|  | 270 | +                ActionListener<InferenceServiceResults> listener | 
|  | 271 | +            ) { | 
|  | 272 | +                var castedInput = inputs.castTo(ChatCompletionInput.class); | 
|  | 273 | +                assertThat(castedInput.getInputs(), is(List.of(testInput))); | 
|  | 274 | +                doInferCalled.set(true); | 
|  | 275 | +                listener.onResponse(mock(InferenceServiceResults.class)); | 
|  | 276 | +            } | 
|  | 277 | +        }; | 
|  | 278 | + | 
|  | 279 | +        try (testService) { | 
|  | 280 | +            var model = mock(Model.class); | 
|  | 281 | +            when(model.getTaskType()).thenReturn(TaskType.COMPLETION); | 
|  | 282 | + | 
|  | 283 | +            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); | 
|  | 284 | + | 
|  | 285 | +            testService.infer(model, null, null, null, List.of(testInput), false, Map.of(), null, null, listener); | 
|  | 286 | +            assertNotNull(listener.actionGet(TIMEOUT)); | 
|  | 287 | +            assertTrue(doInferCalled.get()); | 
|  | 288 | +        } | 
|  | 289 | +    } | 
|  | 290 | + | 
| 192 | 291 |     public static Sender createMockSender() { | 
| 193 | 292 |         var sender = mock(Sender.class); | 
| 194 | 293 |         doAnswer(invocationOnMock -> { | 
|  | 
0 commit comments