Skip to content

Commit df1b006

Browse files
Adding openai service tests
1 parent 1d37d8c commit df1b006

File tree

5 files changed

+96
-24
lines changed

5 files changed

+96
-24
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObjec
4343
private static final String TOOL_FIELD = "tools";
4444
private static final String TEXT_FIELD = "text";
4545
private static final String TYPE_FIELD = "type";
46+
private static final String STREAM_OPTIONS_FIELD = "stream_options";
47+
private static final String INCLUDE_USAGE_FIELD = "include_usage";
4648

4749
private final UnifiedCompletionRequest unifiedRequest;
4850
private final boolean stream;
@@ -169,6 +171,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
169171
}
170172

171173
builder.field(STREAM_FIELD, stream);
174+
if (stream) {
175+
builder.startObject(STREAM_OPTIONS_FIELD);
176+
builder.field(INCLUDE_USAGE_FIELD, true);
177+
builder.endObject();
178+
}
172179
builder.endObject();
173180

174181
return builder;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
public class OpenAiUnifiedChatCompletionRequestTests extends ESTestCase {
3131

3232
public void testCreateRequest_WithUrlOrganizationUserDefined() throws IOException {
33-
var request = createRequest("www.google.com", "org", "secret", "abc", "model", "user");
33+
var request = createRequest("www.google.com", "org", "secret", "abc", "model", "user", true);
3434
var httpRequest = request.createHttpRequest();
3535

3636
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@@ -42,16 +42,27 @@ public void testCreateRequest_WithUrlOrganizationUserDefined() throws IOExceptio
4242
assertThat(httpPost.getLastHeader(ORGANIZATION_HEADER).getValue(), is("org"));
4343

4444
var requestMap = entityAsMap(httpPost.getEntity().getContent());
45-
assertThat(requestMap, aMapWithSize(5));
45+
assertRequestMapWithUser(requestMap, "user");
46+
}
47+
48+
private void assertRequestMapWithoutUser(Map<String, Object> requestMap) {
49+
assertRequestMapWithUser(requestMap, null);
50+
}
51+
52+
private void assertRequestMapWithUser(Map<String, Object> requestMap, @Nullable String user) {
53+
assertThat(requestMap, aMapWithSize(user != null ? 6 : 5));
4654
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc"))));
4755
assertThat(requestMap.get("model"), is("model"));
48-
assertThat(requestMap.get("user"), is("user"));
56+
if (user != null) {
57+
assertThat(requestMap.get("user"), is(user));
58+
}
4959
assertThat(requestMap.get("n"), is(1));
5060
assertTrue((Boolean) requestMap.get("stream"));
61+
assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true)));
5162
}
5263

5364
public void testCreateRequest_WithDefaultUrl() throws URISyntaxException, IOException {
54-
var request = createRequest(null, "org", "secret", "abc", "model", "user");
65+
var request = createRequest(null, "org", "secret", "abc", "model", "user", true);
5566
var httpRequest = request.createHttpRequest();
5667

5768
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@@ -63,16 +74,12 @@ public void testCreateRequest_WithDefaultUrl() throws URISyntaxException, IOExce
6374
assertThat(httpPost.getLastHeader(ORGANIZATION_HEADER).getValue(), is("org"));
6475

6576
var requestMap = entityAsMap(httpPost.getEntity().getContent());
66-
assertThat(requestMap, aMapWithSize(5));
67-
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc"))));
68-
assertThat(requestMap.get("model"), is("model"));
69-
assertThat(requestMap.get("user"), is("user"));
70-
assertThat(requestMap.get("n"), is(1));
71-
assertTrue((Boolean) requestMap.get("stream"));
77+
assertRequestMapWithUser(requestMap, "user");
78+
7279
}
7380

7481
public void testCreateRequest_WithDefaultUrlAndWithoutUserOrganization() throws URISyntaxException, IOException {
75-
var request = createRequest(null, null, "secret", "abc", "model", null);
82+
var request = createRequest(null, null, "secret", "abc", "model", null, true);
7683
var httpRequest = request.createHttpRequest();
7784

7885
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@@ -84,14 +91,10 @@ public void testCreateRequest_WithDefaultUrlAndWithoutUserOrganization() throws
8491
assertNull(httpPost.getLastHeader(ORGANIZATION_HEADER));
8592

8693
var requestMap = entityAsMap(httpPost.getEntity().getContent());
87-
assertThat(requestMap, aMapWithSize(4));
88-
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc"))));
89-
assertThat(requestMap.get("model"), is("model"));
90-
assertThat(requestMap.get("n"), is(1));
91-
assertTrue((Boolean) requestMap.get("stream"));
94+
assertRequestMapWithoutUser(requestMap);
9295
}
9396

