Skip to content

Commit 7d658e4

Browse files
committed
Default E5 endpoint
# Conflicts: # x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
1 parent e833e7b commit 7d658e4

File tree

5 files changed

+78
-8
lines changed

5 files changed

+78
-8
lines changed
Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@
2222
import static org.hamcrest.Matchers.is;
2323
import static org.hamcrest.Matchers.oneOf;
2424

25-
public class DefaultElserIT extends InferenceBaseRestTest {
25+
public class DefaultEndPointsIT extends InferenceBaseRestTest {
2626

2727
private TestThreadPool threadPool;
2828

2929
@Before
3030
public void createThreadPool() {
31-
threadPool = new TestThreadPool(DefaultElserIT.class.getSimpleName());
31+
threadPool = new TestThreadPool(DefaultEndPointsIT.class.getSimpleName());
3232
}
3333

3434
@After
@@ -38,7 +38,7 @@ public void tearDown() throws Exception {
3838
}
3939

4040
@SuppressWarnings("unchecked")
41-
public void testInferCreatesDefaultElser() throws IOException {
41+
public void testInferDeploysDefaultElser() throws IOException {
4242
assumeTrue("Default config requires a feature flag", DefaultElserFeatureFlag.isEnabled());
4343
var model = getModel(ElasticsearchInternalService.DEFAULT_ELSER_ID);
4444
assertDefaultElserConfig(model);
@@ -67,4 +67,39 @@ private static void assertDefaultElserConfig(Map<String, Object> modelConfig) {
6767
Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 1, "max_number_of_allocations", 8))
6868
);
6969
}
70+
71+
@SuppressWarnings("unchecked")
72+
public void testInferDeploysDefaultE5() throws IOException {
73+
assumeTrue("Default config requires a feature flag", DefaultElserFeatureFlag.isEnabled());
74+
var model = getModel(ElasticsearchInternalService.DEFAULT_E5_ID);
75+
assertDefaultE5Config(model);
76+
77+
var inputs = List.of("Hello World", "Goodnight moon");
78+
var queryParams = Map.of("timeout", "120s");
79+
var results = infer(ElasticsearchInternalService.DEFAULT_E5_ID, TaskType.TEXT_EMBEDDING, inputs, queryParams);
80+
var embeddings = (List<Map<String, Object>>) results.get("text_embedding");
81+
assertThat(results.toString(), embeddings, hasSize(2));
82+
}
83+
84+
@SuppressWarnings("unchecked")
85+
private static void assertDefaultE5Config(Map<String, Object> modelConfig) {
86+
assertEquals(modelConfig.toString(), ElasticsearchInternalService.DEFAULT_E5_ID, modelConfig.get("inference_id"));
87+
assertEquals(modelConfig.toString(), ElasticsearchInternalService.NAME, modelConfig.get("service"));
88+
assertEquals(modelConfig.toString(), TaskType.TEXT_EMBEDDING.toString(), modelConfig.get("task_type"));
89+
90+
var serviceSettings = (Map<String, Object>) modelConfig.get("service_settings");
91+
assertThat(
92+
modelConfig.toString(),
93+
serviceSettings.get("model_id"),
94+
is(oneOf(".multilingual-e5-small", ".multilingual-e5-small_linux-x86_64"))
95+
);
96+
assertEquals(modelConfig.toString(), 1, serviceSettings.get("num_threads"));
97+
98+
var adaptiveAllocations = (Map<String, Object>) serviceSettings.get("adaptive_allocations");
99+
assertThat(
100+
modelConfig.toString(),
101+
adaptiveAllocations,
102+
Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 1, "max_number_of_allocations", 8))
103+
);
104+
}
70105
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ public static InferModelAction.Request buildInferenceRequest(
259259
return request;
260260
}
261261

262-
protected abstract boolean isDefaultId(String inferenceId);
262+
abstract boolean isDefaultId(String inferenceId);
263263

264264
protected void maybeStartDeployment(
265265
ElasticsearchInternalModel model,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
7878

7979
public static final int EMBEDDING_MAX_BATCH_SIZE = 10;
8080
public static final String DEFAULT_ELSER_ID = ".elser-2";
81+
public static final String DEFAULT_E5_ID = ".default-multilingual-e5-small"; // TODO what to name this
8182

8283
private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class);
8384
private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(ElasticsearchInternalService.class);
@@ -815,20 +816,47 @@ public List<UnparsedModel> defaultConfigs() {
815816
)
816817
);
817818

819+
// TODO Chunking settings
820+
Map<String, Object> e5Settings = Map.of(
821+
ModelConfigurations.SERVICE_SETTINGS,
822+
Map.of(
823+
ElasticsearchInternalServiceSettings.MODEL_ID,
824+
MULTILINGUAL_E5_SMALL_MODEL_ID, // TODO pick model depending on platform
825+
ElasticsearchInternalServiceSettings.NUM_THREADS,
826+
1,
827+
ElasticsearchInternalServiceSettings.ADAPTIVE_ALLOCATIONS,
828+
Map.of(
829+
"enabled",
830+
Boolean.TRUE,
831+
"min_number_of_allocations",
832+
1,
833+
"max_number_of_allocations",
834+
8 // no max?
835+
)
836+
)
837+
);
838+
818839
return List.of(
819840
new UnparsedModel(
820841
DEFAULT_ELSER_ID,
821842
TaskType.SPARSE_EMBEDDING,
822843
NAME,
823844
elserSettings,
824845
Map.of() // no secrets
846+
),
847+
new UnparsedModel(
848+
DEFAULT_E5_ID,
849+
TaskType.TEXT_EMBEDDING,
850+
NAME,
851+
e5Settings,
852+
Map.of() // no secrets
825853
)
826854
);
827855
}
828856

829857
@Override
830-
protected boolean isDefaultId(String inferenceId) {
831-
return DEFAULT_ELSER_ID.equals(inferenceId);
858+
boolean isDefaultId(String inferenceId) {
859+
return DEFAULT_ELSER_ID.equals(inferenceId) || DEFAULT_E5_ID.equals(inferenceId);
832860
}
833861

834862
static EmbeddingRequestChunker.EmbeddingType embeddingTypeFromTaskTypeAndSettings(

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1541,6 +1541,13 @@ public void testEmbeddingTypeFromTaskTypeAndSettings() {
15411541
assertThat(e.getMessage(), containsString("Chunking is not supported for task type [completion]"));
15421542
}
15431543

1544+
public void testIsDefaultId() {
1545+
var service = createService(mock(Client.class));
1546+
assertTrue(service.isDefaultId(".elser-2"));
1547+
assertTrue(service.isDefaultId(".default-multilingual-e5-small")); // TODO name?
1548+
assertFalse(service.isDefaultId("foo"));
1549+
}
1550+
15441551
private ElasticsearchInternalService createService(Client client) {
15451552
var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client, threadPool);
15461553
return new ElasticsearchInternalService(context);

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,9 +234,9 @@ protected void masterOperation(
234234
if (getModelResponse.getResources().results().size() > 1) {
235235
listener.onFailure(
236236
ExceptionsHelper.badRequestException(
237-
"cannot deploy more than one models at the same time; [{}] matches [{}] models]",
237+
"cannot deploy more than one model at the same time; [{}] matches models [{}]",
238238
request.getModelId(),
239-
getModelResponse.getResources().results().size()
239+
getModelResponse.getResources().results().stream().map(TrainedModelConfig::getModelId).toList()
240240
)
241241
);
242242
return;

0 commit comments

Comments
 (0)