Skip to content

Commit 7794914

Browse files
authored
[8.x] Add ModelRegistryMetadata to Cluster State (elastic#125150)
* Add ModelRegistryMetadata to Cluster State (elastic#121106) This commit integrates `MinimalServiceSettings` (introduced in elastic#120560) into the cluster state for all registered models in the `ModelRegistry`. These settings allow consumers to access configuration details without requiring asynchronous calls to retrieve full model configurations. To ensure consistency, the cluster state metadata must remain synchronized with the models in the inference index. If a mismatch is detected during startup, the master node performs an upgrade to load all model settings from the index. * fix test compil * fix serialisation * Exclude Default Inference Endpoints from Cluster State Storage (elastic#125242) When retrieving a default inference endpoint for the first time, the system automatically creates the endpoint. However, unlike the `put inference model` action, the `get` action does not redirect the request to the master node. Since elastic#121106, we rely on the assumption that every model creation (`put model`) must run on the master node, as it modifies the cluster state. However, this assumption led to a bug where the get action tries to store default inference endpoints from a different node. This change resolves the issue by preventing default inference endpoints from being added to the cluster state. These endpoints are not strictly needed there, as they are already reported by inference services upon startup. **Note:** This bug did not prevent the default endpoints from being used, but it caused repeated attempts to store them in the index, resulting in logging errors on every usage.
1 parent aa89197 commit 7794914

File tree

43 files changed

+1465
-612
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+1465
-612
lines changed

docs/changelog/121106.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 121106
2+
summary: Add `ModelRegistryMetadata` to Cluster State
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,8 @@ static TransportVersion def(int id) {
196196
public static final TransportVersion ML_INFERENCE_DEEPSEEK_8_19 = def(8_841_0_09);
197197
public static final TransportVersion ESQL_SERIALIZE_BLOCK_TYPE_CODE = def(8_841_0_10);
198198
public static final TransportVersion ESQL_FAILURE_FROM_REMOTE = def(8_841_0_11);
199+
public static final TransportVersion ESQL_AGGREGATE_METRIC_DOUBLE_LITERAL = def(8_841_0_12);
200+
public static final TransportVersion INFERENCE_MODEL_REGISTRY_METADATA_8_19 = def(8_841_0_13);
199201

200202
/*
201203
* STOP! READ THIS FIRST! No, really,

server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java

Lines changed: 111 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99

1010
package org.elasticsearch.inference;
1111

12+
import org.elasticsearch.TransportVersion;
13+
import org.elasticsearch.TransportVersions;
14+
import org.elasticsearch.cluster.Diff;
15+
import org.elasticsearch.cluster.SimpleDiffable;
16+
import org.elasticsearch.common.io.stream.StreamInput;
17+
import org.elasticsearch.common.io.stream.StreamOutput;
1218
import org.elasticsearch.core.Nullable;
1319
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
1420
import org.elasticsearch.xcontent.ConstructingObjectParser;
@@ -46,12 +52,16 @@
4652
* @param elementType the type of elements in the embeddings, applicable only for {@link TaskType#TEXT_EMBEDDING} (nullable).
4753
*/
4854
public record MinimalServiceSettings(
55+
@Nullable String service,
4956
TaskType taskType,
5057
@Nullable Integer dimensions,
5158
@Nullable SimilarityMeasure similarity,
5259
@Nullable ElementType elementType
53-
) implements ToXContentObject {
60+
) implements ServiceSettings, SimpleDiffable<MinimalServiceSettings> {
5461

62+
public static final String NAME = "minimal_service_settings";
63+
64+
public static final String SERVICE_FIELD = "service";
5565
public static final String TASK_TYPE_FIELD = "task_type";
5666
static final String DIMENSIONS_FIELD = "dimensions";
5767
static final String SIMILARITY_FIELD = "similarity";
@@ -61,17 +71,20 @@ public record MinimalServiceSettings(
6171
"model_settings",
6272
true,
6373
args -> {
64-
TaskType taskType = TaskType.fromString((String) args[0]);
65-
Integer dimensions = (Integer) args[1];
66-
SimilarityMeasure similarity = args[2] == null ? null : SimilarityMeasure.fromString((String) args[2]);
67-
DenseVectorFieldMapper.ElementType elementType = args[3] == null
74+
String service = (String) args[0];
75+
TaskType taskType = TaskType.fromString((String) args[1]);
76+
Integer dimensions = (Integer) args[2];
77+
SimilarityMeasure similarity = args[3] == null ? null : SimilarityMeasure.fromString((String) args[3]);
78+
DenseVectorFieldMapper.ElementType elementType = args[4] == null
6879
? null
69-
: DenseVectorFieldMapper.ElementType.fromString((String) args[3]);
70-
return new MinimalServiceSettings(taskType, dimensions, similarity, elementType);
80+
: DenseVectorFieldMapper.ElementType.fromString((String) args[4]);
81+
return new MinimalServiceSettings(service, taskType, dimensions, similarity, elementType);
7182
}
7283
);
84+
private static final String UNKNOWN_SERVICE = "_unknown_";
7385

7486
static {
87+
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(SERVICE_FIELD));
7588
PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(TASK_TYPE_FIELD));
7689
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(DIMENSIONS_FIELD));
7790
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(SIMILARITY_FIELD));
@@ -82,51 +95,95 @@ public static MinimalServiceSettings parse(XContentParser parser) throws IOExcep
8295
return PARSER.parse(parser, null);
8396
}
8497

