Skip to content

Commit ce25f45

Browse files
[ML] Remove regex (#113210) (#113380)
Regex is having trouble parsing some of the larger UTF8 characters, so instead we are just going to use our non-regex parser. Fix #113179 Fix #113148 Co-authored-by: Elastic Machine <[email protected]>
1 parent f95f292 commit ce25f45

File tree

1 file changed

+17
-32
lines changed

1 file changed

+17
-32
lines changed

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java

Lines changed: 17 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@
4747
import org.elasticsearch.test.ESIntegTestCase;
4848
import org.elasticsearch.xcontent.ToXContent;
4949
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
50+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
51+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;
52+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
5053

5154
import java.io.IOException;
5255
import java.nio.charset.StandardCharsets;
@@ -59,7 +62,6 @@
5962
import java.util.concurrent.Flow;
6063
import java.util.concurrent.LinkedBlockingDeque;
6164
import java.util.concurrent.TimeUnit;
62-
import java.util.concurrent.atomic.AtomicBoolean;
6365
import java.util.concurrent.atomic.AtomicInteger;
6466
import java.util.concurrent.atomic.AtomicReference;
6567
import java.util.function.Predicate;
@@ -80,9 +82,7 @@ public class ServerSentEventsRestActionListenerTests extends ESIntegTestCase {
8082
private static final String NO_STREAM_ROUTE = "/_inference_no_stream";
8183
private static final Exception expectedException = new IllegalStateException("hello there");
8284
private static final String expectedExceptionAsServerSentEvent = """
83-
\uFEFF\
84-
event: error
85-
data: {\
85+
{\
8686
"error":{"root_cause":[{"type":"illegal_state_exception","reason":"hello there",\
8787
"caused_by":{"type":"illegal_state_exception","reason":"hello there"}}],\
8888
"type":"illegal_state_exception","reason":"hello there"},"status":500\
@@ -323,30 +323,16 @@ protected void releaseResources() {}
323323
}
324324

325325
private static class RandomStringCollector {
326-
private static final Pattern jsonPattern = Pattern.compile("^\uFEFFevent: message\ndata: \\{.*}$");
327-
private static final Pattern endPattern = Pattern.compile("^\uFEFFevent: message\ndata: \\[DONE\\]$");
328-
private final AtomicBoolean hasDoneChunk = new AtomicBoolean(false);
329326
private final Deque<String> stringsVerified = new LinkedBlockingDeque<>();
330-
private volatile String previousTokens = "";
327+
private final ServerSentEventParser sseParser = new ServerSentEventParser();
331328

332329
private void collect(String str) throws IOException {
333-
str = previousTokens + str;
334-
String[] events = str.split("\n\n", -1);
335-
for (var i = 0; i < events.length - 1; i++) {
336-
var line = events[i];
337-
if (jsonPattern.matcher(line).matches() || expectedExceptionAsServerSentEvent.equals(line)) {
338-
stringsVerified.offer(line);
339-
} else if (endPattern.matcher(line).matches()) {
340-
hasDoneChunk.set(true);
341-
} else {
342-
throw new IOException("Line does not match expected JSON message or DONE message. Line: " + line);
343-
}
344-
}
345-
346-
previousTokens = events[events.length - 1];
347-
if (endPattern.matcher(previousTokens.trim()).matches()) {
348-
hasDoneChunk.set(true);
349-
}
330+
sseParser.parse(str.getBytes(StandardCharsets.UTF_8))
331+
.stream()
332+
.filter(event -> event.name() == ServerSentEventField.DATA)
333+
.filter(ServerSentEvent::hasValue)
334+
.map(ServerSentEvent::value)
335+
.forEach(stringsVerified::offer);
350336
}
351337
}
352338

@@ -363,8 +349,8 @@ public void testResponse() {
363349

364350
var response = callAsync(request);
365351
assertThat(response.getStatusLine().getStatusCode(), is(HttpStatus.SC_OK));
366-
assertThat(collector.stringsVerified.size(), equalTo(expectedTestCount));
367-
assertThat(collector.hasDoneChunk.get(), equalTo(true));
352+
assertThat(collector.stringsVerified.size(), equalTo(expectedTestCount + 1)); // normal payload count + done byte
353+
assertThat(collector.stringsVerified.peekLast(), equalTo("[DONE]"));
368354
}
369355

370356
private Response callAsync(Request request) {
@@ -409,10 +395,9 @@ public void testOnFailure() throws IOException {
409395
} catch (ResponseException e) {
410396
var response = e.getResponse();
411397
assertThat(response.getStatusLine().getStatusCode(), is(HttpStatus.SC_INTERNAL_SERVER_ERROR));
412-
assertThat(
413-
EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8),
414-
equalTo(expectedExceptionAsServerSentEvent + "\n\n")
415-
);
398+
assertThat(EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8), equalTo("""
399+
\uFEFFevent: error
400+
data:\s""" + expectedExceptionAsServerSentEvent + "\n\n"));
416401
}
417402
}
418403

@@ -431,7 +416,7 @@ public void testErrorMidStream() {
431416
var response = callAsync(request);
432417
assertThat(response.getStatusLine().getStatusCode(), is(HttpStatus.SC_OK)); // error still starts with 200-OK
433418
assertThat(collector.stringsVerified.size(), equalTo(expectedTestCount + 1)); // normal payload count + last error byte
434-
assertThat("DONE chunk is not sent on error", collector.hasDoneChunk.get(), equalTo(false));
419+
assertThat("DONE chunk is not sent on error", collector.stringsVerified.stream().anyMatch("[DONE]"::equals), equalTo(false));
435420
assertThat(collector.stringsVerified.getLast(), equalTo(expectedExceptionAsServerSentEvent));
436421
}
437422

0 commit comments

Comments
 (0)