Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -249,18 +249,13 @@ private static void validateFieldNotPresent(String field, Object fieldValue, Tas
}
}

/**
 * Wraps these settings into a full {@link ModelConfigurations} for the given inference entity.
 *
 * @param inferenceEntityId the id of the inference endpoint these settings belong to
 * @return a {@code ModelConfigurations} carrying this instance as its service settings;
 *         when no service name is set, {@code UNKNOWN_SERVICE} is used as a placeholder
 */
public ModelConfigurations toModelConfigurations(String inferenceEntityId) {
    // Fall back to UNKNOWN_SERVICE so a configuration can still be built when the
    // originating service name was not recorded.
    return new ModelConfigurations(inferenceEntityId, taskType, service == null ? UNKNOWN_SERVICE : service, this);
}

/**
* Checks if the given {@link MinimalServiceSettings} is equivalent to the current definition.
*/
public boolean canMergeWith(MinimalServiceSettings other) {
return taskType == other.taskType
&& Objects.equals(dimensions, other.dimensions)
&& similarity == other.similarity
&& elementType == other.elementType
&& (service == null || service.equals(other.service));
&& elementType == other.elementType;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ public class InferencePlugin extends Plugin
private final SetOnce<ElasticInferenceServiceComponents> elasticInferenceServiceComponents = new SetOnce<>();
private final SetOnce<InferenceServiceRegistry> inferenceServiceRegistry = new SetOnce<>();
private final SetOnce<ShardBulkInferenceActionFilter> shardBulkInferenceActionFilter = new SetOnce<>();
private final SetOnce<ModelRegistry> modelRegistry = new SetOnce<>();
private List<InferenceServiceExtension> inferenceServiceExtensions;

public InferencePlugin(Settings settings) {
Expand Down Expand Up @@ -260,8 +261,8 @@ public Collection<?> createComponents(PluginServices services) {
var amazonBedrockRequestSenderFactory = new AmazonBedrockRequestSender.Factory(serviceComponents.get(), services.clusterService());
amazonBedrockFactory.set(amazonBedrockRequestSenderFactory);

ModelRegistry modelRegistry = new ModelRegistry(services.clusterService(), services.client());
services.clusterService().addListener(modelRegistry);
modelRegistry.set(new ModelRegistry(services.clusterService(), services.client()));
services.clusterService().addListener(modelRegistry.get());

if (inferenceServiceExtensions == null) {
inferenceServiceExtensions = new ArrayList<>();
Expand Down Expand Up @@ -299,7 +300,7 @@ public Collection<?> createComponents(PluginServices services) {
elasicInferenceServiceFactory.get(),
serviceComponents.get(),
inferenceServiceSettings,
modelRegistry,
modelRegistry.get(),
authorizationHandler
)
)
Expand All @@ -317,14 +318,14 @@ public Collection<?> createComponents(PluginServices services) {
var serviceRegistry = new InferenceServiceRegistry(inferenceServices, factoryContext);
serviceRegistry.init(services.client());
for (var service : serviceRegistry.getServices().values()) {
service.defaultConfigIds().forEach(modelRegistry::addDefaultIds);
service.defaultConfigIds().forEach(modelRegistry.get()::addDefaultIds);
}
inferenceServiceRegistry.set(serviceRegistry);

var actionFilter = new ShardBulkInferenceActionFilter(
services.clusterService(),
serviceRegistry,
modelRegistry,
modelRegistry.get(),
getLicenseState(),
services.indexingPressure()
);
Expand All @@ -334,7 +335,7 @@ public Collection<?> createComponents(PluginServices services) {
var inferenceStats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));

components.add(serviceRegistry);
components.add(modelRegistry);
components.add(modelRegistry.get());
components.add(httpClientManager);
components.add(inferenceStats);

Expand Down Expand Up @@ -498,11 +499,16 @@ public Map<String, MetadataFieldMapper.TypeParser> getMetadataMappers() {
return Map.of(SemanticInferenceMetadataFieldsMapper.NAME, SemanticInferenceMetadataFieldsMapper.PARSER);
}

/**
 * Supplies the lazily-initialized {@link ModelRegistry}.
 * Protected so tests can override it with a stub registry.
 */
protected Supplier<ModelRegistry> getModelRegistry() {
    // Method reference defers the SetOnce#get until the supplier is invoked,
    // i.e. after createComponents has populated the registry.
    return modelRegistry::get;
}

@Override
public Map<String, Mapper.TypeParser> getMappers() {
return Map.of(
SemanticTextFieldMapper.CONTENT_TYPE,
SemanticTextFieldMapper.PARSER,
SemanticTextFieldMapper.parser(getModelRegistry()),
OffsetSourceFieldMapper.CONTENT_TYPE,
OffsetSourceFieldMapper.PARSER
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

package org.elasticsearch.xpack.inference.mapper;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
Expand All @@ -18,6 +20,7 @@
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.ScoreMode;
import org.apache.lucene.util.BitSet;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
Expand Down Expand Up @@ -75,6 +78,7 @@
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;

import java.io.IOException;
import java.io.UncheckedIOException;
Expand All @@ -89,6 +93,7 @@
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.Supplier;

import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING;
import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING;
Expand All @@ -112,6 +117,7 @@
* A {@link FieldMapper} for semantic text fields.
*/
public class SemanticTextFieldMapper extends FieldMapper implements InferenceFieldMapper {
private static final Logger logger = LogManager.getLogger(SemanticTextFieldMapper.class);
public static final NodeFeature SEMANTIC_TEXT_IN_OBJECT_FIELD_FIX = new NodeFeature("semantic_text.in_object_field_fix");
public static final NodeFeature SEMANTIC_TEXT_SINGLE_FIELD_UPDATE_FIX = new NodeFeature("semantic_text.single_field_update_fix");
public static final NodeFeature SEMANTIC_TEXT_DELETE_FIX = new NodeFeature("semantic_text.delete_fix");
Expand All @@ -127,10 +133,12 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
public static final String CONTENT_TYPE = "semantic_text";
public static final String DEFAULT_ELSER_2_INFERENCE_ID = DEFAULT_ELSER_ID;

public static final TypeParser PARSER = new TypeParser(
(n, c) -> new Builder(n, c::bitSetProducer, c.getIndexSettings()),
List.of(validateParserContext(CONTENT_TYPE))
);
/**
 * Builds the mapper {@link TypeParser} for {@code semantic_text} fields.
 *
 * @param modelRegistry supplier of the {@link ModelRegistry}; resolved only when a
 *                      {@link Builder} is created, so the registry may be set after
 *                      this parser is registered — NOTE(review): assumes the supplier
 *                      yields a non-null registry by the time mappings are parsed.
 * @return a parser that creates {@link Builder} instances wired to the registry
 */
// Note: the redundant `final` modifier (a leftover from the former
// `public static final TypeParser PARSER` constant) has been dropped; `final`
// has no effect on a static method.
public static TypeParser parser(Supplier<ModelRegistry> modelRegistry) {
    return new TypeParser(
        (n, c) -> new Builder(n, c::bitSetProducer, c.getIndexSettings(), modelRegistry.get()),
        List.of(validateParserContext(CONTENT_TYPE))
    );
}

public static BiConsumer<String, MappingParserContext> validateParserContext(String type) {
return (n, c) -> {
Expand All @@ -142,6 +150,7 @@ public static BiConsumer<String, MappingParserContext> validateParserContext(Str
}

public static class Builder extends FieldMapper.Builder {
private final ModelRegistry modelRegistry;
private final boolean useLegacyFormat;

private final Parameter<String> inferenceId = Parameter.stringParam(
Expand Down Expand Up @@ -199,14 +208,21 @@ public static Builder from(SemanticTextFieldMapper mapper) {
Builder builder = new Builder(
mapper.leafName(),
mapper.fieldType().getChunksField().bitsetProducer(),
mapper.fieldType().getChunksField().indexSettings()
mapper.fieldType().getChunksField().indexSettings(),
mapper.modelRegistry
);
builder.init(mapper);
return builder;
}

public Builder(String name, Function<Query, BitSetProducer> bitSetProducer, IndexSettings indexSettings) {
public Builder(
String name,
Function<Query, BitSetProducer> bitSetProducer,
IndexSettings indexSettings,
ModelRegistry modelRegistry
) {
super(name);
this.modelRegistry = modelRegistry;
this.useLegacyFormat = InferenceMetadataFieldsMapper.isEnabled(indexSettings.getSettings()) == false;
this.inferenceFieldBuilder = c -> createInferenceField(
c,
Expand Down Expand Up @@ -264,9 +280,32 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {
if (useLegacyFormat && multiFieldsBuilder.hasMultiFields()) {
throw new IllegalArgumentException(CONTENT_TYPE + " field [" + leafName() + "] does not support multi-fields");
}

if (modelSettings.get() == null) {
try {
var resolvedModelSettings = modelRegistry.getMinimalServiceSettings(inferenceId.get());
if (resolvedModelSettings != null) {
modelSettings.setValue(resolvedModelSettings);
}
} catch (ResourceNotFoundException exc) {
// We allow the inference ID to be unregistered at this point.
// This will delay the creation of sub-fields, so indexing and querying for this field won't work
// until the corresponding inference endpoint is created.
}
}

if (modelSettings.get() != null) {
validateServiceSettings(modelSettings.get());
} else {
logger.warn(
"The field [{}] references an unknown inference ID [{}]. "
+ "Indexing and querying this field will not work correctly until the corresponding "
+ "inference endpoint is created.",
leafName(),
inferenceId.get()
);
}

final String fullName = context.buildFullName(leafName());

if (context.isInNestedContext()) {
Expand All @@ -287,7 +326,8 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {
useLegacyFormat,
meta.getValue()
),
builderParams(this, context)
builderParams(this, context),
modelRegistry
);
}

Expand Down Expand Up @@ -328,9 +368,17 @@ private SemanticTextFieldMapper copySettings(SemanticTextFieldMapper mapper, Map
}
}

private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, BuilderParams builderParams) {
// Registry used to resolve minimal service settings for the field's inference id;
// injected by the Builder so the mapper never reaches for global state.
private final ModelRegistry modelRegistry;

/**
 * Creates the mapper.
 *
 * @param simpleName      leaf name of the field
 * @param mappedFieldType resolved field type for this semantic_text field
 * @param builderParams   standard mapper builder parameters (multi-fields, copy_to, ...)
 * @param modelRegistry   registry backing inference-id lookups for this mapper
 */
private SemanticTextFieldMapper(
    String simpleName,
    MappedFieldType mappedFieldType,
    BuilderParams builderParams,
    ModelRegistry modelRegistry
) {
    super(simpleName, mappedFieldType, builderParams);
    // Validate multi-fields eagerly so an unsupported configuration fails at
    // mapper construction rather than at index time.
    ensureMultiFields(builderParams.multiFields().iterator());
    this.modelRegistry = modelRegistry;
}

private void ensureMultiFields(Iterator<FieldMapper> mappers) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.elasticsearch.cluster.ClusterStateAckListener;
import org.elasticsearch.cluster.ClusterStateListener;
import org.elasticsearch.cluster.SimpleBatchedAckListenerTaskExecutor;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.metadata.ProjectId;
import org.elasticsearch.cluster.metadata.ProjectMetadata;
import org.elasticsearch.cluster.service.ClusterService;
Expand Down Expand Up @@ -140,18 +141,18 @@ public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap)
private static final String MODEL_ID_FIELD = "model_id";
private static final Logger logger = LogManager.getLogger(ModelRegistry.class);

private final ClusterService clusterService;
private final OriginSettingClient client;
private final Map<String, InferenceService.DefaultConfigId> defaultConfigIds;

private final MasterServiceTaskQueue<MetadataTask> metadataTaskQueue;
private final AtomicBoolean upgradeMetadataInProgress = new AtomicBoolean(false);
private final Set<String> preventDeletionLock = Collections.newSetFromMap(new ConcurrentHashMap<>());

private volatile Metadata lastMetadata;

public ModelRegistry(ClusterService clusterService, Client client) {
this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN);
this.defaultConfigIds = new ConcurrentHashMap<>();
this.clusterService = clusterService;
var executor = new SimpleBatchedAckListenerTaskExecutor<MetadataTask>() {
@Override
public Tuple<ClusterState, ClusterStateAckListener> executeTask(MetadataTask task, ClusterState clusterState) throws Exception {
Expand Down Expand Up @@ -224,11 +225,18 @@ public void clearDefaultIds() {
* @throws ResourceNotFoundException if the specified id is guaranteed to not exist in the cluster.
*/
public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId) throws ResourceNotFoundException {
synchronized (this) {
assert lastMetadata != null : "initial cluster state not set yet";
if (lastMetadata == null) {
throw new IllegalStateException("initial cluster state not set yet");
}
}
var config = defaultConfigIds.get(inferenceEntityId);
if (config != null) {
return config.settings();
}
var state = ModelRegistryMetadata.fromState(clusterService.state().projectState().metadata());
var project = lastMetadata.getProject(ProjectId.DEFAULT);
var state = ModelRegistryMetadata.fromState(project);
var existing = state.getMinimalServiceSettings(inferenceEntityId);
if (state.isUpgraded() && existing == null) {
throw new ResourceNotFoundException(inferenceEntityId + " does not exist in this cluster.");
Expand Down Expand Up @@ -684,10 +692,14 @@ private ActionListener<BulkResponse> getStoreIndexListener(
if (updateClusterState) {
var storeListener = getStoreMetadataListener(inferenceEntityId, listener);
try {
var projectId = clusterService.state().projectState().projectId();
metadataTaskQueue.submitTask(
"add model [" + inferenceEntityId + "]",
new AddModelMetadataTask(projectId, inferenceEntityId, new MinimalServiceSettings(model), storeListener),
new AddModelMetadataTask(
ProjectId.DEFAULT,
inferenceEntityId,
new MinimalServiceSettings(model),
storeListener
),
timeout
);
} catch (Exception exc) {
Expand Down Expand Up @@ -854,10 +866,9 @@ public void onFailure(Exception exc) {
}
};
try {
var projectId = clusterService.state().projectState().projectId();
metadataTaskQueue.submitTask(
"delete models [" + inferenceEntityIds + "]",
new DeleteModelMetadataTask(projectId, inferenceEntityIds, clusterStateListener),
new DeleteModelMetadataTask(ProjectId.DEFAULT, inferenceEntityIds, clusterStateListener),
null
);
} catch (Exception exc) {
Expand Down Expand Up @@ -935,6 +946,13 @@ static List<InferenceService.DefaultConfigId> taskTypeMatchedDefaults(

@Override
public void clusterChanged(ClusterChangedEvent event) {
if (lastMetadata == null || event.metadataChanged()) {
// keep track of the last applied cluster state
synchronized (this) {
lastMetadata = event.state().metadata();
}
}

if (event.localNodeMaster() == false) {
return;
}
Expand Down Expand Up @@ -984,7 +1002,7 @@ public void onResponse(GetInferenceModelAction.Response response) {
metadataTaskQueue.submitTask(
"model registry auto upgrade",
new UpgradeModelsMetadataTask(
clusterService.state().metadata().getProject().id(),
ProjectId.DEFAULT,
map,
ActionListener.running(() -> upgradeMetadataInProgress.set(false))
),
Expand Down
Loading