Skip to content

Commit 7245632

Browse files
[ML] Unified schema API remove name field (#119799)
* Removing name field * Fixing test
1 parent 75d1050 commit 7245632

File tree

11 files changed

+12
-65
lines changed

11 files changed

+12
-65
lines changed

server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -111,18 +111,14 @@ public void writeTo(StreamOutput out) throws IOException {
111111
out.writeOptionalFloat(topP);
112112
}
113113

114-
public record Message(
115-
Content content,
116-
String role,
117-
@Nullable String name,
118-
@Nullable String toolCallId,
119-
@Nullable List<ToolCall> toolCalls
120-
) implements Writeable {
114+
public record Message(Content content, String role, @Nullable String toolCallId, @Nullable List<ToolCall> toolCalls)
115+
implements
116+
Writeable {
121117

122118
@SuppressWarnings("unchecked")
123119
static final ConstructingObjectParser<Message, Void> PARSER = new ConstructingObjectParser<>(
124120
Message.class.getSimpleName(),
125-
args -> new Message((Content) args[0], (String) args[1], (String) args[2], (String) args[3], (List<ToolCall>) args[4])
121+
args -> new Message((Content) args[0], (String) args[1], (String) args[2], (List<ToolCall>) args[3])
126122
);
127123

128124
static {
@@ -133,7 +129,6 @@ public record Message(
133129
ObjectParser.ValueType.VALUE_ARRAY
134130
);
135131
PARSER.declareString(constructorArg(), new ParseField("role"));
136-
PARSER.declareString(optionalConstructorArg(), new ParseField("name"));
137132
PARSER.declareString(optionalConstructorArg(), new ParseField("tool_call_id"));
138133
PARSER.declareObjectArray(optionalConstructorArg(), ToolCall.PARSER::apply, new ParseField("tool_calls"));
139134
}
@@ -155,7 +150,6 @@ public Message(StreamInput in) throws IOException {
155150
in.readOptionalNamedWriteable(Content.class),
156151
in.readString(),
157152
in.readOptionalString(),
158-
in.readOptionalString(),
159153
in.readOptionalCollectionAsList(ToolCall::new)
160154
);
161155
}
@@ -164,7 +158,6 @@ public Message(StreamInput in) throws IOException {
164158
public void writeTo(StreamOutput out) throws IOException {
165159
out.writeOptionalNamedWriteable(content);
166160
out.writeString(role);
167-
out.writeOptionalString(name);
168161
out.writeOptionalString(toolCallId);
169162
out.writeOptionalCollection(toolCalls);
170163
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ public void testParseAllFields() throws IOException {
3535
"type": "string"
3636
}
3737
],
38-
"name": "a name",
3938
"tool_call_id": "100",
4039
"tool_calls": [
4140
{
@@ -83,7 +82,6 @@ public void testParseAllFields() throws IOException {
8382
List.of(new UnifiedCompletionRequest.ContentObject("some text", "string"))
8483
),
8584
"user",
86-
"a name",
8785
"100",
8886
List.of(
8987
new UnifiedCompletionRequest.ToolCall(
@@ -155,7 +153,6 @@ public void testParsing() throws IOException {
155153
new UnifiedCompletionRequest.ContentString("What is the weather like in Boston today?"),
156154
"user",
157155
null,
158-
null,
159156
null
160157
)
161158
),
@@ -200,7 +197,6 @@ public static UnifiedCompletionRequest.Message randomMessage() {
200197
randomContent(),
201198
randomAlphaOfLength(10),
202199
randomAlphaOfLengthOrNull(10),
203-
randomAlphaOfLengthOrNull(10),
204200
randomToolCallListOrNull()
205201
);
206202
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,7 @@ public UnifiedChatInput(List<String> inputs, String roleValue, boolean stream) {
4040

4141
private static List<UnifiedCompletionRequest.Message> convertToMessages(List<String> inputs, String roleValue) {
4242
return inputs.stream()
43-
.map(
44-
value -> new UnifiedCompletionRequest.Message(
45-
new UnifiedCompletionRequest.ContentString(value),
46-
roleValue,
47-
null,
48-
null,
49-
null
50-
)
51-
)
43+
.map(value -> new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString(value), roleValue, null, null))
5244
.toList();
5345
}
5446

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
7777
}
7878

7979
builder.field(ROLE_FIELD, message.role());
80-
if (message.name() != null) {
81-
builder.field(NAME_FIELD, message.name());
82-
}
8380
if (message.toolCallId() != null) {
8481
builder.field(TOOL_CALL_ID_FIELD, message.toolCallId());
8582
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInputTests.java

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,8 @@ public void testConvertsStringInputToMessages() {
2424
Matchers.is(
2525
UnifiedCompletionRequest.of(
2626
List.of(
27-
new UnifiedCompletionRequest.Message(
28-
new UnifiedCompletionRequest.ContentString("hello"),
29-
"a role",
30-
null,
31-
null,
32-
null
33-
),
34-
new UnifiedCompletionRequest.Message(
35-
new UnifiedCompletionRequest.ContentString("awesome"),
36-
"a role",
37-
null,
38-
null,
39-
null
40-
)
27+
new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "a role", null, null),
28+
new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("awesome"), "a role", null, null)
4129
)
4230
)
4331
)

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ public void testModelUserFieldsSerialization() throws IOException {
3232
new UnifiedCompletionRequest.ContentString("Hello, world!"),
3333
ROLE,
3434
null,
35-
null,
3635
null
3736
);
3837
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ public void testModelUserFieldsSerialization() throws IOException {
3232
new UnifiedCompletionRequest.ContentString("Hello, world!"),
3333
ROLE,
3434
null,
35-
null,
3635
null
3736
);
3837
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntityTests.java

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ public void testBasicSerialization() throws IOException {
3939
new UnifiedCompletionRequest.ContentString("Hello, world!"),
4040
ROLE,
4141
null,
42-
null,
4342
null
4443
);
4544
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
@@ -78,7 +77,6 @@ public void testSerializationWithAllFields() throws IOException {
7877
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
7978
new UnifiedCompletionRequest.ContentString("Hello, world!"),
8079
ROLE,
81-
"name",
8280
"tool_call_id",
8381
Collections.singletonList(
8482
new UnifiedCompletionRequest.ToolCall(
@@ -127,7 +125,6 @@ public void testSerializationWithAllFields() throws IOException {
127125
{
128126
"content": "Hello, world!",
129127
"role": "user",
130-
"name": "name",
131128
"tool_call_id": "tool_call_id",
132129
"tool_calls": [
133130
{
@@ -189,7 +186,6 @@ public void testSerializationWithNullOptionalFields() throws IOException {
189186
new UnifiedCompletionRequest.ContentString("Hello, world!"),
190187
ROLE,
191188
null,
192-
null,
193189
null
194190
);
195191
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
@@ -240,7 +236,6 @@ public void testSerializationWithEmptyLists() throws IOException {
240236
new UnifiedCompletionRequest.ContentString("Hello, world!"),
241237
ROLE,
242238
null,
243-
null,
244239
Collections.emptyList() // empty toolCalls list
245240
);
246241
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
@@ -290,7 +285,6 @@ public void testSerializationWithNestedObjects() throws IOException {
290285
Random random = Randomness.get();
291286

292287
String randomContent = "Hello, world! " + random.nextInt(1000);
293-
String randomName = "name" + random.nextInt(1000);
294288
String randomToolCallId = "tool_call_id" + random.nextInt(1000);
295289
String randomArguments = "arguments" + random.nextInt(1000);
296290
String randomFunctionName = "function_name" + random.nextInt(1000);
@@ -303,7 +297,6 @@ public void testSerializationWithNestedObjects() throws IOException {
303297
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
304298
new UnifiedCompletionRequest.ContentString(randomContent),
305299
ROLE,
306-
randomName,
307300
randomToolCallId,
308301
Collections.singletonList(
309302
new UnifiedCompletionRequest.ToolCall(
@@ -357,7 +350,6 @@ public void testSerializationWithNestedObjects() throws IOException {
357350
{
358351
"content": "%s",
359352
"role": "user",
360-
"name": "%s",
361353
"tool_call_id": "%s",
362354
"tool_calls": [
363355
{
@@ -416,7 +408,6 @@ public void testSerializationWithNestedObjects() throws IOException {
416408
}
417409
""",
418410
randomContent,
419-
randomName,
420411
randomToolCallId,
421412
randomArguments,
422413
randomFunctionName,
@@ -449,11 +440,10 @@ public void testSerializationWithDifferentContentTypes() throws IOException {
449440
new UnifiedCompletionRequest.ContentString(randomContentString),
450441
ROLE,
451442
null,
452-
null,
453443
null
454444
);
455445

456-
UnifiedCompletionRequest.Message messageWithObjects = new UnifiedCompletionRequest.Message(contentObjects, ROLE, null, null, null);
446+
UnifiedCompletionRequest.Message messageWithObjects = new UnifiedCompletionRequest.Message(contentObjects, ROLE, null, null);
457447
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
458448
messageList.add(messageWithString);
459449
messageList.add(messageWithObjects);
@@ -502,7 +492,6 @@ public void testSerializationWithSpecialCharacters() throws IOException {
502492
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
503493
new UnifiedCompletionRequest.ContentString("Hello, world! \n \"Special\" characters: \t \\ /"),
504494
ROLE,
505-
"name\nwith\nnewlines",
506495
"tool_call_id\twith\ttabs",
507496
Collections.singletonList(
508497
new UnifiedCompletionRequest.ToolCall(
@@ -541,7 +530,6 @@ public void testSerializationWithSpecialCharacters() throws IOException {
541530
{
542531
"content": "Hello, world! \\n \\"Special\\" characters: \\t \\\\ /",
543532
"role": "user",
544-
"name": "name\\nwith\\nnewlines",
545533
"tool_call_id": "tool_call_id\\twith\\ttabs",
546534
"tool_calls": [
547535
{
@@ -571,7 +559,6 @@ public void testSerializationWithBooleanFields() throws IOException {
571559
new UnifiedCompletionRequest.ContentString("Hello, world!"),
572560
ROLE,
573561
null,
574-
null,
575562
null
576563
);
577564
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
@@ -641,7 +628,6 @@ public void testSerializationWithoutContentField() throws IOException {
641628
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
642629
null,
643630
"assistant",
644-
"name\nwith\nnewlines",
645631
"tool_call_id\twith\ttabs",
646632
Collections.singletonList(
647633
new UnifiedCompletionRequest.ToolCall(
@@ -669,7 +655,6 @@ public void testSerializationWithoutContentField() throws IOException {
669655
"messages": [
670656
{
671657
"role": "assistant",
672-
"name": "name\\nwith\\nnewlines",
673658
"tool_call_id": "tool_call_id\\twith\\ttabs",
674659
"tool_calls": [
675660
{

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public void testOverridingModelId() {
3333
);
3434

3535
var request = new UnifiedCompletionRequest(
36-
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("message"), "user", null, null, null)),
36+
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("message"), "user", null, null)),
3737
"new_model_id",
3838
null,
3939
null,

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -967,9 +967,7 @@ public void testUnifiedCompletionInfer() throws Exception {
967967
service.unifiedCompletionInfer(
968968
model,
969969
UnifiedCompletionRequest.of(
970-
List.of(
971-
new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null, null)
972-
)
970+
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
973971
),
974972
InferenceAction.Request.DEFAULT_TIMEOUT,
975973
listener

0 commit comments

Comments
 (0)