Skip to content

Commit 4d2a5dd

Browse files
Add unit tests for LlamaChatCompletionRequestEntity to validate message serialization
1 parent 41591ae commit 4d2a5dd

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.llama.request.completion;
9+
10+
import org.elasticsearch.common.Strings;
11+
import org.elasticsearch.common.xcontent.XContentHelper;
12+
import org.elasticsearch.inference.UnifiedCompletionRequest;
13+
import org.elasticsearch.test.ESTestCase;
14+
import org.elasticsearch.xcontent.ToXContent;
15+
import org.elasticsearch.xcontent.XContentBuilder;
16+
import org.elasticsearch.xcontent.json.JsonXContent;
17+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
18+
import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel;
19+
import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests;
20+
21+
import java.io.IOException;
22+
import java.util.ArrayList;
23+
24+
public class LlamaChatCompletionRequestEntityTests extends ESTestCase {
25+
private static final String ROLE = "user";
26+
27+
public void testModelUserFieldsSerialization() throws IOException {
28+
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
29+
new UnifiedCompletionRequest.ContentString("Hello, world!"),
30+
ROLE,
31+
null,
32+
null
33+
);
34+
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
35+
messageList.add(message);
36+
37+
var unifiedRequest = UnifiedCompletionRequest.of(messageList);
38+
39+
UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true);
40+
LlamaChatCompletionModel model = LlamaChatCompletionModelTests.createChatCompletionModel("model", "url", "api-key");
41+
42+
LlamaChatCompletionRequestEntity entity = new LlamaChatCompletionRequestEntity(unifiedChatInput, model);
43+
44+
XContentBuilder builder = JsonXContent.contentBuilder();
45+
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
46+
String expectedJson = """
47+
{
48+
"messages": [{
49+
"content": "Hello, world!",
50+
"role": "user"
51+
}
52+
],
53+
"model": "model",
54+
"n": 1,
55+
"stream": true,
56+
"stream_options": {
57+
"include_usage": true
58+
}
59+
}
60+
""";
61+
assertEquals(XContentHelper.stripWhitespace(expectedJson), Strings.toString(builder));
62+
}
63+
64+
}

0 commit comments

Comments
 (0)