Skip to content

Commit ceef95a

Browse files
Refactor DeltaParser in LlamaStreamingProcessor to improve argument handling
1 parent cc14b18 commit ceef95a

File tree

2 files changed

+308
-0
lines changed

2 files changed

+308
-0
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandler.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,20 @@
77

88
package org.elasticsearch.xpack.inference.services.llama.completion;
99

10+
import org.elasticsearch.inference.InferenceServiceResults;
11+
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
1012
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
1113
import org.elasticsearch.xpack.inference.external.http.HttpResult;
1214
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
1315
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
1416
import org.elasticsearch.xpack.inference.external.request.Request;
17+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
18+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
1519
import org.elasticsearch.xpack.inference.services.llama.response.LlamaErrorResponse;
1620
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;
1721

1822
import java.util.Locale;
23+
import java.util.concurrent.Flow;
1924

2025
public class LlamaChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {
2126

@@ -44,4 +49,13 @@ protected Exception buildError(String message, Request request, HttpResult resul
4449
return super.buildError(message, request, result, errorResponse);
4550
}
4651
}
52+
53+
@Override
54+
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
55+
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
56+
var openAiProcessor = new LlamaStreamingProcessor((m, e) -> buildMidStreamError(request, m, e));
57+
flow.subscribe(serverSentEventProcessor);
58+
serverSentEventProcessor.subscribe(openAiProcessor);
59+
return new StreamingUnifiedChatCompletionResults(openAiProcessor);
60+
}
4761
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,294 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.llama.completion;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
13+
import org.elasticsearch.xcontent.ConstructingObjectParser;
14+
import org.elasticsearch.xcontent.ParseField;
15+
import org.elasticsearch.xcontent.XContentFactory;
16+
import org.elasticsearch.xcontent.XContentParser;
17+
import org.elasticsearch.xcontent.XContentParserConfiguration;
18+
import org.elasticsearch.xcontent.XContentType;
19+
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
20+
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
21+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
22+
23+
import java.io.IOException;
24+
import java.util.ArrayDeque;
25+
import java.util.Collections;
26+
import java.util.Deque;
27+
import java.util.List;
28+
import java.util.function.BiFunction;
29+
import java.util.stream.Stream;
30+
31+
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
32+
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
33+
34+
public class LlamaStreamingProcessor extends DelegatingProcessor<Deque<ServerSentEvent>, StreamingUnifiedChatCompletionResults.Results> {
35+
private static final Logger logger = LogManager.getLogger(LlamaStreamingProcessor.class);

// Payload that is skipped instead of being parsed as JSON; it is compared case-insensitively.
private static final String DONE_MESSAGE = "[done]";

// Fields of the top-level chunk object.
public static final String ID_FIELD = "id";
private static final String CHOICES_FIELD = "choices";
public static final String MODEL_FIELD = "model";
public static final String OBJECT_FIELD = "object";
public static final String USAGE_FIELD = "usage";

// Fields of a single choice. CHOICE_FIELD is only used as the name of the choice parser.
public static final String CHOICE_FIELD = "choice";
private static final String DELTA_FIELD = "delta";
public static final String FINISH_REASON_FIELD = "finish_reason";
public static final String INDEX_FIELD = "index";

// Fields of a delta.
private static final String CONTENT_FIELD = "content";
private static final String FUNCTION_CALL_FIELD = "function_call";
private static final String REFUSAL_FIELD = "refusal";
public static final String ROLE_FIELD = "role";
private static final String TOOL_CALLS_FIELD = "tool_calls";

// Fields of a tool call and of its function.
public static final String FUNCTION_FIELD = "function";
public static final String TYPE_FIELD = "type";
public static final String NAME_FIELD = "name";
public static final String ARGUMENTS_FIELD = "arguments";

// Fields of the usage object.
public static final String COMPLETION_TOKENS_FIELD = "completion_tokens";
public static final String PROMPT_TOKENS_FIELD = "prompt_tokens";
public static final String TOTAL_TOKENS_FIELD = "total_tokens";

/** Converts an event payload and an optional cause into the exception that is thrown from {@code next}. */
private final BiFunction<String, Exception, Exception> errorParser;

/**
 * Creates a processor that reports failures through the given callback.
 *
 * @param errorParser receives the event data and the parse failure (or {@code null} for an explicit error event)
 *                    and returns the exception to throw
 */
public LlamaStreamingProcessor(BiFunction<String, Exception, Exception> errorParser) {
    this.errorParser = errorParser;
}
65+
66+
@Override
67+
protected void next(Deque<ServerSentEvent> item) throws Exception {
68+
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
69+
70+
var results = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>(item.size());
71+
for (var event : item) {
72+
if ("error".equals(event.type()) && event.hasData()) {
73+
throw errorParser.apply(event.data(), null);
74+
} else if (event.hasData()) {
75+
try {
76+
var delta = parse(parserConfig, event);
77+
delta.forEach(results::offer);
78+
} catch (Exception e) {
79+
logger.warn("Failed to parse event from inference provider: {}", event);
80+
throw errorParser.apply(event.data(), e);
81+
}
82+
}
83+
}
84+
85+
if (results.isEmpty()) {
86+
upstream().request(1);
87+
} else {
88+
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results));
89+
}
90+
}
91+
92+
public static Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> parse(
93+
XContentParserConfiguration parserConfig,
94+
ServerSentEvent event
95+
) throws IOException {
96+
if (DONE_MESSAGE.equalsIgnoreCase(event.data())) {
97+
return Stream.empty();
98+
}
99+
100+
return parse(parserConfig, event.data());
101+
}
102+
103+
public static Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> parse(
104+
XContentParserConfiguration parserConfig,
105+
String data
106+
) throws IOException {
107+
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, data)) {
108+
moveToFirstToken(jsonParser);
109+
110+
XContentParser.Token token = jsonParser.currentToken();
111+
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);
112+
113+
StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = ChatCompletionChunkParser.parse(jsonParser);
114+
115+
return Stream.of(chunk);
116+
}
117+
}
118+
119+
public static class ChatCompletionChunkParser {
120+
/**
 * Parser for the top-level chunk object. The argument positions used in the builder follow the declaration order in
 * the static initialiser below: id, choices, model, object, usage.
 */
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<StreamingUnifiedChatCompletionResults.ChatCompletionChunk, Void> PARSER =
    new ConstructingObjectParser<>("chat_completion_chunk", true, args -> {
        var id = (String) args[0];
        var choices = (List<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice>) args[1];
        var model = (String) args[2];
        var object = (String) args[3];
        var usage = (StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage) args[4];
        return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(id, choices, model, object, usage);
    });

static {
    // Declaration order defines the positions in the builder's argument array; do not reorder.
    PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(ID_FIELD));
    PARSER.declareObjectArray(
        ConstructingObjectParser.constructorArg(),
        (p, c) -> ChoiceParser.parse(p),
        new ParseField(CHOICES_FIELD)
    );
    PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(MODEL_FIELD));
    PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(OBJECT_FIELD));
    // Usage is the only optional argument; an explicit null is mapped to null.
    PARSER.declareObjectOrNull(
        ConstructingObjectParser.optionalConstructorArg(),
        (p, c) -> UsageParser.parse(p),
        null,
        new ParseField(USAGE_FIELD)
    );
}

