Skip to content

Commit 2f1c857

Browse files
authored
Exclude Default Inference Endpoints from Cluster State Storage (elastic#125242)
When retrieving a default inference endpoint for the first time, the system automatically creates the endpoint. However, unlike the `put inference model` action, the `get` action does not redirect the request to the master node. Since elastic#121106, we rely on the assumption that every model creation (`put model`) must run on the master node, as it modifies the cluster state. However, this assumption led to a bug where the get action tries to store default inference endpoints from a different node. This change resolves the issue by preventing default inference endpoints from being added to the cluster state. These endpoints are not strictly needed there, as they are already reported by inference services upon startup. **Note:** This bug did not prevent the default endpoints from being used, but it caused repeated attempts to store them in the index, resulting in logging errors on every usage.
1 parent 296bbef commit 2f1c857

File tree

6 files changed

+87
-54
lines changed

6 files changed

+87
-54
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceModelAction.java

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -44,23 +44,15 @@ public static class Request extends AcknowledgedRequest<GetInferenceModelAction.
4444
// no effect when getting a single model
4545
private final boolean persistDefaultConfig;
4646

47-
// For testing only, retrieves the minimal config from the cluster state.
48-
private final boolean returnMinimalConfig;
49-
5047
public Request(String inferenceEntityId, TaskType taskType) {
5148
this(inferenceEntityId, taskType, PERSIST_DEFAULT_CONFIGS);
5249
}
5350

5451
public Request(String inferenceEntityId, TaskType taskType, boolean persistDefaultConfig) {
55-
this(inferenceEntityId, taskType, persistDefaultConfig, false);
56-
}
57-
58-
public Request(String inferenceEntityId, TaskType taskType, boolean persistDefaultConfig, boolean returnMinimalConfig) {
5952
super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
6053
this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);
6154
this.taskType = Objects.requireNonNull(taskType);
6255
this.persistDefaultConfig = persistDefaultConfig;
63-
this.returnMinimalConfig = returnMinimalConfig;
6456
}
6557

6658
public Request(StreamInput in) throws IOException {
@@ -72,13 +64,6 @@ public Request(StreamInput in) throws IOException {
7264
} else {
7365
this.persistDefaultConfig = PERSIST_DEFAULT_CONFIGS;
7466
}
75-
76-
if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_MODEL_REGISTRY_METADATA)) {
77-
this.returnMinimalConfig = in.readBoolean();
78-
} else {
79-
this.returnMinimalConfig = false;
80-
}
81-
8267
}
8368

8469
public String getInferenceEntityId() {
@@ -93,10 +78,6 @@ public boolean isPersistDefaultConfig() {
9378
return persistDefaultConfig;
9479
}
9580

96-
public boolean isReturnMinimalConfig() {
97-
return returnMinimalConfig;
98-
}
99-
10081
@Override
10182
public void writeTo(StreamOutput out) throws IOException {
10283
super.writeTo(out);
@@ -105,10 +86,6 @@ public void writeTo(StreamOutput out) throws IOException {
10586
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) {
10687
out.writeBoolean(this.persistDefaultConfig);
10788
}
108-
109-
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_MODEL_REGISTRY_METADATA)) {
110-
out.writeBoolean(returnMinimalConfig);
111-
}
11289
}
11390

11491
@Override
@@ -118,13 +95,12 @@ public boolean equals(Object o) {
11895
Request request = (Request) o;
11996
return Objects.equals(inferenceEntityId, request.inferenceEntityId)
12097
&& taskType == request.taskType
121-
&& persistDefaultConfig == request.persistDefaultConfig
122-
&& returnMinimalConfig == request.returnMinimalConfig;
98+
&& persistDefaultConfig == request.persistDefaultConfig;
12399
}
124100

125101
@Override
126102
public int hashCode() {
127-
return Objects.hash(inferenceEntityId, taskType, persistDefaultConfig, returnMinimalConfig);
103+
return Objects.hash(inferenceEntityId, taskType, persistDefaultConfig);
128104
}
129105
}
130106

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.util.concurrent.CountDownLatch;
2626

2727
import static org.hamcrest.Matchers.empty;
28+
import static org.hamcrest.Matchers.equalTo;
2829
import static org.hamcrest.Matchers.hasSize;
2930
import static org.hamcrest.Matchers.is;
3031
import static org.hamcrest.Matchers.oneOf;
@@ -62,6 +63,25 @@ public void testGet() throws IOException {
6263
assertDefaultRerankConfig(rerankModel);
6364
}
6465

