Skip to content

Commit 97b330f

Browse files
author
Max Hniebergall
committed
Add support for toolCalls and refusal in streaming completion
1 parent 657561e commit 97b330f

File tree

3 files changed

+168
-17
lines changed

3 files changed

+168
-17
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.Iterator;
2121
import java.util.List;
2222
import java.util.Map;
23+
import java.util.Objects;
2324
import java.util.concurrent.Flow;
2425

2526
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.COMPLETION;
@@ -77,16 +78,102 @@ public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params
7778
}
7879
}
7980

80-
public record Result(String delta) implements ChunkedToXContent {
81+
public record Result(String delta, String refusal, List<ToolCall> toolCalls) implements ChunkedToXContent {
82+
8183
private static final String RESULT = "delta";
84+
private static final String REFUSAL = "refusal";
85+
private static final String TOOL_CALLS = "tool_calls";
86+
87+
public Result(String delta) {
88+
this(delta, "", List.of());
89+
}
8290

8391
@Override
8492
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
8593
return Iterators.concat(
8694
ChunkedToXContentHelper.startObject(),
8795
ChunkedToXContentHelper.field(RESULT, delta),
96+
ChunkedToXContentHelper.field(REFUSAL, refusal),
97+
ChunkedToXContentHelper.startArray(TOOL_CALLS),
98+
Iterators.flatMap(toolCalls.iterator(), t -> t.toXContentChunked(params)),
99+
ChunkedToXContentHelper.endArray(),
88100
ChunkedToXContentHelper.endObject()
89101
);
90102
}
91103
}
104+
105+
public static class ToolCall implements ChunkedToXContent {
106+
private final int index;
107+
private final String id;
108+
private final String functionName;
109+
private final String functionArguments;
110+
111+
public ToolCall(int index, String id, String functionName, String functionArguments) {
112+
this.index = index;
113+
this.id = id;
114+
this.functionName = functionName;
115+
this.functionArguments = functionArguments;
116+
}
117+
118+
public int getIndex() {
119+
return index;
120+
}
121+
122+
public String getId() {
123+
return id;
124+
}
125+
126+
public String getFunctionName() {
127+
return functionName;
128+
}
129+
130+
public String getFunctionArguments() {
131+
return functionArguments;
132+
}
133+
134+
@Override
135+
public boolean equals(Object o) {
136+
if (this == o) return true;
137+
if (o == null || getClass() != o.getClass()) return false;
138+
ToolCall toolCall = (ToolCall) o;
139+
return index == toolCall.index
140+
&& Objects.equals(id, toolCall.id)
141+
&& Objects.equals(functionName, toolCall.functionName)
142+
&& Objects.equals(functionArguments, toolCall.functionArguments);
143+
}
144+
145+
@Override
146+
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
147+
return Iterators.concat(
148+
ChunkedToXContentHelper.startObject(),
149+
ChunkedToXContentHelper.field("index", index),
150+
ChunkedToXContentHelper.field("id", id),
151+
ChunkedToXContentHelper.field("functionName", functionName),
152+
ChunkedToXContentHelper.field("functionArguments", functionArguments),
153+
ChunkedToXContentHelper.endObject()
154+
);
155+
}
156+
157+
@Override
158+
public int hashCode() {
159+
return Objects.hash(index, id, functionName, functionArguments);
160+
}
161+
162+
@Override
163+
public String toString() {
164+
return "ToolCall{"
165+
+ "index="
166+
+ index
167+
+ ", id='"
168+
+ id
169+
+ '\''
170+
+ ", functionName='"
171+
+ functionName
172+
+ '\''
173+
+ ", functionArguments='"
174+
+ functionArguments
175+
+ '\''
176+
+ '}';
177+
}
178+
}
92179
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java

Lines changed: 80 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@
2222

2323
import java.io.IOException;
2424
import java.util.ArrayDeque;
25+
import java.util.ArrayList;
2526
import java.util.Collections;
2627
import java.util.Deque;
2728
import java.util.Iterator;
29+
import java.util.List;
2830
import java.util.Objects;
29-
import java.util.function.Predicate;
3031

