diff --git a/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java b/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java index eb7823d572302..39b768631a586 100644 --- a/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java @@ -242,10 +242,6 @@ private static void validateFieldNotPresent(String field, Object fieldValue, Tas } } - public ModelConfigurations toModelConfigurations(String inferenceEntityId) { - return new ModelConfigurations(inferenceEntityId, taskType, service == null ? UNKNOWN_SERVICE : service, this); - } - /** * Checks if the given {@link MinimalServiceSettings} is equivalent to the current definition. */ @@ -253,7 +249,6 @@ public boolean canMergeWith(MinimalServiceSettings other) { return taskType == other.taskType && Objects.equals(dimensions, other.dimensions) && similarity == other.similarity - && elementType == other.elementType - && (service == null || service.equals(other.service)); + && elementType == other.elementType; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 93b2fa0ebb979..f7d3f791d79d1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -199,6 +199,7 @@ public class InferencePlugin extends Plugin private final SetOnce elasticInferenceServiceComponents = new SetOnce<>(); private final SetOnce inferenceServiceRegistry = new SetOnce<>(); private final SetOnce shardBulkInferenceActionFilter = new SetOnce<>(); + private final SetOnce modelRegistry = new SetOnce<>(); private List inferenceServiceExtensions; public InferencePlugin(Settings settings) { @@ -262,8 +263,8 @@ public Collection createComponents(PluginServices services) { var amazonBedrockRequestSenderFactory = new AmazonBedrockRequestSender.Factory(serviceComponents.get(), services.clusterService()); amazonBedrockFactory.set(amazonBedrockRequestSenderFactory); - ModelRegistry modelRegistry = new ModelRegistry(services.clusterService(), services.client()); - services.clusterService().addListener(modelRegistry); + modelRegistry.set(new ModelRegistry(services.clusterService(), services.client())); + services.clusterService().addListener(modelRegistry.get()); if (inferenceServiceExtensions == null) { inferenceServiceExtensions = new ArrayList<>(); @@ -301,7 +302,7 @@ public Collection createComponents(PluginServices services) { elasicInferenceServiceFactory.get(), serviceComponents.get(), inferenceServiceSettings, - modelRegistry, + modelRegistry.get(), authorizationHandler ) ) @@ -319,18 +320,23 @@ public Collection createComponents(PluginServices services) { var serviceRegistry = new InferenceServiceRegistry(inferenceServices, factoryContext); serviceRegistry.init(services.client()); for (var service : serviceRegistry.getServices().values()) { - service.defaultConfigIds().forEach(modelRegistry::addDefaultIds); + service.defaultConfigIds().forEach(modelRegistry.get()::addDefaultIds); } inferenceServiceRegistry.set(serviceRegistry); - var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry, getLicenseState()); + var actionFilter = new ShardBulkInferenceActionFilter( + services.clusterService(), + serviceRegistry, + modelRegistry.get(), + getLicenseState() + ); shardBulkInferenceActionFilter.set(actionFilter); var meterRegistry = services.telemetryProvider().getMeterRegistry(); var inferenceStats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry)); components.add(serviceRegistry); - components.add(modelRegistry); + components.add(modelRegistry.get()); components.add(httpClientManager); components.add(inferenceStats); @@ -497,11 +503,16 @@ public Map getMetadataMappers() { return Map.of(SemanticInferenceMetadataFieldsMapper.NAME, SemanticInferenceMetadataFieldsMapper.PARSER); } + // Overridable for tests + protected Supplier getModelRegistry() { + return modelRegistry::get; + } + @Override public Map getMappers() { return Map.of( SemanticTextFieldMapper.CONTENT_TYPE, - SemanticTextFieldMapper.PARSER, + SemanticTextFieldMapper.parser(getModelRegistry()), OffsetSourceFieldMapper.CONTENT_TYPE, OffsetSourceFieldMapper.PARSER ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index f78bcc9106979..f882a6a782871 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.inference.mapper; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; @@ -18,6 +20,7 @@ import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.ScoreMode; import org.apache.lucene.util.BitSet; +import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; @@ -75,6 +78,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.io.IOException; import java.io.UncheckedIOException; @@ -89,6 +93,7 @@ import java.util.Set; import java.util.function.BiConsumer; import java.util.function.Function; +import java.util.function.Supplier; import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING; import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; @@ -112,6 +117,7 @@ * A {@link FieldMapper} for semantic text fields. */ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFieldMapper { + private static final Logger logger = LogManager.getLogger(SemanticTextFieldMapper.class); public static final NodeFeature SEMANTIC_TEXT_SEARCH_INFERENCE_ID = new NodeFeature("semantic_text.search_inference_id", true); public static final NodeFeature SEMANTIC_TEXT_DEFAULT_ELSER_2 = new NodeFeature("semantic_text.default_elser_2", true); public static final NodeFeature SEMANTIC_TEXT_IN_OBJECT_FIELD_FIX = new NodeFeature("semantic_text.in_object_field_fix"); @@ -129,10 +135,12 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie public static final String CONTENT_TYPE = "semantic_text"; public static final String DEFAULT_ELSER_2_INFERENCE_ID = DEFAULT_ELSER_ID; - public static final TypeParser PARSER = new TypeParser( - (n, c) -> new Builder(n, c::bitSetProducer, c.getIndexSettings()), - List.of(validateParserContext(CONTENT_TYPE)) - ); + public static final TypeParser parser(Supplier modelRegistry) { + return new TypeParser( + (n, c) -> new Builder(n, c::bitSetProducer, c.getIndexSettings(), modelRegistry.get()), + List.of(validateParserContext(CONTENT_TYPE)) + ); + } public static BiConsumer validateParserContext(String type) { return (n, c) -> { @@ -144,6 +152,7 @@ public static BiConsumer validateParserContext(Str } public static class Builder extends FieldMapper.Builder { + private final ModelRegistry modelRegistry; private final boolean useLegacyFormat; private final Parameter inferenceId = Parameter.stringParam( @@ -201,14 +210,21 @@ public static Builder from(SemanticTextFieldMapper mapper) { Builder builder = new Builder( mapper.leafName(), mapper.fieldType().getChunksField().bitsetProducer(), - mapper.fieldType().getChunksField().indexSettings() + mapper.fieldType().getChunksField().indexSettings(), + mapper.modelRegistry ); builder.init(mapper); return builder; } - public Builder(String name, Function bitSetProducer, IndexSettings indexSettings) { + public Builder( + String name, + Function bitSetProducer, + IndexSettings indexSettings, + ModelRegistry modelRegistry + ) { super(name); + this.modelRegistry = modelRegistry; this.useLegacyFormat = InferenceMetadataFieldsMapper.isEnabled(indexSettings.getSettings()) == false; this.inferenceFieldBuilder = c -> createInferenceField( c, @@ -266,9 +282,32 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) { if (useLegacyFormat && multiFieldsBuilder.hasMultiFields()) { throw new IllegalArgumentException(CONTENT_TYPE + " field [" + leafName() + "] does not support multi-fields"); } + + if (modelSettings.get() == null) { + try { + var resolvedModelSettings = modelRegistry.getMinimalServiceSettings(inferenceId.get()); + if (resolvedModelSettings != null) { + modelSettings.setValue(resolvedModelSettings); + } + } catch (ResourceNotFoundException exc) { + // We allow the inference ID to be unregistered at this point. + // This will delay the creation of sub-fields, so indexing and querying for this field won't work + // until the corresponding inference endpoint is created. + } + } + if (modelSettings.get() != null) { validateServiceSettings(modelSettings.get()); + } else { + logger.warn( + "The field [{}] references an unknown inference ID [{}]. " + + "Indexing and querying this field will not work correctly until the corresponding " + + "inference endpoint is created.", + leafName(), + inferenceId.get() + ); } + final String fullName = context.buildFullName(leafName()); if (context.isInNestedContext()) { @@ -289,7 +328,8 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) { useLegacyFormat, meta.getValue() ), - builderParams(this, context) + builderParams(this, context), + modelRegistry ); } @@ -330,9 +370,17 @@ private SemanticTextFieldMapper copySettings(SemanticTextFieldMapper mapper, Map } } - private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, BuilderParams builderParams) { + private final ModelRegistry modelRegistry; + + private SemanticTextFieldMapper( + String simpleName, + MappedFieldType mappedFieldType, + BuilderParams builderParams, + ModelRegistry modelRegistry + ) { super(simpleName, mappedFieldType, builderParams); ensureMultiFields(builderParams.multiFields().iterator()); + this.modelRegistry = modelRegistry; } private void ensureMultiFields(Iterator mappers) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 5c99135e1dd22..1d031420e53c9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -139,7 +139,6 @@ public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) private static final String MODEL_ID_FIELD = "model_id"; private static final Logger logger = LogManager.getLogger(ModelRegistry.class); - private final ClusterService clusterService; private final OriginSettingClient client; private final Map defaultConfigIds; @@ -147,10 +146,11 @@ public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) private final AtomicBoolean upgradeMetadataInProgress = new AtomicBoolean(false); private final Set preventDeletionLock = Collections.newSetFromMap(new ConcurrentHashMap<>()); + private volatile Metadata lastMetadata; + public ModelRegistry(ClusterService clusterService, Client client) { this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN); this.defaultConfigIds = new ConcurrentHashMap<>(); - this.clusterService = clusterService; var executor = new SimpleBatchedAckListenerTaskExecutor() { @Override public Tuple executeTask(MetadataTask task, ClusterState clusterState) throws Exception { @@ -222,11 +222,17 @@ public void clearDefaultIds() { * @throws ResourceNotFoundException if the specified id is guaranteed to not exist in the cluster. */ public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId) throws ResourceNotFoundException { + synchronized (this) { + assert lastMetadata != null : "initial cluster state not set yet"; + if (lastMetadata == null) { + throw new IllegalStateException("initial cluster state not set yet"); + } + } var config = defaultConfigIds.get(inferenceEntityId); if (config != null) { return config.settings(); } - var state = ModelRegistryMetadata.fromState(clusterService.state().metadata()); + var state = ModelRegistryMetadata.fromState(lastMetadata); var existing = state.getMinimalServiceSettings(inferenceEntityId); if (state.isUpgraded() && existing == null) { throw new ResourceNotFoundException(inferenceEntityId + " does not exist in this cluster."); @@ -931,6 +937,13 @@ static List taskTypeMatchedDefaults( @Override public void clusterChanged(ClusterChangedEvent event) { + if (lastMetadata == null || event.metadataChanged()) { + // keep track of the last applied cluster state + synchronized (this) { + lastMetadata = event.state().metadata(); + } + } + if (event.localNodeMaster() == false) { return; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index f8c24c7e4bb2c..f872d8f302f37 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -24,6 +24,7 @@ import org.apache.lucene.search.join.QueryBitSetProducer; import org.apache.lucene.search.join.ScoreMode; import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest; +import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.CheckedBiFunction; import org.elasticsearch.common.Strings; @@ -62,6 +63,9 @@ import org.elasticsearch.search.LeafNestedDocuments; import org.elasticsearch.search.NestedDocuments; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.test.ClusterServiceUtils; +import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; @@ -69,7 +73,10 @@ import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryWrapper; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.model.TestModel; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.junit.After; import org.junit.AssumptionViolatedException; +import org.junit.Before; import java.io.IOException; import java.util.Collection; @@ -79,6 +86,7 @@ import java.util.Map; import java.util.Set; import java.util.function.BiConsumer; +import java.util.function.Supplier; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKS_FIELD; @@ -100,10 +108,22 @@ public class SemanticTextFieldMapperTests extends MapperTestCase { private final boolean useLegacyFormat; + private TestThreadPool threadPool; + public SemanticTextFieldMapperTests(boolean useLegacyFormat) { this.useLegacyFormat = useLegacyFormat; } + @Before + private void startThreadPool() { + threadPool = createThreadPool(); + } + + @After + private void stopThreadPool() { + threadPool.close(); + } + @ParametersFactory public static Iterable parameters() throws Exception { return List.of(new Object[] { true }, new Object[] { false }); @@ -111,7 +131,20 @@ public static Iterable parameters() throws Exception { @Override protected Collection getPlugins() { - return List.of(new InferencePlugin(Settings.EMPTY), new XPackClientPlugin()); + var clusterService = ClusterServiceUtils.createClusterService(threadPool); + var modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool)); + modelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) { + @Override + public boolean localNodeMaster() { + return false; + } + }); + return List.of(new InferencePlugin(Settings.EMPTY) { + @Override + protected Supplier getModelRegistry() { + return () -> modelRegistry; + } + }, new XPackClientPlugin()); } private MapperService createMapperService(XContentBuilder mappings, boolean useLegacyFormat) throws IOException { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index ce84046d7cc5f..fca6b371e18b0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -22,12 +22,14 @@ import org.elasticsearch.action.ActionType; import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.compress.CompressedXContent; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.IOUtils; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; import org.elasticsearch.index.mapper.MapperService; @@ -46,6 +48,9 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.test.AbstractQueryTestCase; +import org.elasticsearch.test.ClusterServiceUtils; +import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; @@ -60,6 +65,8 @@ import org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; @@ -70,6 +77,7 @@ import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.function.Supplier; import static org.apache.lucene.search.BooleanClause.Occur.FILTER; import static org.apache.lucene.search.BooleanClause.Occur.MUST; @@ -118,6 +126,24 @@ public static void setInferenceResultType() { useSearchInferenceId = randomBoolean(); } + @BeforeClass + public static void startModelRegistry() { + threadPool = new TestThreadPool(SemanticQueryBuilderTests.class.getName()); + var clusterService = ClusterServiceUtils.createClusterService(threadPool); + modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool)); + modelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) { + @Override + public boolean localNodeMaster() { + return false; + } + }); + } + + @AfterClass + public static void stopModelRegistry() { + IOUtils.closeWhileHandlingException(threadPool); + } + @Override @Before public void setUp() throws Exception { @@ -127,7 +153,7 @@ public void setUp() throws Exception { @Override protected Collection> getPlugins() { - return List.of(XPackClientPlugin.class, InferencePlugin.class, FakeMlPlugin.class); + return List.of(XPackClientPlugin.class, InferencePluginWithModelRegistry.class, FakeMlPlugin.class); } @Override @@ -394,4 +420,18 @@ public List getNamedWriteables() { return new MlInferenceNamedXContentProvider().getNamedWriteables(); } } + + private static TestThreadPool threadPool; + private static ModelRegistry modelRegistry; + + public static class InferencePluginWithModelRegistry extends InferencePlugin { + public InferencePluginWithModelRegistry(Settings settings) { + super(settings); + } + + @Override + protected Supplier getModelRegistry() { + return () -> modelRegistry; + } + } } diff --git a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java index 88b2d851bf69f..1477aae9216a2 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java +++ b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java @@ -9,13 +9,20 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import org.elasticsearch.client.Request; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.test.cluster.ElasticsearchCluster; import org.elasticsearch.test.cluster.local.distribution.DistributionType; import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate; import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase; +import org.junit.After; import org.junit.ClassRule; +import java.io.IOException; +import java.util.List; +import java.util.Map; + public class InferenceRestIT extends ESClientYamlSuiteTestCase { @ClassRule @@ -50,4 +57,24 @@ protected String getTestRestCluster() { public static Iterable parameters() throws Exception { return ESClientYamlSuiteTestCase.createParameters(); } + + @After + public void cleanup() throws Exception { + for (var model : getAllModels()) { + var inferenceId = model.get("inference_id"); + try { + var endpoint = Strings.format("_inference/%s?force", inferenceId); + adminClient().performRequest(new Request("DELETE", endpoint)); + } catch (Exception ex) { + logger.warn(() -> "failed to delete inference endpoint " + inferenceId, ex); + } + } + } + + @SuppressWarnings("unchecked") + static List> getAllModels() throws IOException { + var request = new Request("GET", "_inference/_all"); + var response = client().performRequest(request); + return (List>) entityAsMap(response).get("endpoints"); + } }