66+
public void testDefaultModels() throws IOException {
67+
var elserModel = getModel(ElasticsearchInternalService.DEFAULT_ELSER_ID);
68+
assertDefaultElserConfig(elserModel);
69+
70+
var e5Model = getModel(ElasticsearchInternalService.DEFAULT_E5_ID);
71+
assertDefaultE5Config(e5Model);
72+
73+
var rerankModel = getModel(ElasticsearchInternalService.DEFAULT_RERANK_ID);
74+
assertDefaultRerankConfig(rerankModel);
75+
76+
putModel("my-model", mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING));
77+
var registeredModels = getMinimalConfigs();
78+
assertThat(registeredModels.size(), equalTo(1));
79+
assertTrue(registeredModels.containsKey("my-model"));
80+
assertFalse(registeredModels.containsKey(ElasticsearchInternalService.DEFAULT_E5_ID));
81+
assertFalse(registeredModels.containsKey(ElasticsearchInternalService.DEFAULT_ELSER_ID));
82+
assertFalse(registeredModels.containsKey(ElasticsearchInternalService.DEFAULT_RERANK_ID));
83+
}
84+
6585
@SuppressWarnings("unchecked")
6686
public void testInferDeploysDefaultElser() throws IOException {
6787
var model = getModel(ElasticsearchInternalService.DEFAULT_ELSER_ID);

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.common.settings.SecureString;
1717
import org.elasticsearch.common.settings.Settings;
1818
import org.elasticsearch.common.util.concurrent.ThreadContext;
19+
import org.elasticsearch.common.xcontent.support.XContentMapValues;
1920
import org.elasticsearch.core.Nullable;
2021
import org.elasticsearch.inference.TaskType;
2122
import org.elasticsearch.test.cluster.ElasticsearchCluster;
@@ -514,4 +515,13 @@ protected Map<String, Object> getTrainedModel(String inferenceEntityId) throws I
514515
assertStatusOkOrCreated(response);
515516
return entityAsMap(response);
516517
}
518+
519+
@SuppressWarnings("unchecked")
520+
protected Map<String, Map<String, Object>> getMinimalConfigs() throws IOException {
521+
var endpoint = "_cluster/state?filter_path=metadata.model_registry";
522+
var request = new Request("GET", endpoint);
523+
var response = client().performRequest(request);
524+
assertOK(response);
525+
return (Map<String, Map<String, Object>>) XContentMapValues.extractValue("metadata.model_registry.models", entityAsMap(response));
526+
}
517527
}

x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/InferenceUpgradeTestCase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ protected Map<String, Object> get(TaskType taskType, String inferenceId) throws
8989
}
9090

