Skip to content

Commit 7e03356

Browse files
Add Minimal Service Settings to the Model Registry (#120560) (#120946)
This commit introduces minimal service settings in the model registry, accessible without querying the inference index. These settings are now available for the default models exposed by the inference service. The ability to access settings without an inference index query is needed for the semantic text field, as it would benefit from eager validation of configuration during field creation. This is not feasible currently because retrieving service settings relies on an asynchronous call to the inference index. ### Follow-Up Plans: 1. Extend this capability to include minimal service settings for all newly added models, making them accessible via the cluster state. 2. Update the semantic text field to eagerly retrieve service settings directly from the model registry. Co-authored-by: Jonathan Buttner <[email protected]>
1 parent 751c1c5 commit 7e03356

File tree

15 files changed

+439
-254
lines changed

15 files changed

+439
-254
lines changed

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ default boolean canStream(TaskType taskType) {
219219
return supportedStreamingTasks().contains(taskType);
220220
}
221221

222-
record DefaultConfigId(String inferenceId, TaskType taskType, InferenceService service) {};
222+
record DefaultConfigId(String inferenceId, MinimalServiceSettings settings, InferenceService service) {};
223223

224224
/**
225225
* Get the Ids and task type of any default configurations provided by this service
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.inference;
11+
12+
import org.elasticsearch.core.Nullable;
13+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
14+
import org.elasticsearch.xcontent.ConstructingObjectParser;
15+
import org.elasticsearch.xcontent.ParseField;
16+
import org.elasticsearch.xcontent.ToXContentObject;
17+
import org.elasticsearch.xcontent.XContentBuilder;
18+
import org.elasticsearch.xcontent.XContentParser;
19+
20+
import java.io.IOException;
21+
import java.util.Objects;
22+
23+
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
24+
import static org.elasticsearch.inference.TaskType.COMPLETION;
25+
import static org.elasticsearch.inference.TaskType.RERANK;
26+
import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING;
27+
import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING;
28+
29+
/**
30+
* Defines the base settings required to configure an inference endpoint.
31+
*
32+
* These settings are immutable and describe the input and output types that the endpoint will handle.
33+
* They capture the essential properties of an inference model, ensuring the endpoint is correctly configured.
34+
*
35+
* Key properties include:
36+
* <ul>
37+
* <li>{@code taskType} - Specifies the type of task the model performs, such as classification or text embeddings.</li>
38+
* <li>{@code dimensions}, {@code similarity}, and {@code elementType} - These settings are applicable only when
39+
* the {@code taskType} is {@link TaskType#TEXT_EMBEDDING}. They define the structure and behavior of embeddings.</li>
40+
* </ul>
41+
*
42+
* @param taskType the type of task the inference model performs.
43+
* @param dimensions the number of dimensions for the embeddings, applicable only for {@link TaskType#TEXT_EMBEDDING} (nullable).
44+
* @param similarity the similarity measure used for embeddings, applicable only for {@link TaskType#TEXT_EMBEDDING} (nullable).
45+
* @param elementType the type of elements in the embeddings, applicable only for {@link TaskType#TEXT_EMBEDDING} (nullable).
46+
*/
47+
public record MinimalServiceSettings(
48+
TaskType taskType,
49+
@Nullable Integer dimensions,
50+
@Nullable SimilarityMeasure similarity,
51+
@Nullable ElementType elementType
52+
) implements ToXContentObject {
53+
54+
public static final String TASK_TYPE_FIELD = "task_type";
55+
static final String DIMENSIONS_FIELD = "dimensions";
56+
static final String SIMILARITY_FIELD = "similarity";
57+
static final String ELEMENT_TYPE_FIELD = "element_type";
58+
59+
private static final ConstructingObjectParser<MinimalServiceSettings, Void> PARSER = new ConstructingObjectParser<>(
60+
"model_settings",
61+
true,
62+
args -> {
63+
TaskType taskType = TaskType.fromString((String) args[0]);
64+
Integer dimensions = (Integer) args[1];
65+
SimilarityMeasure similarity = args[2] == null ? null : SimilarityMeasure.fromString((String) args[2]);
66+
DenseVectorFieldMapper.ElementType elementType = args[3] == null
67+
? null
68+
: DenseVectorFieldMapper.ElementType.fromString((String) args[3]);
69+
return new MinimalServiceSettings(taskType, dimensions, similarity, elementType);
70+
}
71+
);
72+
73+
static {
74+
PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(TASK_TYPE_FIELD));
75+
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(DIMENSIONS_FIELD));
76+
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(SIMILARITY_FIELD));
77+
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ELEMENT_TYPE_FIELD));
78+
}
79+
80+
public static MinimalServiceSettings parse(XContentParser parser) throws IOException {
81+
return PARSER.parse(parser, null);
82+
}
83+
84+
public static MinimalServiceSettings textEmbedding(int dimensions, SimilarityMeasure similarity, ElementType elementType) {
85+
return new MinimalServiceSettings(TEXT_EMBEDDING, dimensions, similarity, elementType);
86+
}
87+
88+
public static MinimalServiceSettings sparseEmbedding() {
89+
return new MinimalServiceSettings(SPARSE_EMBEDDING, null, null, null);
90+
}
91+
92+
public static MinimalServiceSettings rerank() {
93+
return new MinimalServiceSettings(RERANK, null, null, null);
94+
}
95+
96+
public static MinimalServiceSettings completion() {
97+
return new MinimalServiceSettings(COMPLETION, null, null, null);
98+
}
99+
100+
public MinimalServiceSettings(Model model) {
101+
this(
102+
model.getTaskType(),
103+
model.getServiceSettings().dimensions(),
104+
model.getServiceSettings().similarity(),
105+
model.getServiceSettings().elementType()
106+
);
107+
}
108+
109+
public MinimalServiceSettings(
110+
TaskType taskType,
111+
@Nullable Integer dimensions,
112+
@Nullable SimilarityMeasure similarity,
113+
@Nullable ElementType elementType
114+
) {
115+
this.taskType = Objects.requireNonNull(taskType, "task type must not be null");
116+
this.dimensions = dimensions;
117+
this.similarity = similarity;
118+
this.elementType = elementType;
119+
validate();
120+
}
121+
122+
@Override
123+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
124+
builder.startObject();
125+
builder.field(TASK_TYPE_FIELD, taskType.toString());
126+
if (dimensions != null) {
127+
builder.field(DIMENSIONS_FIELD, dimensions);
128+
}
129+
if (similarity != null) {
130+
builder.field(SIMILARITY_FIELD, similarity);
131+
}
132+
if (elementType != null) {
133+
builder.field(ELEMENT_TYPE_FIELD, elementType);
134+
}
135+
return builder.endObject();
136+
}
137+
138+
@Override
139+
public String toString() {
140+
final StringBuilder sb = new StringBuilder();
141+
sb.append("task_type=").append(taskType);
142+
if (dimensions != null) {
143+
sb.append(", dimensions=").append(dimensions);
144+
}
145+
if (similarity != null) {
146+
sb.append(", similarity=").append(similarity);
147+
}
148+
if (elementType != null) {
149+
sb.append(", element_type=").append(elementType);
150+
}
151+
return sb.toString();
152+
}
153+
154+
private void validate() {
155+
switch (taskType) {
156+
case TEXT_EMBEDDING:
157+
validateFieldPresent(DIMENSIONS_FIELD, dimensions);
158+
validateFieldPresent(SIMILARITY_FIELD, similarity);
159+
validateFieldPresent(ELEMENT_TYPE_FIELD, elementType);
160+
break;
161+
162+
default:
163+
validateFieldNotPresent(DIMENSIONS_FIELD, dimensions);
164+
validateFieldNotPresent(SIMILARITY_FIELD, similarity);
165+
validateFieldNotPresent(ELEMENT_TYPE_FIELD, elementType);
166+
break;
167+
}
168+
}
169+
170+
private void validateFieldPresent(String field, Object fieldValue) {
171+
if (fieldValue == null) {
172+
throw new IllegalArgumentException("required [" + field + "] field is missing for task_type [" + taskType.name() + "]");
173+
}
174+
}
175+
176+
private void validateFieldNotPresent(String field, Object fieldValue) {
177+
if (fieldValue != null) {
178+
throw new IllegalArgumentException("[" + field + "] is not allowed for task_type [" + taskType.name() + "]");
179+
}
180+
}
181+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.inference;
11+
12+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
13+
import org.elasticsearch.test.AbstractXContentTestCase;
14+
import org.elasticsearch.xcontent.XContentParser;
15+
16+
import java.io.IOException;
17+
18+
public class MinimalServiceSettingsTests extends AbstractXContentTestCase<MinimalServiceSettings> {
19+
@Override
20+
protected MinimalServiceSettings createTestInstance() {
21+
TaskType taskType = randomFrom(TaskType.values());
22+
Integer dimensions = null;
23+
SimilarityMeasure similarity = null;
24+
DenseVectorFieldMapper.ElementType elementType = null;
25+
26+
if (taskType == TaskType.TEXT_EMBEDDING) {
27+
dimensions = randomIntBetween(2, 1024);
28+
similarity = randomFrom(SimilarityMeasure.values());
29+
elementType = randomFrom(DenseVectorFieldMapper.ElementType.values());
30+
}
31+
return new MinimalServiceSettings(taskType, dimensions, similarity, elementType);
32+
}
33+
34+
@Override
35+
protected MinimalServiceSettings doParseInstance(XContentParser parser) throws IOException {
36+
return MinimalServiceSettings.parse(parser);
37+
}
38+
39+
@Override
40+
protected boolean supportsUnknownFields() {
41+
return false;
42+
}
43+
}

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,16 @@
1515
import org.elasticsearch.common.io.stream.StreamOutput;
1616
import org.elasticsearch.common.settings.Settings;
1717
import org.elasticsearch.index.IndexNotFoundException;
18+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
1819
import org.elasticsearch.inference.InferenceService;
1920
import org.elasticsearch.inference.InferenceServiceExtension;
21+
import org.elasticsearch.inference.MinimalServiceSettings;
2022
import org.elasticsearch.inference.Model;
2123
import org.elasticsearch.inference.ModelConfigurations;
2224
import org.elasticsearch.inference.ModelSecrets;
2325
import org.elasticsearch.inference.SecretSettings;
2426
import org.elasticsearch.inference.ServiceSettings;
27+
import org.elasticsearch.inference.SimilarityMeasure;
2528
import org.elasticsearch.inference.TaskSettings;
2629
import org.elasticsearch.inference.TaskType;
2730
import org.elasticsearch.inference.UnparsedModel;
@@ -34,6 +37,7 @@
3437
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
3538
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
3639
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
40+
import org.elasticsearch.xpack.inference.registry.ModelRegistryTests;
3741
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel;
3842
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
3943
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettingsTests;
@@ -305,9 +309,9 @@ public void testGetAllModels_WithDefaults() throws Exception {
305309
var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
306310
for (int i = 0; i < defaultModelCount; i++) {
307311
var id = "default-" + i;
308-
var taskType = randomFrom(TaskType.values());
309-
defaultConfigs.add(createModel(id, taskType, serviceName));
310-
defaultIds.add(new InferenceService.DefaultConfigId(id, taskType, service));
312+
var modelSettings = ModelRegistryTests.randomMinimalServiceSettings();
313+
defaultConfigs.add(createModel(id, modelSettings.taskType(), serviceName));
314+
defaultIds.add(new InferenceService.DefaultConfigId(id, modelSettings, service));
311315
}
312316

313317
doAnswer(invocation -> {
@@ -371,9 +375,9 @@ public void testGetAllModels_OnlyDefaults() throws Exception {
371375
var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
372376
for (int i = 0; i < defaultModelCount; i++) {
373377
var id = "default-" + i;
374-
var taskType = randomFrom(TaskType.values());
375-
defaultConfigs.add(createModel(id, taskType, serviceName));
376-
defaultIds.add(new InferenceService.DefaultConfigId(id, taskType, service));
378+
var modelSettings = ModelRegistryTests.randomMinimalServiceSettings();
379+
defaultConfigs.add(createModel(id, modelSettings.taskType(), serviceName));
380+
defaultIds.add(new InferenceService.DefaultConfigId(id, modelSettings, service));
377381
}
378382

379383
doAnswer(invocation -> {
@@ -414,9 +418,9 @@ public void testGetAllModels_withDoNotPersist() throws Exception {
414418
var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
415419
for (int i = 0; i < defaultModelCount; i++) {
416420
var id = "default-" + i;
417-
var taskType = randomFrom(TaskType.values());
418-
defaultConfigs.add(createModel(id, taskType, serviceName));
419-
defaultIds.add(new InferenceService.DefaultConfigId(id, taskType, service));
421+
var modelSettings = ModelRegistryTests.randomMinimalServiceSettings();
422+
defaultConfigs.add(createModel(id, modelSettings.taskType(), serviceName));
423+
defaultIds.add(new InferenceService.DefaultConfigId(id, modelSettings, service));
420424
}
421425

422426
doAnswer(invocation -> {
@@ -452,8 +456,14 @@ public void testGet_WithDefaults() throws InterruptedException {
452456

453457
defaultConfigs.add(createModel("default-sparse", TaskType.SPARSE_EMBEDDING, serviceName));
454458
defaultConfigs.add(createModel("default-text", TaskType.TEXT_EMBEDDING, serviceName));
455-
defaultIds.add(new InferenceService.DefaultConfigId("default-sparse", TaskType.SPARSE_EMBEDDING, service));
456-
defaultIds.add(new InferenceService.DefaultConfigId("default-text", TaskType.TEXT_EMBEDDING, service));
459+
defaultIds.add(new InferenceService.DefaultConfigId("default-sparse", MinimalServiceSettings.sparseEmbedding(), service));
460+
defaultIds.add(
461+
new InferenceService.DefaultConfigId(
462+
"default-text",
463+
MinimalServiceSettings.textEmbedding(384, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT),
464+
service
465+
)
466+
);
457467

458468
doAnswer(invocation -> {
459469
@SuppressWarnings("unchecked")
@@ -499,9 +509,15 @@ public void testGetByTaskType_WithDefaults() throws Exception {
499509

500510
var service = mock(InferenceService.class);
501511
var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
502-
defaultIds.add(new InferenceService.DefaultConfigId("default-sparse", TaskType.SPARSE_EMBEDDING, service));
503-
defaultIds.add(new InferenceService.DefaultConfigId("default-text", TaskType.TEXT_EMBEDDING, service));
504-
defaultIds.add(new InferenceService.DefaultConfigId("default-chat", TaskType.COMPLETION, service));
512+
defaultIds.add(new InferenceService.DefaultConfigId("default-sparse", MinimalServiceSettings.sparseEmbedding(), service));
513+
defaultIds.add(
514+
new InferenceService.DefaultConfigId(
515+
"default-text",
516+
MinimalServiceSettings.textEmbedding(384, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT),
517+
service
518+
)
519+
);
520+
defaultIds.add(new InferenceService.DefaultConfigId("default-chat", MinimalServiceSettings.completion(), service));
505521

506522
doAnswer(invocation -> {
507523
@SuppressWarnings("unchecked")

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.elasticsearch.inference.InferenceService;
3636
import org.elasticsearch.inference.InferenceServiceRegistry;
3737
import org.elasticsearch.inference.InputType;
38+
import org.elasticsearch.inference.MinimalServiceSettings;
3839
import org.elasticsearch.inference.Model;
3940
import org.elasticsearch.inference.UnparsedModel;
4041
import org.elasticsearch.rest.RestStatus;
@@ -438,7 +439,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
438439
useLegacyFormat ? inputs : null,
439440
new SemanticTextField.InferenceResult(
440441
inferenceFieldMetadata.getInferenceId(),
441-
model != null ? new SemanticTextField.ModelSettings(model) : null,
442+
model != null ? new MinimalServiceSettings(model) : null,
442443
chunkMap
443444
),
444445
indexRequest.getContentType()

0 commit comments

Comments
 (0)