Skip to content

Commit df159ce

Browse files
Allow updating inference_id of semantic_text fields (elastic#136120)
Previously the `inference_id` of `semantic_text` fields was not updatable. This commit allows users to update the `inference_id` of a `semantic_text` field. This is particularly useful for scenarios where the user wants to switch to using the same model but from a different service. There are two circumstances when the update is allowed. - No values have been written for the `semantic_text` field. The inference endpoint can be changed freely as there is no need for compatibility between the current and the new endpoint. - The new inference endpoint is compatible with the previous one. The `model_settings` of the new inference endpoint are compatible with those of the current endpoint, thus the update is allowed.
1 parent b6e36eb commit df159ce

File tree

9 files changed

+1067
-65
lines changed

9 files changed

+1067
-65
lines changed

docs/changelog/136120.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 136120
2+
summary: Allow updating `inference_id` of `semantic_text` fields
3+
area: "Mapping"
4+
type: enhancement
5+
issues: []

docs/reference/elasticsearch/mapping-reference/semantic-text.md

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,12 +142,31 @@ While we do encourage experimentation, we do not recommend implementing producti
142142

143143
`inference_id`
144144
: (Optional, string) {{infer-cap}} endpoint that will be used to generate
145-
embeddings for the field. By default, `.elser-2-elasticsearch` is used. This
146-
parameter cannot be updated. Use
147-
the [Create {{infer}} API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put)
145+
embeddings for the field. By default, `.elser-2-elasticsearch` is used.
146+
Use the [Create {{infer}} API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put)
148147
to create the endpoint. If `search_inference_id` is specified, the {{infer}}
149148
endpoint will only be used at index time.
150149

150+
::::{applies-switch}
151+
152+
:::{applies-item} { "stack": "ga 9.0" }
153+
This parameter cannot be updated.
154+
:::
155+
156+
:::{applies-item} { "stack": "ga 9.3" }
157+
158+
You can update this parameter by using
159+
the [Update mapping API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-put-mapping).
160+
You can update the inference endpoint if no values have been indexed or if the new endpoint is compatible with the current one.
161+
162+
::::{warning}
163+
When updating an `inference_id` it is important to ensure the new {{infer}} endpoint produces embeddings compatible with those already indexed. This typically means using the same underlying model.
164+
::::
165+
166+
:::
167+
168+
::::
169+
151170
`search_inference_id`
152171
: (Optional, string) {{infer-cap}} endpoint that will be used to generate
153172
embeddings at query time. You can update this parameter by using

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
7474
"completion_test_service",
7575
"test_reranking_service",
7676
"test_service",
77+
"alternate_sparse_embedding_test_service",
7778
"text_embedding_test_service",
7879
"voyageai",
7980
"watsonxai",
@@ -209,6 +210,7 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
209210
"hugging_face",
210211
"streaming_completion_test_service",
211212
"test_service",
213+
"alternate_sparse_embedding_test_service",
212214
"amazon_sagemaker"
213215
).toArray()
214216
)

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java

Lines changed: 65 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import org.elasticsearch.common.ValidationException;
1414
import org.elasticsearch.common.io.stream.StreamInput;
1515
import org.elasticsearch.common.io.stream.StreamOutput;
16-
import org.elasticsearch.common.util.LazyInitializable;
1716
import org.elasticsearch.core.Nullable;
1817
import org.elasticsearch.core.TimeValue;
1918
import org.elasticsearch.inference.ChunkInferenceInput;
@@ -43,12 +42,13 @@
4342
import java.util.HashMap;
4443
import java.util.List;
4544
import java.util.Map;
45+
import java.util.Objects;
4646