94-
public void testCreateRequest_WithStreaming() throws URISyntaxException, IOException {
97+
public void testCreateRequest_WithStreaming() throws IOException {
9598
var request = createRequest(null, null, "secret", "abc", "model", null, true);
9699
var httpRequest = request.createHttpRequest();
97100

@@ -103,7 +106,7 @@ public void testCreateRequest_WithStreaming() throws URISyntaxException, IOExcep
103106
}
104107

105108
public void testTruncate_DoesNotReduceInputTextSize() throws URISyntaxException, IOException {
106-
var request = createRequest(null, null, "secret", "abcd", "model", null);
109+
var request = createRequest(null, null, "secret", "abcd", "model", null, true);
107110
var truncatedRequest = request.truncate();
108111
assertThat(request.getURI().toString(), is(OpenAiUnifiedChatCompletionRequest.buildDefaultUri().toString()));
109112

@@ -112,17 +115,18 @@ public void testTruncate_DoesNotReduceInputTextSize() throws URISyntaxException,
112115

113116
var httpPost = (HttpPost) httpRequest.httpRequestBase();
114117
var requestMap = entityAsMap(httpPost.getEntity().getContent());
115-
assertThat(requestMap, aMapWithSize(4));
118+
assertThat(requestMap, aMapWithSize(5));
116119

117120
// We do not truncate for OpenAi chat completions
118121
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abcd"))));
119122
assertThat(requestMap.get("model"), is("model"));
120123
assertThat(requestMap.get("n"), is(1));
121124
assertTrue((Boolean) requestMap.get("stream"));
125+
assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true)));
122126
}
123127

124128
public void testTruncationInfo_ReturnsNull() {
125-
var request = createRequest(null, null, "secret", "abcd", "model", null);
129+
var request = createRequest(null, null, "secret", "abcd", "model", null, true);
126130
assertNull(request.getTruncationInfo());
127131
}
128132

@@ -147,7 +151,7 @@ public static OpenAiUnifiedChatCompletionRequest createRequest(
147151
boolean stream
148152
) {
149153
var chatCompletionModel = OpenAiChatCompletionModelTests.createChatCompletionModel(url, org, apiKey, model, user);
150-
return new OpenAiUnifiedChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", true), chatCompletionModel);
154+
return new OpenAiUnifiedChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel);
151155
}
152156

153157
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ public void testParseParams_DefaultsToTaskTypeAny_WhenInferenceId_IsMissing() {
7474
var params = parseParams(
7575
RestRequestTests.contentRestRequest("{}", Map.of(TASK_TYPE_OR_INFERENCE_ID, TaskType.COMPLETION.toString()))
7676
);
77-
assertThat(params, is(new BaseInferenceAction.Params(TASK_TYPE_OR_INFERENCE_ID, TaskType.ANY)));
77+
assertThat(params, is(new BaseInferenceAction.Params("completion", TaskType.ANY)));
7878
}
7979

8080
public void testParseParams_ThrowsStatusException_WhenTaskTypeIsMissing() {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,6 @@ public void sendResponse(RestResponse response) {
7171

7272
// the response content will be null when there is no error
7373
assertNull(responseSetOnce.get().content());
74-
// var responseBody = responseSetOnce.get().content().utf8ToString();
75-
// assertThat(Objects.requireNonNull(responseSetOnce.get().content()).utf8ToString(), equalTo(createResponse()));
7674
assertThat(executeCalled.get(), equalTo(true));
7775
}
7876

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.elasticsearch.inference.Model;
2828
import org.elasticsearch.inference.SimilarityMeasure;
2929
import org.elasticsearch.inference.TaskType;
30+
import org.elasticsearch.inference.UnifiedCompletionRequest;
3031
import org.elasticsearch.test.ESTestCase;
3132
import org.elasticsearch.test.http.MockResponse;
3233
import org.elasticsearch.test.http.MockWebServer;
@@ -920,6 +921,68 @@ public void testInfer_SendsRequest() throws IOException {
920921
}
921922
}
922923

924+
public void testUnifiedCompletionInfer() throws Exception {
925+
// streaming response must be on a single line
926+
String responseJson = """
927+
data: {\
928+
"id":"12345",\
929+
"object":"chat.completion.chunk",\
930+
"created":123456789,\
931+
"model":"gpt-4o-mini",\
932+
"system_fingerprint": "123456789",\
933+
"choices":[\
934+
{\
935+
"index":0,\
936+
"delta":{\
937+
"content":"hello, world"\
938+
},\
939+
"logprobs":null,\
940+
"finish_reason":"stop"\
941+
}\
942+
],\
943+
"usage":{\
944+
"prompt_tokens": 16,\
945+
"completion_tokens": 28,\
946+
"total_tokens": 44,\
947+
"prompt_tokens_details": {\
948+
"cached_tokens": 0,\
949+
"audio_tokens": 0\
950+
},\
951+
"completion_tokens_details": {\
952+
"reasoning_tokens": 0,\
953+
"audio_tokens": 0,\
954+
"accepted_prediction_tokens": 0,\
955+
"rejected_prediction_tokens": 0\
956+
}\
957+
}\
958+
}
959+
960+
""";
961+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
962+
963+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
964+
try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
965+
var model = OpenAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user");
966+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
967+
service.unifiedCompletionInfer(
968+
model,
969+
UnifiedCompletionRequest.of(
970+
List.of(
971+
new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null, null)
972+
)
973+
),
974+
InferenceAction.Request.DEFAULT_TIMEOUT,
975+
listener
976+
);
977+
978+
var result = listener.actionGet(TIMEOUT);
979+
InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent("""
980+
{"id":"12345","choices":[{"delta":{"content":"hello, world"},"finish_reason":"stop","index":0}],""" + """
981+
"model":"gpt-4o-mini","object":"chat.completion.chunk",""" + """
982+
"usage":{"completion_tokens":28,"prompt_tokens":16,"total_tokens":44}}""");
983+
}
984+
}
985+
923986
public void testInfer_StreamRequest() throws Exception {
924987
String responseJson = """
925988
data: {\

0 commit comments

Comments
 (0)