Skip to content

Commit 41591ae

Browse files
Refactor Llama embedding and chat completion tests for consistency and clarity
1 parent e2dce7c commit 41591ae

File tree

5 files changed

+96
-4
lines changed

5 files changed

+96
-4
lines changed

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
3131
import org.elasticsearch.xpack.inference.services.ServiceComponents;
3232
import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests;
33-
import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingModelTests;
33+
import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModelTests;
3434
import org.junit.After;
3535
import org.junit.Before;
3636

@@ -226,7 +226,7 @@ public void testExecute_FailsFromInvalidResponseFormat_ForCompletionAction() thr
226226
}
227227

228228
private PlainActionFuture<InferenceServiceResults> createEmbeddingsFuture(Sender sender, ServiceComponents threadPool) {
229-
var model = LlamaEmbeddingModelTests.createEmbeddingsModel("model", getUrl(webServer), "secret");
229+
var model = LlamaEmbeddingsModelTests.createEmbeddingsModel("model", getUrl(webServer), "secret");
230230
var actionCreator = new LlamaActionCreator(sender, threadPool);
231231
var action = actionCreator.create(model);
232232

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModelTests.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference.services.llama.completion;
99

1010
import org.elasticsearch.common.settings.SecureString;
11+
import org.elasticsearch.inference.EmptySecretSettings;
1112
import org.elasticsearch.inference.TaskType;
1213
import org.elasticsearch.inference.UnifiedCompletionRequest;
1314
import org.elasticsearch.test.ESTestCase;
@@ -39,6 +40,16 @@ public static LlamaChatCompletionModel createChatCompletionModel(String modelId,
3940
);
4041
}
4142

43+
public static LlamaChatCompletionModel createChatCompletionModelNoAuth(String modelId, String url) {
44+
return new LlamaChatCompletionModel(
45+
"id",
46+
TaskType.CHAT_COMPLETION,
47+
"llama",
48+
new LlamaChatCompletionServiceSettings(modelId, url, null),
49+
EmptySecretSettings.INSTANCE
50+
);
51+
}
52+
4253
public void testOverrideWith_UnifiedCompletionRequest_KeepsSameModelId() {
4354
var model = createCompletionModel("model_name", "url", "api_key");
4455
var request = new UnifiedCompletionRequest(
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import org.elasticsearch.test.ESTestCase;
1414
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
1515

16-
public class LlamaEmbeddingModelTests extends ESTestCase {
16+
public class LlamaEmbeddingsModelTests extends ESTestCase {
1717
public static LlamaEmbeddingsModel createEmbeddingsModel(String modelId, String url, String apiKey) {
1818
return new LlamaEmbeddingsModel(
1919
"id",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.llama.request.completion;
9+
10+
import org.apache.http.client.methods.HttpPost;
11+
import org.elasticsearch.test.ESTestCase;
12+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
13+
import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests;
14+
15+
import java.io.IOException;
16+
import java.util.List;
17+
import java.util.Map;
18+
19+
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
20+
import static org.hamcrest.Matchers.instanceOf;
21+
import static org.hamcrest.Matchers.is;
22+
23+
public class LlamaChatCompletionRequestTests extends ESTestCase {
24+
25+
public void testCreateRequest_WithStreaming() throws IOException {
26+
String input = randomAlphaOfLength(15);
27+
var request = createRequest("model", "url", "secret", input, true);
28+
var httpRequest = request.createHttpRequest();
29+
30+
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
31+
var httpPost = (HttpPost) httpRequest.httpRequestBase();
32+
33+
var requestMap = entityAsMap(httpPost.getEntity().getContent());
34+
assertThat(request.getURI().toString(), is("url"));
35+
assertThat(requestMap.get("stream"), is(true));
36+
assertThat(requestMap.get("model"), is("model"));
37+
assertThat(requestMap.get("n"), is(1));
38+
assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true)));
39+
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input))));
40+
assertNotNull(httpPost.getHeaders("Authorization"));
41+
}
42+
43+
public void testCreateRequest_NoStreaming_NoAuthorization() throws IOException {
44+
String input = randomAlphaOfLength(15);
45+
var request = createRequestWithNoAuth("model", "url", input, false);
46+
var httpRequest = request.createHttpRequest();
47+
48+
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
49+
var httpPost = (HttpPost) httpRequest.httpRequestBase();
50+
51+
var requestMap = entityAsMap(httpPost.getEntity().getContent());
52+
assertThat(request.getURI().toString(), is("url"));
53+
assertThat(requestMap.get("stream"), is(false));
54+
assertThat(requestMap.get("model"), is("model"));
55+
assertThat(requestMap.get("n"), is(1));
56+
assertNull(requestMap.get("stream_options"));
57+
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input))));
58+
assertNull(httpPost.getFirstHeader("Authorization"));
59+
}
60+
61+
public void testTruncate_DoesNotReduceInputTextSize() {
62+
String input = randomAlphaOfLength(5);
63+
var request = createRequest("model", "url", "secret", input, true);
64+
assertThat(request.truncate(), is(request));
65+
}
66+
67+
public void testTruncationInfo_ReturnsNull() {
68+
var request = createRequest("model", "url", "secret", randomAlphaOfLength(5), true);
69+
assertNull(request.getTruncationInfo());
70+
}
71+
72+
public static LlamaChatCompletionRequest createRequest(String modelId, String url, String apiKey, String input, boolean stream) {
73+
var chatCompletionModel = LlamaChatCompletionModelTests.createChatCompletionModel(modelId, url, apiKey);
74+
return new LlamaChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel);
75+
}
76+
77+
public static LlamaChatCompletionRequest createRequestWithNoAuth(String modelId, String url, String input, boolean stream) {
78+
var chatCompletionModel = LlamaChatCompletionModelTests.createChatCompletionModelNoAuth(modelId, url);
79+
return new LlamaChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel);
80+
}
81+
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public void testTruncate_DoesNotReduceInputTextSize() throws IOException {
4949
var requestMap = entityAsMap(httpPost.getEntity().getContent());
5050
assertThat(requestMap, aMapWithSize(4));
5151

52-
// We do not truncate for Hugging Face chat completions
52+
// We do not truncate for Mistral chat completions
5353
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input))));
5454
assertThat(requestMap.get("model"), is("model"));
5555
assertThat(requestMap.get("n"), is(1));

0 commit comments

Comments
 (0)