Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/130336.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 130336
summary: "[EIS] Rename the elser 2 default model and the default inference endpoint"
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,34 @@
import org.elasticsearch.inference.TaskType;

import java.io.IOException;
import java.util.List;
import java.util.Map;

import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getAllModels;
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getModels;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertTrue;

public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEISAuthServerTest {

/**
 * Verifies the default inference endpoints returned by the GET models API.
 *
 * Expects five endpoints overall, exactly one of which is a chat completion endpoint, and
 * checks that the {@code .elser-2-elastic} endpoint is registered as a sparse embedding endpoint.
 *
 * @throws IOException if a request to fetch the models fails
 */
public void testGetDefaultEndpoints() throws IOException {
    var allModels = getAllModels();
    var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION);

    // A single size expectation: asserting two different sizes on the same list can never pass.
    assertThat(allModels, hasSize(5));
    assertThat(chatCompletionModels, hasSize(1));

    for (var model : chatCompletionModels) {
        assertEquals("chat_completion", model.get("task_type"));
    }

    assertInferenceIdTaskType(allModels, ".elser-2-elastic", TaskType.SPARSE_EMBEDDING);
}

/**
 * Asserts that {@code models} contains an entry whose {@code inference_id} equals
 * {@code inferenceId} and whose {@code task_type} equals {@code taskType.toString()}.
 *
 * @param models      the model configurations to search, each represented as a map
 * @param inferenceId the inference id that must be present
 * @param taskType    the task type the matching entry is expected to report
 */
private static void assertInferenceIdTaskType(List<Map<String, Object>> models, String inferenceId, TaskType taskType) {
    // Compare from the known non-null side so an entry without "inference_id" is simply skipped
    // instead of throwing a NullPointerException that would hide the assertion message below.
    var model = models.stream().filter(m -> inferenceId.equals(m.get("inference_id"))).findFirst();
    assertTrue("could not find inference id: " + inferenceId, model.isPresent());
    assertThat(model.get().get("task_type"), is(taskType.toString()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public static MockElasticInferenceServiceAuthorizationServer enabledWithRainbowS
"task_types": ["chat"]
},
{
"model_name": "elser-v2",
"model_name": "elser_model_2",
"task_types": ["embed/text/sparse"]
}
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;

public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
Expand Down Expand Up @@ -190,7 +192,7 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
"task_types": ["chat"]
},
{
"model_name": "elser-v2",
"model_name": "elser_model_2",
"task_types": ["embed/text/sparse"]
}
]
Expand All @@ -205,21 +207,17 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
assertThat(
service.defaultConfigIds(),
is(
List.of(
new InferenceService.DefaultConfigId(
".rainbow-sprinkles-elastic",
MinimalServiceSettings.chatCompletion(),
service
)
)
containsInAnyOrder(
new InferenceService.DefaultConfigId(".elser-2-elastic", MinimalServiceSettings.sparseEmbedding(), service),
new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(), service)
)
);
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING)));

PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
service.defaultConfigs(listener);
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-2-elastic"));
assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));

var getModelListener = new PlainActionFuture<UnparsedModel>();
// persists the default endpoints
Expand All @@ -235,7 +233,7 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
{
"models": [
{
"model_name": "elser-v2",
"model_name": "elser_model_2",
"task_types": ["embed/text/sparse"]
}
]
Expand All @@ -248,7 +246,12 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
ensureAuthorizationCallFinished(service);

assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
assertTrue(service.defaultConfigIds().isEmpty());
assertThat(
service.defaultConfigIds(),
containsInAnyOrder(
new InferenceService.DefaultConfigId(".elser-2-elastic", MinimalServiceSettings.sparseEmbedding(), service)
)
);
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));

var getModelListener = new PlainActionFuture<UnparsedModel>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@ public class ElasticInferenceService extends SenderService {
private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION);
private static final String SERVICE_NAME = "Elastic";
static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles";
static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = Strings.format(".%s-elastic", DEFAULT_CHAT_COMPLETION_MODEL_ID_V1);
static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1);

// elser-2
static final String DEFAULT_ELSER_2_MODEL_ID = "elser_model_2";
static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId("elser-2");

