Skip to content

Commit 9fed31c

Browse files
authored
[8.17][ML] Fix timeout ingesting an empty string into a semantic_text field (#118746)
Backport of #117840
1 parent df0956d commit 9fed31c

File tree

8 files changed

+164
-7
lines changed

8 files changed

+164
-7
lines changed

docs/changelog/117840.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 117840
2+
summary: Fix timeout ingesting an empty string into a `semantic_text` field
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ public List<String> chunk(String input, ChunkingSettings chunkingSettings) {
6363
*
6464
* @param input Text to chunk
6565
* @param maxNumberWordsPerChunk Maximum size of the chunk
66-
* @return The input text chunked
66+
* @param includePrecedingSentence Include the previous sentence
67+
* @return The input text offsets
6768
*/
6869
public List<String> chunk(String input, int maxNumberWordsPerChunk, boolean includePrecedingSentence) {
6970
var chunks = new ArrayList<String>();
@@ -154,6 +155,11 @@ public List<String> chunk(String input, int maxNumberWordsPerChunk, boolean incl
154155
chunks.add(input.substring(chunkStart));
155156
}
156157

158+
if (chunks.isEmpty()) {
159+
// The input did not chunk, return the entire input
160+
chunks.add(input);
161+
}
162+
157163
return chunks;
158164
}
159165

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,6 @@ List<ChunkPosition> chunkPositions(String input, int chunkSize, int overlap) {
104104
throw new IllegalArgumentException("Invalid chunking parameters, overlap [" + overlap + "] must be >= 0");
105105
}
106106

107-
if (input.isEmpty()) {
108-
return List.of();
109-
}
110-
111107
var chunkPositions = new ArrayList<ChunkPosition>();
112108

113109
// This position in the chunk is where the next overlapping chunk will start

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
1919
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
2020
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
21+
import org.hamcrest.Matchers;
2122

2223
import java.util.ArrayList;
2324
import java.util.List;
@@ -31,16 +32,62 @@
3132

3233
public class EmbeddingRequestChunkerTests extends ESTestCase {
3334

34-
public void testEmptyInput() {
35+
public void testEmptyInput_WordChunker() {
3536
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
3637
var batches = new EmbeddingRequestChunker(List.of(), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
3738
assertThat(batches, empty());
3839
}
3940

40-
public void testBlankInput() {
41+
public void testEmptyInput_SentenceChunker() {
42+
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
43+
var batches = new EmbeddingRequestChunker(List.of(), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1))
44+
.batchRequestsWithListeners(testListener());
45+
assertThat(batches, empty());
46+
}
47+
48+
public void testWhitespaceInput_SentenceChunker() {
49+
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
50+
var batches = new EmbeddingRequestChunker(List.of(" "), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1))
51+
.batchRequestsWithListeners(testListener());
52+
assertThat(batches, hasSize(1));
53+
assertThat(batches.get(0).batch().inputs(), hasSize(1));
54+
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(" "));
55+
}
56+
57+
public void testBlankInput_WordChunker() {
4158
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
4259
var batches = new EmbeddingRequestChunker(List.of(""), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
4360
assertThat(batches, hasSize(1));
61+
assertThat(batches.get(0).batch().inputs(), hasSize(1));
62+
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(""));
63+
}
64+
65+
public void testBlankInput_SentenceChunker() {
66+
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
67+
var batches = new EmbeddingRequestChunker(List.of(""), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1))
68+
.batchRequestsWithListeners(testListener());
69+
assertThat(batches, hasSize(1));
70+
assertThat(batches.get(0).batch().inputs(), hasSize(1));
71+
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(""));
72+
}
73+
74+
public void testInputThatDoesNotChunk_WordChunker() {
75+
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
76+
var batches = new EmbeddingRequestChunker(List.of("ABBAABBA"), 100, 100, 10, embeddingType).batchRequestsWithListeners(
77+
testListener()
78+
);
79+
assertThat(batches, hasSize(1));
80+
assertThat(batches.get(0).batch().inputs(), hasSize(1));
81+
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA"));
82+
}
83+
84+
public void testInputThatDoesNotChunk_SentenceChunker() {
85+
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
86+
var batches = new EmbeddingRequestChunker(List.of("ABBAABBA"), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1))
87+
.batchRequestsWithListeners(testListener());
88+
assertThat(batches, hasSize(1));
89+
assertThat(batches.get(0).batch().inputs(), hasSize(1));
90+
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA"));
4491
}
4592

