Skip to content

Commit 7e1bdb2

Browse files
Update tests
1 parent c8775f7 commit 7e1bdb2

File tree

1 file changed

+47
-15
lines changed

1 file changed

+47
-15
lines changed

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestTests.java

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,27 +26,52 @@
2626

2727
public class LlamaEmbeddingsRequestTests extends ESTestCase {
2828

29-
public void testCreateRequest_WithAuth_Success() throws IOException {
30-
testCreateRequest_WithAuth_Success("user");
29+
public void testCreateRequest_WithAuth_WithUser_NoDimensions_DimensionsSetByUserFalse_Success() throws IOException {
30+
testCreateRequest_WithAuth_Success("user", null, false, null);
3131
}
3232

33-
public void testCreateRequest_WithAuth_NoUser_Success() throws IOException {
34-
testCreateRequest_WithAuth_Success(null);
33+
public void testCreateRequest_WithAuth_WithUser_NoDimensions_DimensionsSetByUserTrue_Success() throws IOException {
34+
testCreateRequest_WithAuth_Success("user", null, true, null);
3535
}
3636

37-
private void testCreateRequest_WithAuth_Success(String user) throws IOException {
38-
var request = createRequest(user);
37+
public void testCreateRequest_WithAuth_WithUser_WithDimensions_DimensionsSetByUserFalse_Success() throws IOException {
38+
testCreateRequest_WithAuth_Success("user", 384, false, null);
39+
}
40+
41+
public void testCreateRequest_WithAuth_WithUser_WithDimensions_DimensionsSetByUserTrue_Success() throws IOException {
42+
testCreateRequest_WithAuth_Success("user", 384, true, 384);
43+
}
44+
45+
public void testCreateRequest_WithAuth_NoUser_NoDimensions_DimensionsSetByUserFalse_Success() throws IOException {
46+
testCreateRequest_WithAuth_Success(null, null, false, null);
47+
}
48+
49+
public void testCreateRequest_WithAuth_NoUser_NoDimensions_DimensionsSetByUserTrue_Success() throws IOException {
50+
testCreateRequest_WithAuth_Success(null, null, true, null);
51+
}
52+
53+
public void testCreateRequest_WithAuth_NoUser_WithDimensions_DimensionsSetByUserFalse_Success() throws IOException {
54+
testCreateRequest_WithAuth_Success(null, 384, false, null);
55+
}
56+
57+
public void testCreateRequest_WithAuth_NoUser_WithDimensions_DimensionsSetByUserTrue_Success() throws IOException {
58+
testCreateRequest_WithAuth_Success(null, 384, true, 384);
59+
}
60+
61+
private void testCreateRequest_WithAuth_Success(
62+
String user,
63+
Integer dimensions,
64+
boolean dimensionsSetByUser,
65+
Integer expectedDimensions
66+
) throws IOException {
67+
var request = createRequest(user, dimensions, dimensionsSetByUser);
3968
var httpRequest = request.createHttpRequest();
4069
var httpPost = validateRequestUrlAndContentType(httpRequest);
4170

4271
var requestMap = entityAsMap(httpPost.getEntity().getContent());
43-
if (user == null) {
44-
assertThat(requestMap, aMapWithSize(2));
45-
} else {
46-
assertThat(requestMap, aMapWithSize(3));
47-
}
4872
assertThat(requestMap.get("input"), is(List.of("ABCD")));
4973
assertThat(requestMap.get("model"), is("llama-embed"));
74+
assertThat(requestMap.get("dimensions"), is(expectedDimensions));
5075
assertThat(requestMap.get("user"), is(user));
5176
assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer apikey"));
5277
}
@@ -85,7 +110,7 @@ public void testTruncate_ReducesInputTextSizeByHalf_NoUser() throws IOException
85110
}
86111

87112
private static void testTruncate_ReducesInputTextSizeByHalf(String user) throws IOException {
88-
var request = createRequest(user);
113+
var request = createRequest(user, null, false);
89114
var truncatedRequest = request.truncate();
90115

91116
var httpRequest = truncatedRequest.createHttpRequest();
@@ -104,7 +129,7 @@ private static void testTruncate_ReducesInputTextSizeByHalf(String user) throws
104129
}
105130

106131
public void testIsTruncated_ReturnsTrue() {
107-
var request = createRequest("user");
132+
var request = createRequest("user", null, false);
108133
assertFalse(request.getTruncationInfo()[0]);
109134

110135
var truncatedRequest = request.truncate();
@@ -119,8 +144,15 @@ private HttpPost validateRequestUrlAndContentType(HttpRequest request) {
119144
return httpPost;
120145
}
121146

122-
private static LlamaEmbeddingsRequest createRequest(String user) {
123-
var embeddingsModel = LlamaEmbeddingsModelTests.createEmbeddingsModel("llama-embed", "url", "apikey", user, null, false);
147+
private static LlamaEmbeddingsRequest createRequest(String user, Integer dimensions, boolean dimensionsSetByUser) {
148+
var embeddingsModel = LlamaEmbeddingsModelTests.createEmbeddingsModel(
149+
"llama-embed",
150+
"url",
151+
"apikey",
152+
user,
153+
dimensions,
154+
dimensionsSetByUser
155+
);
124156
return new LlamaEmbeddingsRequest(
125157
TruncatorTests.createTruncator(),
126158
new Truncator.TruncationResult(List.of("ABCD"), new boolean[] { false }),

0 commit comments

Comments
 (0)