Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ static TransportVersion def(int id) {
public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14);
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15);
public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16);
public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED_8_19 = def(8_841_0_17);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
Expand Down Expand Up @@ -214,6 +215,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_REMOVE_AGGREGATE_TYPE = def(9_045_0_00);
public static final TransportVersion ADD_PROJECT_ID_TO_DSL_ERROR_INFO = def(9_046_0_00);
public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG = def(9_047_00_0);
public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED = def(9_048_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public void testGetDefaultEndpoints() throws IOException {
var allModels = getAllModels();
var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION);

assertThat(allModels, hasSize(5));
assertThat(allModels, hasSize(6));
assertThat(chatCompletionModels, hasSize(1));

for (var model : chatCompletionModels) {
Expand All @@ -35,6 +35,7 @@ public void testGetDefaultEndpoints() throws IOException {

assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION);
assertInferenceIdTaskType(allModels, ".elser-v2-elastic", TaskType.SPARSE_EMBEDDING);
assertInferenceIdTaskType(allModels, ".multilingual-embed-elastic", TaskType.TEXT_EMBEDDING);
}

private static void assertInferenceIdTaskType(List<Map<String, Object>> models, String inferenceId, TaskType taskType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
@SuppressWarnings("unchecked")
public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
assertThat(services.size(), equalTo(15));
assertThat(services.size(), equalTo(16));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Expand All @@ -79,6 +79,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"elastic",
"elasticsearch",
"googleaistudio",
"googlevertexai",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ public static MockElasticInferenceServiceAuthorizationServer enabledWithRainbowS
{
"model_name": "elser-v2",
"task_types": ["embed/text/sparse"]
},
{
"model_name": "multilingual-embed",
"task_types": ["embed/text/dense"]
}
]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.MinimalServiceSettings;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -197,6 +198,10 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
{
"model_name": "elser-v2",
"task_types": ["embed/text/sparse"]
},
{
"model_name": "multilingual-embed",
"task_types": ["embed/text/dense"]
}
]
}
Expand All @@ -221,16 +226,33 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
".rainbow-sprinkles-elastic",
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
service
),
new InferenceService.DefaultConfigId(
".multilingual-embed-elastic",
MinimalServiceSettings.textEmbedding(
ElasticInferenceService.NAME,
ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(),
DenseVectorFieldMapper.ElementType.FLOAT
),
service
)
)
)
);
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING)));
assertThat(
service.supportedTaskTypes(),
is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING))
);

PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
service.defaultConfigs(listener);
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
assertThat(
listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(),
is(".multilingual-embed-elastic")
);

var getModelListener = new PlainActionFuture<UnparsedModel>();
// persists the default endpoints
Expand Down Expand Up @@ -267,6 +289,16 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
".elser-v2-elastic",
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
service
),
new InferenceService.DefaultConfigId(
".multilingual-embed-elastic",
MinimalServiceSettings.textEmbedding(
ElasticInferenceService.NAME,
ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(),
DenseVectorFieldMapper.ElementType.FLOAT
),
service
)
)
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.elastic;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.entity.ByteArrayEntity;
import org.apache.http.message.BasicHeader;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceSparseEmbeddingsRequest.inputTypeToUsageContext;

public class ElasticInferenceServiceDenseTextEmbeddingsRequest extends ElasticInferenceServiceRequest {

private final URI uri;
private final ElasticInferenceServiceDenseTextEmbeddingsModel model;
private final List<String> inputs;
private final TraceContextHandler traceContextHandler;
private final InputType inputType;

public ElasticInferenceServiceDenseTextEmbeddingsRequest(
ElasticInferenceServiceDenseTextEmbeddingsModel model,
List<String> inputs,
TraceContext traceContext,
ElasticInferenceServiceRequestMetadata metadata,
InputType inputType
) {
super(metadata);
this.inputs = inputs;
this.model = Objects.requireNonNull(model);
this.uri = model.uri();
this.traceContextHandler = new TraceContextHandler(traceContext);
this.inputType = inputType;
}

@Override
public HttpRequestBase createHttpRequestBase() {
var httpPost = new HttpPost(uri);
var usageContext = inputTypeToUsageContext(inputType);
var requestEntity = Strings.toString(
new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(inputs, model.getServiceSettings().modelId(), usageContext)
);

ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
httpPost.setEntity(byteEntity);

traceContextHandler.propagateTraceContext(httpPost);
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));

return httpPost;
}

public TraceContext getTraceContext() {
return traceContextHandler.traceContext();
}

@Override
public String getInferenceEntityId() {
return model.getInferenceEntityId();
}

@Override
public URI getURI() {
return this.uri;
}

@Override
public Request truncate() {
return this;
}

@Override
public boolean[] getTruncationInfo() {
return null;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.elastic;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUsageContext;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

public record ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(
List<String> inputs,
String modelId,
@Nullable ElasticInferenceServiceUsageContext usageContext
) implements ToXContentObject {

private static final String INPUT_FIELD = "input";
private static final String MODEL_FIELD = "model";
private static final String USAGE_CONTEXT = "usage_context";

public ElasticInferenceServiceDenseTextEmbeddingsRequestEntity {
Objects.requireNonNull(inputs);
Objects.requireNonNull(modelId);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.startArray(INPUT_FIELD);

for (String input : inputs) {
builder.value(input);
}

builder.endArray();

builder.field(MODEL_FIELD, modelId);

// optional field
if ((usageContext == ElasticInferenceServiceUsageContext.UNSPECIFIED) == false) {
builder.field(USAGE_CONTEXT, usageContext);
}

builder.endObject();

return builder;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ public class ElasticInferenceServiceAuthorizationResponseEntity implements Infer
"embed/text/sparse",
TaskType.SPARSE_EMBEDDING,
"chat",
TaskType.CHAT_COMPLETION
TaskType.CHAT_COMPLETION,
"embed/text/dense",
TaskType.TEXT_EMBEDDING
);

@SuppressWarnings("unchecked")
Expand Down
Loading