Skip to content

Commit 644cd2e

Browse files
authored
[8.16][ML] Fix .inference index leaking out of tests. (#115343)
Backport of #115023 with additional logic to handle streaming to/from a patch version
1 parent e6cacd9 commit 644cd2e

File tree

8 files changed

+156
-67
lines changed

8 files changed

+156
-67
lines changed

muted-tests.yml

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -324,65 +324,20 @@ tests:
324324
- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT
325325
method: testPutE5WithTrainedModelAndInference
326326
issue: https://github.com/elastic/elasticsearch/issues/114023
327-
- class: org.elasticsearch.xpack.enrich.EnrichRestIT
328-
method: test {p0=enrich/40_synthetic_source/enrich documents over _bulk}
329-
issue: https://github.com/elastic/elasticsearch/issues/114825
330327
- class: org.elasticsearch.threadpool.SimpleThreadPoolIT
331328
method: testThreadPoolMetrics
332329
issue: https://github.com/elastic/elasticsearch/issues/108320
333-
- class: org.elasticsearch.xpack.eql.EqlRestValidationIT
334-
method: testDefaultIndicesOptions
335-
issue: https://github.com/elastic/elasticsearch/issues/114771
336-
- class: org.elasticsearch.xpack.enrich.EnrichRestIT
337-
method: test {p0=enrich/10_basic/Test enrich crud apis}
338-
issue: https://github.com/elastic/elasticsearch/issues/114766
339330
- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT
340331
method: testPutE5Small_withPlatformAgnosticVariant
341332
issue: https://github.com/elastic/elasticsearch/issues/113983
342-
- class: org.elasticsearch.xpack.eql.EqlRestIT
343-
method: testIndexWildcardPatterns
344-
issue: https://github.com/elastic/elasticsearch/issues/114749
345-
- class: org.elasticsearch.xpack.enrich.EnrichRestIT
346-
method: test {p0=enrich/20_standard_index/enrich stats REST response structure}
347-
issue: https://github.com/elastic/elasticsearch/issues/114753
348333
- class: org.elasticsearch.ingest.geoip.DatabaseNodeServiceIT
349334
method: testGzippedDatabase
350335
issue: https://github.com/elastic/elasticsearch/issues/113752
351-
- class: org.elasticsearch.xpack.enrich.EnrichRestIT
352-
method: test {p0=enrich/10_basic/Test using the deprecated elasticsearch_version field results in a warning}
353-
issue: https://github.com/elastic/elasticsearch/issues/114748
354-
- class: org.elasticsearch.xpack.enrich.EnrichRestIT
355-
method: test {p0=enrich/20_standard_index/enrich documents over _bulk via an alias}
356-
issue: https://github.com/elastic/elasticsearch/issues/114763
357-
- class: org.elasticsearch.xpack.eql.EqlRestValidationIT
358-
method: testAllowNoIndicesOption
359-
issue: https://github.com/elastic/elasticsearch/issues/114789
360-
- class: org.elasticsearch.xpack.enrich.EnrichRestIT
361-
method: test {p0=enrich/20_standard_index/enrich documents over _bulk}
362-
issue: https://github.com/elastic/elasticsearch/issues/114768
363-
- class: org.elasticsearch.xpack.eql.EqlStatsIT
364-
method: testEqlRestUsage
365-
issue: https://github.com/elastic/elasticsearch/issues/114790
366-
- class: org.elasticsearch.xpack.eql.EqlRestIT
367-
method: testBadRequests
368-
issue: https://github.com/elastic/elasticsearch/issues/114752
369-
- class: org.elasticsearch.xpack.eql.EqlRestIT
370-
method: testUnicodeChars
371-
issue: https://github.com/elastic/elasticsearch/issues/114791
372336
- class: org.elasticsearch.xpack.rank.rrf.RRFRankClientYamlTestSuiteIT
373337
method: test {yaml=rrf/800_rrf_with_text_similarity_reranker_retriever/explain using rrf retriever and text-similarity}
374338
issue: https://github.com/elastic/elasticsearch/issues/114757
375-
- class: org.elasticsearch.xpack.enrich.EnrichIT
376-
method: testEnrichSpecialTypes
377-
issue: https://github.com/elastic/elasticsearch/issues/114773
378339
- class: org.elasticsearch.license.LicensingTests
379340
issue: https://github.com/elastic/elasticsearch/issues/114865
380-
- class: org.elasticsearch.xpack.enrich.EnrichIT
381-
method: testDeleteIsCaseSensitive
382-
issue: https://github.com/elastic/elasticsearch/issues/114840
383-
- class: org.elasticsearch.xpack.enrich.EnrichIT
384-
method: testImmutablePolicy
385-
issue: https://github.com/elastic/elasticsearch/issues/114839
386341
- class: org.elasticsearch.xpack.security.authc.ldap.GroupMappingIT
387342
issue: https://github.com/elastic/elasticsearch/issues/115221
388343
- class: org.elasticsearch.smoketest.DocsClientYamlTestSuiteIT

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ static TransportVersion def(int id) {
246246
public static final TransportVersion ESQL_PER_AGGREGATE_FILTER = def(8_770_00_0);
247247
public static final TransportVersion ML_INFERENCE_ATTACH_TO_EXISTSING_DEPLOYMENT = def(8_771_00_0);
248248
public static final TransportVersion CONVERT_FAILURE_STORE_OPTIONS_TO_SELECTOR_OPTIONS_INTERNALLY = def(8_772_00_0);
249+
public static final TransportVersion INFERENCE_DONT_PERSIST_ON_READ_BACKPORT_8_16 = def(8_772_00_1);
249250

250251
/*
251252
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceModelAction.java

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.core.inference.action;
99

10+
import org.elasticsearch.TransportVersion;
1011
import org.elasticsearch.TransportVersions;
1112
import org.elasticsearch.action.ActionResponse;
1213
import org.elasticsearch.action.ActionType;
@@ -34,19 +35,46 @@ public GetInferenceModelAction() {
3435

3536
public static class Request extends AcknowledgedRequest<GetInferenceModelAction.Request> {
3637

38+
private static boolean PERSIST_DEFAULT_CONFIGS = true;
39+
40+
public static boolean shouldReadPersistDefault(TransportVersion transportVersion) {
41+
// This constant is defined on future branches but we need to know about it here
42+
final TransportVersion INFERENCE_DONT_PERSIST_ON_READ = new TransportVersion(8_776_00_0);
43+
return transportVersion.onOrAfter(INFERENCE_DONT_PERSIST_ON_READ)
44+
|| transportVersion.isPatchFrom(TransportVersions.INFERENCE_DONT_PERSIST_ON_READ_BACKPORT_8_16);
45+
}
46+
3747
private final String inferenceEntityId;
3848
private final TaskType taskType;
49+
// Default endpoint configurations are persisted on first read.
50+
// Set to false to avoid persisting on read.
51+
// This setting only applies to GET * requests. It has
52+
// no effect when getting a single model
53+
private final boolean persistDefaultConfig;
3954

4055
public Request(String inferenceEntityId, TaskType taskType) {
4156
super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
4257
this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);
4358
this.taskType = Objects.requireNonNull(taskType);
59+
this.persistDefaultConfig = PERSIST_DEFAULT_CONFIGS;
60+
}
61+
62+
public Request(String inferenceEntityId, TaskType taskType, boolean persistDefaultConfig) {
63+
super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
64+
this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);
65+
this.taskType = Objects.requireNonNull(taskType);
66+
this.persistDefaultConfig = persistDefaultConfig;
4467
}
4568

4669
public Request(StreamInput in) throws IOException {
4770
super(in);
4871
this.inferenceEntityId = in.readString();
4972
this.taskType = TaskType.fromStream(in);
73+
if (shouldReadPersistDefault(in.getTransportVersion())) {
74+
this.persistDefaultConfig = in.readBoolean();
75+
} else {
76+
this.persistDefaultConfig = PERSIST_DEFAULT_CONFIGS;
77+
}
5078
}
5179

5280
public String getInferenceEntityId() {
@@ -57,24 +85,33 @@ public TaskType getTaskType() {
5785
return taskType;
5886
}
5987

88+
public boolean isPersistDefaultConfig() {
89+
return persistDefaultConfig;
90+
}
91+
6092
@Override
6193
public void writeTo(StreamOutput out) throws IOException {
6294
super.writeTo(out);
6395
out.writeString(inferenceEntityId);
6496
taskType.writeTo(out);
97+
if (shouldReadPersistDefault(out.getTransportVersion())) {
98+
out.writeBoolean(this.persistDefaultConfig);
99+
}
65100
}
66101

67102
@Override
68103
public boolean equals(Object o) {
69104
if (this == o) return true;
70105
if (o == null || getClass() != o.getClass()) return false;
71106
Request request = (Request) o;
72-
return Objects.equals(inferenceEntityId, request.inferenceEntityId) && taskType == request.taskType;
107+
return Objects.equals(inferenceEntityId, request.inferenceEntityId)
108+
&& taskType == request.taskType
109+
&& persistDefaultConfig == request.persistDefaultConfig;
73110
}
74111

75112
@Override
76113
public int hashCode() {
77-
return Objects.hash(inferenceEntityId, taskType);
114+
return Objects.hash(inferenceEntityId, taskType, persistDefaultConfig);
78115
}
79116
}
80117

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.cluster.service.ClusterService;
1515
import org.elasticsearch.common.io.stream.StreamOutput;
1616
import org.elasticsearch.common.settings.Settings;
17+
import org.elasticsearch.index.IndexNotFoundException;
1718
import org.elasticsearch.inference.InferenceService;
1819
import org.elasticsearch.inference.InferenceServiceExtension;
1920
import org.elasticsearch.inference.Model;
@@ -250,7 +251,7 @@ public void testGetAllModels() throws InterruptedException {
250251
}
251252

252253
AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
253-
blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
254+
blockingCall(listener -> modelRegistry.getAllModels(randomBoolean(), listener), modelHolder, exceptionHolder);
254255
assertNull(exceptionHolder.get());
255256
assertThat(modelHolder.get(), hasSize(modelCount));
256257
var getAllModels = modelHolder.get();
@@ -332,14 +333,14 @@ public void testGetAllModels_WithDefaults() throws Exception {
332333
}
333334

334335
AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
335-
blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
336+
blockingCall(listener -> modelRegistry.getAllModels(randomBoolean(), listener), modelHolder, exceptionHolder);
336337
assertNull(exceptionHolder.get());
337338
assertThat(modelHolder.get(), hasSize(totalModelCount));
338339
var getAllModels = modelHolder.get();
339340
assertReturnModelIsModifiable(modelHolder.get().get(0));
340341

341342
// same result but configs should have been persisted this time
342-
blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
343+
blockingCall(listener -> modelRegistry.getAllModels(randomBoolean(), listener), modelHolder, exceptionHolder);
343344
assertNull(exceptionHolder.get());
344345
assertThat(modelHolder.get(), hasSize(totalModelCount));
345346

@@ -386,7 +387,7 @@ public void testGetAllModels_OnlyDefaults() throws Exception {
386387

387388
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
388389
AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
389-
blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
390+
blockingCall(listener -> modelRegistry.getAllModels(randomBoolean(), listener), modelHolder, exceptionHolder);
390391
assertNull(exceptionHolder.get());
391392
assertThat(modelHolder.get(), hasSize(2));
392393
var getAllModels = modelHolder.get();
@@ -404,6 +405,44 @@ public void testGetAllModels_OnlyDefaults() throws Exception {
404405
}
405406
}
406407

408+
public void testGetAllModels_withDoNotPersist() throws Exception {
409+
int defaultModelCount = 2;
410+
var serviceName = "foo";
411+
var service = mock(InferenceService.class);
412+
413+
var defaultConfigs = new ArrayList<Model>();
414+
var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
415+
for (int i = 0; i < defaultModelCount; i++) {
416+
var id = "default-" + i;
417+
var taskType = randomFrom(TaskType.values());
418+
defaultConfigs.add(createModel(id, taskType, serviceName));
419+
defaultIds.add(new InferenceService.DefaultConfigId(id, taskType, service));
420+
}
421+
422+
doAnswer(invocation -> {
423+
@SuppressWarnings("unchecked")
424+
var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
425+
listener.onResponse(defaultConfigs);
426+
return Void.TYPE;
427+
}).when(service).defaultConfigs(any());
428+
429+
defaultIds.forEach(modelRegistry::addDefaultIds);
430+
431+
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
432+
AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
433+
blockingCall(listener -> modelRegistry.getAllModels(false, listener), modelHolder, exceptionHolder);
434+
assertNull(exceptionHolder.get());
435+
assertThat(modelHolder.get(), hasSize(2));
436+
437+
expectThrows(IndexNotFoundException.class, () -> client().admin().indices().prepareGetIndex().addIndices(".inference").get());
438+
439+
// this time check the index is created
440+
blockingCall(listener -> modelRegistry.getAllModels(true, listener), modelHolder, exceptionHolder);
441+
assertNull(exceptionHolder.get());
442+
assertThat(modelHolder.get(), hasSize(2));
443+
assertInferenceIndexExists();
444+
}
445+
407446
public void testGet_WithDefaults() throws InterruptedException {
408447
var serviceName = "foo";
409448
var service = mock(InferenceService.class);
@@ -512,6 +551,12 @@ public void testGetByTaskType_WithDefaults() throws Exception {
512551
assertReturnModelIsModifiable(modelHolder.get().get(0));
513552
}
514553

554+
private void assertInferenceIndexExists() {
555+
var indexResponse = client().admin().indices().prepareGetIndex().addIndices(".inference").get();
556+
assertNotNull(indexResponse.getSettings());
557+
assertNotNull(indexResponse.getMappings());
558+
}
559+
515560
@SuppressWarnings("unchecked")
516561
private void assertReturnModelIsModifiable(UnparsedModel unparsedModel) {
517562
var settings = unparsedModel.settings();
@@ -550,7 +595,6 @@ private Model buildElserModelConfig(String inferenceEntityId, TaskType taskType)
550595
);
551596
default -> throw new IllegalArgumentException("task type " + taskType + " is not supported");
552597
};
553-
554598
}
555599

556600
protected <T> void blockingCall(Consumer<ActionListener<T>> function, AtomicReference<T> response, AtomicReference<Exception> error)

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ protected void doExecute(
7272
boolean inferenceEntityIdIsWildCard = Strings.isAllOrWildcard(request.getInferenceEntityId());
7373

7474
if (request.getTaskType() == TaskType.ANY && inferenceEntityIdIsWildCard) {
75-
getAllModels(listener);
75+
getAllModels(request.isPersistDefaultConfig(), listener);
7676
} else if (inferenceEntityIdIsWildCard) {
7777
getModelsByTaskType(request.getTaskType(), listener);
7878
} else {
@@ -114,8 +114,11 @@ private void getSingleModel(
114114
}));
115115
}
116116

117-
private void getAllModels(ActionListener<GetInferenceModelAction.Response> listener) {
118-
modelRegistry.getAllModels(listener.delegateFailureAndWrap((l, models) -> executor.execute(() -> parseModels(models, listener))));
117+
private void getAllModels(boolean persistDefaultEndpoints, ActionListener<GetInferenceModelAction.Response> listener) {
118+
modelRegistry.getAllModels(
119+
persistDefaultEndpoints,
120+
listener.delegateFailureAndWrap((l, models) -> executor.execute(() -> parseModels(models, l)))
121+
);
119122
}
120123

121124
private void getModelsByTaskType(TaskType taskType, ActionListener<GetInferenceModelAction.Response> listener) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ protected void masterOperation(
6363
ClusterState state,
6464
ActionListener<XPackUsageFeatureResponse> listener
6565
) {
66-
GetInferenceModelAction.Request getInferenceModelAction = new GetInferenceModelAction.Request("_all", TaskType.ANY);
66+
GetInferenceModelAction.Request getInferenceModelAction = new GetInferenceModelAction.Request("_all", TaskType.ANY, false);
6767
client.execute(GetInferenceModelAction.INSTANCE, getInferenceModelAction, listener.delegateFailureAndWrap((delegate, response) -> {
6868
Map<String, InferenceFeatureSetUsage.ModelStats> stats = new TreeMap<>();
6969
for (ModelConfigurations model : response.getEndpoints()) {

0 commit comments

Comments
 (0)