Skip to content

Commit 66b167e

Browse files
Allow updating inference_id of semantic_text fields
Previously the `inference_id` of `semantic_text` fields was not updatable. This commit allows users to update the `inference_id` of a `semantic_text` field. This is particularly useful for scenarios where the user wants to switch to using the same model but from a different service. There are two circumstances when the update is allowed. - No values have been written for the `semantic_text` field. The inference endpoint can be changed freely as there is no need for compatibility between the current and the new endpoint. - The new inference endpoint is compatible with the previous one. The `model_settings` of the new inference endpoint are compatible with those of the current endpoint, thus the update is allowed.
1 parent 53d5661 commit 66b167e

File tree

6 files changed

+770
-59
lines changed

6 files changed

+770
-59
lines changed

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java

Lines changed: 65 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import org.elasticsearch.common.ValidationException;
1414
import org.elasticsearch.common.io.stream.StreamInput;
1515
import org.elasticsearch.common.io.stream.StreamOutput;
16-
import org.elasticsearch.common.util.LazyInitializable;
1716
import org.elasticsearch.core.Nullable;
1817
import org.elasticsearch.core.TimeValue;
1918
import org.elasticsearch.inference.ChunkInferenceInput;
@@ -43,12 +42,13 @@
4342
import java.util.HashMap;
4443
import java.util.List;
4544
import java.util.Map;
45+
import java.util.Objects;
4646

