Skip to content

Commit 7d09f05

Browse files
authored
[8.19] [EIS] Dense Text Embedding task type integration (#129847) (#129963)
1 parent 48fa4aa commit 7d09f05

File tree

25 files changed

+1839
-177
lines changed

25 files changed

+1839
-177
lines changed

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public void testGetDefaultEndpoints() throws IOException {
3333
var allModels = getAllModels();
3434
var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION);
3535

36-
assertThat(allModels, hasSize(5));
36+
assertThat(allModels, hasSize(7));
3737
assertThat(chatCompletionModels, hasSize(1));
3838

3939
for (var model : chatCompletionModels) {
@@ -42,6 +42,8 @@ public void testGetDefaultEndpoints() throws IOException {
4242

4343
assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION);
4444
assertInferenceIdTaskType(allModels, ".elser-v2-elastic", TaskType.SPARSE_EMBEDDING);
45+
assertInferenceIdTaskType(allModels, ".multilingual-embed-v1-elastic", TaskType.TEXT_EMBEDDING);
46+
assertInferenceIdTaskType(allModels, ".rerank-v1-elastic", TaskType.RERANK);
4547
}
4648

4749
private static void assertInferenceIdTaskType(List<Map<String, Object>> models, String inferenceId, TaskType taskType) {

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated;
2222
import static org.hamcrest.Matchers.containsInAnyOrder;
23+
import static org.hamcrest.Matchers.equalTo;
2324

2425
public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
2526

@@ -76,16 +77,21 @@ private Iterable<String> providers(List<Object> services) {
7677
}
7778

7879
public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
80+
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
81+
assertThat(services.size(), equalTo(18));
82+
7983
assertThat(
8084
providersFor(TaskType.TEXT_EMBEDDING),
8185
containsInAnyOrder(
8286
List.of(
8387
"alibabacloud-ai-search",
8488
"amazonbedrock",
89+
"amazon_sagemaker",
8590
"azureaistudio",
8691
"azureopenai",
8792
"cohere",
8893
"custom",
94+
"elastic",
8995
"elasticsearch",
9096
"googleaistudio",
9197
"googlevertexai",
@@ -95,8 +101,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
95101
"openai",
96102
"text_embedding_test_service",
97103
"voyageai",
98-
"watsonxai",
99-
"amazon_sagemaker"
104+
"watsonxai"
100105
).toArray()
101106
)
102107
);
@@ -114,6 +119,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
114119
"alibabacloud-ai-search",
115120
"cohere",
116121
"custom",
122+
"elastic",
117123
"elasticsearch",
118124
"googlevertexai",
119125
"jinaai",

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@ public void enqueueAuthorizeAllModelsResponse() {
4141
{
4242
"model_name": "elser-v2",
4343
"task_types": ["embed/text/sparse"]
44+
},
45+
{
46+
"model_name": "multilingual-embed-v1",
47+
"task_types": ["embed/text/dense"]
48+
},
49+
{
50+
"model_name": "rerank-v1",
51+
"task_types": ["rerank/text/text-similarity"]
4452
}
4553
]
4654
}

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java

Lines changed: 78 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.action.support.PlainActionFuture;
1212
import org.elasticsearch.common.settings.Settings;
1313
import org.elasticsearch.core.TimeValue;
14+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
1415
import org.elasticsearch.inference.InferenceService;
1516
import org.elasticsearch.inference.MinimalServiceSettings;
1617
import org.elasticsearch.inference.Model;
@@ -43,6 +44,7 @@
4344
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
4445
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
4546
import static org.hamcrest.CoreMatchers.is;
47+
import static org.hamcrest.Matchers.containsInAnyOrder;
4648
import static org.mockito.Mockito.mock;
4749

4850
public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
@@ -94,7 +96,6 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
9496
try (var service = createElasticInferenceService()) {
9597
ensureAuthorizationCallFinished(service);
9698
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
97-
9899
assertThat(
99100
service.defaultConfigIds(),
100101
is(
@@ -191,13 +192,21 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
191192
String responseJson = """
192193
{
193194
"models": [
195+
{
196+
"model_name": "elser-v2",
197+
"task_types": ["embed/text/sparse"]
198+
},
194199
{
195200
"model_name": "rainbow-sprinkles",
196201
"task_types": ["chat"]
197202
},
198203
{
199-
"model_name": "elser-v2",
200-
"task_types": ["embed/text/sparse"]
204+
"model_name": "multilingual-embed-v1",
205+
"task_types": ["embed/text/dense"]
206+
},
207+
{
208+
"model_name": "rerank-v1",
209+
"task_types": ["rerank/text/text-similarity"]
201210
}
202211
]
203212
}
@@ -211,27 +220,48 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
211220
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
212221
assertThat(
213222
service.defaultConfigIds(),
214-
is(
215-
List.of(
216-
new InferenceService.DefaultConfigId(
217-
".elser-v2-elastic",
218-
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
219-
service
223+
containsInAnyOrder(
224+
new InferenceService.DefaultConfigId(
225+
".elser-v2-elastic",
226+
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
227+
service
228+
),
229+
new InferenceService.DefaultConfigId(
230+
".rainbow-sprinkles-elastic",
231+
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
232+
service
233+
),
234+
new InferenceService.DefaultConfigId(
235+
".multilingual-embed-v1-elastic",
236+
MinimalServiceSettings.textEmbedding(
237+
ElasticInferenceService.NAME,
238+
ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
239+
ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(),
240+
DenseVectorFieldMapper.ElementType.FLOAT
220241
),
221-
new InferenceService.DefaultConfigId(
222-
".rainbow-sprinkles-elastic",
223-
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
224-
service
225-
)
242+
service
243+
),
244+
new InferenceService.DefaultConfigId(
245+
".rerank-v1-elastic",
246+
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
247+
service
226248
)
227249
)
228250
);
229-
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING)));
251+
assertThat(
252+
service.supportedTaskTypes(),
253+
is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.TEXT_EMBEDDING))
254+
);
230255

231256
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
232257
service.defaultConfigs(listener);
233258
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
234-
assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
259+
assertThat(
260+
listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(),
261+
is(".multilingual-embed-v1-elastic")
262+
);
263+
assertThat(listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
264+
assertThat(listener.actionGet(TIMEOUT).get(3).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic"));
235265

236266
var getModelListener = new PlainActionFuture<UnparsedModel>();
237267
// persists the default endpoints
@@ -249,6 +279,14 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
249279
{
250280
"model_name": "elser-v2",
251281
"task_types": ["embed/text/sparse"]
282+
},
283+
{
284+
"model_name": "rerank-v1",
285+
"task_types": ["rerank/text/text-similarity"]
286+
},
287+
{
288+
"model_name": "multilingual-embed-v1",
289+
"task_types": ["embed/text/dense"]
252290
}
253291
]
254292
}
@@ -262,17 +300,33 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
262300
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
263301
assertThat(
264302
service.defaultConfigIds(),
265-
is(
266-
List.of(
267-
new InferenceService.DefaultConfigId(
268-
".elser-v2-elastic",
269-
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
270-
service
271-
)
303+
containsInAnyOrder(
304+
new InferenceService.DefaultConfigId(
305+
".elser-v2-elastic",
306+
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
307+
service
308+
),
309+
new InferenceService.DefaultConfigId(
310+
".multilingual-embed-v1-elastic",
311+
MinimalServiceSettings.textEmbedding(
312+
ElasticInferenceService.NAME,
313+
ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
314+
ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(),
315+
DenseVectorFieldMapper.ElementType.FLOAT
316+
),
317+
service
318+
),
319+
new InferenceService.DefaultConfigId(
320+
".rerank-v1-elastic",
321+
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
322+
service
272323
)
273324
)
274325
);
275-
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));
326+
assertThat(
327+
service.supportedTaskTypes(),
328+
is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))
329+
);
276330

277331
var getModelListener = new PlainActionFuture<UnparsedModel>();
278332
modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.response.elastic;
9+
10+
import org.elasticsearch.common.xcontent.XContentParserUtils;
11+
import org.elasticsearch.xcontent.ConstructingObjectParser;
12+
import org.elasticsearch.xcontent.ParseField;
13+
import org.elasticsearch.xcontent.XContentFactory;
14+
import org.elasticsearch.xcontent.XContentParserConfiguration;
15+
import org.elasticsearch.xcontent.XContentType;
16+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
17+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
18+
import org.elasticsearch.xpack.inference.external.request.Request;
19+
20+
import java.io.IOException;
21+
import java.util.List;
22+
23+
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
24+
25+
public class ElasticInferenceServiceDenseTextEmbeddingsResponseEntity {
26+
27+
/**
28+
* Parses the Elastic Inference Service Dense Text Embeddings response.
29+
*
30+
* For a request like:
31+
*
32+
* <pre>
33+
* <code>
34+
* {
35+
* "inputs": ["Embed this text", "Embed this text, too"]
36+
* }
37+
* </code>
38+
* </pre>
39+
*
40+
* The response would look like:
41+
*
42+
* <pre>
43+
* <code>
44+
* {
45+
* "data": [
46+
* [
47+
* 2.1259406,
48+
* 1.7073475,
49+
* 0.9020516
50+
* ],
51+
* (...)
52+
* ],
53+
* "meta": {
54+
* "usage": {...}
55+
* }
56+
* }
57+
* </code>
58+
* </pre>
59+
*/
60+
public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
61+
try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) {
62+
return EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults();
63+
}
64+
}
65+
66+
public record EmbeddingFloatResult(List<EmbeddingFloatResultEntry> embeddingResults) {
67+
@SuppressWarnings("unchecked")
68+
public static final ConstructingObjectParser<EmbeddingFloatResult, Void> PARSER = new ConstructingObjectParser<>(
69+
EmbeddingFloatResult.class.getSimpleName(),
70+
true,
71+
args -> new EmbeddingFloatResult((List<EmbeddingFloatResultEntry>) args[0])
72+
);
73+
74+
static {
75+
// Custom field declaration to handle array of arrays format
76+
PARSER.declareField(constructorArg(), (parser, context) -> {
77+
return XContentParserUtils.parseList(parser, (p, index) -> {
78+
List<Float> embedding = XContentParserUtils.parseList(p, (innerParser, innerIndex) -> innerParser.floatValue());
79+
return EmbeddingFloatResultEntry.fromFloatArray(embedding);
80+
});
81+
}, new ParseField("data"), org.elasticsearch.xcontent.ObjectParser.ValueType.OBJECT_ARRAY);
82+
}
83+
84+
public TextEmbeddingFloatResults toTextEmbeddingFloatResults() {
85+
return new TextEmbeddingFloatResults(
86+
embeddingResults.stream().map(entry -> TextEmbeddingFloatResults.Embedding.of(entry.embedding)).toList()
87+
);
88+
}
89+
}
90+
91+
/**
92+
* Represents a single embedding entry in the response.
93+
* For the Elastic Inference Service, each entry is just an array of floats (no wrapper object).
94+
* This is a simpler wrapper that just holds the float array.
95+
*/
96+
public record EmbeddingFloatResultEntry(List<Float> embedding) {
97+
public static EmbeddingFloatResultEntry fromFloatArray(List<Float> floats) {
98+
return new EmbeddingFloatResultEntry(floats);
99+
}
100+
}
101+
102+
private ElasticInferenceServiceDenseTextEmbeddingsResponseEntity() {}
103+
}

0 commit comments

Comments
 (0)