Skip to content

Commit 9b48dfb

Browse files
committed
Add dense text embedding test cases to ElasticInferenceServiceActionCreatorTests
1 parent fddfd9d commit 9b48dfb

File tree

1 file changed

+209
-0
lines changed

1 file changed

+209
-0
lines changed

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@
2222
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
2323
import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests;
2424
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
25+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
2526
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
2627
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
2728
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
2829
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
2930
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
3031
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModelTests;
32+
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModelTests;
3133
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests;
3234
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
3335
import org.junit.After;
@@ -256,6 +258,213 @@ public void testExecute_ReturnsSuccessfulResponse_ForRerankAction() throws IOExc
256258
}
257259
}
258260

261+
@SuppressWarnings("unchecked")
262+
public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction() throws IOException {
263+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
264+
265+
try (var sender = createSender(senderFactory)) {
266+
sender.start();
267+
268+
String responseJson = """
269+
{
270+
"data": [
271+
[
272+
2.1259406,
273+
1.7073475,
274+
0.9020516
275+
],
276+
[
277+
1.8342123,
278+
2.3456789,
279+
0.7654321
280+
]
281+
]
282+
}
283+
""";
284+
285+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
286+
287+
var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id", null);
288+
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
289+
var action = actionCreator.create(model);
290+
291+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
292+
action.execute(
293+
new EmbeddingsInput(List.of("hello world", "second text"), null, InputType.UNSPECIFIED),
294+
InferenceAction.Request.DEFAULT_TIMEOUT,
295+
listener
296+
);
297+
298+
var result = listener.actionGet(TIMEOUT);
299+
300+
assertThat(result, instanceOf(TextEmbeddingFloatResults.class));
301+
var textEmbeddingResults = (TextEmbeddingFloatResults) result;
302+
assertThat(textEmbeddingResults.embeddings(), hasSize(2));
303+
304+
var firstEmbedding = textEmbeddingResults.embeddings().get(0);
305+
assertThat(firstEmbedding.values(), is(new float[]{2.1259406f, 1.7073475f, 0.9020516f}));
306+
307+
var secondEmbedding = textEmbeddingResults.embeddings().get(1);
308+
assertThat(secondEmbedding.values(), is(new float[]{1.8342123f, 2.3456789f, 0.7654321f}));
309+
310+
assertThat(webServer.requests(), hasSize(1));
311+
assertNull(webServer.requests().get(0).getUri().getQuery());
312+
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
313+
314+
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
315+
assertThat(requestMap.size(), is(2));
316+
assertThat(requestMap.get("input"), instanceOf(List.class));
317+
var inputList = (List<String>) requestMap.get("input");
318+
assertThat(inputList, contains("hello world", "second text"));
319+
assertThat(requestMap.get("model"), is("my-dense-model-id"));
320+
}
321+
}
322+
323+
@SuppressWarnings("unchecked")
324+
public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_WithUsageContext() throws IOException {
325+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
326+
327+
try (var sender = createSender(senderFactory)) {
328+
sender.start();
329+
330+
String responseJson = """
331+
{
332+
"data": [
333+
[
334+
0.1234567,
335+
0.9876543
336+
]
337+
]
338+
}
339+
""";
340+
341+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
342+
343+
var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id", null);
344+
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
345+
var action = actionCreator.create(model);
346+
347+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
348+
action.execute(
349+
new EmbeddingsInput(List.of("search query"), null, InputType.SEARCH),
350+
InferenceAction.Request.DEFAULT_TIMEOUT,
351+
listener
352+
);
353+
354+
var result = listener.actionGet(TIMEOUT);
355+
356+
assertThat(result, instanceOf(TextEmbeddingFloatResults.class));
357+
var textEmbeddingResults = (TextEmbeddingFloatResults) result;
358+
assertThat(textEmbeddingResults.embeddings(), hasSize(1));
359+
360+
var embedding = textEmbeddingResults.embeddings().get(0);
361+
assertThat(embedding.values(), is(new float[]{0.1234567f, 0.9876543f}));
362+
363+
assertThat(webServer.requests(), hasSize(1));
364+
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
365+
assertThat(requestMap.size(), is(3));
366+
assertThat(requestMap.get("input"), instanceOf(List.class));
367+
var inputList = (List<String>) requestMap.get("input");
368+
assertThat(inputList, contains("search query"));
369+
assertThat(requestMap.get("model"), is("my-dense-model-id"));
370+
assertThat(requestMap.get("usage_context"), is("search"));
371+
}
372+
}
373+
374+
@SuppressWarnings("unchecked")
375+
public void testSend_FailsFromInvalidResponseFormat_ForDenseTextEmbeddingsAction() throws IOException {
376+
// timeout as zero for no retries
377+
var settings = buildSettingsWithRetryFields(
378+
TimeValue.timeValueMillis(1),
379+
TimeValue.timeValueMinutes(1),
380+
TimeValue.timeValueSeconds(0)
381+
);
382+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
383+
384+
try (var sender = createSender(senderFactory)) {
385+
sender.start();
386+
387+
// This will fail because the expected output is {"data": [[...]]}
388+
String responseJson = """
389+
{
390+
"data": {
391+
"embedding": [2.1259406, 1.7073475]
392+
}
393+
}
394+
""";
395+
396+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
397+
398+
var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id", null);
399+
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
400+
var action = actionCreator.create(model);
401+
402+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
403+
action.execute(
404+
new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED),
405+
InferenceAction.Request.DEFAULT_TIMEOUT,
406+
listener
407+
);
408+
409+
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
410+
assertThat(
411+
thrownException.getMessage(),
412+
is("Failed to parse object: expecting token of type [START_ARRAY] but found [START_OBJECT]")
413+
);
414+
415+
assertThat(webServer.requests(), hasSize(1));
416+
assertNull(webServer.requests().get(0).getUri().getQuery());
417+
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
418+
419+
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
420+
assertThat(requestMap.size(), is(2));
421+
assertThat(requestMap.get("input"), instanceOf(List.class));
422+
var inputList = (List<String>) requestMap.get("input");
423+
assertThat(inputList, contains("hello world"));
424+
assertThat(requestMap.get("model"), is("my-dense-model-id"));
425+
}
426+
}
427+
428+
@SuppressWarnings("unchecked")
429+
public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_EmptyEmbeddings() throws IOException {
430+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
431+
432+
try (var sender = createSender(senderFactory)) {
433+
sender.start();
434+
435+
String responseJson = """
436+
{
437+
"data": []
438+
}
439+
""";
440+
441+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
442+
443+
var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id", null);
444+
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
445+
var action = actionCreator.create(model);
446+
447+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
448+
action.execute(
449+
new EmbeddingsInput(List.of(), null, InputType.UNSPECIFIED),
450+
InferenceAction.Request.DEFAULT_TIMEOUT,
451+
listener
452+
);
453+
454+
var result = listener.actionGet(TIMEOUT);
455+
456+
assertThat(result, instanceOf(TextEmbeddingFloatResults.class));
457+
var textEmbeddingResults = (TextEmbeddingFloatResults) result;
458+
assertThat(textEmbeddingResults.embeddings(), hasSize(0));
459+
460+
assertThat(webServer.requests(), hasSize(1));
461+
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
462+
assertThat(requestMap.get("input"), instanceOf(List.class));
463+
var inputList = (List<String>) requestMap.get("input");
464+
assertThat(inputList, hasSize(0));
465+
}
466+
}
467+
259468
public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOException {
260469
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
261470

0 commit comments

Comments
 (0)