Skip to content

Commit 0dfd081

Browse files
Refactoring tests and request entities
1 parent 03fada0 commit 0dfd081

File tree

12 files changed

+160
-180
lines changed

12 files changed

+160
-180
lines changed

server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,11 @@ public record UnifiedCompletionRequest(
3333
List<Message> messages,
3434
@Nullable String model,
3535
@Nullable Long maxCompletionTokens,
36-
@Nullable Integer n,
3736
@Nullable Stop stop,
3837
@Nullable Float temperature,
3938
@Nullable ToolChoice toolChoice,
4039
@Nullable List<Tool> tools,
41-
@Nullable Float topP,
42-
@Nullable String user
40+
@Nullable Float topP
4341
) implements Writeable {
4442

4543
public sealed interface Content extends NamedWriteable permits ContentObjects, ContentString {}
@@ -51,21 +49,18 @@ public sealed interface Content extends NamedWriteable permits ContentObjects, C
5149
(List<Message>) args[0],
5250
(String) args[1],
5351
(Long) args[2],
54-
(Integer) args[3],
55-
(Stop) args[4],
56-
(Float) args[5],
57-
(ToolChoice) args[6],
58-
(List<Tool>) args[7],
59-
(Float) args[8],
60-
(String) args[9]
52+
(Stop) args[3],
53+
(Float) args[4],
54+
(ToolChoice) args[5],
55+
(List<Tool>) args[6],
56+
(Float) args[7]
6157
)
6258
);
6359

6460
static {
6561
PARSER.declareObjectArray(constructorArg(), Message.PARSER::apply, new ParseField("messages"));
6662
PARSER.declareString(optionalConstructorArg(), new ParseField("model"));
6763
PARSER.declareLong(optionalConstructorArg(), new ParseField("max_completion_tokens"));
68-
PARSER.declareInt(optionalConstructorArg(), new ParseField("n"));
6964
PARSER.declareField(optionalConstructorArg(), (p, c) -> parseStop(p), new ParseField("stop"), ObjectParser.ValueType.VALUE_ARRAY);
7065
PARSER.declareFloat(optionalConstructorArg(), new ParseField("temperature"));
7166
PARSER.declareField(
@@ -76,7 +71,6 @@ public sealed interface Content extends NamedWriteable permits ContentObjects, C
7671
);
7772
PARSER.declareObjectArray(optionalConstructorArg(), Tool.PARSER::apply, new ParseField("tools"));
7873
PARSER.declareFloat(optionalConstructorArg(), new ParseField("top_p"));
79-
PARSER.declareString(optionalConstructorArg(), new ParseField("user"));
8074
}
8175

8276
public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
@@ -90,18 +84,20 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
9084
);
9185
}
9286

87+
public static UnifiedCompletionRequest of(List<Message> messages) {
88+
return new UnifiedCompletionRequest(messages, null, null, null, null, null, null, null);
89+
}
90+
9391
public UnifiedCompletionRequest(StreamInput in) throws IOException {
9492
this(
9593
in.readCollectionAsImmutableList(Message::new),
9694
in.readOptionalString(),
9795
in.readOptionalVLong(),
98-
in.readOptionalVInt(),
9996
in.readOptionalNamedWriteable(Stop.class),
10097
in.readOptionalFloat(),
10198
in.readOptionalNamedWriteable(ToolChoice.class),
102-
in.readCollectionAsImmutableList(Tool::new),
103-
in.readOptionalFloat(),
104-
in.readOptionalString()
99+
in.readOptionalCollectionAsList(Tool::new),
100+
in.readOptionalFloat()
105101
);
106102
}
107103

@@ -110,13 +106,11 @@ public void writeTo(StreamOutput out) throws IOException {
110106
out.writeCollection(messages);
111107
out.writeOptionalString(model);
112108
out.writeOptionalVLong(maxCompletionTokens);
113-
out.writeOptionalVInt(n);
114109
out.writeOptionalNamedWriteable(stop);
115110
out.writeOptionalFloat(temperature);
116111
out.writeOptionalNamedWriteable(toolChoice);
117112
out.writeOptionalCollection(tools);
118113
out.writeOptionalFloat(topP);
119-
out.writeOptionalString(user);
120114
}
121115

122116
public record Message(Content content, String role, @Nullable String name, @Nullable String toolCallId, List<ToolCall> toolCalls)
@@ -155,7 +149,7 @@ public Message(StreamInput in) throws IOException {
155149
in.readString(),
156150
in.readOptionalString(),
157151
in.readOptionalString(),
158-
in.readCollectionAsImmutableList(ToolCall::new)
152+
in.readOptionalCollectionAsList(ToolCall::new)
159153
);
160154
}
161155

@@ -165,7 +159,7 @@ public void writeTo(StreamOutput out) throws IOException {
165159
out.writeString(role);
166160
out.writeOptionalString(name);
167161
out.writeOptionalString(toolCallId);
168-
out.writeCollection(toolCalls);
162+
out.writeOptionalCollection(toolCalls);
169163
}
170164
}
171165

0 commit comments

Comments
 (0)