|
24 | 24 | import org.elasticsearch.test.ESTestCase; |
25 | 25 | import org.elasticsearch.threadpool.ThreadPool; |
26 | 26 | import org.elasticsearch.xpack.inference.InferencePlugin; |
| 27 | +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; |
27 | 28 | import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; |
28 | 29 | import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; |
| 30 | +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; |
29 | 31 | import org.elasticsearch.xpack.inference.external.http.sender.Sender; |
30 | 32 | import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; |
31 | 33 | import org.junit.After; |
|
43 | 45 | import static org.elasticsearch.xpack.inference.Utils.mockClusterService; |
44 | 46 | import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; |
45 | 47 | import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; |
| 48 | +import static org.hamcrest.Matchers.containsString; |
| 49 | +import static org.hamcrest.Matchers.is; |
46 | 50 | import static org.mockito.ArgumentMatchers.any; |
47 | 51 | import static org.mockito.Mockito.doAnswer; |
48 | 52 | import static org.mockito.Mockito.mock; |
@@ -189,6 +193,101 @@ protected void doInfer( |
189 | 193 | } |
190 | 194 | } |
191 | 195 |
|
| 196 | + public void testReturnsValidationException_WhenQueryIsNullForRerankTaskType() throws IOException { |
| 197 | + var sender = createMockSender(); |
| 198 | + |
| 199 | + var factory = mock(HttpRequestSender.Factory.class); |
| 200 | + when(factory.createSender()).thenReturn(sender); |
| 201 | + |
| 202 | + try (var testService = new TestSenderService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { |
| 203 | + var model = mock(Model.class); |
| 204 | + when(model.getTaskType()).thenReturn(TaskType.RERANK); |
| 205 | + |
| 206 | + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); |
| 207 | + |
| 208 | + testService.infer(model, null, null, null, List.of("test input"), false, Map.of(), InputType.SEARCH, null, listener); |
| 209 | + var exception = expectThrows(ValidationException.class, () -> listener.actionGet(TIMEOUT)); |
| 210 | + |
| 211 | + assertThat(exception.getMessage(), containsString("Rerank task type requires a non-null query field")); |
| 212 | + } |
| 213 | + } |
| 214 | + |
| 215 | + public void testInferSucceeds_WhenQueryIsDefinedForRerankTaskType() throws IOException { |
| 216 | + var sender = createMockSender(); |
| 217 | + |
| 218 | + var factory = mock(HttpRequestSender.Factory.class); |
| 219 | + when(factory.createSender()).thenReturn(sender); |
| 220 | + |
| 221 | + var queryString = "a query"; |
| 222 | + var testInput = "test input"; |
| 223 | + var doInferCalled = new AtomicReference<>(false); |
| 224 | + |
| 225 | + var testService = new TestSenderService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()) { |
| 226 | + @Override |
| 227 | + protected void doInfer( |
| 228 | + Model model, |
| 229 | + InferenceInputs inputs, |
| 230 | + Map<String, Object> taskSettings, |
| 231 | + TimeValue timeout, |
| 232 | + ActionListener<InferenceServiceResults> listener |
| 233 | + ) { |
| 234 | + var queryDocs = inputs.castTo(QueryAndDocsInputs.class); |
| 235 | + assertThat(queryDocs.getQuery(), is(queryString)); |
| 236 | + assertThat(queryDocs.getChunks(), is(List.of(testInput))); |
| 237 | + doInferCalled.set(true); |
| 238 | + listener.onResponse(mock(InferenceServiceResults.class)); |
| 239 | + } |
| 240 | + }; |
| 241 | + |
| 242 | + try (testService) { |
| 243 | + var model = mock(Model.class); |
| 244 | + when(model.getTaskType()).thenReturn(TaskType.RERANK); |
| 245 | + |
| 246 | + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); |
| 247 | + |
| 248 | + testService.infer(model, queryString, null, null, List.of(testInput), false, Map.of(), null, null, listener); |
| 249 | + assertNotNull(listener.actionGet(TIMEOUT)); |
| 250 | + assertTrue(doInferCalled.get()); |
| 251 | + } |
| 252 | + } |
| 253 | + |
| 254 | + public void testInferSucceeds_WhenQueryIsNotDefinedForCompletionTaskType() throws IOException { |
| 255 | + var sender = createMockSender(); |
| 256 | + |
| 257 | + var factory = mock(HttpRequestSender.Factory.class); |
| 258 | + when(factory.createSender()).thenReturn(sender); |
| 259 | + |
| 260 | + var testInput = "test input"; |
| 261 | + var doInferCalled = new AtomicReference<>(false); |
| 262 | + |
| 263 | + var testService = new TestSenderService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()) { |
| 264 | + @Override |
| 265 | + protected void doInfer( |
| 266 | + Model model, |
| 267 | + InferenceInputs inputs, |
| 268 | + Map<String, Object> taskSettings, |
| 269 | + TimeValue timeout, |
| 270 | + ActionListener<InferenceServiceResults> listener |
| 271 | + ) { |
| 272 | + var castedInput = inputs.castTo(ChatCompletionInput.class); |
| 273 | + assertThat(castedInput.getInputs(), is(List.of(testInput))); |
| 274 | + doInferCalled.set(true); |
| 275 | + listener.onResponse(mock(InferenceServiceResults.class)); |
| 276 | + } |
| 277 | + }; |
| 278 | + |
| 279 | + try (testService) { |
| 280 | + var model = mock(Model.class); |
| 281 | + when(model.getTaskType()).thenReturn(TaskType.COMPLETION); |
| 282 | + |
| 283 | + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); |
| 284 | + |
| 285 | + testService.infer(model, null, null, null, List.of(testInput), false, Map.of(), null, null, listener); |
| 286 | + assertNotNull(listener.actionGet(TIMEOUT)); |
| 287 | + assertTrue(doInferCalled.get()); |
| 288 | + } |
| 289 | + } |
| 290 | + |
192 | 291 | public static Sender createMockSender() { |
193 | 292 | var sender = mock(Sender.class); |
194 | 293 | doAnswer(invocationOnMock -> { |
|
0 commit comments