Skip to content

Commit 1573d53

Browse files
Add unit tests for LlamaEmbeddingsRequest to validate request creation and truncation behavior
1 parent 4d2a5dd commit 1573d53

File tree

3 files changed

+116
-1
lines changed

3 files changed

+116
-1
lines changed

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModelTests.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference.services.llama.embeddings;
99

1010
import org.elasticsearch.common.settings.SecureString;
11+
import org.elasticsearch.inference.EmptySecretSettings;
1112
import org.elasticsearch.inference.EmptyTaskSettings;
1213
import org.elasticsearch.inference.TaskType;
1314
import org.elasticsearch.test.ESTestCase;
@@ -25,4 +26,16 @@ public static LlamaEmbeddingsModel createEmbeddingsModel(String modelId, String
2526
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
2627
);
2728
}
29+
30+
public static LlamaEmbeddingsModel createEmbeddingsModelNoAuth(String modelId, String url) {
31+
return new LlamaEmbeddingsModel(
32+
"id",
33+
TaskType.TEXT_EMBEDDING,
34+
"llama",
35+
new LlamaEmbeddingsServiceSettings(modelId, url, null, null, null, null),
36+
EmptyTaskSettings.INSTANCE,
37+
null,
38+
EmptySecretSettings.INSTANCE
39+
);
40+
}
2841
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.inference.services.llama.request.completion;
99

10+
import org.apache.http.HttpHeaders;
1011
import org.apache.http.client.methods.HttpPost;
1112
import org.elasticsearch.test.ESTestCase;
1213
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
@@ -37,7 +38,7 @@ public void testCreateRequest_WithStreaming() throws IOException {
3738
assertThat(requestMap.get("n"), is(1));
3839
assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true)));
3940
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input))));
40-
assertNotNull(httpPost.getHeaders("Authorization"));
41+
assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
4142
}
4243

4344
public void testCreateRequest_NoStreaming_NoAuthorization() throws IOException {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.llama.request.embeddings;
9+
10+
import org.apache.http.HttpHeaders;
11+
import org.apache.http.client.methods.HttpPost;
12+
import org.elasticsearch.test.ESTestCase;
13+
import org.elasticsearch.xcontent.XContentType;
14+
import org.elasticsearch.xpack.inference.common.Truncator;
15+
import org.elasticsearch.xpack.inference.common.TruncatorTests;
16+
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
17+
import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModelTests;
18+
19+
import java.io.IOException;
20+
import java.util.List;
21+
22+
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
23+
import static org.hamcrest.Matchers.aMapWithSize;
24+
import static org.hamcrest.Matchers.instanceOf;
25+
import static org.hamcrest.Matchers.is;
26+
27+
public class LlamaEmbeddingsRequestTests extends ESTestCase {
28+
29+
public void testCreateRequest_WithAuth_Success() throws IOException {
30+
var request = createRequest();
31+
var httpRequest = request.createHttpRequest();
32+
var httpPost = validateRequestUrlAndContentType(httpRequest);
33+
34+
var requestMap = entityAsMap(httpPost.getEntity().getContent());
35+
assertThat(requestMap, aMapWithSize(2));
36+
assertThat(requestMap.get("contents"), is(List.of("ABCD")));
37+
assertThat(requestMap.get("model_id"), is("llama-embed"));
38+
assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer apikey"));
39+
}
40+
41+
public void testCreateRequest_NoAuth_Success() throws IOException {
42+
var request = createRequestNoAuth();
43+
var httpRequest = request.createHttpRequest();
44+
var httpPost = validateRequestUrlAndContentType(httpRequest);
45+
46+
var requestMap = entityAsMap(httpPost.getEntity().getContent());
47+
assertThat(requestMap, aMapWithSize(2));
48+
assertThat(requestMap.get("contents"), is(List.of("ABCD")));
49+
assertThat(requestMap.get("model_id"), is("llama-embed"));
50+
assertNull(httpPost.getFirstHeader("Authorization"));
51+
}
52+
53+
public void testTruncate_ReducesInputTextSizeByHalf() throws IOException {
54+
var request = createRequest();
55+
var truncatedRequest = request.truncate();
56+
57+
var httpRequest = truncatedRequest.createHttpRequest();
58+
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
59+
60+
var httpPost = (HttpPost) httpRequest.httpRequestBase();
61+
var requestMap = entityAsMap(httpPost.getEntity().getContent());
62+
assertThat(requestMap, aMapWithSize(2));
63+
assertThat(requestMap.get("contents"), is(List.of("AB")));
64+
assertThat(requestMap.get("model_id"), is("llama-embed"));
65+
}
66+
67+
public void testIsTruncated_ReturnsTrue() {
68+
var request = createRequest();
69+
assertFalse(request.getTruncationInfo()[0]);
70+
71+
var truncatedRequest = request.truncate();
72+
assertTrue(truncatedRequest.getTruncationInfo()[0]);
73+
}
74+
75+
private HttpPost validateRequestUrlAndContentType(HttpRequest request) {
76+
assertThat(request.httpRequestBase(), instanceOf(HttpPost.class));
77+
var httpPost = (HttpPost) request.httpRequestBase();
78+
assertThat(httpPost.getURI().toString(), is("url"));
79+
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaTypeWithoutParameters()));
80+
return httpPost;
81+
}
82+
83+
private static LlamaEmbeddingsRequest createRequest() {
84+
var embeddingsModel = LlamaEmbeddingsModelTests.createEmbeddingsModel("llama-embed", "url", "apikey");
85+
return new LlamaEmbeddingsRequest(
86+
TruncatorTests.createTruncator(),
87+
new Truncator.TruncationResult(List.of("ABCD"), new boolean[] { false }),
88+
embeddingsModel
89+
);
90+
}
91+
92+
private static LlamaEmbeddingsRequest createRequestNoAuth() {
93+
var embeddingsModel = LlamaEmbeddingsModelTests.createEmbeddingsModelNoAuth("llama-embed", "url");
94+
return new LlamaEmbeddingsRequest(
95+
TruncatorTests.createTruncator(),
96+
new Truncator.TruncationResult(List.of("ABCD"), new boolean[] { false }),
97+
embeddingsModel
98+
);
99+
}
100+
101+
}

0 commit comments

Comments
 (0)