Skip to content

Commit 6ee2618

Browse files
committed
[ML] Refactor SSE Parsing (elastic#125959)
`ServerSentEvent` is now a record holding the event `type` and its `data`, rather than a record that paired a separate `ServerSentEventField` name with a `value`.
- `value` was renamed to `data`
- `hasValue` was renamed to `hasData`
- Parsing was refactored to read more like the spec: fields are written to a buffer, which is flushed when a blank line is reached
- Multiline data payloads are now supported
1 parent 49601ae commit 6ee2618

File tree

15 files changed

+209
-158
lines changed

15 files changed

+209
-158
lines changed

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -299,17 +299,13 @@ public void testUnsupportedStream() throws Exception {
299299

300300
try {
301301
var events = streamInferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomUUID()), null);
302-
assertThat(events.size(), equalTo(2));
302+
assertThat(events.size(), equalTo(1));
303303
events.forEach(event -> {
304-
switch (event.name()) {
305-
case EVENT -> assertThat(event.value(), equalToIgnoringCase("error"));
306-
case DATA -> assertThat(
307-
event.value(),
308-
containsString(
309-
"Streaming is not allowed for service [streaming_completion_test_service] and task [sparse_embedding]"
310-
)
311-
);
312-
}
304+
assertThat(event.type(), equalToIgnoringCase("error"));
305+
assertThat(
306+
event.data(),
307+
containsString("Streaming is not allowed for service [streaming_completion_test_service] and task [sparse_embedding]")
308+
);
313309
});
314310
} finally {
315311
deleteModel(modelId);
@@ -331,12 +327,10 @@ public void testSupportedStream() throws Exception {
331327
input.stream().map(s -> s.toUpperCase(Locale.ROOT)).map(str -> "{\"completion\":[{\"delta\":\"" + str + "\"}]}"),
332328
Stream.of("[DONE]")
333329
).iterator();
334-
assertThat(events.size(), equalTo((input.size() + 1) * 2));
330+
assertThat(events.size(), equalTo(input.size() + 1));
335331
events.forEach(event -> {
336-
switch (event.name()) {
337-
case EVENT -> assertThat(event.value(), equalToIgnoringCase("message"));
338-
case DATA -> assertThat(event.value(), equalTo(expectedResponses.next()));
339-
}
332+
assertThat(event.type(), equalToIgnoringCase("message"));
333+
assertThat(event.data(), equalTo(expectedResponses.next()));
340334
});
341335
} finally {
342336
deleteModel(modelId);
@@ -359,12 +353,10 @@ public void testUnifiedCompletionInference() throws Exception {
359353
VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER
360354
);
361355
var expectedResponses = expectedResultsIterator(input);
362-
assertThat(events.size(), equalTo((input.size() + 1) * 2));
356+
assertThat(events.size(), equalTo(input.size() + 1));
363357
events.forEach(event -> {
364-
switch (event.name()) {
365-
case EVENT -> assertThat(event.value(), equalToIgnoringCase("message"));
366-
case DATA -> assertThat(event.value(), equalTo(expectedResponses.next()));
367-
}
358+
assertThat(event.type(), equalToIgnoringCase("message"));
359+
assertThat(event.data(), equalTo(expectedResponses.next()));
368360
});
369361
} finally {
370362
deleteModel(modelId);

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
5252
import org.elasticsearch.xpack.core.inference.results.XContentFormattedException;
5353
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
54-
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;
5554
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
5655

5756
import java.io.IOException;
@@ -364,9 +363,8 @@ private static class RandomStringCollector {
364363
private void collect(String str) throws IOException {
365364
sseParser.parse(str.getBytes(StandardCharsets.UTF_8))
366365
.stream()
367-
.filter(event -> event.name() == ServerSentEventField.DATA)
368-
.filter(ServerSentEvent::hasValue)
369-
.map(ServerSentEvent::value)
366+
.filter(ServerSentEvent::hasData)
367+
.map(ServerSentEvent::data)
370368
.forEach(stringsVerified::offer);
371369
}
372370
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.apache.logging.log4j.Logger;
1212
import org.elasticsearch.xcontent.XContentParserConfiguration;
1313
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
14-
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;
1514

1615
import java.io.IOException;
1716
import java.util.ArrayDeque;
@@ -40,7 +39,7 @@ public static <ParsedChunk> Deque<ParsedChunk> parseEvent(
4039
) throws Exception {
4140
var results = new ArrayDeque<ParsedChunk>(item.size());
4241
for (ServerSentEvent event : item) {
43-
if (ServerSentEventField.DATA == event.name() && event.hasValue()) {
42+
if (event.hasData()) {
4443
try {
4544
var delta = parseFunction.apply(parserConfig, event);
4645
delta.forEachRemaining(results::offer);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicStreamingProcessor.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
1919
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
2020
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
21-
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;
2221

2322
import java.io.IOException;
2423
import java.util.ArrayDeque;
@@ -42,8 +41,8 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
4241

4342
var results = new ArrayDeque<StreamingChatCompletionResults.Result>(item.size());
4443
for (var event : item) {
45-
if (event.name() == ServerSentEventField.DATA && event.hasValue()) {
46-
try (var parser = parser(event.value())) {
44+
if (event.hasData()) {
45+
try (var parser = parser(event.data())) {
4746
var eventType = eventType(parser);
4847
switch (eventType) {
4948
case "error" -> {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioStreamingProcessor.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
1919
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
2020
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
21-
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;
2221

2322
import java.io.IOException;
2423
import java.util.ArrayDeque;
@@ -37,8 +36,8 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
3736
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
3837
var results = new ArrayDeque<StreamingChatCompletionResults.Result>(item.size());
3938
for (ServerSentEvent event : item) {
40-
if (ServerSentEventField.DATA == event.name() && event.hasValue()) {
41-
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.value())) {
39+
if (event.hasData()) {
40+
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) {
4241
var delta = content.apply(jsonParser);
4342
results.offer(new StreamingChatCompletionResults.Result(delta));
4443
} catch (Exception e) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,11 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
124124

125125
private static Iterator<StreamingChatCompletionResults.Result> parse(XContentParserConfiguration parserConfig, ServerSentEvent event)
126126
throws IOException {
127-
if (DONE_MESSAGE.equalsIgnoreCase(event.value())) {
127+
if (DONE_MESSAGE.equalsIgnoreCase(event.data())) {
128128
return Collections.emptyIterator();
129129
}
130130

131-
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.value())) {
131+
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) {
132132
moveToFirstToken(jsonParser);
133133

134134
XContentParser.Token token = jsonParser.currentToken();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
2020
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
2121
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
22-
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;
2322

2423
import java.io.IOException;
2524
import java.util.ArrayDeque;
@@ -62,7 +61,6 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<
6261

6362
private final BiFunction<String, Exception, Exception> errorParser;
6463
private final Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> buffer = new LinkedBlockingDeque<>();
65-
private volatile boolean previousEventWasError = false;
6664

6765
public OpenAiUnifiedStreamingProcessor(BiFunction<String, Exception, Exception> errorParser) {
6866
this.errorParser = errorParser;
@@ -83,19 +81,15 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
8381

8482
var results = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>(item.size());
8583
for (var event : item) {
86-
if (ServerSentEventField.EVENT == event.name() && "error".equals(event.value())) {
87-
previousEventWasError = true;
88-
} else if (ServerSentEventField.DATA == event.name() && event.hasValue()) {
89-
if (previousEventWasError) {
90-
throw errorParser.apply(event.value(), null);
91-
}
92-
84+
if ("error".equals(event.type()) && event.hasData()) {
85+
throw errorParser.apply(event.data(), null);
86+
} else if (event.hasData()) {
9387
try {
9488
var delta = parse(parserConfig, event);
9589
delta.forEachRemaining(results::offer);
9690
} catch (Exception e) {
9791
logger.warn("Failed to parse event from inference provider: {}", event);
98-
throw errorParser.apply(event.value(), e);
92+
throw errorParser.apply(event.data(), e);
9993
}
10094
}
10195
}
@@ -118,11 +112,11 @@ private static Iterator<StreamingUnifiedChatCompletionResults.ChatCompletionChun
118112
XContentParserConfiguration parserConfig,
119113
ServerSentEvent event
120114
) throws IOException {
121-
if (DONE_MESSAGE.equalsIgnoreCase(event.value())) {
115+
if (DONE_MESSAGE.equalsIgnoreCase(event.data())) {
122116
return Collections.emptyIterator();
123117
}
124118

125-
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.value())) {
119+
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) {
126120
moveToFirstToken(jsonParser);
127121

128122
XContentParser.Token token = jsonParser.currentToken();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEvent.java

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,26 @@
99

1010
/**
1111
* Server-Sent Event message: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation
12-
* Messages always contain a {@link ServerSentEventField} and a non-null payload value.
13-
* When the stream is parsed and there is no value associated with a {@link ServerSentEventField}, an empty-string is set as the value.
1412
*/
15-
public record ServerSentEvent(ServerSentEventField name, String value) {
13+
public record ServerSentEvent(String type, String data) {
1614

1715
private static final String EMPTY = "";
16+
private static final String MESSAGE = "message";
1817

19-
public ServerSentEvent(ServerSentEventField name) {
20-
this(name, EMPTY);
18+
public static ServerSentEvent empty() {
19+
return new ServerSentEvent(EMPTY, EMPTY);
2120
}
2221

23-
// treat null value as an empty string, don't break parsing
24-
public ServerSentEvent(ServerSentEventField name, String value) {
25-
this.name = name;
26-
this.value = value != null ? value : EMPTY;
22+
public ServerSentEvent(String data) {
23+
this(MESSAGE, data);
2724
}
2825

29-
public boolean hasValue() {
30-
return value.isBlank() == false;
26+
public ServerSentEvent {
27+
data = data != null ? data : EMPTY;
28+
type = type != null && type.isBlank() == false ? type : MESSAGE;
29+
}
30+
31+
public boolean hasData() {
32+
return data.isBlank() == false;
3133
}
3234
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/ServerSentEventField.java

Lines changed: 0 additions & 38 deletions
This file was deleted.

0 commit comments

Comments
 (0)