Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/137219.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 137219
summary: Perform query field validation for rerank task type
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ private static InferenceInputs createInput(
case RERANK -> {
ValidationException validationException = new ValidationException();
service.validateRerankParameters(returnDocuments, topN, validationException);

if (query == null) {
validationException.addValidationError("Rerank task type requires a non-null query field");
}

if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.junit.After;
Expand All @@ -34,9 +36,12 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
Expand Down Expand Up @@ -101,7 +106,113 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep
verifyNoMoreInteractions(sender);
}

private static final class TestSenderService extends SenderService {
public void testReturnsValidationException_WhenQueryIsNullForRerankTaskType() throws IOException {
var sender = mock(Sender.class);

var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender()).thenReturn(sender);

try (var testService = new TestSenderService(factory, createWithEmptySettings(threadPool))) {
var model = mock(Model.class);
when(model.getTaskType()).thenReturn(TaskType.RERANK);

var exception = expectThrows(
ValidationException.class,
() -> testService.infer(
model,
null,
null,
null,
List.of("test input"),
false,
Map.of(),
InputType.SEARCH,
null,
new PlainActionFuture<>()
)
);

assertThat(exception.getMessage(), containsString("Rerank task type requires a non-null query field"));
}
}

public void testInferSucceeds_WhenQueryIsDefinedForRerankTaskType() throws IOException {
var sender = mock(Sender.class);

var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender()).thenReturn(sender);

var queryString = "a query";
var testInput = "test input";
var doInferCalled = new AtomicReference<>(false);

var testService = new TestSenderService(factory, createWithEmptySettings(threadPool)) {
@Override
protected void doInfer(
Model model,
InferenceInputs inputs,
Map<String, Object> taskSettings,
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
var queryDocs = inputs.castTo(QueryAndDocsInputs.class);
assertThat(queryDocs.getQuery(), is(queryString));
assertThat(queryDocs.getChunks(), is(List.of(testInput)));
doInferCalled.set(true);
listener.onResponse(mock(InferenceServiceResults.class));
}
};

try (testService) {
var model = mock(Model.class);
when(model.getTaskType()).thenReturn(TaskType.RERANK);

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();

testService.infer(model, queryString, null, null, List.of(testInput), false, Map.of(), null, null, listener);
assertNotNull(listener.actionGet(TIMEOUT));
assertTrue(doInferCalled.get());
}
}

public void testInferSucceeds_WhenQueryIsNotDefinedForCompletionTaskType() throws IOException {
var sender = mock(Sender.class);

var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender()).thenReturn(sender);

var testInput = "test input";
var doInferCalled = new AtomicReference<>(false);

var testService = new TestSenderService(factory, createWithEmptySettings(threadPool)) {
@Override
protected void doInfer(
Model model,
InferenceInputs inputs,
Map<String, Object> taskSettings,
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
var castedInput = inputs.castTo(ChatCompletionInput.class);
assertThat(castedInput.getInputs(), is(List.of(testInput)));
doInferCalled.set(true);
listener.onResponse(mock(InferenceServiceResults.class));
}
};

try (testService) {
var model = mock(Model.class);
when(model.getTaskType()).thenReturn(TaskType.COMPLETION);

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();

testService.infer(model, null, null, null, List.of(testInput), false, Map.of(), null, null, listener);
assertNotNull(listener.actionGet(TIMEOUT));
assertTrue(doInferCalled.get());
}
}

private static class TestSenderService extends SenderService {
TestSenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
}
Expand Down
Loading