Skip to content

Commit c2cec39

Browse files
authored
[8.16][ML] Pick best model variant for the default elser endpoint (#114758)
* [ML] Pick best model variant for the default elser endpoint (#114690) # Conflicts: # x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java # x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java # x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml * fix test * fix test
1 parent 2ca9d7a commit c2cec39

File tree

23 files changed

+448
-370
lines changed

23 files changed

+448
-370
lines changed

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -192,12 +192,22 @@ default boolean canStream(TaskType taskType) {
192192
return supportedStreamingTasks().contains(taskType);
193193
}
194194

195+
record DefaultConfigId(String inferenceId, TaskType taskType, InferenceService service) {};
196+
195197
/**
196-
* A service can define default configurations that can be
197-
* used out of the box without creating an endpoint first.
198-
* @return Default configurations provided by this service
198+
* Get the Ids and task type of any default configurations provided by this service
199+
* @return Defaults
199200
*/
200-
default List<UnparsedModel> defaultConfigs() {
201+
default List<DefaultConfigId> defaultConfigIds() {
201202
return List.of();
202203
}
204+
205+
/**
206+
* Call the listener with the default model configurations defined by
207+
* the service
208+
* @param defaultsListener The listener
209+
*/
210+
default void defaultConfigs(ActionListener<List<Model>> defaultsListener) {
211+
defaultsListener.onResponse(List.of());
212+
}
203213
}

server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
package org.elasticsearch.inference;
1111

1212
import org.elasticsearch.client.internal.Client;
13+
import org.elasticsearch.cluster.service.ClusterService;
14+
import org.elasticsearch.common.settings.Settings;
1315
import org.elasticsearch.threadpool.ThreadPool;
1416

1517
import java.util.List;
@@ -21,7 +23,7 @@ public interface InferenceServiceExtension {
2123

2224
List<Factory> getInferenceServiceFactories();
2325

24-
record InferenceServiceFactoryContext(Client client, ThreadPool threadPool) {}
26+
record InferenceServiceFactoryContext(Client client, ThreadPool threadPool, ClusterService clusterService, Settings settings) {}
2527

2628
interface Factory {
2729
/**

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningField.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ public final class MachineLearningField {
3737
Setting.Property.NodeScope
3838
);
3939

40+
public static final Setting<Integer> MAX_LAZY_ML_NODES = Setting.intSetting(
41+
"xpack.ml.max_lazy_ml_nodes",
42+
0,
43+
0,
44+
Setting.Property.OperatorDynamic,
45+
Setting.Property.NodeScope
46+
);
47+
4048
/**
4149
* This boolean value indicates if `max_machine_memory_percent` should be ignored and an automatic calculation is used instead.
4250
*

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java

Lines changed: 100 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
import org.elasticsearch.TransportVersion;
1212
import org.elasticsearch.action.ActionListener;
1313
import org.elasticsearch.client.internal.Client;
14+
import org.elasticsearch.cluster.service.ClusterService;
1415
import org.elasticsearch.common.io.stream.StreamOutput;
16+
import org.elasticsearch.common.settings.Settings;
17+
import org.elasticsearch.inference.InferenceService;
1518
import org.elasticsearch.inference.InferenceServiceExtension;
1619
import org.elasticsearch.inference.Model;
1720
import org.elasticsearch.inference.ModelConfigurations;
@@ -46,6 +49,7 @@
4649
import java.util.concurrent.CountDownLatch;
4750
import java.util.concurrent.atomic.AtomicReference;
4851
import java.util.function.Consumer;
52+
import java.util.function.Function;
4953
import java.util.stream.Collectors;
5054

5155
import static org.hamcrest.CoreMatchers.equalTo;
@@ -56,6 +60,8 @@
5660
import static org.hamcrest.Matchers.instanceOf;
5761
import static org.hamcrest.Matchers.not;
5862
import static org.hamcrest.Matchers.nullValue;
63+
import static org.mockito.ArgumentMatchers.any;
64+
import static org.mockito.Mockito.doAnswer;
5965
import static org.mockito.Mockito.mock;
6066

6167
public class ModelRegistryIT extends ESSingleNodeTestCase {
@@ -121,7 +127,12 @@ public void testGetModel() throws Exception {
121127
assertEquals(model.getConfigurations().getService(), modelHolder.get().service());
122128

123129
var elserService = new ElasticsearchInternalService(
124-
new InferenceServiceExtension.InferenceServiceFactoryContext(mock(Client.class), mock(ThreadPool.class))
130+
new InferenceServiceExtension.InferenceServiceFactoryContext(
131+
mock(Client.class),
132+
mock(ThreadPool.class),
133+
mock(ClusterService.class),
134+
Settings.EMPTY
135+
)
125136
);
126137
ElasticsearchInternalModel roundTripModel = (ElasticsearchInternalModel) elserService.parsePersistedConfigWithSecrets(
127138
modelHolder.get().inferenceEntityId(),
@@ -282,26 +293,38 @@ public void testGetModelWithSecrets() throws InterruptedException {
282293
}
283294

284295
public void testGetAllModels_WithDefaults() throws Exception {
285-
var service = "foo";
286-
var secret = "abc";
296+
var serviceName = "foo";
287297
int configuredModelCount = 10;
288298
int defaultModelCount = 2;
289299
int totalModelCount = 12;
290300

291-
var defaultConfigs = new HashMap<String, UnparsedModel>();
301+
var service = mock(InferenceService.class);
302+
303+
var defaultConfigs = new ArrayList<Model>();
304+
var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
292305
for (int i = 0; i < defaultModelCount; i++) {
293306
var id = "default-" + i;
294-
defaultConfigs.put(id, createUnparsedConfig(id, randomFrom(TaskType.values()), service, secret));
307+
var taskType = randomFrom(TaskType.values());
308+
defaultConfigs.add(createModel(id, taskType, serviceName));
309+
defaultIds.add(new InferenceService.DefaultConfigId(id, taskType, service));
295310
}
296-
defaultConfigs.values().forEach(modelRegistry::addDefaultConfiguration);
311+
312+
doAnswer(invocation -> {
313+
@SuppressWarnings("unchecked")
314+
var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
315+
listener.onResponse(defaultConfigs);
316+
return Void.TYPE;
317+
}).when(service).defaultConfigs(any());
318+
319+
defaultIds.forEach(modelRegistry::addDefaultIds);
297320

298321
AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
299322
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
300323

301324
var createdModels = new HashMap<String, Model>();
302325
for (int i = 0; i < configuredModelCount; i++) {
303326
var id = randomAlphaOfLength(5) + i;
304-
var model = createModel(id, randomFrom(TaskType.values()), service);
327+
var model = createModel(id, randomFrom(TaskType.values()), serviceName);
305328
createdModels.put(id, model);
306329
blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder);
307330
assertThat(putModelHolder.get(), is(true));
@@ -315,16 +338,22 @@ public void testGetAllModels_WithDefaults() throws Exception {
315338
var getAllModels = modelHolder.get();
316339
assertReturnModelIsModifiable(modelHolder.get().get(0));
317340

341+
// same result but configs should have been persisted this time
342+
blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
343+
assertNull(exceptionHolder.get());
344+
assertThat(modelHolder.get(), hasSize(totalModelCount));
345+
318346
// sort in the same order as the returned models
319-
var ids = new ArrayList<>(defaultConfigs.keySet().stream().toList());
347+
var ids = new ArrayList<>(defaultIds.stream().map(InferenceService.DefaultConfigId::inferenceId).toList());
320348
ids.addAll(createdModels.keySet().stream().toList());
321349
ids.sort(String::compareTo);
350+
var configsById = defaultConfigs.stream().collect(Collectors.toMap(Model::getInferenceEntityId, Function.identity()));
322351
for (int i = 0; i < totalModelCount; i++) {
323352
var id = ids.get(i);
324353
assertEquals(id, getAllModels.get(i).inferenceEntityId());
325354
if (id.startsWith("default")) {
326-
assertEquals(defaultConfigs.get(id).taskType(), getAllModels.get(i).taskType());
327-
assertEquals(defaultConfigs.get(id).service(), getAllModels.get(i).service());
355+
assertEquals(configsById.get(id).getTaskType(), getAllModels.get(i).taskType());
356+
assertEquals(configsById.get(id).getConfigurations().getService(), getAllModels.get(i).service());
328357
} else {
329358
assertEquals(createdModels.get(id).getTaskType(), getAllModels.get(i).taskType());
330359
assertEquals(createdModels.get(id).getConfigurations().getService(), getAllModels.get(i).service());
@@ -333,16 +362,27 @@ public void testGetAllModels_WithDefaults() throws Exception {
333362
}
334363

335364
public void testGetAllModels_OnlyDefaults() throws Exception {
336-
var service = "foo";
337-
var secret = "abc";
338365
int defaultModelCount = 2;
366+
var serviceName = "foo";
367+
var service = mock(InferenceService.class);
339368

340-
var defaultConfigs = new HashMap<String, UnparsedModel>();
369+
var defaultConfigs = new ArrayList<Model>();
370+
var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
341371
for (int i = 0; i < defaultModelCount; i++) {
342372
var id = "default-" + i;
343-
defaultConfigs.put(id, createUnparsedConfig(id, randomFrom(TaskType.values()), service, secret));
373+
var taskType = randomFrom(TaskType.values());
374+
defaultConfigs.add(createModel(id, taskType, serviceName));
375+
defaultIds.add(new InferenceService.DefaultConfigId(id, taskType, service));
344376
}
345-
defaultConfigs.values().forEach(modelRegistry::addDefaultConfiguration);
377+
378+
doAnswer(invocation -> {
379+
@SuppressWarnings("unchecked")
380+
var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
381+
listener.onResponse(defaultConfigs);
382+
return Void.TYPE;
383+
}).when(service).defaultConfigs(any());
384+
385+
defaultIds.forEach(modelRegistry::addDefaultIds);
346386

347387
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
348388
AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
@@ -353,31 +393,42 @@ public void testGetAllModels_OnlyDefaults() throws Exception {
353393
assertReturnModelIsModifiable(modelHolder.get().get(0));
354394

355395
// sort in the same order as the returned models
356-
var ids = new ArrayList<>(defaultConfigs.keySet().stream().toList());
396+
var configsById = defaultConfigs.stream().collect(Collectors.toMap(Model::getInferenceEntityId, Function.identity()));
397+
var ids = new ArrayList<>(configsById.keySet().stream().toList());
357398
ids.sort(String::compareTo);
358399
for (int i = 0; i < defaultModelCount; i++) {
359400
var id = ids.get(i);
360401
assertEquals(id, getAllModels.get(i).inferenceEntityId());
361-
assertEquals(defaultConfigs.get(id).taskType(), getAllModels.get(i).taskType());
362-
assertEquals(defaultConfigs.get(id).service(), getAllModels.get(i).service());
402+
assertEquals(configsById.get(id).getTaskType(), getAllModels.get(i).taskType());
403+
assertEquals(configsById.get(id).getConfigurations().getService(), getAllModels.get(i).service());
363404
}
364405
}
365406

366407
public void testGet_WithDefaults() throws InterruptedException {
367-
var service = "foo";
368-
var secret = "abc";
408+
var serviceName = "foo";
409+
var service = mock(InferenceService.class);
410+
411+
var defaultConfigs = new ArrayList<Model>();
412+
var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
369413

370-
var defaultSparse = createUnparsedConfig("default-sparse", TaskType.SPARSE_EMBEDDING, service, secret);
371-
var defaultText = createUnparsedConfig("default-text", TaskType.TEXT_EMBEDDING, service, secret);
414+
defaultConfigs.add(createModel("default-sparse", TaskType.SPARSE_EMBEDDING, serviceName));
415+
defaultConfigs.add(createModel("default-text", TaskType.TEXT_EMBEDDING, serviceName));
416+
defaultIds.add(new InferenceService.DefaultConfigId("default-sparse", TaskType.SPARSE_EMBEDDING, service));
417+
defaultIds.add(new InferenceService.DefaultConfigId("default-text", TaskType.TEXT_EMBEDDING, service));
372418

373-
modelRegistry.addDefaultConfiguration(defaultSparse);
374-
modelRegistry.addDefaultConfiguration(defaultText);
419+
doAnswer(invocation -> {
420+
@SuppressWarnings("unchecked")
421+
var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
422+
listener.onResponse(defaultConfigs);
423+
return Void.TYPE;
424+
}).when(service).defaultConfigs(any());
425+
defaultIds.forEach(modelRegistry::addDefaultIds);
375426

376427
AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
377428
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
378429

379-
var configured1 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), service);
380-
var configured2 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), service);
430+
var configured1 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), serviceName);
431+
var configured2 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), serviceName);
381432
blockingCall(listener -> modelRegistry.storeModel(configured1, listener), putModelHolder, exceptionHolder);
382433
assertThat(putModelHolder.get(), is(true));
383434
blockingCall(listener -> modelRegistry.storeModel(configured2, listener), putModelHolder, exceptionHolder);
@@ -386,6 +437,7 @@ public void testGet_WithDefaults() throws InterruptedException {
386437

387438
AtomicReference<UnparsedModel> modelHolder = new AtomicReference<>();
388439
blockingCall(listener -> modelRegistry.getModel("default-sparse", listener), modelHolder, exceptionHolder);
440+
assertNull(exceptionHolder.get());
389441
assertEquals("default-sparse", modelHolder.get().inferenceEntityId());
390442
assertEquals(TaskType.SPARSE_EMBEDDING, modelHolder.get().taskType());
391443
assertReturnModelIsModifiable(modelHolder.get());
@@ -400,23 +452,32 @@ public void testGet_WithDefaults() throws InterruptedException {
400452
}
401453

402454
public void testGetByTaskType_WithDefaults() throws Exception {
403-
var service = "foo";
404-
var secret = "abc";
405-
406-
var defaultSparse = createUnparsedConfig("default-sparse", TaskType.SPARSE_EMBEDDING, service, secret);
407-
var defaultText = createUnparsedConfig("default-text", TaskType.TEXT_EMBEDDING, service, secret);
408-
var defaultChat = createUnparsedConfig("default-chat", TaskType.COMPLETION, service, secret);
409-
410-
modelRegistry.addDefaultConfiguration(defaultSparse);
411-
modelRegistry.addDefaultConfiguration(defaultText);
412-
modelRegistry.addDefaultConfiguration(defaultChat);
455+
var serviceName = "foo";
456+
457+
var defaultSparse = createModel("default-sparse", TaskType.SPARSE_EMBEDDING, serviceName);
458+
var defaultText = createModel("default-text", TaskType.TEXT_EMBEDDING, serviceName);
459+
var defaultChat = createModel("default-chat", TaskType.COMPLETION, serviceName);
460+
461+
var service = mock(InferenceService.class);
462+
var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
463+
defaultIds.add(new InferenceService.DefaultConfigId("default-sparse", TaskType.SPARSE_EMBEDDING, service));
464+
defaultIds.add(new InferenceService.DefaultConfigId("default-text", TaskType.TEXT_EMBEDDING, service));
465+
defaultIds.add(new InferenceService.DefaultConfigId("default-chat", TaskType.COMPLETION, service));
466+
467+
doAnswer(invocation -> {
468+
@SuppressWarnings("unchecked")
469+
var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
470+
listener.onResponse(List.of(defaultSparse, defaultChat, defaultText));
471+
return Void.TYPE;
472+
}).when(service).defaultConfigs(any());
473+
defaultIds.forEach(modelRegistry::addDefaultIds);
413474

414475
AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
415476
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
416477

417-
var configuredSparse = createModel("configured-sparse", TaskType.SPARSE_EMBEDDING, service);
418-
var configuredText = createModel("configured-text", TaskType.TEXT_EMBEDDING, service);
419-
var configuredRerank = createModel("configured-rerank", TaskType.RERANK, service);
478+
var configuredSparse = createModel("configured-sparse", TaskType.SPARSE_EMBEDDING, serviceName);
479+
var configuredText = createModel("configured-text", TaskType.TEXT_EMBEDDING, serviceName);
480+
var configuredRerank = createModel("configured-rerank", TaskType.RERANK, serviceName);
420481
blockingCall(listener -> modelRegistry.storeModel(configuredSparse, listener), putModelHolder, exceptionHolder);
421482
assertThat(putModelHolder.get(), is(true));
422483
blockingCall(listener -> modelRegistry.storeModel(configuredText, listener), putModelHolder, exceptionHolder);
@@ -530,10 +591,6 @@ public static Model createModelWithSecrets(String inferenceEntityId, TaskType ta
530591
);
531592
}
532593

533-
public static UnparsedModel createUnparsedConfig(String inferenceEntityId, TaskType taskType, String service, String secret) {
534-
return new UnparsedModel(inferenceEntityId, taskType, service, Map.of("a", "b"), Map.of("secret", secret));
535-
}
536-
537594
private static class TestModelOfAnyKind extends ModelConfigurations {
538595

539596
record TestModelServiceSettings() implements ServiceSettings {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,13 +212,21 @@ public Collection<?> createComponents(PluginServices services) {
212212
);
213213
}
214214

215-
var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(services.client(), services.threadPool());
215+
var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(
216+
services.client(),
217+
services.threadPool(),
218+
services.clusterService(),
219+
settings
220+
);
221+
216222
// This must be done after the HttpRequestSenderFactory is created so that the services can get the
217223
// reference correctly
218224
var registry = new InferenceServiceRegistry(inferenceServices, factoryContext);
219225
registry.init(services.client());
220-
for (var service : registry.getServices().values()) {
221-
service.defaultConfigs().forEach(modelRegistry::addDefaultConfiguration);
226+
if (DefaultElserFeatureFlag.isEnabled()) {
227+
for (var service : registry.getServices().values()) {
228+
service.defaultConfigIds().forEach(modelRegistry::addDefaultIds);
229+
}
222230
}
223231
inferenceServiceRegistry.set(registry);
224232

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public class SentenceBoundaryChunkingSettings implements ChunkingSettings {
3535
ChunkingSettingsOptions.SENTENCE_OVERLAP.toString()
3636
);
3737

38-
private static int DEFAULT_OVERLAP = 0;
38+
private static int DEFAULT_OVERLAP = 1;
3939

4040
protected final int maxChunkSize;
4141
protected int sentenceOverlap = DEFAULT_OVERLAP;
@@ -69,17 +69,18 @@ public static SentenceBoundaryChunkingSettings fromMap(Map<String, Object> map)
6969
validationException
7070
);
7171

72-
Integer sentenceOverlap = ServiceUtils.extractOptionalPositiveInteger(
72+
Integer sentenceOverlap = ServiceUtils.removeAsType(
7373
map,
7474
ChunkingSettingsOptions.SENTENCE_OVERLAP.toString(),
75-
ModelConfigurations.CHUNKING_SETTINGS,
75+
Integer.class,
7676
validationException
7777
);
78-
79-
if (sentenceOverlap != null && sentenceOverlap > 1) {
78+
if (sentenceOverlap == null) {
79+
sentenceOverlap = DEFAULT_OVERLAP;
80+
} else if (sentenceOverlap > 1 || sentenceOverlap < 0) {
8081
validationException.addValidationError(
81-
ChunkingSettingsOptions.SENTENCE_OVERLAP.toString() + "[" + sentenceOverlap + "] must be either 0 or 1"
82-
); // todo better
82+
ChunkingSettingsOptions.SENTENCE_OVERLAP + "[" + sentenceOverlap + "] must be either 0 or 1"
83+
);
8384
}
8485

8586
if (validationException.validationErrors().isEmpty() == false) {

0 commit comments

Comments (0)