Skip to content

Commit 74522c4

Browse files
authored
[ML] Pick best model variant for the default elser endpoint (#114690)
1 parent 9eab11c commit 74522c4

File tree

23 files changed

+444
-355
lines changed

23 files changed

+444
-355
lines changed

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -192,12 +192,22 @@ default boolean canStream(TaskType taskType) {
192192
return supportedStreamingTasks().contains(taskType);
193193
}
194194

195+
record DefaultConfigId(String inferenceId, TaskType taskType, InferenceService service) {};
196+
195197
/**
196-
* A service can define default configurations that can be
197-
* used out of the box without creating an endpoint first.
198-
* @return Default configurations provided by this service
198+
* Get the Ids and task type of any default configurations provided by this service
199+
* @return Defaults
199200
*/
200-
default List<UnparsedModel> defaultConfigs() {
201+
default List<DefaultConfigId> defaultConfigIds() {
201202
return List.of();
202203
}
204+
205+
/**
206+
* Call the listener with the default model configurations defined by
207+
* the service
208+
* @param defaultsListener The listener
209+
*/
210+
default void defaultConfigs(ActionListener<List<Model>> defaultsListener) {
211+
defaultsListener.onResponse(List.of());
212+
}
203213
}

server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
package org.elasticsearch.inference;
1111

1212
import org.elasticsearch.client.internal.Client;
13+
import org.elasticsearch.cluster.service.ClusterService;
14+
import org.elasticsearch.common.settings.Settings;
1315
import org.elasticsearch.threadpool.ThreadPool;
1416

1517
import java.util.List;
@@ -21,7 +23,7 @@ public interface InferenceServiceExtension {
2123

2224
List<Factory> getInferenceServiceFactories();
2325

24-
record InferenceServiceFactoryContext(Client client, ThreadPool threadPool) {}
26+
record InferenceServiceFactoryContext(Client client, ThreadPool threadPool, ClusterService clusterService, Settings settings) {}
2527

2628
interface Factory {
2729
/**

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningField.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ public final class MachineLearningField {
3737
Setting.Property.NodeScope
3838
);
3939

40+
public static final Setting<Integer> MAX_LAZY_ML_NODES = Setting.intSetting(
41+
"xpack.ml.max_lazy_ml_nodes",
42+
0,
43+
0,
44+
Setting.Property.OperatorDynamic,
45+
Setting.Property.NodeScope
46+
);
47+
4048
/**
4149
* This boolean value indicates if `max_machine_memory_percent` should be ignored and an automatic calculation is used instead.
4250
*

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java

Lines changed: 100 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
import org.elasticsearch.TransportVersion;
1212
import org.elasticsearch.action.ActionListener;
1313
import org.elasticsearch.client.internal.Client;
14+
import org.elasticsearch.cluster.service.ClusterService;
1415
import org.elasticsearch.common.io.stream.StreamOutput;
16+
import org.elasticsearch.common.settings.Settings;
17+
import org.elasticsearch.inference.InferenceService;
1518
import org.elasticsearch.inference.InferenceServiceExtension;
1619
import org.elasticsearch.inference.Model;
1720
import org.elasticsearch.inference.ModelConfigurations;
@@ -47,6 +50,7 @@
4750
import java.util.concurrent.CountDownLatch;
4851
import java.util.concurrent.atomic.AtomicReference;
4952
import java.util.function.Consumer;
53+
import java.util.function.Function;
5054
import java.util.stream.Collectors;
5155

5256
import static org.hamcrest.CoreMatchers.equalTo;
@@ -57,6 +61,8 @@
5761
import static org.hamcrest.Matchers.instanceOf;
5862
import static org.hamcrest.Matchers.not;
5963
import static org.hamcrest.Matchers.nullValue;
64+
import static org.mockito.ArgumentMatchers.any;
65+
import static org.mockito.Mockito.doAnswer;
6066
import static org.mockito.Mockito.mock;
6167

6268
public class ModelRegistryIT extends ESSingleNodeTestCase {
@@ -122,7 +128,12 @@ public void testGetModel() throws Exception {
122128
assertEquals(model.getConfigurations().getService(), modelHolder.get().service());
123129

124130
var elserService = new ElasticsearchInternalService(
125-
new InferenceServiceExtension.InferenceServiceFactoryContext(mock(Client.class), mock(ThreadPool.class))
131+
new InferenceServiceExtension.InferenceServiceFactoryContext(
132+
mock(Client.class),
133+
mock(ThreadPool.class),
134+
mock(ClusterService.class),
135+
Settings.EMPTY
136+
)
126137
);
127138
ElasticsearchInternalModel roundTripModel = (ElasticsearchInternalModel) elserService.parsePersistedConfigWithSecrets(
128139
modelHolder.get().inferenceEntityId(),
@@ -283,26 +294,38 @@ public void testGetModelWithSecrets() throws InterruptedException {
283294
}
284295

285296
public void testGetAllModels_WithDefaults() throws Exception {
286-
var service = "foo";
287-
var secret = "abc";
297+
var serviceName = "foo";
288298
int configuredModelCount = 10;
289299
int defaultModelCount = 2;
290300
int totalModelCount = 12;
291301

292-
var defaultConfigs = new HashMap<String, UnparsedModel>();
302+
var service = mock(InferenceService.class);
303+
304+
var defaultConfigs = new ArrayList<Model>();
305+
var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
293306
for (int i = 0; i < defaultModelCount; i++) {
294307
var id = "default-" + i;
295-
defaultConfigs.put(id, createUnparsedConfig(id, randomFrom(TaskType.values()), service, secret));
308+
var taskType = randomFrom(TaskType.values());
309+
defaultConfigs.add(createModel(id, taskType, serviceName));
310+
defaultIds.add(new InferenceService.DefaultConfigId(id, taskType, service));
296311
}
297-
defaultConfigs.values().forEach(modelRegistry::addDefaultConfiguration);
312+
313+
doAnswer(invocation -> {
314+
@SuppressWarnings("unchecked")
315+
var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
316+
listener.onResponse(defaultConfigs);
317+
return Void.TYPE;
318+
}).when(service).defaultConfigs(any());
319+
320+
defaultIds.forEach(modelRegistry::addDefaultIds);
298321

299322
AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
300323
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
301324

302325
var createdModels = new HashMap<String, Model>();
303326
for (int i = 0; i < configuredModelCount; i++) {
304327
var id = randomAlphaOfLength(5) + i;
305-
var model = createModel(id, randomFrom(TaskType.values()), service);
328+
var model = createModel(id, randomFrom(TaskType.values()), serviceName);
306329
createdModels.put(id, model);
307330
blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder);
308331
assertThat(putModelHolder.get(), is(true));
@@ -316,16 +339,22 @@ public void testGetAllModels_WithDefaults() throws Exception {
316339
var getAllModels = modelHolder.get();
317340
assertReturnModelIsModifiable(modelHolder.get().get(0));
318341

342+
// same result but configs should have been persisted this time
343+
blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
344+
assertNull(exceptionHolder.get());
345+
assertThat(modelHolder.get(), hasSize(totalModelCount));
346+
319347
// sort in the same order as the returned models
320-
var ids = new ArrayList<>(defaultConfigs.keySet().stream().toList());
348+
var ids = new ArrayList<>(defaultIds.stream().map(InferenceService.DefaultConfigId::inferenceId).toList());
321349
ids.addAll(createdModels.keySet().stream().toList());
322350
ids.sort(String::compareTo);
351+
var configsById = defaultConfigs.stream().collect(Collectors.toMap(Model::getInferenceEntityId, Function.identity()));
323352
for (int i = 0; i < totalModelCount; i++) {
324353
var id = ids.get(i);
325354
assertEquals(id, getAllModels.get(i).inferenceEntityId());
326355
if (id.startsWith("default")) {
327-
assertEquals(defaultConfigs.get(id).taskType(), getAllModels.get(i).taskType());
328-
assertEquals(defaultConfigs.get(id).service(), getAllModels.get(i).service());
356+
assertEquals(configsById.get(id).getTaskType(), getAllModels.get(i).taskType());
357+
assertEquals(configsById.get(id).getConfigurations().getService(), getAllModels.get(i).service());
329358
} else {
330359
assertEquals(createdModels.get(id).getTaskType(), getAllModels.get(i).taskType());
331360
assertEquals(createdModels.get(id).getConfigurations().getService(), getAllModels.get(i).service());
@@ -334,16 +363,27 @@ public void testGetAllModels_WithDefaults() throws Exception {
334363
}
335364

336365
public void testGetAllModels_OnlyDefaults() throws Exception {
337-
var service = "foo";
338-
var secret = "abc";
339366
int defaultModelCount = 2;
367+
var serviceName = "foo";
368+
var service = mock(InferenceService.class);
340369

341-
var defaultConfigs = new HashMap<String, UnparsedModel>();
370+
var defaultConfigs = new ArrayList<Model>();
371+
var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
342372
for (int i = 0; i < defaultModelCount; i++) {
343373
var id = "default-" + i;
344-
defaultConfigs.put(id, createUnparsedConfig(id, randomFrom(TaskType.values()), service, secret));
374+
var taskType = randomFrom(TaskType.values());
375+
defaultConfigs.add(createModel(id, taskType, serviceName));
376+
defaultIds.add(new InferenceService.DefaultConfigId(id, taskType, service));
345377
}
346-
defaultConfigs.values().forEach(modelRegistry::addDefaultConfiguration);
378+
379+
doAnswer(invocation -> {
380+
@SuppressWarnings("unchecked")
381+
var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
382+
listener.onResponse(defaultConfigs);
383+
return Void.TYPE;
384+
}).when(service).defaultConfigs(any());
385+
386+
defaultIds.forEach(modelRegistry::addDefaultIds);
347387

348388
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
349389
AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
@@ -354,31 +394,42 @@ public void testGetAllModels_OnlyDefaults() throws Exception {
354394
assertReturnModelIsModifiable(modelHolder.get().get(0));
355395

356396
// sort in the same order as the returned models
357-
var ids = new ArrayList<>(defaultConfigs.keySet().stream().toList());
397+
var configsById = defaultConfigs.stream().collect(Collectors.toMap(Model::getInferenceEntityId, Function.identity()));
398+
var ids = new ArrayList<>(configsById.keySet().stream().toList());
358399
ids.sort(String::compareTo);
359400
for (int i = 0; i < defaultModelCount; i++) {
360401
var id = ids.get(i);
361402
assertEquals(id, getAllModels.get(i).inferenceEntityId());
362-
assertEquals(defaultConfigs.get(id).taskType(), getAllModels.get(i).taskType());
363-
assertEquals(defaultConfigs.get(id).service(), getAllModels.get(i).service());
403+
assertEquals(configsById.get(id).getTaskType(), getAllModels.get(i).taskType());
404+
assertEquals(configsById.get(id).getConfigurations().getService(), getAllModels.get(i).service());
364405
}
365406
}
366407

367408
public void testGet_WithDefaults() throws InterruptedException {
368-
var service = "foo";
369-
var secret = "abc";
409+
var serviceName = "foo";
410+
var service = mock(InferenceService.class);
411+
412+
var defaultConfigs = new ArrayList<Model>();
413+
var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
370414

371-
var defaultSparse = createUnparsedConfig("default-sparse", TaskType.SPARSE_EMBEDDING, service, secret);
372-
var defaultText = createUnparsedConfig("default-text", TaskType.TEXT_EMBEDDING, service, secret);
415+
defaultConfigs.add(createModel("default-sparse", TaskType.SPARSE_EMBEDDING, serviceName));
416+
defaultConfigs.add(createModel("default-text", TaskType.TEXT_EMBEDDING, serviceName));
417+
defaultIds.add(new InferenceService.DefaultConfigId("default-sparse", TaskType.SPARSE_EMBEDDING, service));
418+
defaultIds.add(new InferenceService.DefaultConfigId("default-text", TaskType.TEXT_EMBEDDING, service));
373419

374-
modelRegistry.addDefaultConfiguration(defaultSparse);
375-
modelRegistry.addDefaultConfiguration(defaultText);
420+
doAnswer(invocation -> {
421+
@SuppressWarnings("unchecked")
422+
var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
423+
listener.onResponse(defaultConfigs);
424+
return Void.TYPE;
425+
}).when(service).defaultConfigs(any());
426+
defaultIds.forEach(modelRegistry::addDefaultIds);
376427

377428
AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
378429
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
379430

380-
var configured1 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), service);
381-
var configured2 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), service);
431+
var configured1 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), serviceName);
432+
var configured2 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), serviceName);
382433
blockingCall(listener -> modelRegistry.storeModel(configured1, listener), putModelHolder, exceptionHolder);
383434
assertThat(putModelHolder.get(), is(true));
384435
blockingCall(listener -> modelRegistry.storeModel(configured2, listener), putModelHolder, exceptionHolder);
@@ -387,6 +438,7 @@ public void testGet_WithDefaults() throws InterruptedException {
387438

388439
AtomicReference<UnparsedModel> modelHolder = new AtomicReference<>();
389440
blockingCall(listener -> modelRegistry.getModel("default-sparse", listener), modelHolder, exceptionHolder);
441+
assertNull(exceptionHolder.get());
390442
assertEquals("default-sparse", modelHolder.get().inferenceEntityId());
391443
assertEquals(TaskType.SPARSE_EMBEDDING, modelHolder.get().taskType());
392444
assertReturnModelIsModifiable(modelHolder.get());
@@ -401,23 +453,32 @@ public void testGet_WithDefaults() throws InterruptedException {
401453
}
402454

403455
public void testGetByTaskType_WithDefaults() throws Exception {
404-
var service = "foo";
405-
var secret = "abc";
406-
407-
var defaultSparse = createUnparsedConfig("default-sparse", TaskType.SPARSE_EMBEDDING, service, secret);
408-
var defaultText = createUnparsedConfig("default-text", TaskType.TEXT_EMBEDDING, service, secret);
409-
var defaultChat = createUnparsedConfig("default-chat", TaskType.COMPLETION, service, secret);
410-
411-
modelRegistry.addDefaultConfiguration(defaultSparse);
412-
modelRegistry.addDefaultConfiguration(defaultText);
413-
modelRegistry.addDefaultConfiguration(defaultChat);
456+
var serviceName = "foo";
457+
458+
var defaultSparse = createModel("default-sparse", TaskType.SPARSE_EMBEDDING, serviceName);
459+
var defaultText = createModel("default-text", TaskType.TEXT_EMBEDDING, serviceName);
460+
var defaultChat = createModel("default-chat", TaskType.COMPLETION, serviceName);
461+
462+
var service = mock(InferenceService.class);
463+
var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
464+
defaultIds.add(new InferenceService.DefaultConfigId("default-sparse", TaskType.SPARSE_EMBEDDING, service));
465+
defaultIds.add(new InferenceService.DefaultConfigId("default-text", TaskType.TEXT_EMBEDDING, service));
466+
defaultIds.add(new InferenceService.DefaultConfigId("default-chat", TaskType.COMPLETION, service));
467+
468+
doAnswer(invocation -> {
469+
@SuppressWarnings("unchecked")
470+
var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
471+
listener.onResponse(List.of(defaultSparse, defaultChat, defaultText));
472+
return Void.TYPE;
473+
}).when(service).defaultConfigs(any());
474+
defaultIds.forEach(modelRegistry::addDefaultIds);
414475

415476
AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
416477
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
417478

418-
var configuredSparse = createModel("configured-sparse", TaskType.SPARSE_EMBEDDING, service);
419-
var configuredText = createModel("configured-text", TaskType.TEXT_EMBEDDING, service);
420-
var configuredRerank = createModel("configured-rerank", TaskType.RERANK, service);
479+
var configuredSparse = createModel("configured-sparse", TaskType.SPARSE_EMBEDDING, serviceName);
480+
var configuredText = createModel("configured-text", TaskType.TEXT_EMBEDDING, serviceName);
481+
var configuredRerank = createModel("configured-rerank", TaskType.RERANK, serviceName);
421482
blockingCall(listener -> modelRegistry.storeModel(configuredSparse, listener), putModelHolder, exceptionHolder);
422483
assertThat(putModelHolder.get(), is(true));
423484
blockingCall(listener -> modelRegistry.storeModel(configuredText, listener), putModelHolder, exceptionHolder);
@@ -531,10 +592,6 @@ public static Model createModelWithSecrets(String inferenceEntityId, TaskType ta
531592
);
532593
}
533594

534-
public static UnparsedModel createUnparsedConfig(String inferenceEntityId, TaskType taskType, String service, String secret) {
535-
return new UnparsedModel(inferenceEntityId, taskType, service, Map.of("a", "b"), Map.of("secret", secret));
536-
}
537-
538595
private static class TestModelOfAnyKind extends ModelConfigurations {
539596

540597
record TestModelServiceSettings() implements ServiceSettings {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,14 +212,20 @@ public Collection<?> createComponents(PluginServices services) {
212212
);
213213
}
214214

215-
var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(services.client(), services.threadPool());
215+
var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(
216+
services.client(),
217+
services.threadPool(),
218+
services.clusterService(),
219+
settings
220+
);
221+
216222
// This must be done after the HttpRequestSenderFactory is created so that the services can get the
217223
// reference correctly
218224
var registry = new InferenceServiceRegistry(inferenceServices, factoryContext);
219225
registry.init(services.client());
220226
if (DefaultElserFeatureFlag.isEnabled()) {
221227
for (var service : registry.getServices().values()) {
222-
service.defaultConfigs().forEach(modelRegistry::addDefaultConfiguration);
228+
service.defaultConfigIds().forEach(modelRegistry::addDefaultIds);
223229
}
224230
}
225231
inferenceServiceRegistry.set(registry);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public class SentenceBoundaryChunkingSettings implements ChunkingSettings {
3535
ChunkingSettingsOptions.SENTENCE_OVERLAP.toString()
3636
);
3737

38-
private static int DEFAULT_OVERLAP = 0;
38+
private static int DEFAULT_OVERLAP = 1;
3939

4040
protected final int maxChunkSize;
4141
protected int sentenceOverlap = DEFAULT_OVERLAP;
@@ -69,17 +69,18 @@ public static SentenceBoundaryChunkingSettings fromMap(Map<String, Object> map)
6969
validationException
7070
);
7171

72-
Integer sentenceOverlap = ServiceUtils.extractOptionalPositiveInteger(
72+
Integer sentenceOverlap = ServiceUtils.removeAsType(
7373
map,
7474
ChunkingSettingsOptions.SENTENCE_OVERLAP.toString(),
75-
ModelConfigurations.CHUNKING_SETTINGS,
75+
Integer.class,
7676
validationException
7777
);
78-
79-
if (sentenceOverlap != null && sentenceOverlap > 1) {
78+
if (sentenceOverlap == null) {
79+
sentenceOverlap = DEFAULT_OVERLAP;
80+
} else if (sentenceOverlap > 1 || sentenceOverlap < 0) {
8081
validationException.addValidationError(
81-
ChunkingSettingsOptions.SENTENCE_OVERLAP.toString() + "[" + sentenceOverlap + "] must be either 0 or 1"
82-
); // todo better
82+
ChunkingSettingsOptions.SENTENCE_OVERLAP + "[" + sentenceOverlap + "] must be either 0 or 1"
83+
);
8384
}
8485

8586
if (validationException.validationErrors().isEmpty() == false) {

0 commit comments

Comments (0)