Skip to content

Commit 3f1a75a

Browse files
committed
More tests/corrections
1 parent d41538a commit 3f1a75a

File tree

5 files changed

+199
-259
lines changed

5 files changed

+199
-259
lines changed

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java

Lines changed: 28 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,15 @@
1919
import org.elasticsearch.threadpool.ThreadPool;
2020
import org.elasticsearch.xcontent.XContentType;
2121
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
22-
import org.elasticsearch.xpack.inference.external.action.cohere.CohereActionCreator;
22+
import org.elasticsearch.xpack.inference.external.action.voyageai.VoyageAIActionCreator;
2323
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
24-
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
2524
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
2625
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
2726
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
28-
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
29-
import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModelTests;
30-
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
31-
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests;
32-
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
33-
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests;
27+
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType;
28+
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests;
29+
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
30+
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettingsTests;
3431
import org.hamcrest.MatcherAssert;
3532
import org.junit.After;
3633
import org.junit.Before;
@@ -45,7 +42,6 @@
4542
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
4643
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
4744
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
48-
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
4945
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
5046
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
5147
import static org.hamcrest.Matchers.is;
@@ -73,50 +69,42 @@ public void shutdown() throws IOException {
7369
webServer.close();
7470
}
7571

76-
public void testCreate_CohereEmbeddingsModel() throws IOException {
72+
public void testCreate_VoyageAIEmbeddingsModel() throws IOException {
7773
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
7874

7975
try (var sender = createSender(senderFactory)) {
8076
sender.start();
8177

8278
String responseJson = """
8379
{
84-
"id": "de37399c-5df6-47cb-bc57-e3c5680c977b",
85-
"texts": [
86-
"hello"
87-
],
88-
"embeddings": {
89-
"float": [
90-
[
80+
"object": "list",
81+
"data": [{
82+
"object": "embedding",
83+
"embedding": [
9184
0.123,
9285
-0.123
93-
]
94-
]
95-
},
96-
"meta": {
97-
"api_version": {
98-
"version": "1"
99-
},
100-
"billed_units": {
101-
"input_tokens": 1
102-
}
103-
},
104-
"response_type": "embeddings_by_type"
86+
],
87+
"index": 0
88+
}],
89+
"model": "voyage-3-large",
90+
"usage": {
91+
"total_tokens": 123
92+
}
10593
}
10694
""";
10795
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
10896

109-
var model = CohereEmbeddingsModelTests.createModel(
97+
var model = VoyageAIEmbeddingsModelTests.createModel(
11098
getUrl(webServer),
11199
"secret",
112-
new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START),
100+
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true),
113101
1024,
114102
1024,
115103
"model",
116-
CohereEmbeddingType.FLOAT
104+
VoyageAIEmbeddingType.FLOAT
117105
);
118-
var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool));
119-
var overriddenTaskSettings = CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH, CohereTruncation.END);
106+
var actionCreator = new VoyageAIActionCreator(sender, createWithEmptySettings(threadPool));
107+
var overriddenTaskSettings = VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH);
120108
var action = actionCreator.create(model, overriddenTaskSettings, InputType.UNSPECIFIED);
121109

122110
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
@@ -138,139 +126,15 @@ public void testCreate_CohereEmbeddingsModel() throws IOException {
138126
requestMap,
139127
is(
140128
Map.of(
141-
"texts",
142-
List.of("abc"),
143-
"model",
144-
"model",
145-
"input_type",
146-
"search_query",
147-
"embedding_types",
148-
List.of("float"),
149-
"truncate",
150-
"end"
129+
"output_dtype","float",
130+
"truncation", true,
131+
"input_type", "query",
132+
"output_dimension",1024,
133+
"input", List.of("abc"),
134+
"model", "model"
151135
)
152136
)
153137
);
154138
}
155139
}
156-
157-
public void testCreate_CohereCompletionModel_WithModelSpecified() throws IOException {
158-
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
159-
160-
try (var sender = createSender(senderFactory)) {
161-
sender.start();
162-
163-
String responseJson = """
164-
{
165-
"response_id": "some id",
166-
"text": "result",
167-
"generation_id": "some id",
168-
"chat_history": [
169-
{
170-
"role": "USER",
171-
"message": "input"
172-
},
173-
{
174-
"role": "CHATBOT",
175-
"message": "result"
176-
}
177-
],
178-
"finish_reason": "COMPLETE",
179-
"meta": {
180-
"api_version": {
181-
"version": "1"
182-
},
183-
"billed_units": {
184-
"input_tokens": 4,
185-
"output_tokens": 191
186-
},
187-
"tokens": {
188-
"input_tokens": 70,
189-
"output_tokens": 191
190-
}
191-
}
192-
}
193-
""";
194-
195-
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
196-
197-
var model = CohereCompletionModelTests.createModel(getUrl(webServer), "secret", "model");
198-
var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool));
199-
var action = actionCreator.create(model, Map.of());
200-
201-
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
202-
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
203-
204-
var result = listener.actionGet(TIMEOUT);
205-
206-
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result"))));
207-
assertThat(webServer.requests(), hasSize(1));
208-
assertNull(webServer.requests().get(0).getUri().getQuery());
209-
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), is(XContentType.JSON.mediaType()));
210-
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), is("Bearer secret"));
211-
212-
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
213-
assertThat(requestMap, is(Map.of("message", "abc", "model", "model")));
214-
}
215-
}
216-
217-
public void testCreate_CohereCompletionModel_WithoutModelSpecified() throws IOException {
218-
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
219-
220-
try (var sender = createSender(senderFactory)) {
221-
sender.start();
222-
223-
String responseJson = """
224-
{
225-
"response_id": "some id",
226-
"text": "result",
227-
"generation_id": "some id",
228-
"chat_history": [
229-
{
230-
"role": "USER",
231-
"message": "input"
232-
},
233-
{
234-
"role": "CHATBOT",
235-
"message": "result"
236-
}
237-
],
238-
"finish_reason": "COMPLETE",
239-
"meta": {
240-
"api_version": {
241-
"version": "1"
242-
},
243-
"billed_units": {
244-
"input_tokens": 4,
245-
"output_tokens": 191
246-
},
247-
"tokens": {
248-
"input_tokens": 70,
249-
"output_tokens": 191
250-
}
251-
}
252-
}
253-
""";
254-
255-
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
256-
257-
var model = CohereCompletionModelTests.createModel(getUrl(webServer), "secret", null);
258-
var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool));
259-
var action = actionCreator.create(model, Map.of());
260-
261-
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
262-
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
263-
264-
var result = listener.actionGet(TIMEOUT);
265-
266-
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result"))));
267-
assertThat(webServer.requests(), hasSize(1));
268-
assertNull(webServer.requests().get(0).getUri().getQuery());
269-
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), is(XContentType.JSON.mediaType()));
270-
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), is("Bearer secret"));
271-
272-
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
273-
assertThat(requestMap, is(Map.of("message", "abc")));
274-
}
275-
}
276140
}

0 commit comments

Comments
 (0)