Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public void testGetDefaultEndpoints() throws IOException {
var allModels = getAllModels();
var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION);

assertThat(allModels, hasSize(5));
assertThat(allModels, hasSize(7));
assertThat(chatCompletionModels, hasSize(1));

for (var model : chatCompletionModels) {
Expand All @@ -42,6 +42,8 @@ public void testGetDefaultEndpoints() throws IOException {

assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION);
assertInferenceIdTaskType(allModels, ".elser-v2-elastic", TaskType.SPARSE_EMBEDDING);
assertInferenceIdTaskType(allModels, ".multilingual-embed-v1-elastic", TaskType.TEXT_EMBEDDING);
assertInferenceIdTaskType(allModels, ".rerank-v1-elastic", TaskType.RERANK);
}

private static void assertInferenceIdTaskType(List<Map<String, Object>> models, String inferenceId, TaskType taskType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;

public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {

Expand Down Expand Up @@ -76,16 +77,21 @@ private Iterable<String> providers(List<Object> services) {
}

public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
assertThat(services.size(), equalTo(18));

assertThat(
providersFor(TaskType.TEXT_EMBEDDING),
containsInAnyOrder(
List.of(
"alibabacloud-ai-search",
"amazonbedrock",
"amazon_sagemaker",
"azureaistudio",
"azureopenai",
"cohere",
"custom",
"elastic",
"elasticsearch",
"googleaistudio",
"googlevertexai",
Expand All @@ -95,8 +101,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
"openai",
"text_embedding_test_service",
"voyageai",
"watsonxai",
"amazon_sagemaker"
"watsonxai"
).toArray()
)
);
Expand All @@ -114,6 +119,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
"alibabacloud-ai-search",
"cohere",
"custom",
"elastic",
"elasticsearch",
"googlevertexai",
"jinaai",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ public void enqueueAuthorizeAllModelsResponse() {
{
"model_name": "elser-v2",
"task_types": ["embed/text/sparse"]
},
{
"model_name": "multilingual-embed-v1",
"task_types": ["embed/text/dense"]
},
{
"model_name": "rerank-v1",
"task_types": ["rerank/text/text-similarity"]
}
]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.MinimalServiceSettings;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -43,6 +44,7 @@
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.mockito.Mockito.mock;

public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
Expand Down Expand Up @@ -94,7 +96,6 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
try (var service = createElasticInferenceService()) {
ensureAuthorizationCallFinished(service);
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));

assertThat(
service.defaultConfigIds(),
is(
Expand Down Expand Up @@ -191,13 +192,21 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
String responseJson = """
{
"models": [
{
"model_name": "elser-v2",
"task_types": ["embed/text/sparse"]
},
{
"model_name": "rainbow-sprinkles",
"task_types": ["chat"]
},
{
"model_name": "elser-v2",
"task_types": ["embed/text/sparse"]
"model_name": "multilingual-embed-v1",
"task_types": ["embed/text/dense"]
},
{
"model_name": "rerank-v1",
"task_types": ["rerank/text/text-similarity"]
}
]
}
Expand All @@ -211,27 +220,48 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
assertThat(
service.defaultConfigIds(),
is(
List.of(
new InferenceService.DefaultConfigId(
".elser-v2-elastic",
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
service
containsInAnyOrder(
new InferenceService.DefaultConfigId(
".elser-v2-elastic",
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
service
),
new InferenceService.DefaultConfigId(
".rainbow-sprinkles-elastic",
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
service
),
new InferenceService.DefaultConfigId(
".multilingual-embed-v1-elastic",
MinimalServiceSettings.textEmbedding(
ElasticInferenceService.NAME,
ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(),
DenseVectorFieldMapper.ElementType.FLOAT
),
new InferenceService.DefaultConfigId(
".rainbow-sprinkles-elastic",
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
service
)
service
),
new InferenceService.DefaultConfigId(
".rerank-v1-elastic",
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
service
)
)
);
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING)));
assertThat(
service.supportedTaskTypes(),
is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.TEXT_EMBEDDING))
);

PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
service.defaultConfigs(listener);
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
assertThat(
listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(),
is(".multilingual-embed-v1-elastic")
);
assertThat(listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
assertThat(listener.actionGet(TIMEOUT).get(3).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic"));

var getModelListener = new PlainActionFuture<UnparsedModel>();
// persists the default endpoints
Expand All @@ -249,6 +279,14 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
{
"model_name": "elser-v2",
"task_types": ["embed/text/sparse"]
},
{
"model_name": "rerank-v1",
"task_types": ["rerank/text/text-similarity"]
},
{
"model_name": "multilingual-embed-v1",
"task_types": ["embed/text/dense"]
}
]
}
Expand All @@ -262,17 +300,33 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
assertThat(
service.defaultConfigIds(),
is(
List.of(
new InferenceService.DefaultConfigId(
".elser-v2-elastic",
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
service
)
containsInAnyOrder(
new InferenceService.DefaultConfigId(
".elser-v2-elastic",
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
service
),
new InferenceService.DefaultConfigId(
".multilingual-embed-v1-elastic",
MinimalServiceSettings.textEmbedding(
ElasticInferenceService.NAME,
ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(),
DenseVectorFieldMapper.ElementType.FLOAT
),
service
),
new InferenceService.DefaultConfigId(
".rerank-v1-elastic",
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
service
)
)
);
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));
assertThat(
service.supportedTaskTypes(),
is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))
);

var getModelListener = new PlainActionFuture<UnparsedModel>();
modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.response.elastic;

import org.elasticsearch.common.xcontent.XContentParserUtils;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;

import java.io.IOException;
import java.util.List;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;

public class ElasticInferenceServiceDenseTextEmbeddingsResponseEntity {

    /**
     * Converts the body of an Elastic Inference Service dense text embeddings response
     * into {@link TextEmbeddingFloatResults}.
     *
     * <p>Given a request body such as:
     *
     * <pre>
     * <code>
     * {
     *     "inputs": ["Embed this text", "Embed this text, too"]
     * }
     * </code>
     * </pre>
     *
     * <p>the service answers with one float array per input, nested directly under {@code data}:
     *
     * <pre>
     * <code>
     * {
     *     "data": [
     *         [
     *             2.1259406,
     *             1.7073475,
     *             0.9020516
     *         ],
     *         (...)
     *     ],
     *     "meta": {
     *         "usage": {...}
     *     }
     * }
     * </code>
     * </pre>
     *
     * <p>Only the {@code data} field is declared on the parser; the {@code request} argument is not
     * consulted while parsing.
     *
     * @param request  the request that produced the response (unused by this method)
     * @param response the HTTP result whose body holds the JSON payload
     * @return the embeddings from {@code data}, in response order
     * @throws IOException if the JSON parser cannot be created or fails while being closed
     */
    public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
        try (
            var jsonParser = XContentFactory.xContent(XContentType.JSON)
                .createParser(XContentParserConfiguration.EMPTY, response.body())
        ) {
            EmbeddingFloatResult parsedBody = EmbeddingFloatResult.PARSER.apply(jsonParser, null);
            return parsedBody.toTextEmbeddingFloatResults();
        }
    }

    /**
     * Parsed form of the top-level response object: the list of embeddings read from {@code data}.
     */
    public record EmbeddingFloatResult(List<EmbeddingFloatResultEntry> embeddingResults) {
        @SuppressWarnings("unchecked")
        public static final ConstructingObjectParser<EmbeddingFloatResult, Void> PARSER = new ConstructingObjectParser<>(
            EmbeddingFloatResult.class.getSimpleName(),
            true,
            ctorArgs -> new EmbeddingFloatResult((List<EmbeddingFloatResultEntry>) ctorArgs[0])
        );

        static {
            // "data" is an array of bare float arrays, so it is read with nested list parsing:
            // the outer list yields one entry per embedding, the inner list yields its float values.
            PARSER.declareField(
                constructorArg(),
                (dataParser, unusedContext) -> XContentParserUtils.parseList(
                    dataParser,
                    (entryParser, entryIndex) -> EmbeddingFloatResultEntry.fromFloatArray(
                        XContentParserUtils.parseList(entryParser, (valueParser, valueIndex) -> valueParser.floatValue())
                    )
                ),
                new ParseField("data"),
                org.elasticsearch.xcontent.ObjectParser.ValueType.OBJECT_ARRAY
            );
        }

        /**
         * Maps every parsed entry to a {@link TextEmbeddingFloatResults.Embedding}, keeping the order
         * in which the entries were parsed.
         */
        public TextEmbeddingFloatResults toTextEmbeddingFloatResults() {
            List<TextEmbeddingFloatResults.Embedding> converted = embeddingResults.stream()
                .map(EmbeddingFloatResultEntry::embedding)
                .map(values -> TextEmbeddingFloatResults.Embedding.of(values))
                .toList();
            return new TextEmbeddingFloatResults(converted);
        }
    }

    /**
     * One embedding from the response. The Elastic Inference Service sends each embedding as a plain
     * array of floats with no wrapper object, so this record simply carries that list.
     */
    public record EmbeddingFloatResultEntry(List<Float> embedding) {
        /** Wraps the given float values in an entry without copying them. */
        public static EmbeddingFloatResultEntry fromFloatArray(List<Float> floats) {
            return new EmbeddingFloatResultEntry(floats);
        }
    }

    // Static utility holder; not meant to be instantiated.
    private ElasticInferenceServiceDenseTextEmbeddingsResponseEntity() {}
}
Loading