1616import org .elasticsearch .xcontent .ToXContent ;
1717
1818import java .io .IOException ;
19+ import java .util .Arrays ;
20+ import java .util .Collections ;
1921import java .util .Deque ;
2022import java .util .Iterator ;
2123import java .util .List ;
2426import java .util .concurrent .Flow ;
2527
2628import static org .elasticsearch .xpack .core .inference .results .ChatCompletionResults .COMPLETION ;
29+ import static org .elasticsearch .xpack .core .inference .results .ChatCompletionResults .Result .RESULT ;
2730
2831/**
2932 * Chat Completion results that only contain a Flow.Publisher.
@@ -32,6 +35,10 @@ public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends Chu
3235 implements
3336 InferenceServiceResults {
3437
38+ public static final String MODEL_FIELD = "model" ;
39+ public static final String OBJECT_FIELD = "object" ;
40+ public static final String USAGE_FIELD = "usage" ;
41+
3542 @ Override
3643 public boolean isStreaming () {
3744 return true ;
@@ -80,25 +87,51 @@ public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params
8087 }
8188 }
8289
83- public record Result (String delta , String refusal , List <ToolCall > toolCalls ) implements ChunkedToXContent {
84-
85- private static final String RESULT = "delta" ;
86- private static final String REFUSAL = "refusal" ;
87- private static final String TOOL_CALLS = "tool_calls" ;
90+ private static final String REFUSAL_FIELD = "refusal" ;
91+ private static final String TOOL_CALLS_FIELD = "tool_calls" ;
92+ public static final String FINISH_REASON_FIELD = "finish_reason" ;
8893
89- public Result (String delta ) {
90- this (delta , "" , List .of ());
91- }
94+ public record Result (
95+ String delta ,
96+ String refusal ,
97+ List <ToolCall > toolCalls ,
98+ String finishReason ,
99+ String model ,
100+ String object ,
101+ ChatCompletionChunk .Usage usage
102+ ) implements ChunkedToXContent {
92103
93104 @ Override
94105 public Iterator <? extends ToXContent > toXContentChunked (ToXContent .Params params ) {
106+ Iterator <? extends ToXContent > toolCallsIterator = Collections .emptyIterator ();
107+ if (toolCalls != null && toolCalls .isEmpty () == false ) {
108+ toolCallsIterator = Iterators .concat (
109+ ChunkedToXContentHelper .startArray (TOOL_CALLS_FIELD ),
110+ Iterators .flatMap (toolCalls .iterator (), d -> d .toXContentChunked (params )),
111+ ChunkedToXContentHelper .endArray ()
112+ );
113+ }
114+
115+ Iterator <? extends ToXContent > usageIterator = Collections .emptyIterator ();
116+ if (usage != null ) {
117+ usageIterator = Iterators .concat (
118+ ChunkedToXContentHelper .startObject (USAGE_FIELD ),
119+ ChunkedToXContentHelper .field ("completion_tokens" , usage .completionTokens ()),
120+ ChunkedToXContentHelper .field ("prompt_tokens" , usage .promptTokens ()),
121+ ChunkedToXContentHelper .field ("total_tokens" , usage .totalTokens ()),
122+ ChunkedToXContentHelper .endObject ()
123+ );
124+ }
125+
95126 return Iterators .concat (
96127 ChunkedToXContentHelper .startObject (),
97128 ChunkedToXContentHelper .field (RESULT , delta ),
98- ChunkedToXContentHelper .field (REFUSAL , refusal ),
99- ChunkedToXContentHelper .startArray (TOOL_CALLS ),
100- Iterators .flatMap (toolCalls .iterator (), t -> t .toXContentChunked (params )),
101- ChunkedToXContentHelper .endArray (),
129+ ChunkedToXContentHelper .field (REFUSAL_FIELD , refusal ),
130+ toolCallsIterator ,
131+ ChunkedToXContentHelper .field (FINISH_REASON_FIELD , finishReason ),
132+ ChunkedToXContentHelper .field (MODEL_FIELD , model ),
133+ ChunkedToXContentHelper .field (OBJECT_FIELD , object ),
134+ usageIterator ,
102135 ChunkedToXContentHelper .endObject ()
103136 );
104137 }
@@ -178,4 +211,158 @@ public String toString() {
178211 + '}' ;
179212 }
180213 }
214+
215+ public static class ChatCompletionChunk {
216+ private final String id ;
217+ private List <Choice > choices ;
218+ private final String model ;
219+ private final String object ;
220+ private ChatCompletionChunk .Usage usage ;
221+
222+ public ChatCompletionChunk (String id , List <Choice > choices , String model , String object , ChatCompletionChunk .Usage usage ) {
223+ this .id = id ;
224+ this .choices = choices ;
225+ this .model = model ;
226+ this .object = object ;
227+ this .usage = usage ;
228+ }
229+
230+ public ChatCompletionChunk (
231+ String id ,
232+ ChatCompletionChunk .Choice [] choices ,
233+ String model ,
234+ String object ,
235+ ChatCompletionChunk .Usage usage
236+ ) {
237+ this .id = id ;
238+ this .choices = Arrays .stream (choices ).toList ();
239+ this .model = model ;
240+ this .object = object ;
241+ this .usage = usage ;
242+ }
243+
244+ public String getId () {
245+ return id ;
246+ }
247+
248+ public List <Choice > getChoices () {
249+ return choices ;
250+ }
251+
252+ public String getModel () {
253+ return model ;
254+ }
255+
256+ public String getObject () {
257+ return object ;
258+ }
259+
260+ public ChatCompletionChunk .Usage getUsage () {
261+ return usage ;
262+ }
263+
264+ public static class Choice {
265+ private final ChatCompletionChunk .Choice .Delta delta ;
266+ private final String finishReason ;
267+ private final int index ;
268+
269+ public Choice (ChatCompletionChunk .Choice .Delta delta , String finishReason , int index ) {
270+ this .delta = delta ;
271+ this .finishReason = finishReason ;
272+ this .index = index ;
273+ }
274+
275+ public ChatCompletionChunk .Choice .Delta getDelta () {
276+ return delta ;
277+ }
278+
279+ public String getFinishReason () {
280+ return finishReason ;
281+ }
282+
283+ public int getIndex () {
284+ return index ;
285+ }
286+
287+ public static class Delta {
288+ private final String content ;
289+ private final String refusal ;
290+ private final String role ;
291+ private List <ToolCall > toolCalls ;
292+
293+ public Delta (String content , String refusal , String role , List <ToolCall > toolCalls ) {
294+ this .content = content ;
295+ this .refusal = refusal ;
296+ this .role = role ;
297+ this .toolCalls = toolCalls ;
298+ }
299+
300+ public String getContent () {
301+ return content ;
302+ }
303+
304+ public String getRefusal () {
305+ return refusal ;
306+ }
307+
308+ public String getRole () {
309+ return role ;
310+ }
311+
312+ public List <ToolCall > getToolCalls () {
313+ return toolCalls ;
314+ }
315+
316+ public static class ToolCall {
317+ private final int index ;
318+ private final String id ;
319+ public ChatCompletionChunk .Choice .Delta .ToolCall .Function function ;
320+ private final String type ;
321+
322+ public ToolCall (int index , String id , ChatCompletionChunk .Choice .Delta .ToolCall .Function function , String type ) {
323+ this .index = index ;
324+ this .id = id ;
325+ this .function = function ;
326+ this .type = type ;
327+ }
328+
329+ public int getIndex () {
330+ return index ;
331+ }
332+
333+ public String getId () {
334+ return id ;
335+ }
336+
337+ public ChatCompletionChunk .Choice .Delta .ToolCall .Function getFunction () {
338+ return function ;
339+ }
340+
341+ public String getType () {
342+ return type ;
343+ }
344+
345+ public static class Function {
346+ private final String arguments ;
347+ private final String name ;
348+
349+ public Function (String arguments , String name ) {
350+ this .arguments = arguments ;
351+ this .name = name ;
352+ }
353+
354+ public String getArguments () {
355+ return arguments ;
356+ }
357+
358+ public String getName () {
359+ return name ;
360+ }
361+ }
362+ }
363+ }
364+ }
365+
366+ public record Usage (int completionTokens , int promptTokens , int totalTokens ) {}
367+ }
181368}
0 commit comments