4747
public class TestSparseInferenceServiceExtension implements InferenceServiceExtension {
4848

4949
@Override
5050
public List<Factory> getInferenceServiceFactories() {
51-
return List.of(TestInferenceService::new);
51+
return List.of(TestInferenceService::new, TestInferenceService2::new);
5252
}
5353

5454
public static class TestSparseModel extends Model {
@@ -60,16 +60,40 @@ public TestSparseModel(String inferenceEntityId, TestServiceSettings serviceSett
6060
}
6161
}
6262

63-
public static class TestInferenceService extends AbstractTestInferenceService {
63+
public static class TestInferenceService extends AbstractSparseTestInferenceService {
6464
public static final String NAME = "test_service";
6565

66+
public TestInferenceService(InferenceServiceFactoryContext inferenceServiceFactoryContext) {}
67+
68+
@Override
69+
protected String testServiceName() {
70+
return NAME;
71+
}
72+
}
73+
74+
/**
75+
* A second sparse service allows testing updates from one service to another.
76+
*/
77+
public static class TestInferenceService2 extends AbstractSparseTestInferenceService {
78+
public static final String NAME = "test_service_2";
79+
80+
public TestInferenceService2(InferenceServiceFactoryContext inferenceServiceFactoryContext) {}
81+
82+
@Override
83+
protected String testServiceName() {
84+
return NAME;
85+
}
86+
}
87+
88+
abstract static class AbstractSparseTestInferenceService extends AbstractTestInferenceService {
89+
6690
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING);
6791

68-
public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {}
92+
protected abstract String testServiceName();
6993

7094
@Override
7195
public String name() {
72-
return NAME;
96+
return testServiceName();
7397
}
7498

7599
@Override
@@ -92,7 +116,7 @@ public void parseRequestConfig(
92116

93117
@Override
94118
public InferenceServiceConfiguration getConfiguration() {
95-
return Configuration.get();
119+
return new Configuration(testServiceName()).get();
96120
}
97121

98122
@Override
@@ -195,41 +219,43 @@ private static float generateEmbedding(String input, int position) {
195219
}
196220

197221
public static class Configuration {
198-
public static InferenceServiceConfiguration get() {
199-
return configuration.getOrCompute();
222+
223+
private final String serviceName;
224+
225+
Configuration(String serviceName) {
226+
this.serviceName = Objects.requireNonNull(serviceName);
227+
}
228+
229+
InferenceServiceConfiguration get() {
230+
var configurationMap = new HashMap<String, SettingsConfiguration>();
231+
232+
configurationMap.put(
233+
"model",
234+
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING)).setDescription("")
235+
.setLabel("Model")
236+
.setRequired(true)
237+
.setSensitive(false)
238+
.setType(SettingsConfigurationFieldType.STRING)
239+
.build()
240+
);
241+
242+
configurationMap.put(
243+
"hidden_field",
244+
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING)).setDescription("")
245+
.setLabel("Hidden Field")
246+
.setRequired(true)
247+
.setSensitive(false)
248+
.setType(SettingsConfigurationFieldType.STRING)
249+
.build()
250+
);
251+
252+
return new InferenceServiceConfiguration.Builder().setService(serviceName)
253+
.setName(serviceName)
254+
.setTaskTypes(supportedTaskTypes)
255+
.setConfigurations(configurationMap)
256+
.build();
200257
}
201258

202-
private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration = new LazyInitializable<>(
203-
() -> {
204-
var configurationMap = new HashMap<String, SettingsConfiguration>();
205-
206-
configurationMap.put(
207-
"model",
208-
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING)).setDescription("")
209-
.setLabel("Model")
210-
.setRequired(true)
211-
.setSensitive(false)
212-
.setType(SettingsConfigurationFieldType.STRING)
213-
.build()
214-
);
215-
216-
configurationMap.put(
217-
"hidden_field",
218-
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING)).setDescription("")
219-
.setLabel("Hidden Field")
220-
.setRequired(true)
221-
.setSensitive(false)
222-
.setType(SettingsConfigurationFieldType.STRING)
223-
.build()
224-
);
225-
226-
return new InferenceServiceConfiguration.Builder().setService(NAME)
227-
.setName(NAME)
228-
.setTaskTypes(supportedTaskTypes)
229-
.setConfigurations(configurationMap)
230-
.build();
231-
}
232-
);
233259
}
234260
}
235261

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_INDEX_OPTIONS_WITH_DEFAULTS;
2525
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_SPARSE_VECTOR_INDEX_OPTIONS;
2626
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG;
27+
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_UPDATABLE_INFERENCE_ID;
2728
import static org.elasticsearch.xpack.inference.queries.LegacySemanticKnnVectorQueryRewriteInterceptor.SEMANTIC_KNN_FILTER_FIX;
2829
import static org.elasticsearch.xpack.inference.queries.LegacySemanticKnnVectorQueryRewriteInterceptor.SEMANTIC_KNN_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
2930
import static org.elasticsearch.xpack.inference.queries.LegacySemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
@@ -92,6 +93,7 @@ public Set<NodeFeature> getTestFeatures() {
9293
SEMANTIC_TEXT_HIGHLIGHTING_FLAT,
9394
SEMANTIC_TEXT_SPARSE_VECTOR_INDEX_OPTIONS,
9495
SEMANTIC_TEXT_FIELDS_CHUNKS_FORMAT,
96+
SEMANTIC_TEXT_UPDATABLE_INFERENCE_ID,
9597
SemanticQueryBuilder.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS,
9698
SemanticQueryBuilder.SEMANTIC_QUERY_FILTER_FIELD_CAPS_FIX,
9799
InterceptedInferenceQueryBuilder.NEW_SEMANTIC_QUERY_INTERCEPTORS,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java

Lines changed: 88 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
149149
public static final NodeFeature SEMANTIC_TEXT_SPARSE_VECTOR_INDEX_OPTIONS = new NodeFeature(
150150
"semantic_text.sparse_vector_index_options"
151151
);
152+
public static final NodeFeature SEMANTIC_TEXT_UPDATABLE_INFERENCE_ID = new NodeFeature("semantic_text.updatable_inference_id");
152153

153154
public static final String CONTENT_TYPE = "semantic_text";
154155
public static final String DEFAULT_ELSER_2_INFERENCE_ID = DEFAULT_ELSER_ID;
@@ -242,7 +243,7 @@ public Builder(
242243

243244
this.inferenceId = Parameter.stringParam(
244245
INFERENCE_ID_FIELD,
245-
false,
246+
true,
246247
mapper -> ((SemanticTextFieldType) mapper.fieldType()).inferenceId,
247248
DEFAULT_ELSER_2_INFERENCE_ID
248249
).addValidator(v -> {
@@ -325,9 +326,68 @@ protected Parameter<?>[] getParameters() {
325326
@Override
326327
protected void merge(FieldMapper mergeWith, Conflicts conflicts, MapperMergeContext mapperMergeContext) {
327328
SemanticTextFieldMapper semanticMergeWith = (SemanticTextFieldMapper) mergeWith;
328-
semanticMergeWith = copySettings(semanticMergeWith, mapperMergeContext);
329329

330-
// We make sure to merge the inference field first to catch any model conflicts
330+
final boolean isInferenceIdUpdate = semanticMergeWith.fieldType().inferenceId.equals(inferenceId.get()) == false;
331+
final boolean hasExplicitModelSettings = modelSettings.get() != null;
332+
333+
if (isInferenceIdUpdate && hasExplicitModelSettings) {
334+
validateModelsAreCompatibleWhenInferenceIdIsUpdated(semanticMergeWith.fieldType().inferenceId, conflicts);
335+
// As the mapper previously had explicit model settings, we need to apply to the new merged mapper
336+
// the resolved model settings if not explicitly set.
337+
semanticMergeWith = copyWithNewModelSettingsIfNotSet(
338+
semanticMergeWith,
339+
modelRegistry.getMinimalServiceSettings(semanticMergeWith.fieldType().inferenceId),
340+
mapperMergeContext
341+
);
342+
}
343+
344+
semanticMergeWith = copyModelSettingsIfNotSet(semanticMergeWith, mapperMergeContext);
345+
346+
// We make sure to merge the inference field first to catch any model conflicts.
347+
// If inference_id is updated and there are no explicit model settings, we should be
348+
// able to switch to the new inference field without the need to check for conflicts.
349+
if (isInferenceIdUpdate == false || hasExplicitModelSettings) {
350+
mergeInferenceField(mapperMergeContext, semanticMergeWith);
351+
}
352+
353+
super.merge(semanticMergeWith, conflicts, mapperMergeContext);
354+
conflicts.check();
355+
}
356+
357+
private void validateModelsAreCompatibleWhenInferenceIdIsUpdated(String newInferenceId, Conflicts conflicts) {
358+
MinimalServiceSettings currentModelSettings = modelSettings.get();
359+
MinimalServiceSettings updatedModelSettings = modelRegistry.getMinimalServiceSettings(newInferenceId);
360+
if (currentModelSettings != null && updatedModelSettings == null) {
361+
throw new IllegalArgumentException(
362+
"Cannot merge ["
363+
+ CONTENT_TYPE
364+
+ "] field ["
365+
+ leafName()
366+
+ "] because inference endpoint ["
367+
+ newInferenceId
368+
+ "] does not exist."
369+
);
370+
}
371+
if (canMergeModelSettings(currentModelSettings, updatedModelSettings, conflicts) == false) {
372+
throw new IllegalArgumentException(
373+
"Cannot merge ["
374+
+ CONTENT_TYPE
375+
+ "] field ["
376+
+ leafName()
377+
+ "] because inference endpoint ["
378+
+ inferenceId.get()
379+
+ "] with model settings ["
380+
+ currentModelSettings
381+
+ "] is not compatible with new inference endpoint ["
382+
+ newInferenceId
383+
+ "] with model settings ["
384+
+ updatedModelSettings
385+
+ "]."
386+
);
387+
}
388+
}
389+
390+
private void mergeInferenceField(MapperMergeContext mapperMergeContext, SemanticTextFieldMapper semanticMergeWith) {
331391
try {
332392
var context = mapperMergeContext.createChildContext(semanticMergeWith.leafName(), ObjectMapper.Dynamic.FALSE);
333393
var inferenceField = inferenceFieldBuilder.apply(context.getMapperBuilderContext());
@@ -340,9 +400,6 @@ protected void merge(FieldMapper mergeWith, Conflicts conflicts, MapperMergeCont
340400
: "";
341401
throw new IllegalArgumentException(errorMessage, e);
342402
}
343-
344-
super.merge(semanticMergeWith, conflicts, mapperMergeContext);
345-
conflicts.check();
346403
}
347404

348405
/**
@@ -498,18 +555,35 @@ private void validateIndexOptions(SemanticTextIndexOptions indexOptions, String
498555
}
499556

500557
/**
501-
* As necessary, copy settings from this builder to the passed-in mapper.
558+
* As necessary, copy model settings from this builder to the passed-in mapper.
502559
* Used to preserve {@link MinimalServiceSettings} when updating a semantic text mapping to one where the model settings
503560
* are not specified.
504561
*
505562
* @param mapper The mapper
506563
* @return A mapper with the copied settings applied
507564
*/
508-
private SemanticTextFieldMapper copySettings(SemanticTextFieldMapper mapper, MapperMergeContext mapperMergeContext) {
565+
private SemanticTextFieldMapper copyModelSettingsIfNotSet(SemanticTextFieldMapper mapper, MapperMergeContext mapperMergeContext) {
566+
return copyWithNewModelSettingsIfNotSet(mapper, modelSettings.getValue(), mapperMergeContext);
567+
}
568+
569+
/**
570+
* Creates a new mapper with the new model settings if model settings are not set on the mapper.
571+
* If the mapper already has model settings or the new model settings are null, the mapper is
572+
* returned unchanged.
573+
*
574+
* @param mapper The mapper
575+
* @param modelSettings the new model settings. If null the mapper will be returned unchanged.
576+
* @return A mapper with the copied settings applied
577+
*/
578+
private SemanticTextFieldMapper copyWithNewModelSettingsIfNotSet(
579+
SemanticTextFieldMapper mapper,
580+
@Nullable MinimalServiceSettings modelSettings,
581+
MapperMergeContext mapperMergeContext
582+
) {
509583
SemanticTextFieldMapper returnedMapper = mapper;
510584
if (mapper.fieldType().getModelSettings() == null) {
511585
Builder builder = from(mapper);
512-
builder.setModelSettings(modelSettings.getValue());
586+
builder.setModelSettings(modelSettings);
513587
returnedMapper = builder.build(mapperMergeContext.getMapperBuilderContext());
514588
}
515589

@@ -783,6 +857,11 @@ protected void doValidate(MappingLookup mappers) {
783857
}
784858
}
785859

860+
@Override
861+
protected void checkIncomingMergeType(FieldMapper mergeWith) {
862+
super.checkIncomingMergeType(mergeWith);
863+
}
864+
786865
public static class SemanticTextFieldType extends SimpleMappedFieldType {
787866
private final String inferenceId;
788867
private final String searchInferenceId;

0 commit comments

Comments
 (0)