Skip to content

Commit 6d216d5

Browse files
Merge branch 'ml-inference-unified-api-elastic' of github.com:elastic/elasticsearch into ml-inference-unified-api-elastic
2 parents ecdf5c3 + ab53397 commit 6d216d5

File tree

2 files changed

+425
-188
lines changed

2 files changed

+425
-188
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java

Lines changed: 199 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import org.elasticsearch.xcontent.ToXContent;
1717

1818
import java.io.IOException;
19+
import java.util.Arrays;
20+
import java.util.Collections;
1921
import java.util.Deque;
2022
import java.util.Iterator;
2123
import java.util.List;
@@ -24,6 +26,7 @@
2426
import java.util.concurrent.Flow;
2527

2628
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.COMPLETION;
29+
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.Result.RESULT;
2730

2831
/**
2932
* Chat Completion results that only contain a Flow.Publisher.
@@ -32,6 +35,10 @@ public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends Chu
3235
implements
3336
InferenceServiceResults {
3437

38+
public static final String MODEL_FIELD = "model";
39+
public static final String OBJECT_FIELD = "object";
40+
public static final String USAGE_FIELD = "usage";
41+
3542
@Override
3643
public boolean isStreaming() {
3744
return true;
@@ -80,25 +87,51 @@ public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params
8087
}
8188
}
8289

83-
public record Result(String delta, String refusal, List<ToolCall> toolCalls) implements ChunkedToXContent {
84-
85-
private static final String RESULT = "delta";
86-
private static final String REFUSAL = "refusal";
87-
private static final String TOOL_CALLS = "tool_calls";
90+
private static final String REFUSAL_FIELD = "refusal";
91+
private static final String TOOL_CALLS_FIELD = "tool_calls";
92+
public static final String FINISH_REASON_FIELD = "finish_reason";
8893

89-
public Result(String delta) {
90-
this(delta, "", List.of());
91-
}
94+
public record Result(
95+
String delta,
96+
String refusal,
97+
List<ToolCall> toolCalls,
98+
String finishReason,
99+
String model,
100+
String object,
101+
ChatCompletionChunk.Usage usage
102+
) implements ChunkedToXContent {
92103

93104
@Override
94105
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
106+
Iterator<? extends ToXContent> toolCallsIterator = Collections.emptyIterator();
107+
if (toolCalls != null && toolCalls.isEmpty() == false) {
108+
toolCallsIterator = Iterators.concat(
109+
ChunkedToXContentHelper.startArray(TOOL_CALLS_FIELD),
110+
Iterators.flatMap(toolCalls.iterator(), d -> d.toXContentChunked(params)),
111+
ChunkedToXContentHelper.endArray()
112+
);
113+
}
114+
115+
Iterator<? extends ToXContent> usageIterator = Collections.emptyIterator();
116+
if (usage != null) {
117+
usageIterator = Iterators.concat(
118+
ChunkedToXContentHelper.startObject(USAGE_FIELD),
119+
ChunkedToXContentHelper.field("completion_tokens", usage.completionTokens()),
120+
ChunkedToXContentHelper.field("prompt_tokens", usage.promptTokens()),
121+
ChunkedToXContentHelper.field("total_tokens", usage.totalTokens()),
122+
ChunkedToXContentHelper.endObject()
123+
);
124+
}
125+
95126
return Iterators.concat(
96127
ChunkedToXContentHelper.startObject(),
97128
ChunkedToXContentHelper.field(RESULT, delta),
98-
ChunkedToXContentHelper.field(REFUSAL, refusal),
99-
ChunkedToXContentHelper.startArray(TOOL_CALLS),
100-
Iterators.flatMap(toolCalls.iterator(), t -> t.toXContentChunked(params)),
101-
ChunkedToXContentHelper.endArray(),
129+
ChunkedToXContentHelper.field(REFUSAL_FIELD, refusal),
130+
toolCallsIterator,
131+
ChunkedToXContentHelper.field(FINISH_REASON_FIELD, finishReason),
132+
ChunkedToXContentHelper.field(MODEL_FIELD, model),
133+
ChunkedToXContentHelper.field(OBJECT_FIELD, object),
134+
usageIterator,
102135
ChunkedToXContentHelper.endObject()
103136
);
104137
}
@@ -178,4 +211,158 @@ public String toString() {
178211
+ '}';
179212
}
180213
}
214+
215+
public static class ChatCompletionChunk {
216+
private final String id;
217+
private List<Choice> choices;
218+
private final String model;
219+
private final String object;
220+
private ChatCompletionChunk.Usage usage;
221+
222+
public ChatCompletionChunk(String id, List<Choice> choices, String model, String object, ChatCompletionChunk.Usage usage) {
223+
this.id = id;
224+
this.choices = choices;
225+
this.model = model;
226+
this.object = object;
227+
this.usage = usage;
228+
}
229+
230+
public ChatCompletionChunk(
231+
String id,
232+
ChatCompletionChunk.Choice[] choices,
233+
String model,
234+
String object,
235+
ChatCompletionChunk.Usage usage
236+
) {
237+
this.id = id;
238+
this.choices = Arrays.stream(choices).toList();
239+
this.model = model;
240+
this.object = object;
241+
this.usage = usage;
242+
}
243+
244+
public String getId() {
245+
return id;
246+
}
247+
248+
public List<Choice> getChoices() {
249+
return choices;
250+
}
251+
252+
public String getModel() {
253+
return model;
254+
}
255+
256+
public String getObject() {
257+
return object;
258+
}
259+
260+
public ChatCompletionChunk.Usage getUsage() {
261+
return usage;
262+
}
263+
264+
public static class Choice {
265+
private final ChatCompletionChunk.Choice.Delta delta;
266+
private final String finishReason;
267+
private final int index;
268+
269+
public Choice(ChatCompletionChunk.Choice.Delta delta, String finishReason, int index) {
270+
this.delta = delta;
271+
this.finishReason = finishReason;
272+
this.index = index;
273+
}
274+
275+
public ChatCompletionChunk.Choice.Delta getDelta() {
276+
return delta;
277+
}
278+
279+
public String getFinishReason() {
280+
return finishReason;
281+
}
282+
283+
public int getIndex() {
284+
return index;
285+
}
286+
287+
public static class Delta {
288+
private final String content;
289+
private final String refusal;
290+
private final String role;
291+
private List<ToolCall> toolCalls;
292+
293+
public Delta(String content, String refusal, String role, List<ToolCall> toolCalls) {
294+
this.content = content;
295+
this.refusal = refusal;
296+
this.role = role;
297+
this.toolCalls = toolCalls;
298+
}
299+
300+
public String getContent() {
301+
return content;
302+
}
303+
304+
public String getRefusal() {
305+
return refusal;
306+
}
307+
308+
public String getRole() {
309+
return role;
310+
}
311+
312+
public List<ToolCall> getToolCalls() {
313+
return toolCalls;
314+
}
315+
316+
public static class ToolCall {
317+
private final int index;
318+
private final String id;
319+
public ChatCompletionChunk.Choice.Delta.ToolCall.Function function;
320+
private final String type;
321+
322+
public ToolCall(int index, String id, ChatCompletionChunk.Choice.Delta.ToolCall.Function function, String type) {
323+
this.index = index;
324+
this.id = id;
325+
this.function = function;
326+
this.type = type;
327+
}
328+
329+
public int getIndex() {
330+
return index;
331+
}
332+
333+
public String getId() {
334+
return id;
335+
}
336+
337+
public ChatCompletionChunk.Choice.Delta.ToolCall.Function getFunction() {
338+
return function;
339+
}
340+
341+
public String getType() {
342+
return type;
343+
}
344+
345+
public static class Function {
346+
private final String arguments;
347+
private final String name;
348+
349+
public Function(String arguments, String name) {
350+
this.arguments = arguments;
351+
this.name = name;
352+
}
353+
354+
public String getArguments() {
355+
return arguments;
356+
}
357+
358+
public String getName() {
359+
return name;
360+
}
361+
}
362+
}
363+
}
364+
}
365+
366+
public record Usage(int completionTokens, int promptTokens, int totalTokens) {}
367+
}
181368
}

0 commit comments

Comments
 (0)