Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/136120.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 136120
summary: Allow updating `inference_id` of `semantic_text` fields
area: "Search"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Search or Mapping?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say mapping

type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkInferenceInput;
Expand Down Expand Up @@ -43,12 +42,13 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

public class TestSparseInferenceServiceExtension implements InferenceServiceExtension {

@Override
public List<Factory> getInferenceServiceFactories() {
return List.of(TestInferenceService::new);
return List.of(TestInferenceService::new, TestInferenceService2::new);
}

public static class TestSparseModel extends Model {
Expand All @@ -60,16 +60,40 @@ public TestSparseModel(String inferenceEntityId, TestServiceSettings serviceSett
}
}

public static class TestInferenceService extends AbstractTestInferenceService {
public static class TestInferenceService extends AbstractSparseTestInferenceService {
public static final String NAME = "test_service";

public TestInferenceService(InferenceServiceFactoryContext inferenceServiceFactoryContext) {}

@Override
protected String testServiceName() {
return NAME;
}
}

/**
 * A second sparse test service that allows testing updates from one service to another.
 * Like {@link TestInferenceService}, it only supplies a service name; everything else is
 * inherited from {@code AbstractSparseTestInferenceService}.
 */
public static class TestInferenceService2 extends AbstractSparseTestInferenceService {
/** Name returned by {@link #testServiceName()} for this service. */
public static final String NAME = "test_service_2";

/**
 * Creates the service. The factory context is accepted so the constructor can be used as a
 * factory reference ({@code TestInferenceService2::new}); the context itself is not used.
 *
 * @param inferenceServiceFactoryContext the factory context (ignored)
 */
public TestInferenceService2(InferenceServiceFactoryContext inferenceServiceFactoryContext) {}

/**
 * @return {@link #NAME}
 */
@Override
protected String testServiceName() {
return NAME;
}
}

abstract static class AbstractSparseTestInferenceService extends AbstractTestInferenceService {

private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING);

public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {}
protected abstract String testServiceName();

@Override
public String name() {
return NAME;
return testServiceName();
}

@Override
Expand All @@ -92,7 +116,7 @@ public void parseRequestConfig(

@Override
public InferenceServiceConfiguration getConfiguration() {
return Configuration.get();
return new Configuration(testServiceName()).get();
}

@Override
Expand Down Expand Up @@ -195,41 +219,43 @@ private static float generateEmbedding(String input, int position) {
}

public static class Configuration {
public static InferenceServiceConfiguration get() {
return configuration.getOrCompute();

private final String serviceName;

Configuration(String serviceName) {
this.serviceName = Objects.requireNonNull(serviceName);
}

InferenceServiceConfiguration get() {
var configurationMap = new HashMap<String, SettingsConfiguration>();

configurationMap.put(
"model",
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING)).setDescription("")
.setLabel("Model")
.setRequired(true)
.setSensitive(false)
.setType(SettingsConfigurationFieldType.STRING)
.build()
);

configurationMap.put(
"hidden_field",
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING)).setDescription("")
.setLabel("Hidden Field")
.setRequired(true)
.setSensitive(false)
.setType(SettingsConfigurationFieldType.STRING)
.build()
);

return new InferenceServiceConfiguration.Builder().setService(serviceName)
.setName(serviceName)
.setTaskTypes(supportedTaskTypes)
.setConfigurations(configurationMap)
.build();
}

private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration = new LazyInitializable<>(
() -> {
var configurationMap = new HashMap<String, SettingsConfiguration>();

configurationMap.put(
"model",
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING)).setDescription("")
.setLabel("Model")
.setRequired(true)
.setSensitive(false)
.setType(SettingsConfigurationFieldType.STRING)
.build()
);

configurationMap.put(
"hidden_field",
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING)).setDescription("")
.setLabel("Hidden Field")
.setRequired(true)
.setSensitive(false)
.setType(SettingsConfigurationFieldType.STRING)
.build()
);