/**
 * Parses a chunk object from the given parser.
 *
 * @param parser parser supplying the chunk object
 * @return the parsed chunk
 * @throws IOException if reading from the parser fails
 */
public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk parse(XContentParser parser) throws IOException {
    return PARSER.parse(parser, null);
}
154+
155+
private static class ChoiceParser {
156+
private static final ConstructingObjectParser<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice, Void> PARSER =
157+
new ConstructingObjectParser<>(
158+
CHOICE_FIELD,
159+
true,
160+
args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(
161+
(StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta) args[0],
162+
(String) args[1],
163+
(int) args[2]
164+
)
165+
);
166+
167+
static {
168+
PARSER.declareObject(
169+
ConstructingObjectParser.constructorArg(),
170+
(p, c) -> DeltaParser.parse(p),
171+
new ParseField(DELTA_FIELD)
172+
);
173+
PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(FINISH_REASON_FIELD));
174+
PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(INDEX_FIELD));
175+
}
176+
177+
public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice parse(XContentParser parser) {
178+
return PARSER.apply(parser, null);
179+
}
180+
}
181+
182+
// TODO try to move the changes to OpenAiUnifiedStreamingProcessor to avoid duplication, but still have OpenAI payload processible
183+
// TODO try to ignore FUNCTION_CALL_FIELD completely
184+
// TODO try to adapt declareObjectArrayOrNull without optionalConstructorArg because it is not used anywhere in this manner
185+
private static class DeltaParser {
186+
@SuppressWarnings("unchecked")
187+
private static final ConstructingObjectParser<
188+
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta,
189+
Void> PARSER = new ConstructingObjectParser<>(
190+
DELTA_FIELD,
191+
true,
192+
args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(
193+
(String) args[0],
194+
(String) args[2],
195+
(String) args[3],
196+
args.length > 4
197+
? (List<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall>) args[4]
198+
: Collections.emptyList()
199+
)
200+
);
201+
202+
static {
203+
PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(CONTENT_FIELD));
204+
PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(FUNCTION_CALL_FIELD));
205+
PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(REFUSAL_FIELD));
206+
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ROLE_FIELD));
207+
PARSER.declareObjectArrayOrNull(
208+
ConstructingObjectParser.optionalConstructorArg(),
209+
(p, c) -> ToolCallParser.parse(p),
210+
new ParseField(TOOL_CALLS_FIELD)
211+
);
212+
}
213+
214+
public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta parse(XContentParser parser)
215+
throws IOException {
216+
return PARSER.parse(parser, null);
217+
}
218+
}
219+
220+
private static class ToolCallParser {
221+
private static final ConstructingObjectParser<
222+
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall,
223+
Void> PARSER = new ConstructingObjectParser<>(
224+
"tool_call",
225+
true,
226+
args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall(
227+
(int) args[0],
228+
(String) args[1],
229+
(StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function) args[2],
230+
(String) args[3]
231+
)
232+
);
233+
234+
static {
235+
PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(INDEX_FIELD));
236+
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ID_FIELD));
237+
PARSER.declareObject(
238+
ConstructingObjectParser.optionalConstructorArg(),
239+
(p, c) -> FunctionParser.parse(p),
240+
new ParseField(FUNCTION_FIELD)
241+
);
242+
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(TYPE_FIELD));
243+
}
244+
245+
public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall parse(XContentParser parser)
246+
throws IOException {
247+
return PARSER.parse(parser, null);
248+
}
249+
}
250+
251+
private static class FunctionParser {
252+
private static final ConstructingObjectParser<
253+
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function,
254+
Void> PARSER = new ConstructingObjectParser<>(
255+
FUNCTION_FIELD,
256+
true,
257+
args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function(
258+
(String) args[0],
259+
(String) args[1]
260+
)
261+
);
262+
263+
static {
264+
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ARGUMENTS_FIELD));
265+
PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(NAME_FIELD));
266+
}
267+
268+
public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function parse(
269+
XContentParser parser
270+
) throws IOException {
271+
return PARSER.parse(parser, null);
272+
}
273+
}
274+
275+
private static class UsageParser {
276+
private static final ConstructingObjectParser<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage, Void> PARSER =
277+
new ConstructingObjectParser<>(
278+
USAGE_FIELD,
279+
true,
280+
args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage((int) args[0], (int) args[1], (int) args[2])
281+
);
282+
283+
static {
284+
PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(COMPLETION_TOKENS_FIELD));
285+
PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(PROMPT_TOKENS_FIELD));
286+
PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(TOTAL_TOKENS_FIELD));
287+
}
288+
289+
public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage parse(XContentParser parser) throws IOException {
290+
return PARSER.parse(parser, null);
291+
}
292+
}
293+
}
294+
}

0 commit comments

Comments
 (0)