Skip to content

Commit 4eade05

Browse files
Add unit tests for LlamaActionCreator and related models
1 parent a13020c commit 4eade05

File tree

3 files changed

+442
-0
lines changed

3 files changed

+442
-0
lines changed
Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.llama.action;
9+
10+
import org.apache.http.HttpHeaders;
11+
import org.elasticsearch.ElasticsearchException;
12+
import org.elasticsearch.action.support.PlainActionFuture;
13+
import org.elasticsearch.common.settings.Settings;
14+
import org.elasticsearch.core.TimeValue;
15+
import org.elasticsearch.inference.InferenceServiceResults;
16+
import org.elasticsearch.test.ESTestCase;
17+
import org.elasticsearch.test.http.MockResponse;
18+
import org.elasticsearch.test.http.MockWebServer;
19+
import org.elasticsearch.threadpool.ThreadPool;
20+
import org.elasticsearch.xcontent.XContentType;
21+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
22+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests;
23+
import org.elasticsearch.xpack.inference.InputTypeTests;
24+
import org.elasticsearch.xpack.inference.common.TruncatorTests;
25+
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
26+
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
27+
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
28+
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
29+
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
30+
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
31+
import org.elasticsearch.xpack.inference.services.ServiceComponents;
32+
import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests;
33+
import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingModelTests;
34+
import org.junit.After;
35+
import org.junit.Before;
36+
37+
import java.io.IOException;
38+
import java.util.List;
39+
import java.util.Map;
40+
import java.util.concurrent.TimeUnit;
41+
42+
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
43+
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
44+
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
45+
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
46+
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
47+
import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
48+
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
49+
import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager;
50+
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
51+
import static org.hamcrest.Matchers.contains;
52+
import static org.hamcrest.Matchers.equalTo;
53+
import static org.hamcrest.Matchers.hasSize;
54+
import static org.hamcrest.Matchers.instanceOf;
55+
import static org.hamcrest.Matchers.is;
56+
import static org.mockito.Mockito.mock;
57+
58+
public class LlamaActionCreatorTests extends ESTestCase {
59+
60+
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
61+
private final MockWebServer webServer = new MockWebServer();
62+
private ThreadPool threadPool;
63+
private HttpClientManager clientManager;
64+
65+
@Before
66+
public void init() throws Exception {
67+
webServer.start();
68+
threadPool = createThreadPool(inferenceUtilityPool());
69+
clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
70+
}
71+
72+
@After
73+
public void shutdown() throws IOException {
74+
clientManager.close();
75+
terminate(threadPool);
76+
webServer.close();
77+
}
78+
79+
public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws IOException {
80+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
81+
82+
try (var sender = createSender(senderFactory)) {
83+
sender.start();
84+
85+
String responseJson = """
86+
{
87+
"embeddings": [
88+
[
89+
-0.0123,
90+
0.123
91+
]
92+
]
93+
{
94+
""";
95+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
96+
97+
PlainActionFuture<InferenceServiceResults> listener = createEmbeddingsFuture(sender, createWithEmptySettings(threadPool));
98+
99+
var result = listener.actionGet(TIMEOUT);
100+
101+
assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));
102+
103+
assertEmbeddingsRequest();
104+
}
105+
}
106+
107+
public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction() throws IOException {
108+
var settings = buildSettingsWithRetryFields(
109+
TimeValue.timeValueMillis(1),
110+
TimeValue.timeValueMinutes(1),
111+
TimeValue.timeValueSeconds(0)
112+
);
113+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
114+
115+
try (var sender = createSender(senderFactory)) {
116+
sender.start();
117+
118+
String responseJson = """
119+
[
120+
{
121+
"embeddings": [
122+
[
123+
-0.0123,
124+
0.123
125+
]
126+
]
127+
{
128+
]
129+
""";
130+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
131+
132+
PlainActionFuture<InferenceServiceResults> listener = createEmbeddingsFuture(sender, createWithEmptySettings(threadPool));
133+
134+
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
135+
assertThat(
136+
thrownException.getMessage(),
137+
is("Failed to parse object: expecting token of type [START_ARRAY] but found [START_OBJECT]")
138+
);
139+
140+
assertEmbeddingsRequest();
141+
}
142+
}
143+
144+
public void testExecute_ReturnsSuccessfulResponse_ForCompletionAction() throws IOException {
145+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
146+
147+
try (var sender = createSender(senderFactory)) {
148+
sender.start();
149+
150+
String responseJson = """
151+
{
152+
"id": "chatcmpl-03e70a75-efb6-447d-b661-e5ed0bd59ce9",
153+
"choices": [
154+
{
155+
"finish_reason": "length",
156+
"index": 0,
157+
"logprobs": null,
158+
"message": {
159+
"content": "Hello there, how may I assist you today?",
160+
"refusal": null,
161+
"role": "assistant",
162+
"annotations": null,
163+
"audio": null,
164+
"function_call": null,
165+
"tool_calls": null
166+
}
167+
}
168+
],
169+
"created": 1750157476,
170+
"model": "llama3.2:3b",
171+
"object": "chat.completion",
172+
"service_tier": null,
173+
"system_fingerprint": "fp_ollama",
174+
"usage": {
175+
"completion_tokens": 10,
176+
"prompt_tokens": 30,
177+
"total_tokens": 40,
178+
"completion_tokens_details": null,
179+
"prompt_tokens_details": null
180+
}
181+
}
182+
""";
183+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
184+
185+
PlainActionFuture<InferenceServiceResults> listener = createCompletionFuture(sender, createWithEmptySettings(threadPool));
186+
187+
var result = listener.actionGet(TIMEOUT);
188+
189+
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?"))));
190+
191+
assertCompletionRequest();
192+
}
193+
}
194+
195+
public void testExecute_FailsFromInvalidResponseFormat_ForCompletionAction() throws IOException {
196+
var settings = buildSettingsWithRetryFields(
197+
TimeValue.timeValueMillis(1),
198+
TimeValue.timeValueMinutes(1),
199+
TimeValue.timeValueSeconds(0)
200+
);
201+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
202+
203+
try (var sender = createSender(senderFactory)) {
204+
sender.start();
205+
206+
String responseJson = """
207+
{
208+
"invalid_field": "unexpected"
209+
}
210+
""";
211+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
212+
213+
PlainActionFuture<InferenceServiceResults> listener = createCompletionFuture(
214+
sender,
215+
new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator())
216+
);
217+
218+
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
219+
assertThat(
220+
thrownException.getMessage(),
221+
is("Failed to send Llama completion request from inference entity id [id]. Cause: Required [choices]")
222+
);
223+
224+
assertCompletionRequest();
225+
}
226+
}
227+
228+
private PlainActionFuture<InferenceServiceResults> createEmbeddingsFuture(Sender sender, ServiceComponents threadPool) {
229+
var model = LlamaEmbeddingModelTests.createEmbeddingsModel("model", getUrl(webServer), "secret");
230+
var actionCreator = new LlamaActionCreator(sender, threadPool);
231+
var action = actionCreator.create(model);
232+
233+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
234+
action.execute(
235+
new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
236+
InferenceAction.Request.DEFAULT_TIMEOUT,
237+
listener
238+
);
239+
return listener;
240+
}
241+
242+
private PlainActionFuture<InferenceServiceResults> createCompletionFuture(Sender sender, ServiceComponents threadPool) {
243+
var model = LlamaChatCompletionModelTests.createCompletionModel("model", getUrl(webServer), "secret");
244+
var actionCreator = new LlamaActionCreator(sender, threadPool);
245+
var action = actionCreator.create(model);
246+
247+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
248+
action.execute(new ChatCompletionInput(List.of("Hello"), false), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
249+
return listener;
250+
}
251+
252+
private void assertCompletionRequest() throws IOException {
253+
assertCommonRequestProperties();
254+
255+
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
256+
assertThat(requestMap.size(), is(4));
257+
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello"))));
258+
assertThat(requestMap.get("model"), is("model"));
259+
assertThat(requestMap.get("n"), is(1));
260+
assertThat(requestMap.get("stream"), is(false));
261+
}
262+
263+
@SuppressWarnings("unchecked")
264+
private void assertEmbeddingsRequest() throws IOException {
265+
assertCommonRequestProperties();
266+
267+
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
268+
assertThat(requestMap.size(), is(2));
269+
assertThat(requestMap.get("contents"), instanceOf(List.class));
270+
var inputList = (List<String>) requestMap.get("contents");
271+
assertThat(inputList, contains("abc"));
272+
}
273+
274+
private void assertCommonRequestProperties() {
275+
assertThat(webServer.requests(), hasSize(1));
276+
assertNull(webServer.requests().get(0).getUri().getQuery());
277+
assertThat(
278+
webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
279+
equalTo(XContentType.JSON.mediaTypeWithoutParameters())
280+
);
281+
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
282+
}
283+
}

0 commit comments

Comments
 (0)