Skip to content

Commit e2f872e

Browse files
committed
Add ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests
1 parent 9b48dfb commit e2f872e

File tree

2 files changed

+128
-8
lines changed

2 files changed

+128
-8
lines changed

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -302,10 +302,10 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction()
302302
assertThat(textEmbeddingResults.embeddings(), hasSize(2));
303303

304304
var firstEmbedding = textEmbeddingResults.embeddings().get(0);
305-
assertThat(firstEmbedding.values(), is(new float[]{2.1259406f, 1.7073475f, 0.9020516f}));
305+
assertThat(firstEmbedding.values(), is(new float[] { 2.1259406f, 1.7073475f, 0.9020516f }));
306306

307307
var secondEmbedding = textEmbeddingResults.embeddings().get(1);
308-
assertThat(secondEmbedding.values(), is(new float[]{1.8342123f, 2.3456789f, 0.7654321f}));
308+
assertThat(secondEmbedding.values(), is(new float[] { 1.8342123f, 2.3456789f, 0.7654321f }));
309309

310310
assertThat(webServer.requests(), hasSize(1));
311311
assertNull(webServer.requests().get(0).getUri().getQuery());
@@ -358,7 +358,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_W
358358
assertThat(textEmbeddingResults.embeddings(), hasSize(1));
359359

360360
var embedding = textEmbeddingResults.embeddings().get(0);
361-
assertThat(embedding.values(), is(new float[]{0.1234567f, 0.9876543f}));
361+
assertThat(embedding.values(), is(new float[] { 0.1234567f, 0.9876543f }));
362362

363363
assertThat(webServer.requests(), hasSize(1));
364364
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
@@ -445,11 +445,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_E
445445
var action = actionCreator.create(model);
446446

447447
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
448-
action.execute(
449-
new EmbeddingsInput(List.of(), null, InputType.UNSPECIFIED),
450-
InferenceAction.Request.DEFAULT_TIMEOUT,
451-
listener
452-
);
448+
action.execute(new EmbeddingsInput(List.of(), null, InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
453449

454450
var result = listener.actionGet(TIMEOUT);
455451

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.elastic.response;
9+
10+
import org.apache.http.HttpResponse;
11+
import org.elasticsearch.test.ESTestCase;
12+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
13+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
14+
import org.elasticsearch.xpack.inference.external.request.Request;
15+
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceDenseTextEmbeddingsResponseEntity;
16+
17+
import java.nio.charset.StandardCharsets;
18+
19+
import static org.hamcrest.CoreMatchers.is;
20+
import static org.hamcrest.Matchers.hasSize;
21+
import static org.mockito.Mockito.mock;
22+
23+
public class ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests extends ESTestCase {
24+
25+
public void testDenseTextEmbeddingsResponse_SingleEmbeddingInData_NoMeta() throws Exception {
26+
String responseJson = """
27+
{
28+
"data": [
29+
[
30+
1.23,
31+
4.56,
32+
7.89
33+
]
34+
]
35+
}
36+
""";
37+
38+
TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
39+
mock(Request.class),
40+
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
41+
);
42+
43+
assertThat(parsedResults.embeddings(), hasSize(1));
44+
45+
var embedding = parsedResults.embeddings().get(0);
46+
assertThat(embedding.values(), is(new float[] { 1.23f, 4.56f, 7.89f }));
47+
}
48+
49+
public void testDenseTextEmbeddingsResponse_MultipleEmbeddingsInData_NoMeta() throws Exception {
50+
String responseJson = """
51+
{
52+
"data": [
53+
[
54+
1.23,
55+
4.56,
56+
7.89
57+
],
58+
[
59+
0.12,
60+
0.34,
61+
0.56
62+
]
63+
]
64+
}
65+
""";
66+
67+
TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
68+
mock(Request.class),
69+
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
70+
);
71+
72+
assertThat(parsedResults.embeddings(), hasSize(2));
73+
74+
var firstEmbedding = parsedResults.embeddings().get(0);
75+
assertThat(firstEmbedding.values(), is(new float[] { 1.23f, 4.56f, 7.89f }));
76+
77+
var secondEmbedding = parsedResults.embeddings().get(1);
78+
assertThat(secondEmbedding.values(), is(new float[] { 0.12f, 0.34f, 0.56f }));
79+
}
80+
81+
public void testDenseTextEmbeddingsResponse_EmptyData() throws Exception {
82+
String responseJson = """
83+
{
84+
"data": []
85+
}
86+
""";
87+
88+
TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
89+
mock(Request.class),
90+
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
91+
);
92+
93+
assertThat(parsedResults.embeddings(), hasSize(0));
94+
}
95+
96+
public void testDenseTextEmbeddingsResponse_SingleEmbeddingInData_IgnoresMeta() throws Exception {
97+
String responseJson = """
98+
{
99+
"data": [
100+
[
101+
-1.0,
102+
0.0,
103+
1.0
104+
]
105+
],
106+
"meta": {
107+
"usage": {
108+
"total_tokens": 5
109+
}
110+
}
111+
}
112+
""";
113+
114+
TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
115+
mock(Request.class),
116+
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
117+
);
118+
119+
assertThat(parsedResults.embeddings(), hasSize(1));
120+
121+
var embedding = parsedResults.embeddings().get(0);
122+
assertThat(embedding.values(), is(new float[] { -1.0f, 0.0f, 1.0f }));
123+
}
124+
}

0 commit comments

Comments
 (0)