4693
public void testShortInputsAreSingleBatch() {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import java.util.ArrayList;
1717
import java.util.Arrays;
18+
import java.util.List;
1819
import java.util.Locale;
1920

2021
import static org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkerTests.TEST_TEXT;
@@ -27,6 +28,53 @@
2728

2829
public class SentenceBoundaryChunkerTests extends ESTestCase {
2930

31+
/**
32+
* Utility method for testing.
33+
*/
34+
private List<String> textChunks(
35+
SentenceBoundaryChunker chunker,
36+
String input,
37+
int maxNumberWordsPerChunk,
38+
boolean includePrecedingSentence
39+
) {
40+
return chunker.chunk(input, maxNumberWordsPerChunk, includePrecedingSentence);
41+
}
42+
43+
public void testEmptyString() {
44+
var chunks = textChunks(new SentenceBoundaryChunker(), "", 100, randomBoolean());
45+
assertThat(chunks, hasSize(1));
46+
assertThat(chunks.get(0), Matchers.is(""));
47+
}
48+
49+
public void testBlankString() {
50+
var chunks = textChunks(new SentenceBoundaryChunker(), " ", 100, randomBoolean());
51+
assertThat(chunks, hasSize(1));
52+
assertThat(chunks.get(0), Matchers.is(" "));
53+
}
54+
55+
public void testSingleChar() {
56+
var chunks = textChunks(new SentenceBoundaryChunker(), " b", 100, randomBoolean());
57+
assertThat(chunks, Matchers.contains(" b"));
58+
59+
chunks = textChunks(new SentenceBoundaryChunker(), "b", 100, randomBoolean());
60+
assertThat(chunks, Matchers.contains("b"));
61+
62+
chunks = textChunks(new SentenceBoundaryChunker(), ". ", 100, randomBoolean());
63+
assertThat(chunks, Matchers.contains(". "));
64+
65+
chunks = textChunks(new SentenceBoundaryChunker(), " , ", 100, randomBoolean());
66+
assertThat(chunks, Matchers.contains(" , "));
67+
68+
chunks = textChunks(new SentenceBoundaryChunker(), " ,", 100, randomBoolean());
69+
assertThat(chunks, Matchers.contains(" ,"));
70+
}
71+
72+
public void testSingleCharRepeated() {
73+
var input = "a".repeat(32_000);
74+
var chunks = textChunks(new SentenceBoundaryChunker(), input, 100, randomBoolean());
75+
assertThat(chunks, Matchers.contains(input));
76+
}
77+
3078
public void testChunkSplitLargeChunkSizes() {
3179
for (int maxWordsPerChunk : new int[] { 100, 200 }) {
3280
var chunker = new SentenceBoundaryChunker();

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import org.elasticsearch.inference.ChunkingSettings;
1313
import org.elasticsearch.test.ESTestCase;
14+
import org.hamcrest.Matchers;
1415

1516
import java.util.List;
1617
import java.util.Locale;
@@ -226,6 +227,39 @@ public void testWhitespace() {
226227
assertThat(chunks, contains(" "));
227228
}
228229

230+
private List<String> textChunks(WordBoundaryChunker chunker, String input, int chunkSize, int overlap) {
231+
return chunker.chunk(input, chunkSize, overlap);
232+
}
233+
234+
public void testBlankString() {
235+
var chunks = textChunks(new WordBoundaryChunker(), " ", 100, 10);
236+
assertThat(chunks, hasSize(1));
237+
assertThat(chunks.get(0), Matchers.is(" "));
238+
}
239+
240+
public void testSingleChar() {
241+
var chunks = textChunks(new WordBoundaryChunker(), " b", 100, 10);
242+
assertThat(chunks, Matchers.contains(" b"));
243+
244+
chunks = textChunks(new WordBoundaryChunker(), "b", 100, 10);
245+
assertThat(chunks, Matchers.contains("b"));
246+
247+
chunks = textChunks(new WordBoundaryChunker(), ". ", 100, 10);
248+
assertThat(chunks, Matchers.contains(". "));
249+
250+
chunks = textChunks(new WordBoundaryChunker(), " , ", 100, 10);
251+
assertThat(chunks, Matchers.contains(" , "));
252+
253+
chunks = textChunks(new WordBoundaryChunker(), " ,", 100, 10);
254+
assertThat(chunks, Matchers.contains(" ,"));
255+
}
256+
257+
public void testSingleCharRepeated() {
258+
var input = "a".repeat(32_000);
259+
var chunks = textChunks(new WordBoundaryChunker(), input, 100, 10);
260+
assertThat(chunks, Matchers.contains(input));
261+
}
262+
229263
public void testPunctuation() {
230264
int chunkSize = 1;
231265
var chunks = new WordBoundaryChunker().chunk("Comma, separated", chunkSize, 0);

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,6 +1142,22 @@ public void testDeploymentThreadsIncludedInUsage() throws IOException {
11421142
}
11431143
}
11441144

1145+
public void testInferEmptyInput() throws IOException {
1146+
String modelId = "empty_input";
1147+
createPassThroughModel(modelId);
1148+
putModelDefinition(modelId);
1149+
putVocabulary(List.of("these", "are", "my", "words"), modelId);
1150+
startDeployment(modelId);
1151+
1152+
Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/_infer?timeout=30s");
1153+
request.setJsonEntity("""
1154+
{ "docs": [] }
1155+
""");
1156+
1157+
var inferenceResponse = client().performRequest(request);
1158+
assertThat(EntityUtils.toString(inferenceResponse.getEntity()), equalTo("{\"inference_results\":[]}"));
1159+
}
1160+
11451161
private void putModelDefinition(String modelId) throws IOException {
11461162
putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE);
11471163
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,11 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
132132
Response.Builder responseBuilder = Response.builder();
133133
TaskId parentTaskId = new TaskId(clusterService.localNode().getId(), task.getId());
134134

135+
if (request.numberOfDocuments() == 0) {
136+
listener.onResponse(responseBuilder.setId(request.getId()).build());
137+
return;
138+
}
139+
135140
if (MachineLearning.INFERENCE_AGG_FEATURE.check(licenseState)) {
136141
responseBuilder.setLicensed(true);
137142
doInfer(task, request, responseBuilder, parentTaskId, listener);

0 commit comments

Comments (0)