Skip to content

Commit 270ec53

Browse files
authored
Add ModelRegistryMetadata to Cluster State (elastic#121106)
This commit integrates `MinimalServiceSettings` (introduced in elastic#120560) into the cluster state for all registered models in the `ModelRegistry`. These settings allow consumers to access configuration details without requiring asynchronous calls to retrieve full model configurations. To ensure consistency, the cluster state metadata must remain synchronized with the models in the inference index. If a mismatch is detected during startup, the master node performs an upgrade to load all model settings from the index.
1 parent d20528b commit 270ec53

File tree

41 files changed

+1473
-614
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+1473
-614
lines changed

docs/changelog/121106.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 121106
2+
summary: Add `ModelRegistryMetadata` to Cluster State
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ static TransportVersion def(int id) {
187187
public static final TransportVersion ML_INFERENCE_DEEPSEEK = def(9_029_00_0);
188188
public static final TransportVersion ESQL_FAILURE_FROM_REMOTE = def(9_030_00_0);
189189
public static final TransportVersion INDEX_RESHARDING_METADATA = def(9_031_0_00);
190+
public static final TransportVersion INFERENCE_MODEL_REGISTRY_METADATA = def(9_032_0_00);
190191

191192
/*
192193
* STOP! READ THIS FIRST! No, really,

server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java

Lines changed: 111 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99

1010
package org.elasticsearch.inference;
1111

12+
import org.elasticsearch.TransportVersion;
13+
import org.elasticsearch.TransportVersions;
14+
import org.elasticsearch.cluster.Diff;
15+
import org.elasticsearch.cluster.SimpleDiffable;
16+
import org.elasticsearch.common.io.stream.StreamInput;
17+
import org.elasticsearch.common.io.stream.StreamOutput;
1218
import org.elasticsearch.core.Nullable;
1319
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
1420
import org.elasticsearch.xcontent.ConstructingObjectParser;
@@ -46,12 +52,16 @@
4652
* @param elementType the type of elements in the embeddings, applicable only for {@link TaskType#TEXT_EMBEDDING} (nullable).
4753
*/
4854
public record MinimalServiceSettings(
55+
@Nullable String service,
4956
TaskType taskType,
5057
@Nullable Integer dimensions,
5158
@Nullable SimilarityMeasure similarity,
5259
@Nullable ElementType elementType
53-
) implements ToXContentObject {
60+
) implements ServiceSettings, SimpleDiffable<MinimalServiceSettings> {
5461

62+
public static final String NAME = "minimal_service_settings";
63+
64+
public static final String SERVICE_FIELD = "service";
5565
public static final String TASK_TYPE_FIELD = "task_type";
5666
static final String DIMENSIONS_FIELD = "dimensions";
5767
static final String SIMILARITY_FIELD = "similarity";
@@ -61,17 +71,20 @@ public record MinimalServiceSettings(
6171
"model_settings",
6272
true,
6373
args -> {
64-
TaskType taskType = TaskType.fromString((String) args[0]);
65-
Integer dimensions = (Integer) args[1];
66-
SimilarityMeasure similarity = args[2] == null ? null : SimilarityMeasure.fromString((String) args[2]);
67-
DenseVectorFieldMapper.ElementType elementType = args[3] == null
74+
String service = (String) args[0];
75+
TaskType taskType = TaskType.fromString((String) args[1]);
76+
Integer dimensions = (Integer) args[2];
77+
SimilarityMeasure similarity = args[3] == null ? null : SimilarityMeasure.fromString((String) args[3]);
78+
DenseVectorFieldMapper.ElementType elementType = args[4] == null
6879
? null
69-
: DenseVectorFieldMapper.ElementType.fromString((String) args[3]);
70-
return new MinimalServiceSettings(taskType, dimensions, similarity, elementType);
80+
: DenseVectorFieldMapper.ElementType.fromString((String) args[4]);
81+
return new MinimalServiceSettings(service, taskType, dimensions, similarity, elementType);
7182
}
7283
);
84+
private static final String UNKNOWN_SERVICE = "_unknown_";
7385

7486
static {
87+
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(SERVICE_FIELD));
7588
PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(TASK_TYPE_FIELD));
7689
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(DIMENSIONS_FIELD));
7790
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(SIMILARITY_FIELD));
@@ -82,51 +95,95 @@ public static MinimalServiceSettings parse(XContentParser parser) throws IOExcep
8295
return PARSER.parse(parser, null);
8396
}
8497

85-
public static MinimalServiceSettings textEmbedding(int dimensions, SimilarityMeasure similarity, ElementType elementType) {
86-
return new MinimalServiceSettings(TEXT_EMBEDDING, dimensions, similarity, elementType);
98+
public static MinimalServiceSettings textEmbedding(
99+
String serviceName,
100+
int dimensions,
101+
SimilarityMeasure similarity,
102+
ElementType elementType
103+
) {
104+
return new MinimalServiceSettings(serviceName, TEXT_EMBEDDING, dimensions, similarity, elementType);
105+
}
106+
107+
public static MinimalServiceSettings sparseEmbedding(String serviceName) {
108+
return new MinimalServiceSettings(serviceName, SPARSE_EMBEDDING, null, null, null);
87109
}
88110

89-
public static MinimalServiceSettings sparseEmbedding() {
90-
return new MinimalServiceSettings(SPARSE_EMBEDDING, null, null, null);
111+
public static MinimalServiceSettings rerank(String serviceName) {
112+
return new MinimalServiceSettings(serviceName, RERANK, null, null, null);
91113
}
92114

93-
public static MinimalServiceSettings rerank() {
94-
return new MinimalServiceSettings(RERANK, null, null, null);
115+
public static MinimalServiceSettings completion(String serviceName) {
116+
return new MinimalServiceSettings(serviceName, COMPLETION, null, null, null);
95117
}
96118

97-
public static MinimalServiceSettings completion() {
98-
return new MinimalServiceSettings(COMPLETION, null, null, null);
119+
public static MinimalServiceSettings chatCompletion(String serviceName) {
120+
return new MinimalServiceSettings(serviceName, CHAT_COMPLETION, null, null, null);
99121
}
100122

101-
public static MinimalServiceSettings chatCompletion() {
102-
return new MinimalServiceSettings(CHAT_COMPLETION, null, null, null);
123+
public MinimalServiceSettings {
124+
Objects.requireNonNull(taskType, "task type must not be null");
125+
validate(taskType, dimensions, similarity, elementType);
103126
}
104127

105128
public MinimalServiceSettings(Model model) {
106129
this(
130+
model.getConfigurations().getService(),
107131
model.getTaskType(),
108132
model.getServiceSettings().dimensions(),
109133
model.getServiceSettings().similarity(),
110134
model.getServiceSettings().elementType()
111135
);
112136
}
113137

114-
public MinimalServiceSettings(
115-
TaskType taskType,
116-
@Nullable Integer dimensions,
117-
@Nullable SimilarityMeasure similarity,
118-
@Nullable ElementType elementType
119-
) {
120-
this.taskType = Objects.requireNonNull(taskType, "task type must not be null");
121-
this.dimensions = dimensions;
122-
this.similarity = similarity;
123-
this.elementType = elementType;
124-
validate();
138+
public MinimalServiceSettings(StreamInput in) throws IOException {
139+
this(
140+
in.readOptionalString(),
141+
TaskType.fromStream(in),
142+
in.readOptionalInt(),
143+
in.readOptionalEnum(SimilarityMeasure.class),
144+
in.readOptionalEnum(ElementType.class)
145+
);
146+
}
147+
148+
@Override
149+
public void writeTo(StreamOutput out) throws IOException {
150+
out.writeOptionalString(service);
151+
taskType.writeTo(out);
152+
out.writeOptionalInt(dimensions);
153+
out.writeOptionalEnum(similarity);
154+
out.writeOptionalEnum(elementType);
155+
}
156+
157+
@Override
158+
public String getWriteableName() {
159+
return NAME;
160+
}
161+
162+
@Override
163+
public TransportVersion getMinimalSupportedVersion() {
164+
return TransportVersions.INFERENCE_MODEL_REGISTRY_METADATA;
165+
}
166+
167+
@Override
168+
public ToXContentObject getFilteredXContentObject() {
169+
return this::toXContent;
170+
}
171+
172+
@Override
173+
public String modelId() {
174+
return null;
175+
}
176+
177+
public static Diff<MinimalServiceSettings> readDiffFrom(StreamInput in) throws IOException {
178+
return SimpleDiffable.readDiffFrom(MinimalServiceSettings::new, in);
125179
}
126180

127181
@Override
128182
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
129183
builder.startObject();
184+
if (service != null) {
185+
builder.field(SERVICE_FIELD, service);
186+
}
130187
builder.field(TASK_TYPE_FIELD, taskType.toString());
131188
if (dimensions != null) {
132189
builder.field(DIMENSIONS_FIELD, dimensions);
@@ -143,7 +200,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
143200
@Override
144201
public String toString() {
145202
final StringBuilder sb = new StringBuilder();
146-
sb.append("task_type=").append(taskType);
203+
sb.append("service=").append(service);
204+
sb.append(", task_type=").append(taskType);
147205
if (dimensions != null) {
148206
sb.append(", dimensions=").append(dimensions);
149207
}
@@ -156,31 +214,46 @@ public String toString() {
156214
return sb.toString();
157215
}
158216

159-
private void validate() {
217+
private static void validate(TaskType taskType, Integer dimensions, SimilarityMeasure similarity, ElementType elementType) {
160218
switch (taskType) {
161219
case TEXT_EMBEDDING:
162-
validateFieldPresent(DIMENSIONS_FIELD, dimensions);
163-
validateFieldPresent(SIMILARITY_FIELD, similarity);
164-
validateFieldPresent(ELEMENT_TYPE_FIELD, elementType);
220+
validateFieldPresent(DIMENSIONS_FIELD, dimensions, taskType);
221+
validateFieldPresent(SIMILARITY_FIELD, similarity, taskType);
222+
validateFieldPresent(ELEMENT_TYPE_FIELD, elementType, taskType);
165223
break;
166224

167225
default:
168-
validateFieldNotPresent(DIMENSIONS_FIELD, dimensions);
169-
validateFieldNotPresent(SIMILARITY_FIELD, similarity);
170-
validateFieldNotPresent(ELEMENT_TYPE_FIELD, elementType);
226+
validateFieldNotPresent(DIMENSIONS_FIELD, dimensions, taskType);
227+
validateFieldNotPresent(SIMILARITY_FIELD, similarity, taskType);
228+
validateFieldNotPresent(ELEMENT_TYPE_FIELD, elementType, taskType);
171229
break;
172230
}
173231
}
174232

175-
private void validateFieldPresent(String field, Object fieldValue) {
233+
private static void validateFieldPresent(String field, Object fieldValue, TaskType taskType) {
176234
if (fieldValue == null) {
177235
throw new IllegalArgumentException("required [" + field + "] field is missing for task_type [" + taskType.name() + "]");
178236
}
179237
}
180238

181-
private void validateFieldNotPresent(String field, Object fieldValue) {
239+
private static void validateFieldNotPresent(String field, Object fieldValue, TaskType taskType) {
182240
if (fieldValue != null) {
183241
throw new IllegalArgumentException("[" + field + "] is not allowed for task_type [" + taskType.name() + "]");
184242
}
185243
}
244+
245+
public ModelConfigurations toModelConfigurations(String inferenceEntityId) {
246+
return new ModelConfigurations(inferenceEntityId, taskType, service == null ? UNKNOWN_SERVICE : service, this);
247+
}
248+
249+
/**
250+
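* Checks whether the given {@link MinimalServiceSettings} is compatible with this definition: the task type, dimensions, similarity and element type must all be equal, and the service must be equal unless this instance's service is {@code null} (in which case any service is accepted).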
251+
*/
252+
public boolean canMergeWith(MinimalServiceSettings other) {
253+
return taskType == other.taskType
254+
&& Objects.equals(dimensions, other.dimensions)
255+
&& similarity == other.similarity
256+
&& elementType == other.elementType
257+
&& (service == null || service.equals(other.service));
258+
}
186259
}

server/src/test/java/org/elasticsearch/inference/MinimalServiceSettingsTests.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616
import java.io.IOException;
1717

1818
public class MinimalServiceSettingsTests extends AbstractXContentTestCase<MinimalServiceSettings> {
19-
@Override
20-
protected MinimalServiceSettings createTestInstance() {
19+
public static MinimalServiceSettings randomInstance() {
2120
TaskType taskType = randomFrom(TaskType.values());
2221
Integer dimensions = null;
2322
SimilarityMeasure similarity = null;
@@ -28,7 +27,12 @@ protected MinimalServiceSettings createTestInstance() {
2827
similarity = randomFrom(SimilarityMeasure.values());
2928
elementType = randomFrom(DenseVectorFieldMapper.ElementType.values());
3029
}
31-
return new MinimalServiceSettings(taskType, dimensions, similarity, elementType);
30+
return new MinimalServiceSettings(randomBoolean() ? null : randomAlphaOfLength(10), taskType, dimensions, similarity, elementType);
31+
}
32+
33+
@Override
34+
protected MinimalServiceSettings createTestInstance() {
35+
return randomInstance();
3236
}
3337

3438
@Override

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceModelAction.java

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,23 @@ public static class Request extends AcknowledgedRequest<GetInferenceModelAction.
4444
// no effect when getting a single model
4545
private final boolean persistDefaultConfig;
4646

47+
// For testing only, retrieves the minimal config from the cluster state.
48+
private final boolean returnMinimalConfig;
49+
4750
public Request(String inferenceEntityId, TaskType taskType) {
48-
super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
49-
this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);
50-
this.taskType = Objects.requireNonNull(taskType);
51-
this.persistDefaultConfig = PERSIST_DEFAULT_CONFIGS;
51+
this(inferenceEntityId, taskType, PERSIST_DEFAULT_CONFIGS);
5252
}
5353

5454
public Request(String inferenceEntityId, TaskType taskType, boolean persistDefaultConfig) {
55+
this(inferenceEntityId, taskType, persistDefaultConfig, false);
56+
}
57+
58+
public Request(String inferenceEntityId, TaskType taskType, boolean persistDefaultConfig, boolean returnMinimalConfig) {
5559
super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
5660
this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);
5761
this.taskType = Objects.requireNonNull(taskType);
5862
this.persistDefaultConfig = persistDefaultConfig;
63+
this.returnMinimalConfig = returnMinimalConfig;
5964
}
6065

6166
public Request(StreamInput in) throws IOException {
@@ -68,6 +73,12 @@ public Request(StreamInput in) throws IOException {
6873
this.persistDefaultConfig = PERSIST_DEFAULT_CONFIGS;
6974
}
7075

76+
if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_MODEL_REGISTRY_METADATA)) {
77+
this.returnMinimalConfig = in.readBoolean();
78+
} else {
79+
this.returnMinimalConfig = false;
80+
}
81+
7182
}
7283

7384
public String getInferenceEntityId() {
@@ -82,6 +93,10 @@ public boolean isPersistDefaultConfig() {
8293
return persistDefaultConfig;
8394
}
8495

96+
public boolean isReturnMinimalConfig() {
97+
return returnMinimalConfig;
98+
}
99+
85100
@Override
86101
public void writeTo(StreamOutput out) throws IOException {
87102
super.writeTo(out);
@@ -90,6 +105,10 @@ public void writeTo(StreamOutput out) throws IOException {
90105
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) {
91106
out.writeBoolean(this.persistDefaultConfig);
92107
}
108+
109+
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_MODEL_REGISTRY_METADATA)) {
110+
out.writeBoolean(returnMinimalConfig);
111+
}
93112
}
94113

95114
@Override
@@ -99,12 +118,13 @@ public boolean equals(Object o) {
99118
Request request = (Request) o;
100119
return Objects.equals(inferenceEntityId, request.inferenceEntityId)
101120
&& taskType == request.taskType
102-
&& persistDefaultConfig == request.persistDefaultConfig;
121+
&& persistDefaultConfig == request.persistDefaultConfig
122+
&& returnMinimalConfig == request.returnMinimalConfig;
103123
}
104124

105125
@Override
106126
public int hashCode() {
107-
return Objects.hash(inferenceEntityId, taskType, persistDefaultConfig);
127+
return Objects.hash(inferenceEntityId, taskType, persistDefaultConfig, returnMinimalConfig);
108128
}
109129
}
110130

x-pack/plugin/esql/qa/server/multi-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/multi_node/SemanticMatchIT.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
@ThreadLeakFilters(filters = TestClustersThreadFilter.class)
1818
public class SemanticMatchIT extends SemanticMatchTestCase {
1919
@ClassRule
20-
public static ElasticsearchCluster cluster = Clusters.testCluster(spec -> spec.plugin("inference-service-test"));
20+
public static ElasticsearchCluster cluster = Clusters.testCluster(
21+
spec -> spec.module("x-pack-inference").plugin("inference-service-test")
22+
);
2123

2224
@Override
2325
protected String getTestRestCluster() {

0 commit comments

Comments
 (0)