Skip to content

Commit 52eb9e5

Browse files
[9.2] [ML] Perform query field validation for rerank task type (#137219) (#137292)
* Fixing merge error * Fixing import
1 parent 437952c commit 52eb9e5

File tree

4 files changed

+348
-161
lines changed

4 files changed

+348
-161
lines changed

docs/changelog/137219.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 137219
2+
summary: Perform query field validation for rerank task type
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ private static InferenceInputs createInput(
9696
case RERANK -> {
9797
ValidationException validationException = new ValidationException();
9898
service.validateRerankParameters(returnDocuments, topN, validationException);
99+
100+
if (query == null) {
101+
validationException.addValidationError("Rerank task type requires a non-null query field");
102+
}
103+
99104
if (validationException.validationErrors().isEmpty() == false) {
100105
throw validationException;
101106
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424
import org.elasticsearch.test.ESTestCase;
2525
import org.elasticsearch.threadpool.ThreadPool;
2626
import org.elasticsearch.xpack.inference.InferencePlugin;
27+
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
2728
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
2829
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
30+
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
2931
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
3032
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
3133
import org.junit.After;
@@ -43,6 +45,8 @@
4345
import static org.elasticsearch.xpack.inference.Utils.mockClusterService;
4446
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
4547
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
48+
import static org.hamcrest.Matchers.containsString;
49+
import static org.hamcrest.Matchers.is;
4650
import static org.mockito.ArgumentMatchers.any;
4751
import static org.mockito.Mockito.doAnswer;
4852
import static org.mockito.Mockito.mock;
@@ -189,6 +193,101 @@ protected void doInfer(
189193
}
190194
}
191195

196+
public void testReturnsValidationException_WhenQueryIsNullForRerankTaskType() throws IOException {
197+
var sender = createMockSender();
198+
199+
var factory = mock(HttpRequestSender.Factory.class);
200+
when(factory.createSender()).thenReturn(sender);
201+
202+
try (var testService = new TestSenderService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
203+
var model = mock(Model.class);
204+
when(model.getTaskType()).thenReturn(TaskType.RERANK);
205+
206+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
207+
208+
testService.infer(model, null, null, null, List.of("test input"), false, Map.of(), InputType.SEARCH, null, listener);
209+
var exception = expectThrows(ValidationException.class, () -> listener.actionGet(TIMEOUT));
210+
211+
assertThat(exception.getMessage(), containsString("Rerank task type requires a non-null query field"));
212+
}
213+
}
214+
215+
public void testInferSucceeds_WhenQueryIsDefinedForRerankTaskType() throws IOException {
216+
var sender = createMockSender();
217+
218+
var factory = mock(HttpRequestSender.Factory.class);
219+
when(factory.createSender()).thenReturn(sender);
220+
221+
var queryString = "a query";
222+
var testInput = "test input";
223+
var doInferCalled = new AtomicReference<>(false);
224+
225+
var testService = new TestSenderService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()) {
226+
@Override
227+
protected void doInfer(
228+
Model model,
229+
InferenceInputs inputs,
230+
Map<String, Object> taskSettings,
231+
TimeValue timeout,
232+
ActionListener<InferenceServiceResults> listener
233+
) {
234+
var queryDocs = inputs.castTo(QueryAndDocsInputs.class);
235+
assertThat(queryDocs.getQuery(), is(queryString));
236+
assertThat(queryDocs.getChunks(), is(List.of(testInput)));
237+
doInferCalled.set(true);
238+
listener.onResponse(mock(InferenceServiceResults.class));
239+
}
240+
};
241+
242+
try (testService) {
243+
var model = mock(Model.class);
244+
when(model.getTaskType()).thenReturn(TaskType.RERANK);
245+
246+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
247+
248+
testService.infer(model, queryString, null, null, List.of(testInput), false, Map.of(), null, null, listener);
249+
assertNotNull(listener.actionGet(TIMEOUT));
250+
assertTrue(doInferCalled.get());
251+
}
252+
}
253+
254+
public void testInferSucceeds_WhenQueryIsNotDefinedForCompletionTaskType() throws IOException {
255+
var sender = createMockSender();
256+
257+
var factory = mock(HttpRequestSender.Factory.class);
258+
when(factory.createSender()).thenReturn(sender);
259+
260+
var testInput = "test input";
261+
var doInferCalled = new AtomicReference<>(false);
262+
263+
var testService = new TestSenderService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()) {
264+
@Override
265+
protected void doInfer(
266+
Model model,
267+
InferenceInputs inputs,
268+
Map<String, Object> taskSettings,
269+
TimeValue timeout,
270+
ActionListener<InferenceServiceResults> listener
271+
) {
272+
var castedInput = inputs.castTo(ChatCompletionInput.class);
273+
assertThat(castedInput.getInputs(), is(List.of(testInput)));
274+
doInferCalled.set(true);
275+
listener.onResponse(mock(InferenceServiceResults.class));
276+
}
277+
};
278+
279+
try (testService) {
280+
var model = mock(Model.class);
281+
when(model.getTaskType()).thenReturn(TaskType.COMPLETION);
282+
283+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
284+
285+
testService.infer(model, null, null, null, List.of(testInput), false, Map.of(), null, null, listener);
286+
assertNotNull(listener.actionGet(TIMEOUT));
287+
assertTrue(doInferCalled.get());
288+
}
289+
}
290+
192291
public static Sender createMockSender() {
193292
var sender = mock(Sender.class);
194293
doAnswer(invocationOnMock -> {

0 commit comments

Comments
 (0)