Skip to content

Commit 5efba5b

Browse files
authored
[ML] Default inference endpoint for the multilingual-e5-small model (#114683)
1 parent 50c02f4 commit 5efba5b

File tree

8 files changed

+88
-36
lines changed

8 files changed

+88
-36
lines changed

docs/changelog/114683.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 114683
2+
summary: Default inference endpoint for the multilingual-e5-small model
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

docs/reference/rest-api/usage.asciidoc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,12 @@ GET /_xpack/usage
210210
"service": "elasticsearch",
211211
"task_type": "SPARSE_EMBEDDING",
212212
"count": 1
213-
}
213+
},
214+
{
215+
"service": "elasticsearch",
216+
"task_type": "TEXT_EMBEDDING",
217+
"count": 1
218+
}
214219
]
215220
},
216221
"logstash" : {
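
DefaultElserIT.java → DefaultEndPointsIT.java (renamed; file path header missing from this capture)

Lines changed: 38 additions & 3 deletions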
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@
2222
import static org.hamcrest.Matchers.is;
2323
import static org.hamcrest.Matchers.oneOf;
2424

25-
public class DefaultElserIT extends InferenceBaseRestTest {
25+
public class DefaultEndPointsIT extends InferenceBaseRestTest {
2626

2727
private TestThreadPool threadPool;
2828

2929
@Before
3030
public void createThreadPool() {
31-
threadPool = new TestThreadPool(DefaultElserIT.class.getSimpleName());
31+
threadPool = new TestThreadPool(DefaultEndPointsIT.class.getSimpleName());
3232
}
3333

3434
@After
@@ -38,7 +38,7 @@ public void tearDown() throws Exception {
3838
}
3939

4040
@SuppressWarnings("unchecked")
41-
public void testInferCreatesDefaultElser() throws IOException {
41+
public void testInferDeploysDefaultElser() throws IOException {
4242
assumeTrue("Default config requires a feature flag", DefaultElserFeatureFlag.isEnabled());
4343
var model = getModel(ElasticsearchInternalService.DEFAULT_ELSER_ID);
4444
assertDefaultElserConfig(model);
@@ -67,4 +67,39 @@ private static void assertDefaultElserConfig(Map<String, Object> modelConfig) {
6767
Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 1, "max_number_of_allocations", 8))
6868
);
6969
}
70+
71+
@SuppressWarnings("unchecked")
72+
public void testInferDeploysDefaultE5() throws IOException {
73+
assumeTrue("Default config requires a feature flag", DefaultElserFeatureFlag.isEnabled());
74+
var model = getModel(ElasticsearchInternalService.DEFAULT_E5_ID);
75+
assertDefaultE5Config(model);
76+
77+
var inputs = List.of("Hello World", "Goodnight moon");
78+
var queryParams = Map.of("timeout", "120s");
79+
var results = infer(ElasticsearchInternalService.DEFAULT_E5_ID, TaskType.TEXT_EMBEDDING, inputs, queryParams);
80+
var embeddings = (List<Map<String, Object>>) results.get("text_embedding");
81+
assertThat(results.toString(), embeddings, hasSize(2));
82+
}
83+
84+
@SuppressWarnings("unchecked")
85+
private static void assertDefaultE5Config(Map<String, Object> modelConfig) {
86+
assertEquals(modelConfig.toString(), ElasticsearchInternalService.DEFAULT_E5_ID, modelConfig.get("inference_id"));
87+
assertEquals(modelConfig.toString(), ElasticsearchInternalService.NAME, modelConfig.get("service"));
88+
assertEquals(modelConfig.toString(), TaskType.TEXT_EMBEDDING.toString(), modelConfig.get("task_type"));
89+
90+
var serviceSettings = (Map<String, Object>) modelConfig.get("service_settings");
91+
assertThat(
92+
modelConfig.toString(),
93+
serviceSettings.get("model_id"),
94+
is(oneOf(".multilingual-e5-small", ".multilingual-e5-small_linux-x86_64"))
95+
);
96+
assertEquals(modelConfig.toString(), 1, serviceSettings.get("num_threads"));
97+
98+
var adaptiveAllocations = (Map<String, Object>) serviceSettings.get("adaptive_allocations");
99+
assertThat(
100+
modelConfig.toString(),
101+
adaptiveAllocations,
102+
Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 1, "max_number_of_allocations", 8))
103+
);
104+
}
70105
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public void testCRUD() throws IOException {
4040
}
4141

4242
var getAllModels = getAllModels();
43-
int numModels = DefaultElserFeatureFlag.isEnabled() ? 10 : 9;
43+
int numModels = DefaultElserFeatureFlag.isEnabled() ? 11 : 9;
4444
assertThat(getAllModels, hasSize(numModels));
4545

4646
var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING);
@@ -51,7 +51,8 @@ public void testCRUD() throws IOException {
5151
}
5252

5353
var getDenseModels = getModels("_all", TaskType.TEXT_EMBEDDING);
54-
assertThat(getDenseModels, hasSize(4));
54+
int numDenseModels = DefaultElserFeatureFlag.isEnabled() ? 5 : 4;
55+
assertThat(getDenseModels, hasSize(numDenseModels));
5556
for (var denseModel : getDenseModels) {
5657
assertEquals("text_embedding", denseModel.get("task_type"));
5758
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,12 @@ private void preferredVariantFromPlatformArchitecture(ActionListener<PreferredMo
239239
// However, in Elastic cloud ml nodes run on Linux x86
240240
delegate.onResponse(PreferredModelVariant.LINUX_X86_OPTIMIZED);
241241
} else {
242-
delegate.onResponse(PreferredModelVariant.PLATFORM_AGNOSTIC);
242+
boolean homogenous = architectures.size() == 1;
243+
if (homogenous && architectures.iterator().next().equals("linux-x86_64")) {
244+
delegate.onResponse(PreferredModelVariant.LINUX_X86_OPTIMIZED);
245+
} else {
246+
delegate.onResponse(PreferredModelVariant.PLATFORM_AGNOSTIC);
247+
}
243248
}
244249
}),
245250
client,
@@ -270,7 +275,7 @@ public static InferModelAction.Request buildInferenceRequest(
270275
return request;
271276
}
272277

273-
protected abstract boolean isDefaultId(String inferenceId);
278+
abstract boolean isDefaultId(String inferenceId);
274279

275280
protected void maybeStartDeployment(
276281
ElasticsearchInternalModel model,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
7979

8080
public static final int EMBEDDING_MAX_BATCH_SIZE = 10;
8181
public static final String DEFAULT_ELSER_ID = ".elser-2";
82+
public static final String DEFAULT_E5_ID = ".multi-e5-small";
8283

8384
private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class);
8485
private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(ElasticsearchInternalService.class);
@@ -389,14 +390,6 @@ private void elserCase(
389390
);
390391
}
391392

392-
if (modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic(preferredModelVariant, esServiceSettingsBuilder.getModelId())) {
393-
throw new IllegalArgumentException(
394-
"Error parsing request config, model id does not match any models available on this platform. Was ["
395-
+ esServiceSettingsBuilder.getModelId()
396-
+ "]"
397-
);
398-
}
399-
400393
throwIfNotEmptyMap(config, name());
401394
throwIfNotEmptyMap(serviceSettingsMap, name());
402395

@@ -412,19 +405,6 @@ private void elserCase(
412405
);
413406
}
414407

415-
private static boolean modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic(
416-
PreferredModelVariant preferredModelVariant,
417-
String modelId
418-
) {
419-
return modelId.equals(
420-
selectDefaultModelVariantBasedOnClusterArchitecture(
421-
preferredModelVariant,
422-
MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86,
423-
MULTILINGUAL_E5_SMALL_MODEL_ID
424-
)
425-
);
426-
}
427-
428408
@Override
429409
public Model parsePersistedConfigWithSecrets(
430410
String inferenceEntityId,
@@ -800,7 +780,10 @@ private RankedDocsResults textSimilarityResultsToRankedDocs(
800780
}
801781

802782
public List<DefaultConfigId> defaultConfigIds() {
803-
return List.of(new DefaultConfigId(DEFAULT_ELSER_ID, TaskType.SPARSE_EMBEDDING, this));
783+
return List.of(
784+
new DefaultConfigId(DEFAULT_ELSER_ID, TaskType.SPARSE_EMBEDDING, this),
785+
new DefaultConfigId(DEFAULT_E5_ID, TaskType.TEXT_EMBEDDING, this)
786+
);
804787
}
805788

806789
@Override
@@ -876,13 +859,24 @@ private List<Model> defaultConfigs(boolean useLinuxOptimizedModel) {
876859
ElserMlNodeTaskSettings.DEFAULT,
877860
null // default chunking settings
878861
);
879-
880-
return List.of(defaultElser);
862+
var defaultE5 = new MultilingualE5SmallModel(
863+
DEFAULT_E5_ID,
864+
TaskType.TEXT_EMBEDDING,
865+
NAME,
866+
new MultilingualE5SmallInternalServiceSettings(
867+
null,
868+
1,
869+
useLinuxOptimizedModel ? MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86 : MULTILINGUAL_E5_SMALL_MODEL_ID,
870+
new AdaptiveAllocationsSettings(Boolean.TRUE, 1, 8)
871+
),
872+
null // default chunking settings
873+
);
874+
return List.of(defaultElser, defaultE5);
881875
}
882876

883877
@Override
884-
protected boolean isDefaultId(String inferenceId) {
885-
return DEFAULT_ELSER_ID.equals(inferenceId);
878+
boolean isDefaultId(String inferenceId) {
879+
return DEFAULT_ELSER_ID.equals(inferenceId) || DEFAULT_E5_ID.equals(inferenceId);
886880
}
887881

888882
static EmbeddingRequestChunker.EmbeddingType embeddingTypeFromTaskTypeAndSettings(

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1551,6 +1551,13 @@ public void testEmbeddingTypeFromTaskTypeAndSettings() {
15511551
assertThat(e.getMessage(), containsString("Chunking is not supported for task type [completion]"));
15521552
}
15531553

1554+
public void testIsDefaultId() {
1555+
var service = createService(mock(Client.class));
1556+
assertTrue(service.isDefaultId(".elser-2"));
1557+
assertTrue(service.isDefaultId(".multi-e5-small"));
1558+
assertFalse(service.isDefaultId("foo"));
1559+
}
1560+
15541561
private ElasticsearchInternalService createService(Client client) {
15551562
var cs = mock(ClusterService.class);
15561563
var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES));

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,9 +234,9 @@ protected void masterOperation(
234234
if (getModelResponse.getResources().results().size() > 1) {
235235
listener.onFailure(
236236
ExceptionsHelper.badRequestException(
237-
"cannot deploy more than one models at the same time; [{}] matches [{}] models]",
237+
"cannot deploy more than one model at the same time; [{}] matches models [{}]",
238238
request.getModelId(),
239-
getModelResponse.getResources().results().size()
239+
getModelResponse.getResources().results().stream().map(TrainedModelConfig::getModelId).toList()
240240
)
241241
);
242242
return;

0 commit comments

Comments
 (0)