3132
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
3233
import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList;
@@ -111,6 +112,8 @@ public class OpenAiStreamingProcessor extends DelegatingProcessor<Deque<ServerSe
111112
private static final String DELTA_FIELD = "delta";
112113
private static final String CONTENT_FIELD = "content";
113114
private static final String DONE_MESSAGE = "[done]";
115+
private static final String REFUSAL_FIELD = "refusal";
116+
private static final String TOOL_CALLS_FIELD = "tool_calls";
114117

115118
@Override
116119
protected void next(Deque<ServerSentEvent> item) throws Exception {
@@ -159,6 +162,10 @@ private Iterator<StreamingChatCompletionResults.Result> parse(XContentParserConf
159162

160163
ensureExpectedToken(XContentParser.Token.START_OBJECT, currentToken, parser);
161164

165+
String content = null;
166+
String refusal = null;
167+
List<StreamingChatCompletionResults.ToolCall> toolCalls = new ArrayList<>();
168+
162169
currentToken = parser.nextToken();
163170

164171
// continue until the end of delta
@@ -167,25 +174,84 @@ private Iterator<StreamingChatCompletionResults.Result> parse(XContentParserConf
167174
parser.skipChildren();
168175
}
169176

170-
if (currentToken == XContentParser.Token.FIELD_NAME && parser.currentName().equals(CONTENT_FIELD)) {
171-
parser.nextToken();
172-
ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser);
173-
var content = parser.text();
174-
consumeUntilObjectEnd(parser); // end delta
175-
consumeUntilObjectEnd(parser); // end choices
176-
return content;
177+
if (currentToken == XContentParser.Token.FIELD_NAME) {
178+
switch (parser.currentName()) {
179+
case CONTENT_FIELD:
180+
parser.nextToken();
181+
ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser);
182+
content = parser.text();
183+
break;
184+
case REFUSAL_FIELD:
185+
parser.nextToken();
186+
ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser);
187+
refusal = parser.text();
188+
break;
189+
case TOOL_CALLS_FIELD:
190+
parser.nextToken();
191+
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
192+
toolCalls = parseToolCalls(parser);
193+
break;
194+
}
177195
}
178196

179197
currentToken = parser.nextToken();
180198
}
181199

200+
consumeUntilObjectEnd(parser); // end delta
182201
consumeUntilObjectEnd(parser); // end choices
183-
return ""; // stopped
184-
}).stream()
185-
.filter(Objects::nonNull)
186-
.filter(Predicate.not(String::isEmpty))
187-
.map(StreamingChatCompletionResults.Result::new)
188-
.iterator();
202+
203+
return new StreamingChatCompletionResults.Result(content, refusal, toolCalls);
204+
}).stream().filter(Objects::nonNull).iterator();
205+
}
206+
}
207+
208+
private List<StreamingChatCompletionResults.ToolCall> parseToolCalls(XContentParser parser) throws IOException {
209+
List<StreamingChatCompletionResults.ToolCall> toolCalls = new ArrayList<>();
210+
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
211+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
212+
int index = -1;
213+
String id = null;
214+
String functionName = null;
215+
String functionArguments = null;
216+
217+
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
218+
if (parser.currentToken() == XContentParser.Token.FIELD_NAME) {
219+
switch (parser.currentName()) {
220+
case "index":
221+
parser.nextToken();
222+
ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, parser.currentToken(), parser);
223+
index = parser.intValue();
224+
break;
225+
case "id":
226+
parser.nextToken();
227+
ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser);
228+
id = parser.text();
229+
break;
230+
case "function":
231+
parser.nextToken();
232+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
233+
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
234+
if (parser.currentToken() == XContentParser.Token.FIELD_NAME) {
235+
switch (parser.currentName()) {
236+
case "name":
237+
parser.nextToken();
238+
ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser);
239+
functionName = parser.text();
240+
break;
241+
case "arguments":
242+
parser.nextToken();
243+
ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser);
244+
functionArguments = parser.text();
245+
break;
246+
}
247+
}
248+
}
249+
break;
250+
}
251+
}
252+
}
253+
toolCalls.add(new StreamingChatCompletionResults.ToolCall(index, id, functionName, functionArguments));
189254
}
255+
return toolCalls;
190256
}
191257
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
174174
builder.field(STREAM_FIELD, stream);
175175
builder.endObject();
176176

177-
System.out.println(Strings.toString(builder));
178-
179177
return builder;
180178
}
181179
}

0 commit comments

Comments
 (0)