Skip to content

Commit 8a185b9

Browse files
Make API_COMPLETIONS_PATH public and add unit tests for Ai21ChatCompletionRequestEntity and Ai21ChatCompletionRequest
1 parent 0e7f310 commit 8a185b9

File tree

4 files changed

+141
-2
lines changed

4 files changed

+141
-2
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/completion/Ai21ChatCompletionModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
* This class extends RateLimitGroupingModel to handle rate limiting based on model and API key.
2929
*/
3030
public class Ai21ChatCompletionModel extends Ai21Model {
31-
private static final String API_COMPLETIONS_PATH = "https://api.ai21.com/studio/v1/chat/completions";
31+
public static final String API_COMPLETIONS_PATH = "https://api.ai21.com/studio/v1/chat/completions";
3232

3333
/**
3434
* Constructor for Ai21ChatCompletionModel.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.ai21.request;
9+
10+
import org.elasticsearch.common.Strings;
11+
import org.elasticsearch.common.xcontent.XContentHelper;
12+
import org.elasticsearch.inference.UnifiedCompletionRequest;
13+
import org.elasticsearch.test.ESTestCase;
14+
import org.elasticsearch.xcontent.ToXContent;
15+
import org.elasticsearch.xcontent.XContentBuilder;
16+
import org.elasticsearch.xcontent.json.JsonXContent;
17+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
18+
import org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionModel;
19+
20+
import java.io.IOException;
21+
import java.util.ArrayList;
22+
23+
import static org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionModelTests.createCompletionModel;
24+
25+
public class Ai21ChatCompletionRequestEntityTests extends ESTestCase {
26+
private static final String ROLE = "user";
27+
28+
public void testModelUserFieldsSerialization() throws IOException {
29+
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
30+
new UnifiedCompletionRequest.ContentString("Hello, world!"),
31+
ROLE,
32+
null,
33+
null
34+
);
35+
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
36+
messageList.add(message);
37+
38+
var unifiedRequest = UnifiedCompletionRequest.of(messageList);
39+
40+
UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true);
41+
Ai21ChatCompletionModel model = createCompletionModel("api-key", "test-model");
42+
43+
Ai21ChatCompletionRequestEntity entity = new Ai21ChatCompletionRequestEntity(unifiedChatInput, model);
44+
45+
XContentBuilder builder = JsonXContent.contentBuilder();
46+
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
47+
String expectedJson = """
48+
{
49+
"messages": [{
50+
"content": "Hello, world!",
51+
"role": "user"
52+
}
53+
],
54+
"model": "test-model",
55+
"n": 1,
56+
"stream": true,
57+
"stream_options": {
58+
"include_usage": true
59+
}
60+
}
61+
""";
62+
assertEquals(XContentHelper.stripWhitespace(expectedJson), Strings.toString(builder));
63+
}
64+
65+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.ai21.request;
9+
10+
import org.apache.http.client.methods.HttpPost;
11+
import org.elasticsearch.core.Nullable;
12+
import org.elasticsearch.test.ESTestCase;
13+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
14+
import org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionModel;
15+
import org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionModelTests;
16+
17+
import java.io.IOException;
18+
import java.util.List;
19+
import java.util.Map;
20+
21+
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
22+
import static org.hamcrest.Matchers.aMapWithSize;
23+
import static org.hamcrest.Matchers.instanceOf;
24+
import static org.hamcrest.Matchers.is;
25+
26+
public class Ai21ChatCompletionRequestTests extends ESTestCase {
27+
28+
public void testCreateRequest_WithStreaming() throws IOException {
29+
var request = createRequest("secret", randomAlphaOfLength(15), "model", true);
30+
var httpRequest = request.createHttpRequest();
31+
32+
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
33+
var httpPost = (HttpPost) httpRequest.httpRequestBase();
34+
35+
var requestMap = entityAsMap(httpPost.getEntity().getContent());
36+
assertThat(requestMap.get("stream"), is(true));
37+
}
38+
39+
public void testTruncate_DoesNotReduceInputTextSize() throws IOException {
40+
String input = randomAlphaOfLength(5);
41+
var request = createRequest("secret", input, "model", true);
42+
var truncatedRequest = request.truncate();
43+
assertThat(request.getURI().toString(), is(Ai21ChatCompletionModel.API_COMPLETIONS_PATH));
44+
45+
var httpRequest = truncatedRequest.createHttpRequest();
46+
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
47+
48+
var httpPost = (HttpPost) httpRequest.httpRequestBase();
49+
var requestMap = entityAsMap(httpPost.getEntity().getContent());
50+
assertThat(requestMap, aMapWithSize(5));
51+
52+
// We do not truncate for AI21 chat completions
53+
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input))));
54+
assertThat(requestMap.get("model"), is("model"));
55+
assertThat(requestMap.get("n"), is(1));
56+
assertTrue((Boolean) requestMap.get("stream"));
57+
assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true)));
58+
}
59+
60+
public void testTruncationInfo_ReturnsNull() {
61+
var request = createRequest("secret", randomAlphaOfLength(5), "model", true);
62+
assertNull(request.getTruncationInfo());
63+
}
64+
65+
public static Ai21ChatCompletionRequest createRequest(String apiKey, String input, @Nullable String model) {
66+
return createRequest(apiKey, input, model, false);
67+
}
68+
69+
public static Ai21ChatCompletionRequest createRequest(String apiKey, String input, @Nullable String model, boolean stream) {
70+
var chatCompletionModel = Ai21ChatCompletionModelTests.createCompletionModel(apiKey, model);
71+
return new Ai21ChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel);
72+
}
73+
74+
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public void testTruncate_DoesNotReduceInputTextSize() throws IOException {
4949
var requestMap = entityAsMap(httpPost.getEntity().getContent());
5050
assertThat(requestMap, aMapWithSize(4));
5151

52-
// We do not truncate for Hugging Face chat completions
52+
// We do not truncate for Mistral chat completions
5353
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input))));
5454
assertThat(requestMap.get("model"), is("model"));
5555
assertThat(requestMap.get("n"), is(1));

0 commit comments

Comments
 (0)