/**
* The task types that the {@link InferenceAction.Request} can accept.
Expand Down Expand Up @@ -133,6 +137,19 @@ private static Map<String, DefaultModelConfig> initDefaultEndpoints(
elasticInferenceServiceComponents
),
MinimalServiceSettings.chatCompletion()
),
DEFAULT_ELSER_2_MODEL_ID,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How come this was missing - does this mean EIS-based ELSER was never available in 9.0?

new DefaultModelConfig(
new ElasticInferenceServiceSparseEmbeddingsModel(
DEFAULT_ELSER_ENDPOINT_ID_V2,
TaskType.SPARSE_EMBEDDING,
NAME,
new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_2_MODEL_ID, null, null),
EmptyTaskSettings.INSTANCE,
EmptySecretSettings.INSTANCE,
elasticInferenceServiceComponents
),
MinimalServiceSettings.sparseEmbedding()
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,4 @@ public static boolean isValidModel(String model) {
return model != null && VALID_ELSER_MODEL_IDS.contains(model);
}

/**
 * Reports whether {@code model} is accepted as an EIS model id.
 *
 * @param model the model id to check; a {@code null} value does not equal the constant and yields {@code false}
 * @return {@code true} only when {@code model} equals {@code ELSER_V2_MODEL}
 */
public static boolean isValidEisModel(String model) {
    final boolean matchesElserV2 = ELSER_V2_MODEL.equals(model);
    return matchesElserV2;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.isA;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -960,6 +964,18 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
{
"model_name": "rainbow-sprinkles",
"task_types": ["chat"]
},
{
"model_name": "elser_model_2",
"task_types": ["embed/text/sparse"]
},
{
"model_name": "multilingual-embed-v1",
"task_types": ["embed/text/dense"]
},
{
"model_name": "rerank-v1",
"task_types": ["rerank/text/text-similarity"]
}
]
}
Expand All @@ -976,15 +992,19 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
service.defaultConfigIds(),
is(
List.of(
new InferenceService.DefaultConfigId(".elser-2-elastic", MinimalServiceSettings.sparseEmbedding(), service),
new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(), service)
)
)
);
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)));

PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
service.defaultConfigs(listener);
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
var models = listener.actionGet(TIMEOUT);
assertThat(models.size(), is(2));
assertThat(models.get(0).getConfigurations().getInferenceEntityId(), is(".elser-2-elastic"));
assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import org.elasticsearch.xpack.inference.services.elastic.DefaultModelConfig;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.junit.Before;
Expand Down Expand Up @@ -165,6 +167,19 @@ private static Map<String, DefaultModelConfig> initDefaultEndpoints() {
ElasticInferenceServiceComponents.EMPTY_INSTANCE
),
MinimalServiceSettings.chatCompletion()
),
"elser-2",
new DefaultModelConfig(
new ElasticInferenceServiceSparseEmbeddingsModel(
defaultEndpointId("elser-2"),
TaskType.SPARSE_EMBEDDING,
"test",
new ElasticInferenceServiceSparseEmbeddingsServiceSettings("elser-2", null, null),
EmptyTaskSettings.INSTANCE,
EmptySecretSettings.INSTANCE,
ElasticInferenceServiceComponents.EMPTY_INSTANCE
),
MinimalServiceSettings.sparseEmbedding()
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,7 @@ public void testIsValidModel() {
assertTrue(org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.isValidModel(randomElserModel()));
}

/** Checks that the {@code ELSER_V2_MODEL} id is accepted by {@code isValidEisModel}. */
public void testIsValidEisModel() {
    String eisModelId = org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL;
    boolean accepted = org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.isValidEisModel(eisModelId);
    assertTrue(accepted);
}

/** Checks that an unknown model id is rejected by {@code isValidModel}. */
public void testIsInvalidModel() {
    boolean accepted = org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.isValidModel("invalid");
    assertFalse(accepted);
}

/** Checks that the {@code ELSER_V2_MODEL_LINUX_X86} id is rejected by {@code isValidEisModel}. */
public void testIsInvalidEisModel() {
    String platformSpecificId = ElserModels.ELSER_V2_MODEL_LINUX_X86;
    boolean accepted = org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.isValidEisModel(platformSpecificId);
    assertFalse(accepted);
}
}