4747
public class TestSparseInferenceServiceExtension implements InferenceServiceExtension {
4848

4949
@Override
5050
public List<Factory> getInferenceServiceFactories() {
51-
return List.of(TestInferenceService::new);
51+
return List.of(TestInferenceService::new, TestAlternateSparseInferenceService::new);
5252
}
5353

5454
public static class TestSparseModel extends Model {
@@ -60,16 +60,40 @@ public TestSparseModel(String inferenceEntityId, TestServiceSettings serviceSett
6060
}
6161
}
6262

63-
public static class TestInferenceService extends AbstractTestInferenceService {
63+
public static class TestInferenceService extends AbstractSparseTestInferenceService {
6464
public static final String NAME = "test_service";
6565

66+
public TestInferenceService(InferenceServiceFactoryContext inferenceServiceFactoryContext) {}
67+
68+
@Override
69+
protected String testServiceName() {
70+
return NAME;
71+
}
72+
}
73+
74+
/**
75+
* A second sparse service allows testing updates from one service to another.
76+
*/
77+
public static class TestAlternateSparseInferenceService extends AbstractSparseTestInferenceService {
78+
public static final String NAME = "alternate_sparse_embedding_test_service";
79+
80+
public TestAlternateSparseInferenceService(InferenceServiceFactoryContext inferenceServiceFactoryContext) {}
81+
82+
@Override
83+
protected String testServiceName() {
84+
return NAME;
85+
}
86+
}
87+
88+
abstract static class AbstractSparseTestInferenceService extends AbstractTestInferenceService {
89+
6690
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING);
6791

68-
public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {}
92+
protected abstract String testServiceName();
6993

7094
@Override
7195
public String name() {
72-
return NAME;
96+
return testServiceName();
7397
}
7498

7599
@Override
@@ -92,7 +116,7 @@ public void parseRequestConfig(
92116

93117
@Override
94118
public InferenceServiceConfiguration getConfiguration() {
95-
return Configuration.get();
119+
return new Configuration(testServiceName()).get();
96120
}
97121

98122
@Override
@@ -195,41 +219,43 @@ private static float generateEmbedding(String input, int position) {
195219
}
196220

197221
public static class Configuration {
198-
public static InferenceServiceConfiguration get() {
199-
return configuration.getOrCompute();
222+
223+
private final String serviceName;
224+
225+
Configuration(String serviceName) {
226+
this.serviceName = Objects.requireNonNull(serviceName);
227+
}
228+
229+
InferenceServiceConfiguration get() {
230+
var configurationMap = new HashMap<String, SettingsConfiguration>();
231+
232+
configurationMap.put(
233+
"model",
234+
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING)).setDescription("")
235+
.setLabel("Model")
236+
.setRequired(true)
237+
.setSensitive(false)
238+
.setType(SettingsConfigurationFieldType.STRING)
239+
.build()
240+
);
241+
242+
configurationMap.put(
243+
"hidden_field",
244+
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING)).setDescription("")
245+
.setLabel("Hidden Field")
246+
.setRequired(true)
247+
.setSensitive(false)
248+
.setType(SettingsConfigurationFieldType.STRING)
249+
.build()
250+
);
251+
252+
return new InferenceServiceConfiguration.Builder().setService(serviceName)
253+
.setName(serviceName)
254+
.setTaskTypes(supportedTaskTypes)
255+
.setConfigurations(configurationMap)
256+
.build();
200257
}
201258

202-
private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration = new LazyInitializable<>(
203-
() -> {
204-
var configurationMap = new HashMap<String, SettingsConfiguration>();
205-
206-
configurationMap.put(
207-
"model",
208-
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING)).setDescription("")
209-
.setLabel("Model")
210-
.setRequired(true)
211-
.setSensitive(false)
212-
.setType(SettingsConfigurationFieldType.STRING)
213-
.build()
214-
);
215-
216-
configurationMap.put(
217-
"hidden_field",
218-
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING)).setDescription("")
219-
.setLabel("Hidden Field")
220-
.setRequired(true)
221-
.setSensitive(false)
222-
.setType(SettingsConfigurationFieldType.STRING)
223-
.build()
224-
);
225-
226-
return new InferenceServiceConfiguration.Builder().setService(NAME)
227-
.setName(NAME)
228-
.setTaskTypes(supportedTaskTypes)
229-
.setConfigurations(configurationMap)
230-
.build();
231-
}
232-
);
233259
}
234260
}
235261

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_INDEX_OPTIONS_WITH_DEFAULTS;
2525
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_SPARSE_VECTOR_INDEX_OPTIONS;
2626
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG;
27+
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_UPDATABLE_INFERENCE_ID;
2728
import static org.elasticsearch.xpack.inference.queries.LegacySemanticKnnVectorQueryRewriteInterceptor.SEMANTIC_KNN_FILTER_FIX;
2829
import static org.elasticsearch.xpack.inference.queries.LegacySemanticKnnVectorQueryRewriteInterceptor.SEMANTIC_KNN_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
2930
import static org.elasticsearch.xpack.inference.queries.LegacySemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
@@ -93,6 +94,7 @@ public Set<NodeFeature> getTestFeatures() {
9394
SEMANTIC_TEXT_HIGHLIGHTING_FLAT,
9495
SEMANTIC_TEXT_SPARSE_VECTOR_INDEX_OPTIONS,
9596
SEMANTIC_TEXT_FIELDS_CHUNKS_FORMAT,
97+
SEMANTIC_TEXT_UPDATABLE_INFERENCE_ID,
9698
SemanticQueryBuilder.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS,
9799
SemanticQueryBuilder.SEMANTIC_QUERY_FILTER_FIELD_CAPS_FIX,
98100
InterceptedInferenceQueryBuilder.NEW_SEMANTIC_QUERY_INTERCEPTORS,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java

Lines changed: 71 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
150150
public static final NodeFeature SEMANTIC_TEXT_SPARSE_VECTOR_INDEX_OPTIONS = new NodeFeature(
151151
"semantic_text.sparse_vector_index_options"
152152
);
153+
public static final NodeFeature SEMANTIC_TEXT_UPDATABLE_INFERENCE_ID = new NodeFeature("semantic_text.updatable_inference_id");
153154

154155
public static final String CONTENT_TYPE = "semantic_text";
155156
public static final String DEFAULT_ELSER_2_INFERENCE_ID = DEFAULT_ELSER_ID;
@@ -243,7 +244,7 @@ public Builder(
243244

244245
this.inferenceId = Parameter.stringParam(
245246
INFERENCE_ID_FIELD,
246-
false,
247+
true,
247248
mapper -> ((SemanticTextFieldType) mapper.fieldType()).inferenceId,
248249
DEFAULT_ELSER_2_INFERENCE_ID
249250
).addValidator(v -> {
@@ -321,9 +322,65 @@ protected Parameter<?>[] getParameters() {
321322
@Override
322323
protected void merge(FieldMapper mergeWith, Conflicts conflicts, MapperMergeContext mapperMergeContext) {
323324
SemanticTextFieldMapper semanticMergeWith = (SemanticTextFieldMapper) mergeWith;
324-
semanticMergeWith = copySettings(semanticMergeWith, mapperMergeContext);
325325

326-
// We make sure to merge the inference field first to catch any model conflicts
326+
final boolean isInferenceIdUpdate = semanticMergeWith.fieldType().inferenceId.equals(inferenceId.get()) == false;
327+
final boolean hasExplicitModelSettings = modelSettings.get() != null;
328+
329+
MinimalServiceSettings updatedModelSettings = modelSettings.get();
330+
if (isInferenceIdUpdate && hasExplicitModelSettings) {
331+
validateModelsAreCompatibleWhenInferenceIdIsUpdated(semanticMergeWith.fieldType().inferenceId, conflicts);
332+
// As the mapper previously had explicit model settings, we need to apply to the new merged mapper
333+
// the resolved model settings if not explicitly set.
334+
updatedModelSettings = modelRegistry.getMinimalServiceSettings(semanticMergeWith.fieldType().inferenceId);
335+
}
336+
337+
semanticMergeWith = copyWithNewModelSettingsIfNotSet(semanticMergeWith, updatedModelSettings, mapperMergeContext);
338+
339+
// We make sure to merge the inference field first to catch any model conflicts.
340+
// If inference_id is updated and there are no explicit model settings, we should be
341+
// able to switch to the new inference field without the need to check for conflicts.
342+
if (isInferenceIdUpdate == false || hasExplicitModelSettings) {
343+
mergeInferenceField(mapperMergeContext, semanticMergeWith);
344+
}
345+
346+
super.merge(semanticMergeWith, conflicts, mapperMergeContext);
347+
conflicts.check();
348+
}
349+
350+
private void validateModelsAreCompatibleWhenInferenceIdIsUpdated(String newInferenceId, Conflicts conflicts) {
351+
MinimalServiceSettings currentModelSettings = modelSettings.get();
352+
MinimalServiceSettings updatedModelSettings = modelRegistry.getMinimalServiceSettings(newInferenceId);
353+
if (currentModelSettings != null && updatedModelSettings == null) {
354+
throw new IllegalArgumentException(
355+
"Cannot update ["
356+
+ CONTENT_TYPE
357+
+ "] field ["
358+
+ leafName()
359+
+ "] because inference endpoint ["
360+
+ newInferenceId
361+
+ "] does not exist."
362+
);
363+
}
364+
if (canMergeModelSettings(currentModelSettings, updatedModelSettings, conflicts) == false) {
365+
throw new IllegalArgumentException(
366+
"Cannot update ["
367+
+ CONTENT_TYPE
368+
+ "] field ["
369+
+ leafName()
370+
+ "] because inference endpoint ["
371+
+ inferenceId.get()
372+
+ "] with model settings ["
373+
+ currentModelSettings
374+
+ "] is not compatible with new inference endpoint ["
375+
+ newInferenceId
376+
+ "] with model settings ["
377+
+ updatedModelSettings
378+
+ "]."
379+
);
380+
}
381+
}
382+
383+
private void mergeInferenceField(MapperMergeContext mapperMergeContext, SemanticTextFieldMapper semanticMergeWith) {
327384
try {
328385
var context = mapperMergeContext.createChildContext(semanticMergeWith.leafName(), ObjectMapper.Dynamic.FALSE);
329386
var inferenceField = inferenceFieldBuilder.apply(context.getMapperBuilderContext());
@@ -336,9 +393,6 @@ protected void merge(FieldMapper mergeWith, Conflicts conflicts, MapperMergeCont
336393
: "";
337394
throw new IllegalArgumentException(errorMessage, e);
338395
}
339-
340-
super.merge(semanticMergeWith, conflicts, mapperMergeContext);
341-
conflicts.check();
342396
}
343397

344398
/**
@@ -494,18 +548,23 @@ private void validateIndexOptions(SemanticTextIndexOptions indexOptions, String
494548
}
495549

496550
/**
497-
* As necessary, copy settings from this builder to the passed-in mapper.
498-
* Used to preserve {@link MinimalServiceSettings} when updating a semantic text mapping to one where the model settings
499-
* are not specified.
551+
* Creates a new mapper with the new model settings if model settings are not set on the mapper.
552+
* If the mapper already has model settings or the new model settings are null, the mapper is
553+
* returned unchanged.
500554
*
501-
* @param mapper The mapper
555+
* @param mapper The mapper
556+
* @param modelSettings the new model settings. If null the mapper will be returned unchanged.
502557
* @return A mapper with the copied settings applied
503558
*/
504-
private SemanticTextFieldMapper copySettings(SemanticTextFieldMapper mapper, MapperMergeContext mapperMergeContext) {
559+
private SemanticTextFieldMapper copyWithNewModelSettingsIfNotSet(
560+
SemanticTextFieldMapper mapper,
561+
@Nullable MinimalServiceSettings modelSettings,
562+
MapperMergeContext mapperMergeContext
563+
) {
505564
SemanticTextFieldMapper returnedMapper = mapper;
506565
if (mapper.fieldType().getModelSettings() == null) {
507566
Builder builder = from(mapper);
508-
builder.setModelSettings(modelSettings.getValue());
567+
builder.setModelSettings(modelSettings);
509568
returnedMapper = builder.build(mapperMergeContext.getMapperBuilderContext());
510569
}
511570

0 commit comments

Comments
 (0)