Skip to content

Commit 66a41fb

Browse files
[Inference API] Make message content optional in unified API (elastic#118998)
* Allow for null/empty content field * remove tests which checked for null content * [CI] Auto commit changes from spotless * Improvements from review --------- Co-authored-by: elasticsearchmachine <[email protected]> (cherry picked from commit 79a8226) # Conflicts: # x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java
1 parent 12a9b37 commit 66a41fb

File tree

3 files changed

+54
-111
lines changed

3 files changed

+54
-111
lines changed

server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,12 @@ public record Message(Content content, String role, @Nullable String name, @Null
122122
);
123123

124124
static {
125-
PARSER.declareField(constructorArg(), (p, c) -> parseContent(p), new ParseField("content"), ObjectParser.ValueType.VALUE_ARRAY);
125+
PARSER.declareField(
126+
optionalConstructorArg(),
127+
(p, c) -> parseContent(p),
128+
new ParseField("content"),
129+
ObjectParser.ValueType.VALUE_ARRAY
130+
);
126131
PARSER.declareString(constructorArg(), new ParseField("role"));
127132
PARSER.declareString(optionalConstructorArg(), new ParseField("name"));
128133
PARSER.declareString(optionalConstructorArg(), new ParseField("tool_call_id"));
@@ -143,7 +148,7 @@ private static Content parseContent(XContentParser parser) throws IOException {
143148

144149
public Message(StreamInput in) throws IOException {
145150
this(
146-
in.readNamedWriteable(Content.class),
151+
in.readOptionalNamedWriteable(Content.class),
147152
in.readString(),
148153
in.readOptionalString(),
149154
in.readOptionalString(),
@@ -153,7 +158,7 @@ public Message(StreamInput in) throws IOException {
153158

154159
@Override
155160
public void writeTo(StreamOutput out) throws IOException {
156-
out.writeNamedWriteable(content);
161+
out.writeOptionalNamedWriteable(content);
157162
out.writeString(role);
158163
out.writeOptionalString(name);
159164
out.writeOptionalString(toolCallId);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
6666
for (UnifiedCompletionRequest.Message message : unifiedRequest.messages()) {
6767
builder.startObject();
6868
{
69-
if (message.content() instanceof UnifiedCompletionRequest.ContentString contentString) {
69+
if (message.content() == null) {
70+
// content is optional
71+
} else if (message.content() instanceof UnifiedCompletionRequest.ContentString contentString) {
7072
builder.field(CONTENT_FIELD, contentString.content());
7173
} else if (message.content() instanceof UnifiedCompletionRequest.ContentObjects contentObjects) {
7274
builder.startArray(CONTENT_FIELD);
@@ -77,10 +79,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
7779
builder.endObject();
7880
}
7981
builder.endArray();
80-
} else {
81-
throw new IllegalArgumentException(
82-
Strings.format("Unsupported message.content class received: %s", message.content().getClass().getSimpleName())
83-
);
8482
}
8583

8684
builder.field(ROLE_FIELD, message.role());

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java

Lines changed: 43 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -702,122 +702,62 @@ public void testSerializationWithBooleanFields() throws IOException {
702702
assertJsonEquals(expectedJsonFalse, jsonStringFalse);
703703
}
704704

705-
// 9. Serialization with Missing Required Fields
706-
// Test with missing required fields to ensure appropriate exceptions are thrown.
707-
public void testSerializationWithMissingRequiredFields() {
708-
// Create a message with missing content (required field)
705+
// 9. a test without the content field to show that the content field is optional
706+
public void testSerializationWithoutContentField() throws IOException {
709707
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
710-
null, // missing content
711-
OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD,
712-
null,
713708
null,
714-
null
715-
);
716-
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
717-
messageList.add(message);
718-
// Create the unified request
719-
UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(
720-
messageList,
721-
null, // model
722-
null, // maxCompletionTokens
723-
null, // stop
724-
null, // temperature
725-
null, // toolChoice
726-
null, // tools
727-
null // topP
728-
);
729-
730-
// Create the unified chat input
731-
UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true);
732-
733-
OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null);
734-
735-
// Create the entity
736-
OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model);
737-
738-
// Attempt to serialize to XContent and expect an exception
739-
try {
740-
XContentBuilder builder = JsonXContent.contentBuilder();
741-
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
742-
fail("Expected an exception due to missing required fields");
743-
} catch (NullPointerException | IOException e) {
744-
// Expected exception
745-
}
746-
}
747-
748-
// 10. Serialization with Mixed Valid and Invalid Data
749-
// Test with a mix of valid and invalid data to ensure the serializer handles it gracefully.
750-
public void testSerializationWithMixedValidAndInvalidData() throws IOException {
751-
// Create a valid message
752-
UnifiedCompletionRequest.Message validMessage = new UnifiedCompletionRequest.Message(
753-
new UnifiedCompletionRequest.ContentString("Valid content"),
754-
OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD,
755-
"validName",
756-
"validToolCallId",
757-
Collections.singletonList(
758-
new UnifiedCompletionRequest.ToolCall(
759-
"validId",
760-
new UnifiedCompletionRequest.ToolCall.FunctionField("validArguments", "validFunctionName"),
761-
"validType"
762-
)
763-
)
764-
);
765-
766-
// Create an invalid message with null content
767-
UnifiedCompletionRequest.Message invalidMessage = new UnifiedCompletionRequest.Message(
768-
null, // invalid content
769-
OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD,
770-
"invalidName",
771-
"invalidToolCallId",
709+
"assistant",
710+
"name\nwith\nnewlines",
711+
"tool_call_id\twith\ttabs",
772712
Collections.singletonList(
773713
new UnifiedCompletionRequest.ToolCall(
774-
"invalidId",
775-
new UnifiedCompletionRequest.ToolCall.FunctionField("invalidArguments", "invalidFunctionName"),
776-
"invalidType"
714+
"id\\with\\backslashes",
715+
new UnifiedCompletionRequest.ToolCall.FunctionField("arguments\"with\"quotes", "function_name/with/slashes"),
716+
"type"
777717
)
778718
)
779719
);
780720
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
781-
messageList.add(validMessage);
782-
messageList.add(invalidMessage);
783-
// Create the unified request with both valid and invalid messages
784-
UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(
785-
messageList,
786-
"model-name",
787-
100L, // maxCompletionTokens
788-
Collections.singletonList("stop"),
789-
0.9f, // temperature
790-
new UnifiedCompletionRequest.ToolChoiceString("tool_choice"),
791-
Collections.singletonList(
792-
new UnifiedCompletionRequest.Tool(
793-
"type",
794-
new UnifiedCompletionRequest.Tool.FunctionField(
795-
"Fetches the weather in the given location",
796-
"get_weather",
797-
createParameters(),
798-
true
799-
)
800-
)
801-
),
802-
0.8f // topP
803-
);
721+
messageList.add(message);
722+
UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(messageList, null, null, null, null, null, null, null);
804723

805-
// Create the unified chat input
806724
UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true);
725+
OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", null);
807726

808-
OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null);
809-
810-
// Create the entity
811727
OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model);
812728

813-
// Serialize to XContent and verify
814-
try {
815-
XContentBuilder builder = JsonXContent.contentBuilder();
816-
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
817-
fail("Expected an exception due to invalid data");
818-
} catch (NullPointerException | IOException e) {
819-
// Expected exception
820-
}
729+
XContentBuilder builder = JsonXContent.contentBuilder();
730+
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
731+
732+
String jsonString = Strings.toString(builder);
733+
String expectedJson = """
734+
{
735+
"messages": [
736+
{
737+
"role": "assistant",
738+
"name": "name\\nwith\\nnewlines",
739+
"tool_call_id": "tool_call_id\\twith\\ttabs",
740+
"tool_calls": [
741+
{
742+
"id": "id\\\\with\\\\backslashes",
743+
"function": {
744+
"arguments": "arguments\\"with\\"quotes",
745+
"name": "function_name/with/slashes"
746+
},
747+
"type": "type"
748+
}
749+
]
750+
}
751+
],
752+
"model": "test-endpoint",
753+
"n": 1,
754+
"stream": true,
755+
"stream_options": {
756+
"include_usage": true
757+
}
758+
}
759+
""";
760+
assertJsonEquals(jsonString, expectedJson);
821761
}
822762

823763
public static Map<String, Object> createParameters() {

0 commit comments

Comments
 (0)