Skip to content

Commit 0250d6f

Browse files
[9.1] [ML] Perform query field validation for rerank task type (#137219) (#137293)
* [ML] Perform query field validation for rerank task type (#137219)
  * Validating that query is defined for rerank task
  * Validation for completion works without query field
  * Update docs/changelog/137219.yaml
  * Addressing feedback
  * [CI] Auto commit changes from spotless

  ---------

  Co-authored-by: elasticsearchmachine <[email protected]>

  (cherry picked from commit 34145ed)

  # Conflicts:
  #   x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java
  #   x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
* Fixing tests
1 parent 088813f commit 0250d6f

File tree

4 files changed

+361
-161
lines changed

4 files changed

+361
-161
lines changed

docs/changelog/137219.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 137219
2+
summary: Perform query field validation for rerank task type
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ private static InferenceInputs createInput(
9090
case RERANK -> {
9191
ValidationException validationException = new ValidationException();
9292
service.validateRerankParameters(returnDocuments, topN, validationException);
93+
94+
if (query == null) {
95+
validationException.addValidationError("Rerank task type requires a non-null query field");
96+
}
97+
9398
if (validationException.validationErrors().isEmpty() == false) {
9499
throw validationException;
95100
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@
2121
import org.elasticsearch.inference.TaskType;
2222
import org.elasticsearch.test.ESTestCase;
2323
import org.elasticsearch.threadpool.ThreadPool;
24+
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
2425
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
2526
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
27+
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
2628
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
2729
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
2830
import org.junit.After;
@@ -34,9 +36,12 @@
3436
import java.util.List;
3537
import java.util.Map;
3638
import java.util.concurrent.TimeUnit;
39+
import java.util.concurrent.atomic.AtomicReference;
3740

3841
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
3942
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
43+
import static org.hamcrest.Matchers.containsString;
44+
import static org.hamcrest.Matchers.is;
4045
import static org.mockito.Mockito.mock;
4146
import static org.mockito.Mockito.times;
4247
import static org.mockito.Mockito.verify;
@@ -101,7 +106,113 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep
101106
verifyNoMoreInteractions(sender);
102107
}
103108

104-
private static final class TestSenderService extends SenderService {
109+
public void testReturnsValidationException_WhenQueryIsNullForRerankTaskType() throws IOException {
110+
var sender = mock(Sender.class);
111+
112+
var factory = mock(HttpRequestSender.Factory.class);
113+
when(factory.createSender()).thenReturn(sender);
114+
115+
try (var testService = new TestSenderService(factory, createWithEmptySettings(threadPool))) {
116+
var model = mock(Model.class);
117+
when(model.getTaskType()).thenReturn(TaskType.RERANK);
118+
119+
var exception = expectThrows(
120+
ValidationException.class,
121+
() -> testService.infer(
122+
model,
123+
null,
124+
null,
125+
null,
126+
List.of("test input"),
127+
false,
128+
Map.of(),
129+
InputType.SEARCH,
130+
null,
131+
new PlainActionFuture<>()
132+
)
133+
);
134+
135+
assertThat(exception.getMessage(), containsString("Rerank task type requires a non-null query field"));
136+
}
137+
}
138+
139+
public void testInferSucceeds_WhenQueryIsDefinedForRerankTaskType() throws IOException {
140+
var sender = mock(Sender.class);
141+
142+
var factory = mock(HttpRequestSender.Factory.class);
143+
when(factory.createSender()).thenReturn(sender);
144+
145+
var queryString = "a query";
146+
var testInput = "test input";
147+
var doInferCalled = new AtomicReference<>(false);
148+
149+
var testService = new TestSenderService(factory, createWithEmptySettings(threadPool)) {
150+
@Override
151+
protected void doInfer(
152+
Model model,
153+
InferenceInputs inputs,
154+
Map<String, Object> taskSettings,
155+
TimeValue timeout,
156+
ActionListener<InferenceServiceResults> listener
157+
) {
158+
var queryDocs = inputs.castTo(QueryAndDocsInputs.class);
159+
assertThat(queryDocs.getQuery(), is(queryString));
160+
assertThat(queryDocs.getChunks(), is(List.of(testInput)));
161+
doInferCalled.set(true);
162+
listener.onResponse(mock(InferenceServiceResults.class));
163+
}
164+
};
165+
166+
try (testService) {
167+
var model = mock(Model.class);
168+
when(model.getTaskType()).thenReturn(TaskType.RERANK);
169+
170+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
171+
172+
testService.infer(model, queryString, null, null, List.of(testInput), false, Map.of(), null, null, listener);
173+
assertNotNull(listener.actionGet(TIMEOUT));
174+
assertTrue(doInferCalled.get());
175+
}
176+
}
177+
178+
public void testInferSucceeds_WhenQueryIsNotDefinedForCompletionTaskType() throws IOException {
179+
var sender = mock(Sender.class);
180+
181+
var factory = mock(HttpRequestSender.Factory.class);
182+
when(factory.createSender()).thenReturn(sender);
183+
184+
var testInput = "test input";
185+
var doInferCalled = new AtomicReference<>(false);
186+
187+
var testService = new TestSenderService(factory, createWithEmptySettings(threadPool)) {
188+
@Override
189+
protected void doInfer(
190+
Model model,
191+
InferenceInputs inputs,
192+
Map<String, Object> taskSettings,
193+
TimeValue timeout,
194+
ActionListener<InferenceServiceResults> listener
195+
) {
196+
var castedInput = inputs.castTo(ChatCompletionInput.class);
197+
assertThat(castedInput.getInputs(), is(List.of(testInput)));
198+
doInferCalled.set(true);
199+
listener.onResponse(mock(InferenceServiceResults.class));
200+
}
201+
};
202+
203+
try (testService) {
204+
var model = mock(Model.class);
205+
when(model.getTaskType()).thenReturn(TaskType.COMPLETION);
206+
207+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
208+
209+
testService.infer(model, null, null, null, List.of(testInput), false, Map.of(), null, null, listener);
210+
assertNotNull(listener.actionGet(TIMEOUT));
211+
assertTrue(doInferCalled.get());
212+
}
213+
}
214+
215+
private static class TestSenderService extends SenderService {
105216
/**
 * Creates the test service by handing both arguments unchanged to the superclass constructor.
 *
 * @param factory           sender factory passed through to the superclass
 * @param serviceComponents service components passed through to the superclass
 */
TestSenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
106217
super(factory, serviceComponents);
107218
}

0 commit comments

Comments
 (0)