Skip to content

Commit 69ab6d8

Browse files
[ML] Perform query field validation for rerank task type (#137219)
* Validating that query is defined for rerank task
* Validation for completion works without query field
* Update docs/changelog/137219.yaml
* Addressing feedback
* [CI] Auto commit changes from spotless

---------

Co-authored-by: elasticsearchmachine <[email protected]>

(cherry picked from commit 34145ed)

# Conflicts:
#   x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java
#   x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/CohereServiceTests.java
1 parent 088813f commit 69ab6d8

File tree

4 files changed

+350
-161
lines changed

4 files changed

+350
-161
lines changed

docs/changelog/137219.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 137219
2+
summary: Perform query field validation for rerank task type
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ private static InferenceInputs createInput(
9090
case RERANK -> {
9191
ValidationException validationException = new ValidationException();
9292
service.validateRerankParameters(returnDocuments, topN, validationException);
93+
94+
if (query == null) {
95+
validationException.addValidationError("Rerank task type requires a non-null query field");
96+
}
97+
9398
if (validationException.validationErrors().isEmpty() == false) {
9499
throw validationException;
95100
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@
2121
import org.elasticsearch.inference.TaskType;
2222
import org.elasticsearch.test.ESTestCase;
2323
import org.elasticsearch.threadpool.ThreadPool;
24+
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
2425
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
2526
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
27+
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
2628
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
2729
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
2830
import org.junit.After;
@@ -34,9 +36,12 @@
3436
import java.util.List;
3537
import java.util.Map;
3638
import java.util.concurrent.TimeUnit;
39+
import java.util.concurrent.atomic.AtomicReference;
3740

3841
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
3942
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
43+
import static org.hamcrest.Matchers.containsString;
44+
import static org.hamcrest.Matchers.is;
4045
import static org.mockito.Mockito.mock;
4146
import static org.mockito.Mockito.times;
4247
import static org.mockito.Mockito.verify;
@@ -101,7 +106,102 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep
101106
verifyNoMoreInteractions(sender);
102107
}
103108

104-
private static final class TestSenderService extends SenderService {
109+
public void testReturnsValidationException_WhenQueryIsNullForRerankTaskType() throws IOException {
110+
var sender = mock(Sender.class);
111+
112+
var factory = mock(HttpRequestSender.Factory.class);
113+
when(factory.createSender()).thenReturn(sender);
114+
115+
try (var testService = new TestSenderService(factory, createWithEmptySettings(threadPool))) {
116+
var model = mock(Model.class);
117+
when(model.getTaskType()).thenReturn(TaskType.RERANK);
118+
119+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
120+
121+
testService.infer(model, null, null, null, List.of("test input"), false, Map.of(), InputType.SEARCH, null, listener);
122+
var exception = expectThrows(ValidationException.class, () -> listener.actionGet(TIMEOUT));
123+
124+
assertThat(exception.getMessage(), containsString("Rerank task type requires a non-null query field"));
125+
}
126+
}
127+
128+
public void testInferSucceeds_WhenQueryIsDefinedForRerankTaskType() throws IOException {
129+
var sender = mock(Sender.class);
130+
131+
var factory = mock(HttpRequestSender.Factory.class);
132+
when(factory.createSender()).thenReturn(sender);
133+
134+
var queryString = "a query";
135+
var testInput = "test input";
136+
var doInferCalled = new AtomicReference<>(false);
137+
138+
var testService = new TestSenderService(factory, createWithEmptySettings(threadPool)) {
139+
@Override
140+
protected void doInfer(
141+
Model model,
142+
InferenceInputs inputs,
143+
Map<String, Object> taskSettings,
144+
TimeValue timeout,
145+
ActionListener<InferenceServiceResults> listener
146+
) {
147+
var queryDocs = inputs.castTo(QueryAndDocsInputs.class);
148+
assertThat(queryDocs.getQuery(), is(queryString));
149+
assertThat(queryDocs.getChunks(), is(List.of(testInput)));
150+
doInferCalled.set(true);
151+
listener.onResponse(mock(InferenceServiceResults.class));
152+
}
153+
};
154+
155+
try (testService) {
156+
var model = mock(Model.class);
157+
when(model.getTaskType()).thenReturn(TaskType.RERANK);
158+
159+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
160+
161+
testService.infer(model, queryString, null, null, List.of(testInput), false, Map.of(), null, null, listener);
162+
assertNotNull(listener.actionGet(TIMEOUT));
163+
assertTrue(doInferCalled.get());
164+
}
165+
}
166+
167+
public void testInferSucceeds_WhenQueryIsNotDefinedForCompletionTaskType() throws IOException {
168+
var sender = mock(Sender.class);
169+
170+
var factory = mock(HttpRequestSender.Factory.class);
171+
when(factory.createSender()).thenReturn(sender);
172+
173+
var testInput = "test input";
174+
var doInferCalled = new AtomicReference<>(false);
175+
176+
var testService = new TestSenderService(factory, createWithEmptySettings(threadPool)) {
177+
@Override
178+
protected void doInfer(
179+
Model model,
180+
InferenceInputs inputs,
181+
Map<String, Object> taskSettings,
182+
TimeValue timeout,
183+
ActionListener<InferenceServiceResults> listener
184+
) {
185+
var castedInput = inputs.castTo(ChatCompletionInput.class);
186+
assertThat(castedInput.getInputs(), is(List.of(testInput)));
187+
doInferCalled.set(true);
188+
listener.onResponse(mock(InferenceServiceResults.class));
189+
}
190+
};
191+
192+
try (testService) {
193+
var model = mock(Model.class);
194+
when(model.getTaskType()).thenReturn(TaskType.COMPLETION);
195+
196+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
197+
198+
testService.infer(model, null, null, null, List.of(testInput), false, Map.of(), null, null, listener);
199+
assertNotNull(listener.actionGet(TIMEOUT));
200+
assertTrue(doInferCalled.get());
201+
}
202+
}
203+
204+
private static class TestSenderService extends SenderService {
105205
TestSenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
106206
super(factory, serviceComponents);
107207
}

0 commit comments

Comments
 (0)