85-
public static MinimalServiceSettings textEmbedding(int dimensions, SimilarityMeasure similarity, ElementType elementType) {
86-
return new MinimalServiceSettings(TEXT_EMBEDDING, dimensions, similarity, elementType);
98+
public static MinimalServiceSettings textEmbedding(
99+
String serviceName,
100+
int dimensions,
101+
SimilarityMeasure similarity,
102+
ElementType elementType
103+
) {
104+
return new MinimalServiceSettings(serviceName, TEXT_EMBEDDING, dimensions, similarity, elementType);
105+
}
106+
107+
public static MinimalServiceSettings sparseEmbedding(String serviceName) {
108+
return new MinimalServiceSettings(serviceName, SPARSE_EMBEDDING, null, null, null);
87109
}
88110

89-
public static MinimalServiceSettings sparseEmbedding() {
90-
return new MinimalServiceSettings(SPARSE_EMBEDDING, null, null, null);
111+
public static MinimalServiceSettings rerank(String serviceName) {
112+
return new MinimalServiceSettings(serviceName, RERANK, null, null, null);
91113
}
92114

93-
public static MinimalServiceSettings rerank() {
94-
return new MinimalServiceSettings(RERANK, null, null, null);
115+
public static MinimalServiceSettings completion(String serviceName) {
116+
return new MinimalServiceSettings(serviceName, COMPLETION, null, null, null);
95117
}
96118

97-
public static MinimalServiceSettings completion() {
98-
return new MinimalServiceSettings(COMPLETION, null, null, null);
119+
public static MinimalServiceSettings chatCompletion(String serviceName) {
120+
return new MinimalServiceSettings(serviceName, CHAT_COMPLETION, null, null, null);
99121
}
100122

101-
public static MinimalServiceSettings chatCompletion() {
102-
return new MinimalServiceSettings(CHAT_COMPLETION, null, null, null);
123+
public MinimalServiceSettings {
124+
Objects.requireNonNull(taskType, "task type must not be null");
125+
validate(taskType, dimensions, similarity, elementType);
103126
}
104127

105128
public MinimalServiceSettings(Model model) {
106129
this(
130+
model.getConfigurations().getService(),
107131
model.getTaskType(),
108132
model.getServiceSettings().dimensions(),
109133
model.getServiceSettings().similarity(),
110134
model.getServiceSettings().elementType()
111135
);
112136
}
113137

114-
public MinimalServiceSettings(
115-
TaskType taskType,
116-
@Nullable Integer dimensions,
117-
@Nullable SimilarityMeasure similarity,
118-
@Nullable ElementType elementType
119-
) {
120-
this.taskType = Objects.requireNonNull(taskType, "task type must not be null");
121-
this.dimensions = dimensions;
122-
this.similarity = similarity;
123-
this.elementType = elementType;
124-
validate();
138+
public MinimalServiceSettings(StreamInput in) throws IOException {
139+
this(
140+
in.readOptionalString(),
141+
TaskType.fromStream(in),
142+
in.readOptionalInt(),
143+
in.readOptionalEnum(SimilarityMeasure.class),
144+
in.readOptionalEnum(ElementType.class)
145+
);
146+
}
147+
148+
@Override
149+
public void writeTo(StreamOutput out) throws IOException {
150+
out.writeOptionalString(service);
151+
taskType.writeTo(out);
152+
out.writeOptionalInt(dimensions);
153+
out.writeOptionalEnum(similarity);
154+
out.writeOptionalEnum(elementType);
155+
}
156+
157+
@Override
158+
public String getWriteableName() {
159+
return NAME;
160+
}
161+
162+
@Override
163+
public TransportVersion getMinimalSupportedVersion() {
164+
return TransportVersions.INFERENCE_MODEL_REGISTRY_METADATA_8_19;
165+
}
166+
167+
@Override
168+
public ToXContentObject getFilteredXContentObject() {
169+
return this::toXContent;
170+
}
171+
172+
@Override
173+
public String modelId() {
174+
return null;
175+
}
176+
177+
public static Diff<MinimalServiceSettings> readDiffFrom(StreamInput in) throws IOException {
178+
return SimpleDiffable.readDiffFrom(MinimalServiceSettings::new, in);
125179
}
126180

127181
@Override
128182
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
129183
builder.startObject();
184+
if (service != null) {
185+
builder.field(SERVICE_FIELD, service);
186+
}
130187
builder.field(TASK_TYPE_FIELD, taskType.toString());
131188
if (dimensions != null) {
132189
builder.field(DIMENSIONS_FIELD, dimensions);
@@ -143,7 +200,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
143200
@Override
144201
public String toString() {
145202
final StringBuilder sb = new StringBuilder();
146-
sb.append("task_type=").append(taskType);
203+
sb.append("service=").append(service);
204+
sb.append(", task_type=").append(taskType);
147205
if (dimensions != null) {
148206
sb.append(", dimensions=").append(dimensions);
149207
}
@@ -156,31 +214,46 @@ public String toString() {
156214
return sb.toString();
157215
}
158216

159-
private void validate() {
217+
private static void validate(TaskType taskType, Integer dimensions, SimilarityMeasure similarity, ElementType elementType) {
160218
switch (taskType) {
161219
case TEXT_EMBEDDING:
162-
validateFieldPresent(DIMENSIONS_FIELD, dimensions);
163-
validateFieldPresent(SIMILARITY_FIELD, similarity);
164-
validateFieldPresent(ELEMENT_TYPE_FIELD, elementType);
220+
validateFieldPresent(DIMENSIONS_FIELD, dimensions, taskType);
221+
validateFieldPresent(SIMILARITY_FIELD, similarity, taskType);
222+
validateFieldPresent(ELEMENT_TYPE_FIELD, elementType, taskType);
165223
break;
166224

167225
default:
168-
validateFieldNotPresent(DIMENSIONS_FIELD, dimensions);
169-
validateFieldNotPresent(SIMILARITY_FIELD, similarity);
170-
validateFieldNotPresent(ELEMENT_TYPE_FIELD, elementType);
226+
validateFieldNotPresent(DIMENSIONS_FIELD, dimensions, taskType);
227+
validateFieldNotPresent(SIMILARITY_FIELD, similarity, taskType);
228+
validateFieldNotPresent(ELEMENT_TYPE_FIELD, elementType, taskType);
171229
break;
172230
}
173231
}
174232

175-
private void validateFieldPresent(String field, Object fieldValue) {
233+
private static void validateFieldPresent(String field, Object fieldValue, TaskType taskType) {
176234
if (fieldValue == null) {
177235
throw new IllegalArgumentException("required [" + field + "] field is missing for task_type [" + taskType.name() + "]");
178236
}
179237
}
180238

181-
private void validateFieldNotPresent(String field, Object fieldValue) {
239+
private static void validateFieldNotPresent(String field, Object fieldValue, TaskType taskType) {
182240
if (fieldValue != null) {
183241
throw new IllegalArgumentException("[" + field + "] is not allowed for task_type [" + taskType.name() + "]");
184242
}
185243
}
244+
245+
public ModelConfigurations toModelConfigurations(String inferenceEntityId) {
246+
return new ModelConfigurations(inferenceEntityId, taskType, service == null ? UNKNOWN_SERVICE : service, this);
247+
}
248+
249+
/**
250+
* Checks if the given {@link MinimalServiceSettings} is equivalent to the current definition.
251+
*/
252+
public boolean canMergeWith(MinimalServiceSettings other) {
253+
return taskType == other.taskType
254+
&& Objects.equals(dimensions, other.dimensions)
255+
&& similarity == other.similarity
256+
&& elementType == other.elementType
257+
&& (service == null || service.equals(other.service));
258+
}
186259
}

server/src/test/java/org/elasticsearch/inference/MinimalServiceSettingsTests.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616
import java.io.IOException;
1717

1818
public class MinimalServiceSettingsTests extends AbstractXContentTestCase<MinimalServiceSettings> {
19-
@Override
20-
protected MinimalServiceSettings createTestInstance() {
19+
public static MinimalServiceSettings randomInstance() {
2120
TaskType taskType = randomFrom(TaskType.values());
2221
Integer dimensions = null;
2322
SimilarityMeasure similarity = null;
@@ -28,7 +27,12 @@ protected MinimalServiceSettings createTestInstance() {
2827
similarity = randomFrom(SimilarityMeasure.values());
2928
elementType = randomFrom(DenseVectorFieldMapper.ElementType.values());
3029
}
31-
return new MinimalServiceSettings(taskType, dimensions, similarity, elementType);
30+
return new MinimalServiceSettings(randomBoolean() ? null : randomAlphaOfLength(10), taskType, dimensions, similarity, elementType);
31+
}
32+
33+
@Override
34+
protected MinimalServiceSettings createTestInstance() {
35+
return randomInstance();
3236
}
3337

3438
@Override

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceModelAction.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,7 @@ public static class Request extends AcknowledgedRequest<GetInferenceModelAction.
4545
private final boolean persistDefaultConfig;
4646

4747
public Request(String inferenceEntityId, TaskType taskType) {
48-
super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
49-
this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);
50-
this.taskType = Objects.requireNonNull(taskType);
51-
this.persistDefaultConfig = PERSIST_DEFAULT_CONFIGS;
48+
this(inferenceEntityId, taskType, PERSIST_DEFAULT_CONFIGS);
5249
}
5350

5451
public Request(String inferenceEntityId, TaskType taskType, boolean persistDefaultConfig) {
@@ -68,7 +65,6 @@ public Request(StreamInput in) throws IOException {
6865
} else {
6966
this.persistDefaultConfig = PERSIST_DEFAULT_CONFIGS;
7067
}
71-
7268
}
7369

7470
public String getInferenceEntityId() {

x-pack/plugin/esql/qa/server/multi-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/multi_node/SemanticMatchIT.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
@ThreadLeakFilters(filters = TestClustersThreadFilter.class)
1818
public class SemanticMatchIT extends SemanticMatchTestCase {
1919
@ClassRule
20-
public static ElasticsearchCluster cluster = Clusters.testCluster(spec -> spec.plugin("inference-service-test"));
20+
public static ElasticsearchCluster cluster = Clusters.testCluster(
21+
spec -> spec.module("x-pack-inference").plugin("inference-service-test")
22+
);
2123

2224
@Override
2325
protected String getTestRestCluster() {

x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/SemanticMatchTestCase.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.io.IOException;
1919
import java.util.Map;
2020

21+
import static org.hamcrest.Matchers.equalTo;
2122
import static org.hamcrest.core.StringContains.containsString;
2223

2324
public abstract class SemanticMatchTestCase extends ESRestTestCase {
@@ -88,16 +89,22 @@ public void setUpTextEmbeddingInferenceEndpoint() throws IOException {
8889
Request request = new Request("PUT", "_inference/text_embedding/test_dense_inference");
8990
request.setJsonEntity("""
9091
{
91-
"service": "test_service",
92+
"service": "text_embedding_test_service",
9293
"service_settings": {
9394
"model": "my_model",
94-
"api_key": "abc64"
95+
"api_key": "abc64",
96+
"dimensions": 128
9597
},
9698
"task_settings": {
9799
}
98100
}
99101
""");
100-
adminClient().performRequest(request);
102+
try {
103+
adminClient().performRequest(request);
104+
} catch (ResponseException exc) {
105+
// in case the removal failed
106+
assertThat(exc.getResponse().getStatusLine().getStatusCode(), equalTo(400));
107+
}
101108
}
102109

103110
@After

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import java.util.concurrent.CountDownLatch;
2525

2626
import static org.hamcrest.Matchers.empty;
27+
import static org.hamcrest.Matchers.equalTo;
2728
import static org.hamcrest.Matchers.hasSize;
2829
import static org.hamcrest.Matchers.is;
2930
import static org.hamcrest.Matchers.oneOf;
@@ -54,6 +55,25 @@ public void testGet() throws IOException {
5455
assertDefaultRerankConfig(rerankModel);
5556
}
5657

58+
public void testDefaultModels() throws IOException {
59+
var elserModel = getModel(ElasticsearchInternalService.DEFAULT_ELSER_ID);
60+
assertDefaultElserConfig(elserModel);
61+
62+
var e5Model = getModel(ElasticsearchInternalService.DEFAULT_E5_ID);
63+
assertDefaultE5Config(e5Model);
64+
65+
var rerankModel = getModel(ElasticsearchInternalService.DEFAULT_RERANK_ID);
66+
assertDefaultRerankConfig(rerankModel);
67+
68+
putModel("my-model", mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING));
69+
var registeredModels = getMinimalConfigs();
70+
assertThat(registeredModels.size(), equalTo(1));
71+
assertTrue(registeredModels.containsKey("my-model"));
72+
assertFalse(registeredModels.containsKey(ElasticsearchInternalService.DEFAULT_E5_ID));
73+
assertFalse(registeredModels.containsKey(ElasticsearchInternalService.DEFAULT_ELSER_ID));
74+
assertFalse(registeredModels.containsKey(ElasticsearchInternalService.DEFAULT_RERANK_ID));
75+
}
76+
5777
@SuppressWarnings("unchecked")
5878
public void testInferDeploysDefaultElser() throws IOException {
5979
var model = getModel(ElasticsearchInternalService.DEFAULT_ELSER_ID);

0 commit comments

Comments
 (0)