Skip to content

Commit c43f449

Browse files
committed
Add option not to persist for get all
1 parent 831d55f commit c43f449

File tree

8 files changed

+128
-23
lines changed

8 files changed

+128
-23
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ static TransportVersion def(int id) {
176176
public static final TransportVersion CONVERT_FAILURE_STORE_OPTIONS_TO_SELECTOR_OPTIONS_INTERNALLY = def(8_772_00_0);
177177
public static final TransportVersion REMOVE_MIN_COMPATIBLE_SHARD_NODE = def(8_773_00_0);
178178
public static final TransportVersion REVERT_REMOVE_MIN_COMPATIBLE_SHARD_NODE = def(8_774_00_0);
179+
public static final TransportVersion INFERENCE_DONT_PERSIST_ON_READ = def(8_775_00_0);
179180

180181
/*
181182
* STOP! READ THIS FIRST! No, really,

test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestCase.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,8 +1121,6 @@ protected static void wipeAllIndices(boolean preserveSecurityIndices) throws IOE
11211121
if (preserveSecurityIndices) {
11221122
indexPatterns.add("-.security-*");
11231123
}
1124-
// always preserve inference index
1125-
indexPatterns.add("-.inference");
11261124
final Request deleteRequest = new Request("DELETE", Strings.collectionToCommaDelimitedString(indexPatterns));
11271125
deleteRequest.addParameter("expand_wildcards", "open,closed,hidden");
11281126
final Response response = adminClient().performRequest(deleteRequest);

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceModelAction.java

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,40 @@ public GetInferenceModelAction() {
3434

3535
public static class Request extends AcknowledgedRequest<GetInferenceModelAction.Request> {
3636

37+
// Default for doNotPersistDefaultConfigs: used by the two-argument constructor and
// when reading from a stream older than INFERENCE_DONT_PERSIST_ON_READ.
// Declared final so the shared default cannot be reassigned at runtime.
private static final boolean DEFAULT_DO_NOT_PERSIST_DEFAULT_CONFIGS = false;
38+
3739
private final String inferenceEntityId;
3840
private final TaskType taskType;
41+
// Default endpoint configurations are persisted on first read.
42+
// Set to true to avoid persisting on read.
43+
// This setting only applies to GET * requests; it has
44+
// no effect when getting a single model
45+
private final boolean doNotPersistDefaultConfigs;
3946

4047
public Request(String inferenceEntityId, TaskType taskType) {
4148
super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
4249
this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);
4350
this.taskType = Objects.requireNonNull(taskType);
51+
this.doNotPersistDefaultConfigs = DEFAULT_DO_NOT_PERSIST_DEFAULT_CONFIGS;
52+
}
53+
54+
public Request(String inferenceEntityId, TaskType taskType, boolean doNotPersistDefaultConfigs) {
55+
super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
56+
this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);
57+
this.taskType = Objects.requireNonNull(taskType);
58+
this.doNotPersistDefaultConfigs = doNotPersistDefaultConfigs;
4459
}
4560

4661
public Request(StreamInput in) throws IOException {
4762
super(in);
4863
this.inferenceEntityId = in.readString();
4964
this.taskType = TaskType.fromStream(in);
65+
if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_DONT_PERSIST_ON_READ)) {
66+
this.doNotPersistDefaultConfigs = in.readBoolean();
67+
} else {
68+
this.doNotPersistDefaultConfigs = DEFAULT_DO_NOT_PERSIST_DEFAULT_CONFIGS;
69+
}
70+
5071
}
5172

5273
public String getInferenceEntityId() {
@@ -57,24 +78,32 @@ public TaskType getTaskType() {
5778
return taskType;
5879
}
5980

81+
public boolean isDoNotPersistDefaultConfigs() {
82+
return doNotPersistDefaultConfigs;
83+
}
84+
6085
@Override
6186
public void writeTo(StreamOutput out) throws IOException {
6287
super.writeTo(out);
6388
out.writeString(inferenceEntityId);
6489
taskType.writeTo(out);
90+
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_DONT_PERSIST_ON_READ)) {
91+
out.writeBoolean(this.doNotPersistDefaultConfigs);
92+
}
6593
}
6694

6795
@Override
6896
public boolean equals(Object o) {
6997
if (this == o) return true;
7098
if (o == null || getClass() != o.getClass()) return false;
7199
Request request = (Request) o;
72-
return Objects.equals(inferenceEntityId, request.inferenceEntityId) && taskType == request.taskType;
100+
return Objects.equals(inferenceEntityId, request.inferenceEntityId) && taskType == request.taskType &&
101+
doNotPersistDefaultConfigs == request.doNotPersistDefaultConfigs;
73102
}
74103

75104
@Override
76105
public int hashCode() {
77-
return Objects.hash(inferenceEntityId, taskType);
106+
return Objects.hash(inferenceEntityId, taskType, doNotPersistDefaultConfigs);
78107
}
79108
}
80109

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.cluster.service.ClusterService;
1515
import org.elasticsearch.common.io.stream.StreamOutput;
1616
import org.elasticsearch.common.settings.Settings;
17+
import org.elasticsearch.index.IndexNotFoundException;
1718
import org.elasticsearch.inference.InferenceService;
1819
import org.elasticsearch.inference.InferenceServiceExtension;
1920
import org.elasticsearch.inference.Model;
@@ -251,7 +252,7 @@ public void testGetAllModels() throws InterruptedException {
251252
}
252253

253254
AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
254-
blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
255+
blockingCall(listener -> modelRegistry.getAllModels(randomBoolean(), listener), modelHolder, exceptionHolder);
255256
assertNull(exceptionHolder.get());
256257
assertThat(modelHolder.get(), hasSize(modelCount));
257258
var getAllModels = modelHolder.get();
@@ -333,14 +334,14 @@ public void testGetAllModels_WithDefaults() throws Exception {
333334
}
334335

335336
AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
336-
blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
337+
blockingCall(listener -> modelRegistry.getAllModels(randomBoolean(), listener), modelHolder, exceptionHolder);
337338
assertNull(exceptionHolder.get());
338339
assertThat(modelHolder.get(), hasSize(totalModelCount));
339340
var getAllModels = modelHolder.get();
340341
assertReturnModelIsModifiable(modelHolder.get().get(0));
341342

342343
// same result but configs should have been persisted this time
343-
blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
344+
blockingCall(listener -> modelRegistry.getAllModels(randomBoolean(), listener), modelHolder, exceptionHolder);
344345
assertNull(exceptionHolder.get());
345346
assertThat(modelHolder.get(), hasSize(totalModelCount));
346347

@@ -387,7 +388,7 @@ public void testGetAllModels_OnlyDefaults() throws Exception {
387388

388389
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
389390
AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
390-
blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
391+
blockingCall(listener -> modelRegistry.getAllModels(randomBoolean(), listener), modelHolder, exceptionHolder);
391392
assertNull(exceptionHolder.get());
392393
assertThat(modelHolder.get(), hasSize(2));
393394
var getAllModels = modelHolder.get();
@@ -405,6 +406,44 @@ public void testGetAllModels_OnlyDefaults() throws Exception {
405406
}
406407
}
407408

409+
public void testGetAllModels_withDoNotPersist() throws Exception {
410+
int defaultModelCount = 2;
411+
var serviceName = "foo";
412+
var service = mock(InferenceService.class);
413+
414+
var defaultConfigs = new ArrayList<Model>();
415+
var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
416+
for (int i = 0; i < defaultModelCount; i++) {
417+
var id = "default-" + i;
418+
var taskType = randomFrom(TaskType.values());
419+
defaultConfigs.add(createModel(id, taskType, serviceName));
420+
defaultIds.add(new InferenceService.DefaultConfigId(id, taskType, service));
421+
}
422+
423+
doAnswer(invocation -> {
424+
@SuppressWarnings("unchecked")
425+
var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
426+
listener.onResponse(defaultConfigs);
427+
return Void.TYPE;
428+
}).when(service).defaultConfigs(any());
429+
430+
defaultIds.forEach(modelRegistry::addDefaultIds);
431+
432+
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
433+
AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
434+
blockingCall(listener -> modelRegistry.getAllModels(true, listener), modelHolder, exceptionHolder);
435+
assertNull(exceptionHolder.get());
436+
assertThat(modelHolder.get(), hasSize(2));
437+
438+
expectThrows(IndexNotFoundException.class, () -> client().admin().indices().prepareGetIndex().addIndices(".inference").get());
439+
440+
// this time check the index is created
441+
blockingCall(listener -> modelRegistry.getAllModels(false, listener), modelHolder, exceptionHolder);
442+
assertNull(exceptionHolder.get());
443+
assertThat(modelHolder.get(), hasSize(2));
444+
assertInferenceIndexExists();
445+
}
446+
408447
public void testGet_WithDefaults() throws InterruptedException {
409448
var serviceName = "foo";
410449
var service = mock(InferenceService.class);
@@ -513,6 +552,12 @@ public void testGetByTaskType_WithDefaults() throws Exception {
513552
assertReturnModelIsModifiable(modelHolder.get().get(0));
514553
}
515554

555+
private void assertInferenceIndexExists() {
556+
var indexResponse = client().admin().indices().prepareGetIndex().addIndices(".inference").get();
557+
assertNotNull(indexResponse.getSettings());
558+
assertNotNull(indexResponse.getMappings());
559+
}
560+
516561
@SuppressWarnings("unchecked")
517562
private void assertReturnModelIsModifiable(UnparsedModel unparsedModel) {
518563
var settings = unparsedModel.settings();
@@ -551,7 +596,6 @@ private Model buildElserModelConfig(String inferenceEntityId, TaskType taskType)
551596
);
552597
default -> throw new IllegalArgumentException("task type " + taskType + " is not supported");
553598
};
554-
555599
}
556600

557601
protected <T> void blockingCall(Consumer<ActionListener<T>> function, AtomicReference<T> response, AtomicReference<Exception> error)

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ protected void doExecute(
6969
boolean inferenceEntityIdIsWildCard = Strings.isAllOrWildcard(request.getInferenceEntityId());
7070

7171
if (request.getTaskType() == TaskType.ANY && inferenceEntityIdIsWildCard) {
72-
getAllModels(listener);
72+
getAllModels(request.isDoNotPersistDefaultConfigs(), listener);
7373
} else if (inferenceEntityIdIsWildCard) {
7474
getModelsByTaskType(request.getTaskType(), listener);
7575
} else {
@@ -100,8 +100,9 @@ private void getSingleModel(
100100
}));
101101
}
102102

103-
private void getAllModels(ActionListener<GetInferenceModelAction.Response> listener) {
103+
private void getAllModels(boolean doNotPersistEndpoints, ActionListener<GetInferenceModelAction.Response> listener) {
104104
modelRegistry.getAllModels(
105+
doNotPersistEndpoints == false,
105106
listener.delegateFailureAndWrap((l, models) -> executor.execute(ActionRunnable.supply(l, () -> parseModels(models))))
106107
);
107108
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ protected void masterOperation(
6363
ClusterState state,
6464
ActionListener<XPackUsageFeatureResponse> listener
6565
) {
66-
GetInferenceModelAction.Request getInferenceModelAction = new GetInferenceModelAction.Request("_all", TaskType.ANY);
66+
GetInferenceModelAction.Request getInferenceModelAction = new GetInferenceModelAction.Request("_all", TaskType.ANY, true);
6767
client.execute(GetInferenceModelAction.INSTANCE, getInferenceModelAction, listener.delegateFailureAndWrap((delegate, response) -> {
6868
Map<String, InferenceFeatureSetUsage.ModelStats> stats = new TreeMap<>();
6969
for (ModelConfigurations model : response.getEndpoints()) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,17 @@
6969

7070
import static org.elasticsearch.core.Strings.format;
7171

72+
/**
73+
* Class for persisting and reading inference endpoint configurations.
74+
* Some inference services provide default configurations, the registry is
75+
* made aware of these at start up via {@link #addDefaultIds(InferenceService.DefaultConfigId)}.
76+
* Only the ids and service details are registered at this point
77+
* as the full config definition may not be known at start up.
78+
* The full config is lazily populated on read and persisted to the
79+
* index. This has the effect of creating the backing index on reading
80+
* the configs. {@link #getAllModels(boolean, ActionListener)} has an option
81+
* to not write the default configs to index on read to avoid index creation.
82+
*/
7283
public class ModelRegistry {
7384
public record ModelConfigMap(Map<String, Object> config, Map<String, Object> secrets) {}
7485

@@ -132,7 +143,7 @@ public void getModelWithSecrets(String inferenceEntityId, ActionListener<Unparse
132143
if (searchResponse.getHits().getHits().length == 0) {
133144
var maybeDefault = idMatchedDefault(inferenceEntityId, defaultConfigIds);
134145
if (maybeDefault.isPresent()) {
135-
getDefaultConfig(maybeDefault.get(), listener);
146+
getDefaultConfig(true, maybeDefault.get(), listener);
136147
} else {
137148
delegate.onFailure(inferenceNotFoundException(inferenceEntityId));
138149
}
@@ -163,7 +174,7 @@ public void getModel(String inferenceEntityId, ActionListener<UnparsedModel> lis
163174
if (searchResponse.getHits().getHits().length == 0) {
164175
var maybeDefault = idMatchedDefault(inferenceEntityId, defaultConfigIds);
165176
if (maybeDefault.isPresent()) {
166-
getDefaultConfig(maybeDefault.get(), listener);
177+
getDefaultConfig(true, maybeDefault.get(), listener);
167178
} else {
168179
delegate.onFailure(inferenceNotFoundException(inferenceEntityId));
169180
}
@@ -199,7 +210,7 @@ public void getModelsByTaskType(TaskType taskType, ActionListener<List<UnparsedM
199210
ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
200211
var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList();
201212
var defaultConfigsForTaskType = taskTypeMatchedDefaults(taskType, defaultConfigIds);
202-
addAllDefaultConfigsIfMissing(modelConfigs, defaultConfigsForTaskType, delegate);
213+
addAllDefaultConfigsIfMissing(true, modelConfigs, defaultConfigsForTaskType, delegate);
203214
});
204215

205216
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(TASK_TYPE_FIELD, taskType.toString()));
@@ -216,13 +227,20 @@ public void getModelsByTaskType(TaskType taskType, ActionListener<List<UnparsedM
216227

217228
/**
218229
* Get all models.
230+
* If the default endpoint configurations have not been persisted then only
231+
* persist them if {@code doNotPersistDefaultEndpoints == false}. Persisting the
232+
* configs has the side effect of creating the index.
233+
*
219234
* Secret settings are not included
235+
* @param doNotPersistDefaultEndpoints Don't persist the default endpoint configurations if
236+
* not already persisted. When true this avoids the creation
237+
* of the backing index.
220238
* @param listener Models listener
221239
*/
222-
public void getAllModels(ActionListener<List<UnparsedModel>> listener) {
240+
public void getAllModels(boolean doNotPersistDefaultEndpoints, ActionListener<List<UnparsedModel>> listener) {
223241
ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
224242
var foundConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList();
225-
addAllDefaultConfigsIfMissing(foundConfigs, defaultConfigIds, delegate);
243+
addAllDefaultConfigsIfMissing(doNotPersistDefaultEndpoints, foundConfigs, defaultConfigIds, delegate);
226244
});
227245

228246
// In theory the index should only contain model config documents
@@ -241,6 +259,7 @@ public void getAllModels(ActionListener<List<UnparsedModel>> listener) {
241259
}
242260

243261
private void addAllDefaultConfigsIfMissing(
262+
boolean doNotPersistDefaultEndpoints,
244263
List<UnparsedModel> foundConfigs,
245264
List<InferenceService.DefaultConfigId> matchedDefaults,
246265
ActionListener<List<UnparsedModel>> listener
@@ -263,18 +282,26 @@ private void addAllDefaultConfigsIfMissing(
263282
);
264283

265284
for (var required : missing) {
266-
getDefaultConfig(required, groupedListener);
285+
getDefaultConfig(doNotPersistDefaultEndpoints, required, groupedListener);
267286
}
268287
}
269288
}
270289

271-
private void getDefaultConfig(InferenceService.DefaultConfigId defaultConfig, ActionListener<UnparsedModel> listener) {
290+
private void getDefaultConfig(
291+
boolean doNotPersistDefaultEndpoints,
292+
InferenceService.DefaultConfigId defaultConfig,
293+
ActionListener<UnparsedModel> listener
294+
) {
272295
defaultConfig.service().defaultConfigs(listener.delegateFailureAndWrap((delegate, models) -> {
273296
boolean foundModel = false;
274297
for (var m : models) {
275298
if (m.getInferenceEntityId().equals(defaultConfig.inferenceId())) {
276299
foundModel = true;
277-
storeDefaultEndpoint(m, () -> listener.onResponse(modelToUnparsedModel(m)));
300+
if (doNotPersistDefaultEndpoints) {
301+
listener.onResponse(modelToUnparsedModel(m));
302+
} else {
303+
storeDefaultEndpoint(m, () -> listener.onResponse(modelToUnparsedModel(m)));
304+
}
278305
break;
279306
}
280307
}
@@ -287,7 +314,7 @@ private void getDefaultConfig(InferenceService.DefaultConfigId defaultConfig, Ac
287314
}));
288315
}
289316

290-
public void storeDefaultEndpoint(Model preconfigured, Runnable runAfter) {
317+
private void storeDefaultEndpoint(Model preconfigured, Runnable runAfter) {
291318
var responseListener = ActionListener.<Boolean>wrap(success -> {
292319
logger.debug("Added default inference endpoint [{}]", preconfigured.getInferenceEntityId());
293320
}, exception -> {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceModelRequestTests.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
public class GetInferenceModelRequestTests extends AbstractWireSerializingTestCase<GetInferenceModelAction.Request> {
1616

1717
public static GetInferenceModelAction.Request randomTestInstance() {
18-
return new GetInferenceModelAction.Request(randomAlphaOfLength(8), randomFrom(TaskType.values()));
18+
return new GetInferenceModelAction.Request(randomAlphaOfLength(8), randomFrom(TaskType.values()), randomBoolean());
1919
}
2020

2121
@Override
@@ -30,12 +30,17 @@ protected GetInferenceModelAction.Request createTestInstance() {
3030

3131
@Override
3232
protected GetInferenceModelAction.Request mutateInstance(GetInferenceModelAction.Request instance) {
33-
return switch (randomIntBetween(0, 1)) {
33+
return switch (randomIntBetween(0, 2)) {
3434
case 0 -> new GetInferenceModelAction.Request(instance.getInferenceEntityId() + "foo", instance.getTaskType());
3535
case 1 -> {
3636
var nextTaskType = TaskType.values()[(instance.getTaskType().ordinal() + 1) % TaskType.values().length];
3737
yield new GetInferenceModelAction.Request(instance.getInferenceEntityId(), nextTaskType);
3838
}
39+
case 2 -> new GetInferenceModelAction.Request(
40+
instance.getInferenceEntityId(),
41+
instance.getTaskType(),
42+
instance.isDoNotPersistDefaultConfigs() == false
43+
);
3944
default -> throw new UnsupportedOperationException();
4045
};
4146
}

0 commit comments

Comments
 (0)