Skip to content

Commit d41538a

Browse files
committed
More tests
1 parent 8f6e03b commit d41538a

File tree

13 files changed

+2031
-2
lines changed

13 files changed

+2031
-2
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@
3838

3939
public class VoyageAIEmbeddingsServiceSettings extends FilteredXContentObject implements ServiceSettings {
4040
public static final String NAME = "voyageai_embeddings_service_settings";
41+
public static final VoyageAIEmbeddingsServiceSettings EMPTY_SETTINGS = new VoyageAIEmbeddingsServiceSettings(
42+
null, null, null, null, null
43+
);
4144

4245
static final String EMBEDDING_TYPE = "embedding_type";
4346

@@ -118,7 +121,7 @@ static VoyageAIEmbeddingType fromVoyageAIOrDenseVectorEnumValues(String enumStri
118121

119122
public VoyageAIEmbeddingsServiceSettings(
120123
VoyageAIServiceSettings commonSettings,
121-
VoyageAIEmbeddingType embeddingType,
124+
@Nullable VoyageAIEmbeddingType embeddingType,
122125
@Nullable SimilarityMeasure similarity,
123126
@Nullable Integer dimensions,
124127
@Nullable Integer maxInputTokens
@@ -127,7 +130,7 @@ public VoyageAIEmbeddingsServiceSettings(
127130
this.similarity = similarity;
128131
this.dimensions = dimensions;
129132
this.maxInputTokens = maxInputTokens;
130-
this.embeddingType = Objects.requireNonNull(embeddingType);
133+
this.embeddingType = embeddingType;
131134
}
132135

133136
public VoyageAIEmbeddingsServiceSettings(StreamInput in) throws IOException {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.action.voyageai;
9+
10+
import org.apache.http.HttpHeaders;
11+
import org.elasticsearch.action.support.PlainActionFuture;
12+
import org.elasticsearch.common.settings.Settings;
13+
import org.elasticsearch.core.TimeValue;
14+
import org.elasticsearch.inference.InferenceServiceResults;
15+
import org.elasticsearch.inference.InputType;
16+
import org.elasticsearch.test.ESTestCase;
17+
import org.elasticsearch.test.http.MockResponse;
18+
import org.elasticsearch.test.http.MockWebServer;
19+
import org.elasticsearch.threadpool.ThreadPool;
20+
import org.elasticsearch.xcontent.XContentType;
21+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
22+
import org.elasticsearch.xpack.inference.external.action.cohere.CohereActionCreator;
23+
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
24+
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
25+
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
26+
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
27+
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
28+
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
29+
import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModelTests;
30+
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
31+
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests;
32+
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
33+
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests;
34+
import org.hamcrest.MatcherAssert;
35+
import org.junit.After;
36+
import org.junit.Before;
37+
38+
import java.io.IOException;
39+
import java.util.List;
40+
import java.util.Map;
41+
import java.util.concurrent.TimeUnit;
42+
43+
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
44+
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
45+
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
46+
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
47+
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
48+
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
49+
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
50+
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
51+
import static org.hamcrest.Matchers.is;
52+
import static org.hamcrest.Matchers.hasSize;
53+
import static org.hamcrest.Matchers.equalTo;
54+
import static org.mockito.Mockito.mock;
55+
56+
public class VoyageAIActionCreatorTests extends ESTestCase {
57+
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
58+
private final MockWebServer webServer = new MockWebServer();
59+
private ThreadPool threadPool;
60+
private HttpClientManager clientManager;
61+
62+
@Before
63+
public void init() throws Exception {
64+
webServer.start();
65+
threadPool = createThreadPool(inferenceUtilityPool());
66+
clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
67+
}
68+
69+
@After
70+
public void shutdown() throws IOException {
71+
clientManager.close();
72+
terminate(threadPool);
73+
webServer.close();
74+
}
75+
76+
public void testCreate_CohereEmbeddingsModel() throws IOException {
77+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
78+
79+
try (var sender = createSender(senderFactory)) {
80+
sender.start();
81+
82+
String responseJson = """
83+
{
84+
"id": "de37399c-5df6-47cb-bc57-e3c5680c977b",
85+
"texts": [
86+
"hello"
87+
],
88+
"embeddings": {
89+
"float": [
90+
[
91+
0.123,
92+
-0.123
93+
]
94+
]
95+
},
96+
"meta": {
97+
"api_version": {
98+
"version": "1"
99+
},
100+
"billed_units": {
101+
"input_tokens": 1
102+
}
103+
},
104+
"response_type": "embeddings_by_type"
105+
}
106+
""";
107+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
108+
109+
var model = CohereEmbeddingsModelTests.createModel(
110+
getUrl(webServer),
111+
"secret",
112+
new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START),
113+
1024,
114+
1024,
115+
"model",
116+
CohereEmbeddingType.FLOAT
117+
);
118+
var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool));
119+
var overriddenTaskSettings = CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH, CohereTruncation.END);
120+
var action = actionCreator.create(model, overriddenTaskSettings, InputType.UNSPECIFIED);
121+
122+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
123+
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
124+
125+
var result = listener.actionGet(TIMEOUT);
126+
127+
MatcherAssert.assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F }))));
128+
MatcherAssert.assertThat(webServer.requests(), hasSize(1));
129+
assertNull(webServer.requests().get(0).getUri().getQuery());
130+
MatcherAssert.assertThat(
131+
webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
132+
equalTo(XContentType.JSON.mediaType())
133+
);
134+
MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
135+
136+
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
137+
MatcherAssert.assertThat(
138+
requestMap,
139+
is(
140+
Map.of(
141+
"texts",
142+
List.of("abc"),
143+
"model",
144+
"model",
145+
"input_type",
146+
"search_query",
147+
"embedding_types",
148+
List.of("float"),
149+
"truncate",
150+
"end"
151+
)
152+
)
153+
);
154+
}
155+
}
156+
157+
public void testCreate_CohereCompletionModel_WithModelSpecified() throws IOException {
158+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
159+
160+
try (var sender = createSender(senderFactory)) {
161+
sender.start();
162+
163+
String responseJson = """
164+
{
165+
"response_id": "some id",
166+
"text": "result",
167+
"generation_id": "some id",
168+
"chat_history": [
169+
{
170+
"role": "USER",
171+
"message": "input"
172+
},
173+
{
174+
"role": "CHATBOT",
175+
"message": "result"
176+
}
177+
],
178+
"finish_reason": "COMPLETE",
179+
"meta": {
180+
"api_version": {
181+
"version": "1"
182+
},
183+
"billed_units": {
184+
"input_tokens": 4,
185+
"output_tokens": 191
186+
},
187+
"tokens": {
188+
"input_tokens": 70,
189+
"output_tokens": 191
190+
}
191+
}
192+
}
193+
""";
194+
195+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
196+
197+
var model = CohereCompletionModelTests.createModel(getUrl(webServer), "secret", "model");
198+
var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool));
199+
var action = actionCreator.create(model, Map.of());
200+
201+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
202+
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
203+
204+
var result = listener.actionGet(TIMEOUT);
205+
206+
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result"))));
207+
assertThat(webServer.requests(), hasSize(1));
208+
assertNull(webServer.requests().get(0).getUri().getQuery());
209+
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), is(XContentType.JSON.mediaType()));
210+
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), is("Bearer secret"));
211+
212+
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
213+
assertThat(requestMap, is(Map.of("message", "abc", "model", "model")));
214+
}
215+
}
216+
217+
public void testCreate_CohereCompletionModel_WithoutModelSpecified() throws IOException {
218+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
219+
220+
try (var sender = createSender(senderFactory)) {
221+
sender.start();
222+
223+
String responseJson = """
224+
{
225+
"response_id": "some id",
226+
"text": "result",
227+
"generation_id": "some id",
228+
"chat_history": [
229+
{
230+
"role": "USER",
231+
"message": "input"
232+
},
233+
{
234+
"role": "CHATBOT",
235+
"message": "result"
236+
}
237+
],
238+
"finish_reason": "COMPLETE",
239+
"meta": {
240+
"api_version": {
241+
"version": "1"
242+
},
243+
"billed_units": {
244+
"input_tokens": 4,
245+
"output_tokens": 191
246+
},
247+
"tokens": {
248+
"input_tokens": 70,
249+
"output_tokens": 191
250+
}
251+
}
252+
}
253+
""";
254+
255+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
256+
257+
var model = CohereCompletionModelTests.createModel(getUrl(webServer), "secret", null);
258+
var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool));
259+
var action = actionCreator.create(model, Map.of());
260+
261+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
262+
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
263+
264+
var result = listener.actionGet(TIMEOUT);
265+
266+
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result"))));
267+
assertThat(webServer.requests(), hasSize(1));
268+
assertNull(webServer.requests().get(0).getUri().getQuery());
269+
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), is(XContentType.JSON.mediaType()));
270+
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), is("Bearer secret"));
271+
272+
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
273+
assertThat(requestMap, is(Map.of("message", "abc")));
274+
}
275+
}
276+
}

0 commit comments

Comments
 (0)