Skip to content

Commit 7e119c7

Browse files
authored
[ML] Correctly handle empty inputs in chunkedInfer() (elastic#138838)
- Add method to allow services that implement SenderService to indicate whether they support chunked inference - Return immediately if the input list is empty for services that support chunked inference - Throw exception if the input list is empty for services that do not support chunked inference, to maintain existing behaviour - Add tests for all services that implement doChunkedInfer() - Update DeepSeekServiceTests for new error message (cherry picked from commit f70dbb8)
1 parent 30e5861 commit 7e119c7

File tree

26 files changed

+487
-9
lines changed

26 files changed

+487
-9
lines changed

docs/changelog/138632.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 138632
2+
summary: Correctly handle empty inputs in `chunkedInfer()`
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/core/src/test/java/org/elasticsearch/test/http/MockRequest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public InetSocketAddress getRemoteAddress() {
7575

7676
@Override
7777
public String toString() {
78-
return String.format(Locale.ROOT, "%s %s", method, uri);
78+
return String.format(Locale.ROOT, "%s %s %s", method, uri, body);
7979
}
8080

8181
/**

x-pack/plugin/core/src/test/java/org/elasticsearch/test/http/MockWebServer.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ public void start() throws IOException {
106106
server.createContext("/", s -> {
107107
try {
108108
MockResponse response = responses.poll();
109+
System.out.println("DONAL: done polling responses, got " + response);
109110
MockRequest request = createRequest(s);
110111
requests.add(request);
111112

@@ -245,7 +246,9 @@ public void enqueue(MockResponse response) {
245246
getStartOfBody(response)
246247
);
247248
}
249+
System.out.println("DONAL: adding response " + response);
248250
responses.add(response);
251+
System.out.println("DONAL: done adding response " + response);
249252
}
250253

251254
/**

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,16 @@ public void chunkedInfer(
143143
if (validationException.validationErrors().isEmpty() == false) {
144144
throw validationException;
145145
}
146-
147-
// a non-null query is not supported and is dropped by all providers
148-
doChunkedInfer(model, input, taskSettings, inputType, timeout, listener);
146+
if (supportsChunkedInfer()) {
147+
if (input.isEmpty()) {
148+
listener.onResponse(List.of());
149+
} else {
150+
// a non-null query is not supported and is dropped by all providers
151+
doChunkedInfer(model, input, taskSettings, inputType, timeout, listener);
152+
}
153+
} else {
154+
listener.onFailure(new UnsupportedOperationException(Strings.format("%s service does not support chunked inference", name())));
155+
}
149156
}
150157

151158
protected abstract void doInfer(
@@ -176,6 +183,10 @@ protected abstract void doChunkedInfer(
176183
ActionListener<List<ChunkedInference>> listener
177184
);
178185

186+
protected boolean supportsChunkedInfer() {
187+
return true;
188+
}
189+
179190
public void start(Model model, ActionListener<Boolean> listener) {
180191
init();
181192
doStart(model, listener);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,9 +228,15 @@ protected void doChunkedInfer(
228228
TimeValue timeout,
229229
ActionListener<List<ChunkedInference>> listener
230230
) {
231+
// Should never be called
231232
throw new UnsupportedOperationException("Anthropic service does not support chunked inference");
232233
}
233234

235+
@Override
236+
protected boolean supportsChunkedInfer() {
237+
return false;
238+
}
239+
234240
@Override
235241
public TransportVersion getMinimalSupportedVersion() {
236242
return TransportVersions.V_8_15_0;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,15 @@ protected void doChunkedInfer(
112112
TimeValue timeout,
113113
ActionListener<List<ChunkedInference>> listener
114114
) {
115+
// Should never be called
115116
listener.onFailure(new UnsupportedOperationException(Strings.format("The %s service only supports unified completion", NAME)));
116117
}
117118

119+
@Override
120+
protected boolean supportsChunkedInfer() {
121+
return false;
122+
}
123+
118124
@Override
119125
public String name() {
120126
return NAME;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,9 @@ public void chunkedInfer(
251251
listener.onFailure(createInvalidModelException(model));
252252
return;
253253
}
254+
if (input.isEmpty()) {
255+
listener.onResponse(List.of());
256+
}
254257
try {
255258
var sageMakerModel = ((SageMakerModel) model).override(taskSettings);
256259
var batchedRequests = new EmbeddingRequestChunker<>(

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
7070
import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
7171
import static org.hamcrest.CoreMatchers.is;
72+
import static org.hamcrest.Matchers.empty;
7273
import static org.hamcrest.Matchers.hasSize;
7374
import static org.hamcrest.Matchers.instanceOf;
7475
import static org.mockito.Mockito.mock;
@@ -447,6 +448,27 @@ public void testChunkedInfer_SparseEmbeddingChunkingSettingsNotSet() throws IOEx
447448
testChunkedInfer(TaskType.SPARSE_EMBEDDING, null);
448449
}
449450

451+
public void testChunkedInfer_noInputs() throws IOException {
452+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
453+
454+
PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
455+
try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) {
456+
var model = createModelForTaskType(randomFrom(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING), null);
457+
458+
service.chunkedInfer(
459+
model,
460+
null,
461+
List.of(),
462+
new HashMap<>(),
463+
InputTypeTests.randomWithIngestAndSearch(),
464+
InferenceAction.Request.DEFAULT_TIMEOUT,
465+
listener
466+
);
467+
468+
}
469+
assertThat(listener.actionGet(TIMEOUT), empty());
470+
}
471+
450472
private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettings) throws IOException {
451473
var input = List.of(new ChunkInferenceInput("foo"), new ChunkInferenceInput("bar"));
452474

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
import static org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettingsTests.createEmbeddingsRequestSettingsMap;
8585
import static org.hamcrest.CoreMatchers.is;
8686
import static org.hamcrest.Matchers.containsString;
87+
import static org.hamcrest.Matchers.empty;
8788
import static org.hamcrest.Matchers.hasSize;
8889
import static org.hamcrest.Matchers.instanceOf;
8990
import static org.mockito.Mockito.mock;
@@ -1274,6 +1275,43 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
12741275
testChunkedInfer(model);
12751276
}
12761277

1278+
public void testChunkedInfer_noInputs() throws IOException {
1279+
var model = AmazonBedrockEmbeddingsModelTests.createModel(
1280+
"id",
1281+
"region",
1282+
"model",
1283+
AmazonBedrockProvider.AMAZONTITAN,
1284+
null,
1285+
"access",
1286+
"secret"
1287+
);
1288+
1289+
var sender = mock(Sender.class);
1290+
var factory = mock(HttpRequestSender.Factory.class);
1291+
when(factory.createSender()).thenReturn(sender);
1292+
1293+
var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory(
1294+
ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY),
1295+
mockClusterServiceEmpty()
1296+
);
1297+
1298+
try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) {
1299+
PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
1300+
service.chunkedInfer(
1301+
model,
1302+
null,
1303+
List.of(),
1304+
new HashMap<>(),
1305+
InputType.INTERNAL_INGEST,
1306+
InferenceAction.Request.DEFAULT_TIMEOUT,
1307+
listener
1308+
);
1309+
1310+
var results = listener.actionGet(TIMEOUT);
1311+
assertThat(results, empty());
1312+
}
1313+
}
1314+
12771315
private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOException {
12781316
var sender = mock(Sender.class);
12791317
var factory = mock(HttpRequestSender.Factory.class);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
import static org.elasticsearch.xpack.inference.services.azureaistudio.request.AzureAiStudioRequestFields.API_KEY_HEADER;
8383
import static org.hamcrest.CoreMatchers.is;
8484
import static org.hamcrest.Matchers.containsString;
85+
import static org.hamcrest.Matchers.empty;
8586
import static org.hamcrest.Matchers.equalTo;
8687
import static org.hamcrest.Matchers.hasSize;
8788
import static org.hamcrest.Matchers.instanceOf;
@@ -1061,6 +1062,27 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
10611062
testChunkedInfer(model);
10621063
}
10631064

1065+
public void testChunkedInfer_noInputs() throws IOException {
1066+
var model = AzureAiStudioEmbeddingsModelTests.createModel(
1067+
"id",
1068+
getUrl(webServer),
1069+
AzureAiStudioProvider.OPENAI,
1070+
AzureAiStudioEndpointType.TOKEN,
1071+
"apikey"
1072+
);
1073+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
1074+
1075+
try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
1076+
PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
1077+
List<ChunkInferenceInput> input = List.of();
1078+
service.chunkedInfer(model, null, input, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener);
1079+
1080+
var results = listener.actionGet(TIMEOUT);
1081+
assertThat(results, empty());
1082+
assertThat(webServer.requests(), empty());
1083+
}
1084+
}
1085+
10641086
private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOException {
10651087
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
10661088

0 commit comments

Comments
 (0)