return new InferenceServiceConfiguration.Builder().setService(NAME)
.setName(NAME)
.setTaskTypes(supportedTaskTypes)
.setConfigurations(configurationMap)
.build();
}
);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_INDEX_OPTIONS_WITH_DEFAULTS;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_SPARSE_VECTOR_INDEX_OPTIONS;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_UPDATABLE_INFERENCE_ID;
import static org.elasticsearch.xpack.inference.queries.LegacySemanticKnnVectorQueryRewriteInterceptor.SEMANTIC_KNN_FILTER_FIX;
import static org.elasticsearch.xpack.inference.queries.LegacySemanticKnnVectorQueryRewriteInterceptor.SEMANTIC_KNN_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
import static org.elasticsearch.xpack.inference.queries.LegacySemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
Expand Down Expand Up @@ -92,6 +93,7 @@ public Set<NodeFeature> getTestFeatures() {
SEMANTIC_TEXT_HIGHLIGHTING_FLAT,
SEMANTIC_TEXT_SPARSE_VECTOR_INDEX_OPTIONS,
SEMANTIC_TEXT_FIELDS_CHUNKS_FORMAT,
SEMANTIC_TEXT_UPDATABLE_INFERENCE_ID,
SemanticQueryBuilder.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS,
SemanticQueryBuilder.SEMANTIC_QUERY_FILTER_FIELD_CAPS_FIX,
InterceptedInferenceQueryBuilder.NEW_SEMANTIC_QUERY_INTERCEPTORS,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
public static final NodeFeature SEMANTIC_TEXT_SPARSE_VECTOR_INDEX_OPTIONS = new NodeFeature(
"semantic_text.sparse_vector_index_options"
);
public static final NodeFeature SEMANTIC_TEXT_UPDATABLE_INFERENCE_ID = new NodeFeature("semantic_text.updatable_inference_id");

public static final String CONTENT_TYPE = "semantic_text";
public static final String DEFAULT_ELSER_2_INFERENCE_ID = DEFAULT_ELSER_ID;
Expand Down Expand Up @@ -242,7 +243,7 @@ public Builder(

this.inferenceId = Parameter.stringParam(
INFERENCE_ID_FIELD,
false,
true,
mapper -> ((SemanticTextFieldType) mapper.fieldType()).inferenceId,
DEFAULT_ELSER_2_INFERENCE_ID
).addValidator(v -> {
Expand Down Expand Up @@ -325,9 +326,68 @@ protected Parameter<?>[] getParameters() {
@Override
protected void merge(FieldMapper mergeWith, Conflicts conflicts, MapperMergeContext mapperMergeContext) {
SemanticTextFieldMapper semanticMergeWith = (SemanticTextFieldMapper) mergeWith;
semanticMergeWith = copySettings(semanticMergeWith, mapperMergeContext);

// We make sure to merge the inference field first to catch any model conflicts
final boolean isInferenceIdUpdate = semanticMergeWith.fieldType().inferenceId.equals(inferenceId.get()) == false;
final boolean hasExplicitModelSettings = modelSettings.get() != null;

if (isInferenceIdUpdate && hasExplicitModelSettings) {
validateModelsAreCompatibleWhenInferenceIdIsUpdated(semanticMergeWith.fieldType().inferenceId, conflicts);
// The mapper previously had explicit model settings, so apply the model settings resolved for the
// new inference endpoint to the merged mapper, unless the update sets model settings explicitly.
semanticMergeWith = copyWithNewModelSettingsIfNotSet(
semanticMergeWith,
modelRegistry.getMinimalServiceSettings(semanticMergeWith.fieldType().inferenceId),
mapperMergeContext
);
}

semanticMergeWith = copyModelSettingsIfNotSet(semanticMergeWith, mapperMergeContext);

// We make sure to merge the inference field first to catch any model conflicts.
// If inference_id is updated and there are no explicit model settings, we should be
// able to switch to the new inference field without the need to check for conflicts.
if (isInferenceIdUpdate == false || hasExplicitModelSettings) {
mergeInferenceField(mapperMergeContext, semanticMergeWith);
}

super.merge(semanticMergeWith, conflicts, mapperMergeContext);
conflicts.check();
}

/**
 * Checks that this field's current model settings are compatible with the model settings
 * resolved for {@code newInferenceId}.
 *
 * @param newInferenceId the inference endpoint id the field is being updated to
 * @param conflicts      passed through to {@code canMergeModelSettings}
 * @throws IllegalArgumentException if this field has model settings but none can be resolved for
 *                                  {@code newInferenceId}, or if {@code canMergeModelSettings}
 *                                  reports that the two sets of model settings cannot be merged
 */
private void validateModelsAreCompatibleWhenInferenceIdIsUpdated(String newInferenceId, Conflicts conflicts) {
    final MinimalServiceSettings existingSettings = modelSettings.get();
    final MinimalServiceSettings incomingSettings = modelRegistry.getMinimalServiceSettings(newInferenceId);

    // Existing settings with nothing resolved for the new id: report the endpoint as missing.
    final boolean newEndpointMissing = existingSettings != null && incomingSettings == null;
    if (newEndpointMissing) {
        final String missingEndpointMessage = "Cannot merge ["
            + CONTENT_TYPE
            + "] field ["
            + leafName()
            + "] because inference endpoint ["
            + newInferenceId
            + "] does not exist.";
        throw new IllegalArgumentException(missingEndpointMessage);
    }

    // Nothing more to do when the two sets of settings can be merged.
    if (canMergeModelSettings(existingSettings, incomingSettings, conflicts)) {
        return;
    }

    final String incompatibleMessage = "Cannot merge ["
        + CONTENT_TYPE
        + "] field ["
        + leafName()
        + "] because inference endpoint ["
        + inferenceId.get()
        + "] with model settings ["
        + existingSettings
        + "] is not compatible with new inference endpoint ["
        + newInferenceId
        + "] with model settings ["
        + incomingSettings
        + "].";
    throw new IllegalArgumentException(incompatibleMessage);
}

private void mergeInferenceField(MapperMergeContext mapperMergeContext, SemanticTextFieldMapper semanticMergeWith) {
try {
var context = mapperMergeContext.createChildContext(semanticMergeWith.leafName(), ObjectMapper.Dynamic.FALSE);
var inferenceField = inferenceFieldBuilder.apply(context.getMapperBuilderContext());
Expand All @@ -340,9 +400,6 @@ protected void merge(FieldMapper mergeWith, Conflicts conflicts, MapperMergeCont
: "";
throw new IllegalArgumentException(errorMessage, e);
}

super.merge(semanticMergeWith, conflicts, mapperMergeContext);
conflicts.check();
}

/**
Expand Down Expand Up @@ -498,18 +555,35 @@ private void validateIndexOptions(SemanticTextIndexOptions indexOptions, String
}

/**
* As necessary, copy settings from this builder to the passed-in mapper.
* As necessary, copy model settings from this builder to the passed-in mapper.
* Used to preserve {@link MinimalServiceSettings} when updating a semantic text mapping to one where the model settings
* are not specified.
*
* @param mapper The mapper
* @return A mapper with the copied settings applied
*/
private SemanticTextFieldMapper copySettings(SemanticTextFieldMapper mapper, MapperMergeContext mapperMergeContext) {
private SemanticTextFieldMapper copyModelSettingsIfNotSet(SemanticTextFieldMapper mapper, MapperMergeContext mapperMergeContext) {
// Delegate, passing this builder's own model settings as the value to copy onto the mapper.
return copyWithNewModelSettingsIfNotSet(mapper, modelSettings.getValue(), mapperMergeContext);
}

/**
* Creates a new mapper with the new model settings if model settings are not set on the mapper.
* If the mapper already has model settings or the new model settings are null, the mapper is
* returned unchanged.
*
* @param mapper The mapper
* @param modelSettings the new model settings. If null the mapper will be returned unchanged.
* @return A mapper with the copied settings applied
*/
private SemanticTextFieldMapper copyWithNewModelSettingsIfNotSet(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we're doing a null check here before we copy the settings, but we silently ignore it if that case happens. Should we throw if it's not null and this method is called?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could be that it's not null because the update itself contains model_settings. In that case the explicitly set model_settings take precedence.

SemanticTextFieldMapper mapper,
@Nullable MinimalServiceSettings modelSettings,
MapperMergeContext mapperMergeContext
) {
SemanticTextFieldMapper returnedMapper = mapper;
if (mapper.fieldType().getModelSettings() == null) {
Builder builder = from(mapper);
builder.setModelSettings(modelSettings.getValue());
builder.setModelSettings(modelSettings);
returnedMapper = builder.build(mapperMergeContext.getMapperBuilderContext());
}

Expand Down Expand Up @@ -783,6 +857,11 @@ protected void doValidate(MappingLookup mappers) {
}
}

/**
 * Delegates straight to the superclass implementation without adding any extra checks.
 * NOTE(review): this override adds no behavior of its own; presumably it exists to make the
 * method accessible from this class's package — confirm, otherwise it can be removed.
 *
 * @param mergeWith the mapper passed on to the superclass check
 */
@Override
protected void checkIncomingMergeType(FieldMapper mergeWith) {
super.checkIncomingMergeType(mergeWith);
}

public static class SemanticTextFieldType extends SimpleMappedFieldType {
private final String inferenceId;
private final String searchInferenceId;
Expand Down
Loading