Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.XContentFormattedException;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;

import java.io.IOException;
Expand Down Expand Up @@ -363,9 +362,8 @@ private static class RandomStringCollector {
private void collect(String str) throws IOException {
sseParser.parse(str.getBytes(StandardCharsets.UTF_8))
.stream()
.filter(event -> event.name() == ServerSentEventField.DATA)
.filter(ServerSentEvent::hasValue)
.map(ServerSentEvent::value)
.filter(ServerSentEvent::hasData)
.map(ServerSentEvent::data)
.forEach(stringsVerified::offer);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;

import java.io.IOException;
import java.util.ArrayDeque;
Expand Down Expand Up @@ -40,7 +39,7 @@ public static <ParsedChunk> Deque<ParsedChunk> parseEvent(
) throws Exception {
var results = new ArrayDeque<ParsedChunk>(item.size());
for (ServerSentEvent event : item) {
if (ServerSentEventField.DATA == event.name() && event.hasValue()) {
if (event.hasData()) {
try {
var delta = parseFunction.apply(parserConfig, event);
delta.forEachRemaining(results::offer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;

import java.io.IOException;
import java.util.ArrayDeque;
Expand All @@ -42,8 +41,8 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {

var results = new ArrayDeque<StreamingChatCompletionResults.Result>(item.size());
for (var event : item) {
if (event.name() == ServerSentEventField.DATA && event.hasValue()) {
try (var parser = parser(event.value())) {
if (event.hasData()) {
try (var parser = parser(event.data())) {
var eventType = eventType(parser);
switch (eventType) {
case "error" -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;

import java.io.IOException;
import java.util.ArrayDeque;
Expand All @@ -37,8 +36,8 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
var results = new ArrayDeque<StreamingChatCompletionResults.Result>(item.size());
for (ServerSentEvent event : item) {
if (ServerSentEventField.DATA == event.name() && event.hasValue()) {
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.value())) {
if (event.hasData()) {
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) {
var delta = content.apply(jsonParser);
results.offer(new StreamingChatCompletionResults.Result(delta));
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,11 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {

private static Iterator<StreamingChatCompletionResults.Result> parse(XContentParserConfiguration parserConfig, ServerSentEvent event)
throws IOException {
if (DONE_MESSAGE.equalsIgnoreCase(event.value())) {
if (DONE_MESSAGE.equalsIgnoreCase(event.data())) {
return Collections.emptyIterator();
}

try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.value())) {
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) {
moveToFirstToken(jsonParser);

XContentParser.Token token = jsonParser.currentToken();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;

import java.io.IOException;
import java.util.ArrayDeque;
Expand Down Expand Up @@ -62,7 +61,6 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<

private final BiFunction<String, Exception, Exception> errorParser;
private final Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> buffer = new LinkedBlockingDeque<>();
private volatile boolean previousEventWasError = false;

public OpenAiUnifiedStreamingProcessor(BiFunction<String, Exception, Exception> errorParser) {
this.errorParser = errorParser;
Expand All @@ -83,19 +81,15 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {

var results = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>(item.size());
for (var event : item) {
if (ServerSentEventField.EVENT == event.name() && "error".equals(event.value())) {
previousEventWasError = true;
} else if (ServerSentEventField.DATA == event.name() && event.hasValue()) {
if (previousEventWasError) {
throw errorParser.apply(event.value(), null);
}

if ("error".equals(event.type()) && event.hasData()) {
throw errorParser.apply(event.data(), null);
} else if (event.hasData()) {
try {
var delta = parse(parserConfig, event);
delta.forEachRemaining(results::offer);
} catch (Exception e) {
logger.warn("Failed to parse event from inference provider: {}", event);
throw errorParser.apply(event.value(), e);
throw errorParser.apply(event.data(), e);
}
}
}
Expand All @@ -118,11 +112,11 @@ private static Iterator<StreamingUnifiedChatCompletionResults.ChatCompletionChun
XContentParserConfiguration parserConfig,
ServerSentEvent event
) throws IOException {
if (DONE_MESSAGE.equalsIgnoreCase(event.value())) {
if (DONE_MESSAGE.equalsIgnoreCase(event.data())) {
return Collections.emptyIterator();
}

try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.value())) {
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) {
moveToFirstToken(jsonParser);

XContentParser.Token token = jsonParser.currentToken();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,26 @@

/**
* Server-Sent Event message: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation
* Messages always contain a {@link ServerSentEventField} and a non-null payload value.
* When the stream is parsed and there is no value associated with a {@link ServerSentEventField}, an empty-string is set as the value.
*/
public record ServerSentEvent(ServerSentEventField name, String value) {
public record ServerSentEvent(String type, String data) {

private static final String EMPTY = "";
private static final String MESSAGE = "message";

public ServerSentEvent(ServerSentEventField name) {
this(name, EMPTY);
public static ServerSentEvent empty() {
return new ServerSentEvent(EMPTY, EMPTY);
}

// treat null value as an empty string, don't break parsing
public ServerSentEvent(ServerSentEventField name, String value) {
this.name = name;
this.value = value != null ? value : EMPTY;
public ServerSentEvent(String data) {
this(MESSAGE, data);
}

public boolean hasValue() {
return value.isBlank() == false;
public ServerSentEvent {
data = data != null ? data : EMPTY;
type = type != null && type.isBlank() == false ? type : MESSAGE;
}

public boolean hasData() {
return data.isBlank() == false;
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.nio.charset.StandardCharsets;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Locale;
import java.util.Optional;
import java.util.regex.Pattern;

Expand All @@ -20,11 +21,15 @@
* If the line starts with a colon, we discard this event.
* If the line contains a colon, we process it into {@link ServerSentEvent} with a non-empty value.
* If the line does not contain a colon, we process it into {@link ServerSentEvent} with an empty string value.
* If the line's field is not one of {@link ServerSentEventField}, we discard this event.
* If the line's field is not one of (data, event), we discard this event. `id` and `retry` are not implemented, because we do not use them
* and have no plans to use them.
*/
public class ServerSentEventParser {
private static final Pattern END_OF_LINE_REGEX = Pattern.compile("\\n|\\r|\\r\\n");
private static final Pattern END_OF_LINE_REGEX = Pattern.compile("\\r\\n|\\n|\\r");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't looked at the String class docs for a while and was pleased to find there is actually a method for splitting a string like this:

https://docs.oracle.com/en/java/javase/17/docs/api/java.base/java/lang/String.html#lines()

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, I didn't know that existed. Let's use that, though it does change the logic a bit.

private static final String BOM = "\uFEFF";
private static final String TYPE_FIELD = "event";
private static final String DATA_FIELD = "data";
private final EventBuffer eventBuffer = new EventBuffer();
private volatile String previousTokens = "";

public Deque<ServerSentEvent> parse(byte[] bytes) {
Expand All @@ -39,11 +44,13 @@ public Deque<ServerSentEvent> parse(byte[] bytes) {
for (var i = 0; i < lines.length - 1; i++) {
var line = lines[i].replace(BOM, "");

if (line.isBlank() == false && line.startsWith(":") == false) {
if (line.isBlank()) {
eventBuffer.dispatch().ifPresent(collector::offer);
} else if (line.startsWith(":") == false) {
if (line.contains(":")) {
fieldValueEvent(line).ifPresent(collector::offer);
} else {
ServerSentEventField.oneOf(line).map(ServerSentEvent::new).ifPresent(collector::offer);
fieldValueEvent(line);
} else if (DATA_FIELD.equals(line.toLowerCase(Locale.ROOT))) {
eventBuffer.data("");
}
}
}
Expand All @@ -55,21 +62,64 @@ public Deque<ServerSentEvent> parse(byte[] bytes) {
return collector;
}

private Optional<ServerSentEvent> fieldValueEvent(String lineWithColon) {
private void fieldValueEvent(String lineWithColon) {
var firstColon = lineWithColon.indexOf(":");
var fieldStr = lineWithColon.substring(0, firstColon);
var serverSentField = ServerSentEventField.oneOf(fieldStr);

if ((firstColon + 1) != lineWithColon.length()) {
var value = lineWithColon.substring(firstColon + 1);
if (value.equals(" ") == false) {
var trimmedValue = value.charAt(0) == ' ' ? value.substring(1) : value;
return serverSentField.map(field -> new ServerSentEvent(field, trimmedValue));
var fieldStr = lineWithColon.substring(0, firstColon).toLowerCase(Locale.ROOT);

var value = lineWithColon.substring(firstColon + 1);
var trimmedValue = value.length() > 0 && value.charAt(0) == ' ' ? value.substring(1) : value;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not String::trim() or String::stripLeading()?
Is the idea to literally remove only the first space char?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, the idea is to literally remove only the first space char:

Collect the characters on the line after the first U+003A COLON character (:), and let value be that string. If value starts with a U+0020 SPACE character, remove it from value.

https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation

Or at least I'm interpreting that as "if there are two or more spaces, only remove one space"


if (DATA_FIELD.equals(fieldStr)) {
eventBuffer.data(trimmedValue);
} else if (TYPE_FIELD.equals(fieldStr)) {
eventBuffer.type(trimmedValue);
}
}

private static class EventBuffer {
private static final char LINE_FEED = '\n';
private static final String MESSAGE = "message";
private StringBuilder type = new StringBuilder();
private StringBuilder data = new StringBuilder();
private boolean appendLineFeed = false;

private void type(String type) {
this.type.append(type);
}

private void data(String data) {
if (appendLineFeed) {
this.data.append(LINE_FEED);
} else {
// the next time we add data, append line feed
appendLineFeed = true;
}
this.data.append(data);
}

// if we have "data:" or "data: ", treat it like a no-value line
return serverSentField.map(ServerSentEvent::new);
private Optional<ServerSentEvent> dispatch() {
var dataValue = data.toString();

// if the data buffer is empty, reset without dispatching
if (dataValue.isEmpty()) {
reset();
return Optional.empty();
}

// if the type buffer is not empty, set that as the type, else default to message
var typeValue = type.toString();
typeValue = typeValue.isBlank() ? MESSAGE : typeValue;

reset();

return Optional.of(new ServerSentEvent(typeValue, dataValue));
}

private void reset() {
type = new StringBuilder();
data = new StringBuilder();
appendLineFeed = false;
}
}

}
Loading