
Commit 7c2e075

[ML] Ignore unrecognized openai sse fields (#114715)
Azure and Llama send back fields we do not expect, so the parser is rewritten to better handle unknown fields by dropping them.
1 parent 74522c4 commit 7c2e075
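
For illustration, the Llama-style chunk exercised by the new test below arrives as the data field of a server-sent event and carries fields the parser did not previously recognize:

data: {"id":"12345", "object":"chat.completion.chunk", "created":123456789, "model":"Llama-2-7b-chat", "system_fingerprint":"123456789", "choices":[{"index":0, "delta":{"role":"assistant"}, "logprobs":null, "finish_reason":null}]}

With the old code, a delta like this (a "role" but no "content") would fall into the strict positionParserAtTokenAfterField lookup for "content"; the rewritten loop instead skips unrecognized fields and nested structures and yields an empty result.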

File tree

3 files changed (+59 −14 lines)

docs/changelog/114715.yaml

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+pr: 114715
+summary: Ignore unrecognized openai sse fields
+area: Machine Learning
+type: bug
+issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java

Lines changed: 18 additions & 14 deletions
@@ -110,8 +110,6 @@ public class OpenAiStreamingProcessor extends DelegatingProcessor<Deque<ServerSe
     private static final String CHOICES_FIELD = "choices";
     private static final String DELTA_FIELD = "delta";
     private static final String CONTENT_FIELD = "content";
-    private static final String FINISH_REASON_FIELD = "finish_reason";
-    private static final String STOP_MESSAGE = "stop";
     private static final String DONE_MESSAGE = "[done]";
 
     @Override
@@ -162,21 +160,27 @@ private Iterator<StreamingChatCompletionResults.Result> parse(XContentParserConf
     ensureExpectedToken(XContentParser.Token.START_OBJECT, currentToken, parser);
 
     currentToken = parser.nextToken();
-    if (currentToken == XContentParser.Token.END_OBJECT) {
-        consumeUntilObjectEnd(parser); // end choices
-        return ""; // stopped
-    }
 
-    if (currentToken == XContentParser.Token.FIELD_NAME && parser.currentName().equals(CONTENT_FIELD)) {
-        parser.nextToken();
-    } else {
-        positionParserAtTokenAfterField(parser, CONTENT_FIELD, FAILED_TO_FIND_FIELD_TEMPLATE);
+    // continue until the end of delta
+    while (currentToken != null && currentToken != XContentParser.Token.END_OBJECT) {
+        if (currentToken == XContentParser.Token.START_OBJECT || currentToken == XContentParser.Token.START_ARRAY) {
+            parser.skipChildren();
+        }
+
+        if (currentToken == XContentParser.Token.FIELD_NAME && parser.currentName().equals(CONTENT_FIELD)) {
+            parser.nextToken();
+            ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser);
+            var content = parser.text();
+            consumeUntilObjectEnd(parser); // end delta
+            consumeUntilObjectEnd(parser); // end choices
+            return content;
+        }
+
+        currentToken = parser.nextToken();
     }
-    ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser);
-    var content = parser.text();
-    consumeUntilObjectEnd(parser); // end delta
+
     consumeUntilObjectEnd(parser); // end choices
-    return content;
+    return ""; // stopped
 }).stream()
     .filter(Objects::nonNull)
     .filter(Predicate.not(String::isEmpty))
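
The change above is the usual streaming-JSON pattern: walk the fields of the delta object, skip whole unknown objects or arrays, and stop once "content" is found. Below is a minimal standalone sketch of the same idea, written against Jackson's streaming JsonParser rather than the Elasticsearch XContentParser used in this commit; the class and method names are illustrative, not part of this change.

import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonToken;

import java.io.IOException;

public class DeltaContentSketch {

    // Returns the "content" value of a delta-style object, or "" when the
    // object only carries fields we do not recognize (role, logprobs, ...).
    static String extractContent(String deltaJson) throws IOException {
        try (JsonParser parser = new JsonFactory().createParser(deltaJson)) {
            parser.nextToken();                                  // START_OBJECT
            JsonToken token = parser.nextToken();                // first field, or END_OBJECT
            while (token != null && token != JsonToken.END_OBJECT) {
                if (token == JsonToken.START_OBJECT || token == JsonToken.START_ARRAY) {
                    parser.skipChildren();                       // drop unknown structures wholesale
                } else if (token == JsonToken.FIELD_NAME && "content".equals(parser.getCurrentName())) {
                    // advance to the value; ignore anything that is not a string
                    return parser.nextToken() == JsonToken.VALUE_STRING ? parser.getText() : "";
                }
                token = parser.nextToken();
            }
            return "";                                           // no content: treat the chunk as empty
        }
    }

    public static void main(String[] args) throws IOException {
        System.out.println(extractContent("{\"role\":\"assistant\"}"));                       // ""
        System.out.println(extractContent("{\"role\":\"assistant\",\"content\":\"Hello\"}")); // "Hello"
    }
}

As in the commit, an unrecognized field never aborts parsing: its value (including nested objects and arrays) is simply consumed and discarded.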

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessorTests.java

Lines changed: 36 additions & 0 deletions
@@ -149,6 +149,42 @@ public void testDoneMessageIsIgnored() throws Exception {
         verify(downstream, times(0)).onNext(any());
     }
 
+    public void testInitialLlamaResponseIsIgnored() throws Exception {
+        var item = new ArrayDeque<ServerSentEvent>();
+        item.offer(new ServerSentEvent(ServerSentEventField.DATA, """
+            {
+                "id":"12345",
+                "object":"chat.completion.chunk",
+                "created":123456789,
+                "model":"Llama-2-7b-chat",
+                "system_fingerprint": "123456789",
+                "choices":[
+                    {
+                        "index":0,
+                        "delta":{
+                            "role":"assistant"
+                        },
+                        "logprobs":null,
+                        "finish_reason":null
+                    }
+                ]
+            }
+            """));
+
+        var processor = new OpenAiStreamingProcessor();
+
+        Flow.Subscriber<ChunkedToXContent> downstream = mock();
+        processor.subscribe(downstream);
+
+        Flow.Subscription upstream = mock();
+        processor.onSubscribe(upstream);
+
+        processor.next(item);
+
+        verify(upstream, times(1)).request(1);
+        verify(downstream, times(0)).onNext(any());
+    }
+
     private String toJsonString(ChunkedToXContent chunkedToXContent) throws IOException {
         try (var builder = XContentFactory.jsonBuilder()) {
             chunkedToXContent.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
