Skip to content

Commit f054dca

Browse files
committed
Add working dense text embeddings integration with default endpoint. Some tests WIP
1 parent 6e4cb81 commit f054dca

19 files changed

+1092
-214
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ static TransportVersion def(int id) {
157157
public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14);
158158
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15);
159159
public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16);
160+
public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED_8_19 = def(8_841_0_17);
160161
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
161162
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
162163
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
@@ -214,6 +215,7 @@ static TransportVersion def(int id) {
214215
public static final TransportVersion ESQL_REMOVE_AGGREGATE_TYPE = def(9_045_0_00);
215216
public static final TransportVersion ADD_PROJECT_ID_TO_DSL_ERROR_INFO = def(9_046_0_00);
216217
public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG = def(9_047_00_0);
218+
public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED = def(9_048_00_0);
217219

218220
/*
219221
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public void testGetDefaultEndpoints() throws IOException {
2626
var allModels = getAllModels();
2727
var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION);
2828

29-
assertThat(allModels, hasSize(5));
29+
assertThat(allModels, hasSize(6));
3030
assertThat(chatCompletionModels, hasSize(1));
3131

3232
for (var model : chatCompletionModels) {
@@ -35,6 +35,7 @@ public void testGetDefaultEndpoints() throws IOException {
3535

3636
assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION);
3737
assertInferenceIdTaskType(allModels, ".elser-v2-elastic", TaskType.SPARSE_EMBEDDING);
38+
assertInferenceIdTaskType(allModels, ".multilingual-embed-elastic", TaskType.TEXT_EMBEDDING);
3839
}
3940

4041
private static void assertInferenceIdTaskType(List<Map<String, Object>> models, String inferenceId, TaskType taskType) {

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
6464
@SuppressWarnings("unchecked")
6565
public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
6666
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
67-
assertThat(services.size(), equalTo(15));
67+
assertThat(services.size(), equalTo(16));
6868

6969
String[] providers = new String[services.size()];
7070
for (int i = 0; i < services.size(); i++) {
@@ -79,6 +79,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
7979
"azureaistudio",
8080
"azureopenai",
8181
"cohere",
82+
"elastic",
8283
"elasticsearch",
8384
"googleaistudio",
8485
"googlevertexai",

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ public static MockElasticInferenceServiceAuthorizationServer enabledWithRainbowS
3636
{
3737
"model_name": "elser-v2",
3838
"task_types": ["embed/text/sparse"]
39+
},
40+
{
41+
"model_name": "multilingual-embed",
42+
"task_types": ["embed/text/dense"]
3943
}
4044
]
4145
}

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.action.support.PlainActionFuture;
1212
import org.elasticsearch.common.settings.Settings;
1313
import org.elasticsearch.core.TimeValue;
14+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
1415
import org.elasticsearch.inference.InferenceService;
1516
import org.elasticsearch.inference.MinimalServiceSettings;
1617
import org.elasticsearch.inference.Model;
@@ -197,6 +198,10 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
197198
{
198199
"model_name": "elser-v2",
199200
"task_types": ["embed/text/sparse"]
201+
},
202+
{
203+
"model_name": "multilingual-embed",
204+
"task_types": ["embed/text/dense"]
200205
}
201206
]
202207
}
@@ -221,16 +226,33 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
221226
".rainbow-sprinkles-elastic",
222227
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
223228
service
229+
),
230+
new InferenceService.DefaultConfigId(
231+
".multilingual-embed-elastic",
232+
MinimalServiceSettings.textEmbedding(
233+
ElasticInferenceService.NAME,
234+
ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
235+
ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(),
236+
DenseVectorFieldMapper.ElementType.FLOAT
237+
),
238+
service
224239
)
225240
)
226241
)
227242
);
228-
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING)));
243+
assertThat(
244+
service.supportedTaskTypes(),
245+
is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING))
246+
);
229247

230248
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
231249
service.defaultConfigs(listener);
232250
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
233251
assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
252+
assertThat(
253+
listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(),
254+
is(".multilingual-embed-elastic")
255+
);
234256

235257
var getModelListener = new PlainActionFuture<UnparsedModel>();
236258
// persists the default endpoints
@@ -267,6 +289,16 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
267289
".elser-v2-elastic",
268290
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
269291
service
292+
),
293+
new InferenceService.DefaultConfigId(
294+
".multilingual-embed-elastic",
295+
MinimalServiceSettings.textEmbedding(
296+
ElasticInferenceService.NAME,
297+
ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
298+
ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(),
299+
DenseVectorFieldMapper.ElementType.FLOAT
300+
),
301+
service
270302
)
271303
)
272304
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request.elastic;
9+
10+
import org.apache.http.HttpHeaders;
11+
import org.apache.http.client.methods.HttpPost;
12+
import org.apache.http.client.methods.HttpRequestBase;
13+
import org.apache.http.entity.ByteArrayEntity;
14+
import org.apache.http.message.BasicHeader;
15+
import org.elasticsearch.common.Strings;
16+
import org.elasticsearch.inference.InputType;
17+
import org.elasticsearch.xcontent.XContentType;
18+
import org.elasticsearch.xpack.inference.external.request.Request;
19+
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
20+
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
21+
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;
22+
23+
import java.net.URI;
24+
import java.nio.charset.StandardCharsets;
25+
import java.util.List;
26+
import java.util.Objects;
27+
28+
import static org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceSparseEmbeddingsRequest.inputTypeToUsageContext;
29+
30+
public class ElasticInferenceServiceDenseTextEmbeddingsRequest extends ElasticInferenceServiceRequest {
31+
32+
private final URI uri;
33+
private final ElasticInferenceServiceDenseTextEmbeddingsModel model;
34+
private final List<String> inputs;
35+
private final TraceContextHandler traceContextHandler;
36+
private final InputType inputType;
37+
38+
public ElasticInferenceServiceDenseTextEmbeddingsRequest(
39+
ElasticInferenceServiceDenseTextEmbeddingsModel model,
40+
List<String> inputs,
41+
TraceContext traceContext,
42+
ElasticInferenceServiceRequestMetadata metadata,
43+
InputType inputType
44+
) {
45+
super(metadata);
46+
this.inputs = inputs;
47+
this.model = Objects.requireNonNull(model);
48+
this.uri = model.uri();
49+
this.traceContextHandler = new TraceContextHandler(traceContext);
50+
this.inputType = inputType;
51+
}
52+
53+
@Override
54+
public HttpRequestBase createHttpRequestBase() {
55+
var httpPost = new HttpPost(uri);
56+
var usageContext = inputTypeToUsageContext(inputType);
57+
var requestEntity = Strings.toString(
58+
new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(inputs, model.getServiceSettings().modelId(), usageContext)
59+
);
60+
61+
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
62+
httpPost.setEntity(byteEntity);
63+
64+
traceContextHandler.propagateTraceContext(httpPost);
65+
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));
66+
67+
return httpPost;
68+
}
69+
70+
public TraceContext getTraceContext() {
71+
return traceContextHandler.traceContext();
72+
}
73+
74+
@Override
75+
public String getInferenceEntityId() {
76+
return model.getInferenceEntityId();
77+
}
78+
79+
@Override
80+
public URI getURI() {
81+
return this.uri;
82+
}
83+
84+
@Override
85+
public Request truncate() {
86+
return this;
87+
}
88+
89+
@Override
90+
public boolean[] getTruncationInfo() {
91+
return null;
92+
}
93+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request.elastic;
9+
10+
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.xcontent.ToXContentObject;
12+
import org.elasticsearch.xcontent.XContentBuilder;
13+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUsageContext;
14+
15+
import java.io.IOException;
16+
import java.util.List;
17+
import java.util.Objects;
18+
19+
public record ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(
20+
List<String> inputs,
21+
String modelId,
22+
@Nullable ElasticInferenceServiceUsageContext usageContext
23+
) implements ToXContentObject {
24+
25+
private static final String INPUT_FIELD = "input";
26+
private static final String MODEL_FIELD = "model";
27+
private static final String USAGE_CONTEXT = "usage_context";
28+
29+
public ElasticInferenceServiceDenseTextEmbeddingsRequestEntity {
30+
Objects.requireNonNull(inputs);
31+
Objects.requireNonNull(modelId);
32+
}
33+
34+
@Override
35+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
36+
builder.startObject();
37+
builder.startArray(INPUT_FIELD);
38+
39+
for (String input : inputs) {
40+
builder.value(input);
41+
}
42+
43+
builder.endArray();
44+
45+
builder.field(MODEL_FIELD, modelId);
46+
47+
// optional field
48+
if ((usageContext == ElasticInferenceServiceUsageContext.UNSPECIFIED) == false) {
49+
builder.field(USAGE_CONTEXT, usageContext);
50+
}
51+
52+
builder.endObject();
53+
54+
return builder;
55+
}
56+
57+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceAuthorizationResponseEntity.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ public class ElasticInferenceServiceAuthorizationResponseEntity implements Infer
4343
"embed/text/sparse",
4444
TaskType.SPARSE_EMBEDDING,
4545
"chat",
46-
TaskType.CHAT_COMPLETION
46+
TaskType.CHAT_COMPLETION,
47+
"embed/text/dense",
48+
TaskType.TEXT_EMBEDDING
4749
);
4850

4951
@SuppressWarnings("unchecked")

0 commit comments

Comments
 (0)