Skip to content

Commit 0c2e90a

Browse files
authored
[ML] Write Chat Completion JSON (#128592) (#128642)
Most providers write the UnifiedCompletionRequest JSON as we received it, with some exception: - the modelId can be null and/or overwritten from various locations - `max_completion_tokens` repalced `max_tokens`, but some providers still use the deprecated field name We will handle the variations using Params, otherwise all of the XContent building code has moved into UnifiedCompletionRequest so it can be reused across providers.
1 parent 8ea4172 commit 0c2e90a

File tree

8 files changed

+233
-184
lines changed

8 files changed

+233
-184
lines changed

server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java

Lines changed: 196 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
import org.elasticsearch.xcontent.ConstructingObjectParser;
2020
import org.elasticsearch.xcontent.ObjectParser;
2121
import org.elasticsearch.xcontent.ParseField;
22+
import org.elasticsearch.xcontent.ToXContent;
23+
import org.elasticsearch.xcontent.ToXContentFragment;
24+
import org.elasticsearch.xcontent.ToXContentObject;
25+
import org.elasticsearch.xcontent.XContentBuilder;
2226
import org.elasticsearch.xcontent.XContentParseException;
2327
import org.elasticsearch.xcontent.XContentParser;
2428

@@ -38,9 +42,68 @@ public record UnifiedCompletionRequest(
3842
@Nullable ToolChoice toolChoice,
3943
@Nullable List<Tool> tools,
4044
@Nullable Float topP
41-
) implements Writeable {
45+
) implements Writeable, ToXContentFragment {
46+
47+
public static final String NAME_FIELD = "name";
48+
public static final String TOOL_CALL_ID_FIELD = "tool_call_id";
49+
public static final String TOOL_CALLS_FIELD = "tool_calls";
50+
public static final String ID_FIELD = "id";
51+
public static final String FUNCTION_FIELD = "function";
52+
public static final String ARGUMENTS_FIELD = "arguments";
53+
public static final String DESCRIPTION_FIELD = "description";
54+
public static final String PARAMETERS_FIELD = "parameters";
55+
public static final String STRICT_FIELD = "strict";
56+
public static final String TOP_P_FIELD = "top_p";
57+
public static final String MESSAGES_FIELD = "messages";
58+
private static final String ROLE_FIELD = "role";
59+
private static final String CONTENT_FIELD = "content";
60+
private static final String STOP_FIELD = "stop";
61+
private static final String TEMPERATURE_FIELD = "temperature";
62+
private static final String TOOL_CHOICE_FIELD = "tool_choice";
63+
private static final String TOOL_FIELD = "tools";
64+
private static final String TEXT_FIELD = "text";
65+
private static final String TYPE_FIELD = "type";
66+
private static final String MODEL_FIELD = "model";
67+
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";
68+
private static final String MAX_TOKENS_FIELD = "max_tokens";
69+
70+
/**
71+
* We currently allow providers to override the model id that is written to JSON.
72+
* Rather than use {@link #model()}, providers are expected to pass in the modelId via
73+
* {@link org.elasticsearch.xcontent.ToXContent.Params}.
74+
*/
75+
private static final String MODEL_ID_PARAM = "model_id_value";
76+
/**
77+
* Some providers only support the now-deprecated {@link #MAX_TOKENS_FIELD}, others have migrated to
78+
* {@link #MAX_COMPLETION_TOKENS_FIELD}. Providers are expected to pass in their supported field name.
79+
*/
80+
private static final String MAX_TOKENS_PARAM = "max_tokens_field";
81+
82+
/**
83+
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
84+
* - Key: {@link #MODEL_FIELD}, Value: modelId
85+
* - Key: {@link #MAX_TOKENS_FIELD}, Value: {@link #maxCompletionTokens()}
86+
*/
87+
public static Params withMaxTokens(String modelId, Params params) {
88+
return new DelegatingMapParams(
89+
Map.ofEntries(Map.entry(MODEL_ID_PARAM, modelId), Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD)),
90+
params
91+
);
92+
}
4293

43-
public sealed interface Content extends NamedWriteable permits ContentObjects, ContentString {}
94+
/**
95+
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
96+
* - Key: {@link #MODEL_FIELD}, Value: modelId
97+
* - Key: {@link #MAX_COMPLETION_TOKENS_FIELD}, Value: {@link #maxCompletionTokens()}
98+
*/
99+
public static Params withMaxCompletionTokensTokens(String modelId, Params params) {
100+
return new DelegatingMapParams(
101+
Map.ofEntries(Map.entry(MODEL_ID_PARAM, modelId), Map.entry(MAX_TOKENS_PARAM, MAX_COMPLETION_TOKENS_FIELD)),
102+
params
103+
);
104+
}
105+
106+
public sealed interface Content extends NamedWriteable, ToXContent permits ContentObjects, ContentString {}
44107

45108
@SuppressWarnings("unchecked")
46109
public static final ConstructingObjectParser<UnifiedCompletionRequest, Void> PARSER = new ConstructingObjectParser<>(
@@ -111,9 +174,40 @@ public void writeTo(StreamOutput out) throws IOException {
111174
out.writeOptionalFloat(topP);
112175
}
113176

177+
@Override
178+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
179+
builder.field(MESSAGES_FIELD, messages);
180+
if (stop != null && (stop.isEmpty() == false)) {
181+
builder.field(STOP_FIELD, stop);
182+
}
183+
if (temperature != null) {
184+
builder.field(TEMPERATURE_FIELD, temperature);
185+
}
186+
if (toolChoice != null) {
187+
toolChoice.toXContent(builder, params);
188+
}
189+
if (tools != null && (tools.isEmpty() == false)) {
190+
builder.field(TOOL_FIELD, tools);
191+
}
192+
if (topP != null) {
193+
builder.field(TOP_P_FIELD, topP);
194+
}
195+
// some providers only support the now-deprecated max_tokens, others have migrated to max_completion_tokens
196+
if (maxCompletionTokens != null && params.param(MAX_TOKENS_PARAM) != null) {
197+
builder.field(params.param(MAX_TOKENS_PARAM), maxCompletionTokens);
198+
}
199+
// some implementations handle modelId differently, for example OpenAI has a default in the server settings and override it there
200+
// so we allow implementations to pass in the model id via the params
201+
if (params.param(MODEL_ID_PARAM) != null) {
202+
builder.field(MODEL_FIELD, params.param(MODEL_ID_PARAM));
203+
}
204+
return builder;
205+
}
206+
114207
public record Message(Content content, String role, @Nullable String toolCallId, @Nullable List<ToolCall> toolCalls)
115208
implements
116-
Writeable {
209+
Writeable,
210+
ToXContentObject {
117211

118212
@SuppressWarnings("unchecked")
119213
static final ConstructingObjectParser<Message, Void> PARSER = new ConstructingObjectParser<>(
@@ -161,6 +255,24 @@ public void writeTo(StreamOutput out) throws IOException {
161255
out.writeOptionalString(toolCallId);
162256
out.writeOptionalCollection(toolCalls);
163257
}
258+
259+
@Override
260+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
261+
builder.startObject();
262+
263+
if (content != null) {
264+
content.toXContent(builder, params);
265+
}
266+
builder.field(ROLE_FIELD, role);
267+
if (toolCallId != null) {
268+
builder.field(TOOL_CALL_ID_FIELD, toolCallId);
269+
}
270+
if (toolCalls != null) {
271+
builder.field(TOOL_CALLS_FIELD, toolCalls);
272+
}
273+
274+
return builder.endObject();
275+
}
164276
}
165277

166278
public record ContentObjects(List<ContentObject> contentObjects) implements Content, NamedWriteable {
@@ -180,9 +292,14 @@ public void writeTo(StreamOutput out) throws IOException {
180292
public String getWriteableName() {
181293
return NAME;
182294
}
295+
296+
@Override
297+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
298+
return builder.field(CONTENT_FIELD, contentObjects);
299+
}
183300
}
184301

185-
public record ContentObject(String text, String type) implements Writeable {
302+
public record ContentObject(String text, String type) implements Writeable, ToXContentObject {
186303
static final ConstructingObjectParser<ContentObject, Void> PARSER = new ConstructingObjectParser<>(
187304
ContentObject.class.getSimpleName(),
188305
args -> new ContentObject((String) args[0], (String) args[1])
@@ -207,6 +324,13 @@ public String toString() {
207324
return text + ":" + type;
208325
}
209326

327+
@Override
328+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
329+
builder.startObject();
330+
builder.field(TEXT_FIELD, text);
331+
builder.field(TYPE_FIELD, type);
332+
return builder.endObject();
333+
}
210334
}
211335

212336
public record ContentString(String content) implements Content, NamedWriteable {
@@ -234,9 +358,14 @@ public String getWriteableName() {
234358
public String toString() {
235359
return content;
236360
}
361+
362+
@Override
363+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
364+
return builder.field(CONTENT_FIELD, content);
365+
}
237366
}
238367

239-
public record ToolCall(String id, FunctionField function, String type) implements Writeable {
368+
public record ToolCall(String id, FunctionField function, String type) implements Writeable, ToXContentObject {
240369

241370
static final ConstructingObjectParser<ToolCall, Void> PARSER = new ConstructingObjectParser<>(
242371
ToolCall.class.getSimpleName(),
@@ -260,7 +389,16 @@ public void writeTo(StreamOutput out) throws IOException {
260389
out.writeString(type);
261390
}
262391

263-
public record FunctionField(String arguments, String name) implements Writeable {
392+
@Override
393+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
394+
builder.startObject();
395+
builder.field(ID_FIELD, id);
396+
builder.field(FUNCTION_FIELD, function);
397+
builder.field(TYPE_FIELD, type);
398+
return builder.endObject();
399+
}
400+
401+
public record FunctionField(String arguments, String name) implements Writeable, ToXContentObject {
264402
static final ConstructingObjectParser<FunctionField, Void> PARSER = new ConstructingObjectParser<>(
265403
"tool_call_function_field",
266404
args -> new FunctionField((String) args[0], (String) args[1])
@@ -280,6 +418,14 @@ public void writeTo(StreamOutput out) throws IOException {
280418
out.writeString(arguments);
281419
out.writeString(name);
282420
}
421+
422+
@Override
423+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
424+
builder.startObject();
425+
builder.field(ARGUMENTS_FIELD, arguments);
426+
builder.field(NAME_FIELD, name);
427+
return builder.endObject();
428+
}
283429
}
284430
}
285431

@@ -294,7 +440,7 @@ private static ToolChoice parseToolChoice(XContentParser parser) throws IOExcept
294440
throw new XContentParseException("Unsupported token [" + token + "]");
295441
}
296442

297-
public sealed interface ToolChoice extends NamedWriteable permits ToolChoiceObject, ToolChoiceString {}
443+
public sealed interface ToolChoice extends NamedWriteable, ToXContent permits ToolChoiceObject, ToolChoiceString {}
298444

299445
public record ToolChoiceObject(String type, FunctionField function) implements ToolChoice, NamedWriteable {
300446

@@ -325,7 +471,15 @@ public String getWriteableName() {
325471
return NAME;
326472
}
327473

328-
public record FunctionField(String name) implements Writeable {
474+
@Override
475+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
476+
builder.startObject(TOOL_CHOICE_FIELD);
477+
builder.field(TYPE_FIELD, type);
478+
builder.field(FUNCTION_FIELD, function);
479+
return builder.endObject();
480+
}
481+
482+
public record FunctionField(String name) implements Writeable, ToXContentObject {
329483
static final ConstructingObjectParser<FunctionField, Void> PARSER = new ConstructingObjectParser<>(
330484
"tool_choice_function_field",
331485
args -> new FunctionField((String) args[0])
@@ -343,6 +497,11 @@ public FunctionField(StreamInput in) throws IOException {
343497
public void writeTo(StreamOutput out) throws IOException {
344498
out.writeString(name);
345499
}
500+
501+
@Override
502+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
503+
return builder.startObject().field(NAME_FIELD, name).endObject();
504+
}
346505
}
347506
}
348507

@@ -367,9 +526,14 @@ public void writeTo(StreamOutput out) throws IOException {
367526
public String getWriteableName() {
368527
return NAME;
369528
}
529+
530+
@Override
531+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
532+
return builder.field(TOOL_CHOICE_FIELD, value);
533+
}
370534
}
371535

372-
public record Tool(String type, FunctionField function) implements Writeable {
536+
public record Tool(String type, FunctionField function) implements Writeable, ToXContentObject {
373537

374538
static final ConstructingObjectParser<Tool, Void> PARSER = new ConstructingObjectParser<>(
375539
Tool.class.getSimpleName(),
@@ -391,12 +555,22 @@ public void writeTo(StreamOutput out) throws IOException {
391555
function.writeTo(out);
392556
}
393557

558+
@Override
559+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
560+
builder.startObject();
561+
562+
builder.field(TYPE_FIELD, type);
563+
builder.field(FUNCTION_FIELD, function);
564+
565+
return builder.endObject();
566+
}
567+
394568
public record FunctionField(
395569
@Nullable String description,
396570
String name,
397571
@Nullable Map<String, Object> parameters,
398572
@Nullable Boolean strict
399-
) implements Writeable {
573+
) implements Writeable, ToXContentObject {
400574

401575
@SuppressWarnings("unchecked")
402576
static final ConstructingObjectParser<FunctionField, Void> PARSER = new ConstructingObjectParser<>(
@@ -422,6 +596,18 @@ public void writeTo(StreamOutput out) throws IOException {
422596
out.writeGenericMap(parameters);
423597
out.writeOptionalBoolean(strict);
424598
}
599+
600+
@Override
601+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
602+
builder.startObject();
603+
builder.field(DESCRIPTION_FIELD, description);
604+
builder.field(NAME_FIELD, name);
605+
builder.field(PARAMETERS_FIELD, parameters);
606+
if (strict != null) {
607+
builder.field(STRICT_FIELD, strict);
608+
}
609+
return builder.endObject();
610+
}
425611
}
426612
}
427613
}

0 commit comments

Comments
 (0)