9191
@SuppressWarnings("unchecked")
92-
protected Map<String, Map<String, Object>> getMinimalConfig() throws IOException {
92+
protected Map<String, Map<String, Object>> getMinimalConfigs() throws IOException {
9393
var endpoint = "_cluster/state?filter_path=metadata.model_registry";
9494
var request = new Request("GET", endpoint);
9595
var response = client().performRequest(request);

x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/ModelRegistryUpgradeIT.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ public void testUpgradeModels() throws Exception {
8383
@SuppressWarnings("unchecked")
8484
private void assertMinimalModelsAreUpgraded() throws IOException {
8585
var fullModels = (List<Map<String, Object>>) get(TaskType.ANY, "*").get("endpoints");
86-
var minimalModels = getMinimalConfig();
86+
var minimalModels = getMinimalConfigs();
8787
assertMinimalModelsAreUpgraded(
8888
fullModels.stream().collect(Collectors.toMap(a -> (String) a.get("inference_id"), a -> a)),
8989
minimalModels
@@ -92,9 +92,14 @@ private void assertMinimalModelsAreUpgraded() throws IOException {
9292

9393
@SuppressWarnings("unchecked")
9494
private void assertMinimalModelsAreUpgraded(
95-
Map<String, Map<String, Object>> fullModels,
95+
Map<String, Map<String, Object>> fullModelsWithDefaults,
9696
Map<String, Map<String, Object>> minimalModels
9797
) {
98+
// remove the default models as they are not stored in cluster state.
99+
var fullModels = fullModelsWithDefaults.entrySet()
100+
.stream()
101+
.filter(e -> e.getKey().startsWith(".") == false)
102+
.collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue()));
98103
assertThat(fullModels.size(), greaterThan(0));
99104
assertThat(fullModels.size(), equalTo(minimalModels.size()));
100105
for (var entry : fullModels.entrySet()) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,9 @@ private void storeDefaultEndpoint(Model preconfigured, Runnable runAfter) {
428428
}
429429
});
430430

431-
storeModel(preconfigured, ActionListener.runAfter(responseListener, runAfter), AcknowledgedRequest.DEFAULT_ACK_TIMEOUT);
431+
// Store the model in the index without adding it to the cluster state,
432+
// as default models are already managed under defaultConfigIds.
433+
storeModel(preconfigured, false, ActionListener.runAfter(responseListener, runAfter), AcknowledgedRequest.DEFAULT_ACK_TIMEOUT);
432434
}
433435

434436
private ArrayList<ModelConfigMap> parseHitsAsModels(SearchHits hits) {
@@ -618,9 +620,15 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi
618620

619621
/**
620622
* Note: storeModel does not overwrite existing models and thus does not need to check the lock
623+
*
624+
* <p><b>WARNING:</b> This function must always be called on a master node. Failure to do so will result in an error.
621625
*/
622626
public void storeModel(Model model, ActionListener<Boolean> listener, TimeValue timeout) {
623-
ActionListener<BulkResponse> bulkResponseActionListener = getStoreIndexListener(model, listener, timeout);
627+
storeModel(model, true, listener, timeout);
628+
}
629+
630+
private void storeModel(Model model, boolean addToClusterState, ActionListener<Boolean> listener, TimeValue timeout) {
631+
ActionListener<BulkResponse> bulkResponseActionListener = getStoreIndexListener(model, addToClusterState, listener, timeout);
624632

625633
IndexRequest configRequest = createIndexRequest(
626634
Model.documentId(model.getConfigurations().getInferenceEntityId()),
@@ -643,7 +651,12 @@ public void storeModel(Model model, ActionListener<Boolean> listener, TimeValue
643651
.execute(bulkResponseActionListener);
644652
}
645653

646-
private ActionListener<BulkResponse> getStoreIndexListener(Model model, ActionListener<Boolean> listener, TimeValue timeout) {
654+
private ActionListener<BulkResponse> getStoreIndexListener(
655+
Model model,
656+
boolean addToClusterState,
657+
ActionListener<Boolean> listener,
658+
TimeValue timeout
659+
) {
647660
return ActionListener.wrap(bulkItemResponses -> {
648661
var inferenceEntityId = model.getConfigurations().getInferenceEntityId();
649662

@@ -667,16 +680,20 @@ private ActionListener<BulkResponse> getStoreIndexListener(Model model, ActionLi
667680
BulkItemResponse.Failure failure = getFirstBulkFailure(bulkItemResponses);
668681

669682
if (failure == null) {
670-
var storeListener = getStoreMetadataListener(inferenceEntityId, listener);
671-
try {
672-
var projectId = clusterService.state().projectState().projectId();
673-
metadataTaskQueue.submitTask(
674-
"add model [" + inferenceEntityId + "]",
675-
new AddModelMetadataTask(projectId, inferenceEntityId, new MinimalServiceSettings(model), storeListener),
676-
timeout
677-
);
678-
} catch (Exception exc) {
679-
storeListener.onFailure(exc);
683+
if (addToClusterState) {
684+
var storeListener = getStoreMetadataListener(inferenceEntityId, listener);
685+
try {
686+
var projectId = clusterService.state().projectState().projectId();
687+
metadataTaskQueue.submitTask(
688+
"add model [" + inferenceEntityId + "]",
689+
new AddModelMetadataTask(projectId, inferenceEntityId, new MinimalServiceSettings(model), storeListener),
690+
timeout
691+
);
692+
} catch (Exception exc) {
693+
storeListener.onFailure(exc);
694+
}
695+
} else {
696+
listener.onResponse(Boolean.TRUE);
680697
}
681698
return;
682699
}
@@ -719,7 +736,8 @@ public void onFailure(Exception exc) {
719736
+ "inconsistent state. Please try deleting and re-adding the endpoint.",
720737
inferenceEntityId
721738
),
722-
RestStatus.INTERNAL_SERVER_ERROR
739+
RestStatus.INTERNAL_SERVER_ERROR,
740+
exc
723741
)
724742
);
725743
}));
@@ -810,7 +828,8 @@ public void onFailure(Exception exc) {
810828
+ "inconsistent state. Please try deleting the endpoint again.",
811829
inferenceEntityIds
812830
),
813-
RestStatus.INTERNAL_SERVER_ERROR
831+
RestStatus.INTERNAL_SERVER_ERROR,
832+
exc
814833
)
815834
);
816835
}
@@ -924,16 +943,19 @@ public void clusterChanged(ClusterChangedEvent event) {
924943
public void onResponse(GetInferenceModelAction.Response response) {
925944
Map<String, MinimalServiceSettings> map = new HashMap<>();
926945
for (var model : response.getEndpoints()) {
927-
map.put(
928-
model.getInferenceEntityId(),
929-
new MinimalServiceSettings(
930-
model.getService(),
931-
model.getTaskType(),
932-
model.getServiceSettings().dimensions(),
933-
model.getServiceSettings().similarity(),
934-
model.getServiceSettings().elementType()
935-
)
936-
);
946+
// ignore default models
947+
if (defaultConfigIds.containsKey(model.getInferenceEntityId()) == false) {
948+
map.put(
949+
model.getInferenceEntityId(),
950+
new MinimalServiceSettings(
951+
model.getService(),
952+
model.getTaskType(),
953+
model.getServiceSettings().dimensions(),
954+
model.getServiceSettings().similarity(),
955+
model.getServiceSettings().elementType()
956+
)
957+
);
958+
}
937959
}
938960
metadataTaskQueue.submitTask(
939961
"model registry auto upgrade",

0 commit comments

Comments
 (0)