From c14d90f8f6b99bf8c1ef2c4d124ba9be8be4ef87 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Thu, 10 Apr 2025 20:43:23 +0200 Subject: [PATCH 1/2] Expose model registry to SemanticTextFieldMapper This change integrates the new model registry with the `SemanticTextFieldMapper`, allowing inference IDs to be eagerly resolved at parse time. It also preserves the existing lenient behavior: no error is thrown if the specified inference id does not exist, only a warning is logged. --- .../inference/MinimalServiceSettings.java | 7 +-- .../xpack/inference/InferencePlugin.java | 25 +++++--- .../mapper/SemanticTextFieldMapper.java | 61 ++++++++++++++++--- .../inference/registry/ModelRegistry.java | 15 ++++- .../mapper/SemanticTextFieldMapperTests.java | 35 ++++++++++- .../queries/SemanticQueryBuilderTests.java | 42 ++++++++++++- .../xpack/inference/InferenceRestIT.java | 27 ++++++++ 7 files changed, 188 insertions(+), 24 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java b/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java index e4d8ab24f6e73..b9d2696b347c7 100644 --- a/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java @@ -249,10 +249,6 @@ private static void validateFieldNotPresent(String field, Object fieldValue, Tas } } - public ModelConfigurations toModelConfigurations(String inferenceEntityId) { - return new ModelConfigurations(inferenceEntityId, taskType, service == null ? UNKNOWN_SERVICE : service, this); - } - /** * Checks if the given {@link MinimalServiceSettings} is equivalent to the current definition. */ @@ -260,7 +256,6 @@ public boolean canMergeWith(MinimalServiceSettings other) { return taskType == other.taskType && Objects.equals(dimensions, other.dimensions) && similarity == other.similarity - && elementType == other.elementType - && (service == null || service.equals(other.service)); + && elementType == other.elementType; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index a270e52af8552..af1cb41c3871d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -197,6 +197,7 @@ public class InferencePlugin extends Plugin private final SetOnce elasticInferenceServiceComponents = new SetOnce<>(); private final SetOnce inferenceServiceRegistry = new SetOnce<>(); private final SetOnce shardBulkInferenceActionFilter = new SetOnce<>(); + private final SetOnce modelRegistry = new SetOnce<>(); private List inferenceServiceExtensions; public InferencePlugin(Settings settings) { @@ -260,8 +261,8 @@ public Collection createComponents(PluginServices services) { var amazonBedrockRequestSenderFactory = new AmazonBedrockRequestSender.Factory(serviceComponents.get(), services.clusterService()); amazonBedrockFactory.set(amazonBedrockRequestSenderFactory); - ModelRegistry modelRegistry = new ModelRegistry(services.clusterService(), services.client()); - services.clusterService().addListener(modelRegistry); + modelRegistry.set(new ModelRegistry(services.clusterService(), services.client())); + services.clusterService().addListener(modelRegistry.get()); if (inferenceServiceExtensions == null) { inferenceServiceExtensions = new ArrayList<>(); @@ -299,7 +300,7 @@ public Collection createComponents(PluginServices services) { elasicInferenceServiceFactory.get(), serviceComponents.get(), inferenceServiceSettings, - modelRegistry, + modelRegistry.get(), authorizationHandler ) ) @@ -317,18 +318,23 @@ public Collection createComponents(PluginServices services) { var serviceRegistry = new InferenceServiceRegistry(inferenceServices, factoryContext); serviceRegistry.init(services.client()); for (var service : serviceRegistry.getServices().values()) { - service.defaultConfigIds().forEach(modelRegistry::addDefaultIds); + service.defaultConfigIds().forEach(modelRegistry.get()::addDefaultIds); } inferenceServiceRegistry.set(serviceRegistry); - var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry, getLicenseState()); + var actionFilter = new ShardBulkInferenceActionFilter( + services.clusterService(), + serviceRegistry, + modelRegistry.get(), + getLicenseState() + ); shardBulkInferenceActionFilter.set(actionFilter); var meterRegistry = services.telemetryProvider().getMeterRegistry(); var inferenceStats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry)); components.add(serviceRegistry); - components.add(modelRegistry); + components.add(modelRegistry.get()); components.add(httpClientManager); components.add(inferenceStats); @@ -492,11 +498,16 @@ public Map getMetadataMappers() { return Map.of(SemanticInferenceMetadataFieldsMapper.NAME, SemanticInferenceMetadataFieldsMapper.PARSER); } + // Overridable for tests + protected Supplier getModelRegistry() { + return () -> modelRegistry.get(); + } + @Override public Map getMappers() { return Map.of( SemanticTextFieldMapper.CONTENT_TYPE, - SemanticTextFieldMapper.PARSER, + SemanticTextFieldMapper.parser(getModelRegistry()), OffsetSourceFieldMapper.CONTENT_TYPE, OffsetSourceFieldMapper.PARSER ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 3a942a8e73537..5e63b79e07bbe 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.inference.mapper; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; @@ -18,6 +20,7 @@ import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.ScoreMode; import org.apache.lucene.util.BitSet; +import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; @@ -75,6 +78,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.io.IOException; import java.io.UncheckedIOException; @@ -89,6 +93,7 @@ import java.util.Set; import java.util.function.BiConsumer; import java.util.function.Function; +import java.util.function.Supplier; import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING; import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; @@ -112,6 +117,7 @@ * A {@link FieldMapper} for semantic text fields. */ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFieldMapper { + private static final Logger logger = LogManager.getLogger(SemanticTextFieldMapper.class); public static final NodeFeature SEMANTIC_TEXT_IN_OBJECT_FIELD_FIX = new NodeFeature("semantic_text.in_object_field_fix"); public static final NodeFeature SEMANTIC_TEXT_SINGLE_FIELD_UPDATE_FIX = new NodeFeature("semantic_text.single_field_update_fix"); public static final NodeFeature SEMANTIC_TEXT_DELETE_FIX = new NodeFeature("semantic_text.delete_fix"); @@ -127,10 +133,12 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie public static final String CONTENT_TYPE = "semantic_text"; public static final String DEFAULT_ELSER_2_INFERENCE_ID = DEFAULT_ELSER_ID; - public static final TypeParser PARSER = new TypeParser( - (n, c) -> new Builder(n, c::bitSetProducer, c.getIndexSettings()), - List.of(validateParserContext(CONTENT_TYPE)) - ); + public static final TypeParser parser(Supplier modelRegistry) { + return new TypeParser( + (n, c) -> new Builder(n, c::bitSetProducer, c.getIndexSettings(), modelRegistry.get()), + List.of(validateParserContext(CONTENT_TYPE)) + ); + } public static BiConsumer validateParserContext(String type) { return (n, c) -> { @@ -142,6 +150,7 @@ public static BiConsumer validateParserContext(Str } public static class Builder extends FieldMapper.Builder { + private final ModelRegistry modelRegistry; private final boolean useLegacyFormat; private final Parameter inferenceId = Parameter.stringParam( @@ -199,14 +208,21 @@ public static Builder from(SemanticTextFieldMapper mapper) { Builder builder = new Builder( mapper.leafName(), mapper.fieldType().getChunksField().bitsetProducer(), - mapper.fieldType().getChunksField().indexSettings() + mapper.fieldType().getChunksField().indexSettings(), + mapper.modelRegistry ); builder.init(mapper); return builder; } - public Builder(String name, Function bitSetProducer, IndexSettings indexSettings) { + public Builder( + String name, + Function bitSetProducer, + IndexSettings indexSettings, + ModelRegistry modelRegistry + ) { super(name); + this.modelRegistry = modelRegistry; this.useLegacyFormat = InferenceMetadataFieldsMapper.isEnabled(indexSettings.getSettings()) == false; this.inferenceFieldBuilder = c -> createInferenceField( c, @@ -264,6 +280,26 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) { if (useLegacyFormat && multiFieldsBuilder.hasMultiFields()) { throw new IllegalArgumentException(CONTENT_TYPE + " field [" + leafName() + "] does not support multi-fields"); } + if (modelSettings.get() == null) { + try { + var resolvedModelSettings = modelRegistry.getMinimalServiceSettings(inferenceId.get()); + if (resolvedModelSettings != null) { + modelSettings.setValue(resolvedModelSettings); + } + } catch (ResourceNotFoundException exc) { + // We allow the inference ID to be unregistered at this point. + // This will delay the creation of sub-fields, so indexing and querying for this field won't work + // until the corresponding inference endpoint is created. + logger.warn( + "The field [{}] references an unknown inference ID [{}]. " + + "Indexing and querying this field will not work correctly until the corresponding " + + "inference endpoint is created.", + leafName(), + inferenceId.get() + ); + } + } + if (modelSettings.get() != null) { validateServiceSettings(modelSettings.get()); } @@ -287,7 +323,8 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) { useLegacyFormat, meta.getValue() ), - builderParams(this, context) + builderParams(this, context), + modelRegistry ); } @@ -328,9 +365,17 @@ private SemanticTextFieldMapper copySettings(SemanticTextFieldMapper mapper, Map } } - private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, BuilderParams builderParams) { + private final ModelRegistry modelRegistry; + + private SemanticTextFieldMapper( + String simpleName, + MappedFieldType mappedFieldType, + BuilderParams builderParams, + ModelRegistry modelRegistry + ) { super(simpleName, mappedFieldType, builderParams); ensureMultiFields(builderParams.multiFields().iterator()); + this.modelRegistry = modelRegistry; } private void ensureMultiFields(Iterator mappers) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 19c5dcfa96e42..04d4fa2c1d3c6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -148,6 +148,8 @@ public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) private final AtomicBoolean upgradeMetadataInProgress = new AtomicBoolean(false); private final Set preventDeletionLock = Collections.newSetFromMap(new ConcurrentHashMap<>()); + private volatile ClusterState lastClusterState; + public ModelRegistry(ClusterService clusterService, Client client) { this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN); this.defaultConfigIds = new ConcurrentHashMap<>(); @@ -224,11 +226,17 @@ public void clearDefaultIds() { * @throws ResourceNotFoundException if the specified id is guaranteed to not exist in the cluster. */ public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId) throws ResourceNotFoundException { + synchronized (this) { + assert lastClusterState != null : "initial cluster state not set yet"; + if (lastClusterState == null) { + throw new IllegalStateException("initial cluster state not set yet"); + } + } var config = defaultConfigIds.get(inferenceEntityId); if (config != null) { return config.settings(); } - var state = ModelRegistryMetadata.fromState(clusterService.state().projectState().metadata()); + var state = ModelRegistryMetadata.fromState(lastClusterState.projectState().metadata()); var existing = state.getMinimalServiceSettings(inferenceEntityId); if (state.isUpgraded() && existing == null) { throw new ResourceNotFoundException(inferenceEntityId + " does not exist in this cluster."); @@ -935,6 +943,11 @@ static List taskTypeMatchedDefaults( @Override public void clusterChanged(ClusterChangedEvent event) { + // keep track of the last applied cluster state + synchronized (this) { + lastClusterState = event.state(); + } + if (event.localNodeMaster() == false) { return; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index 4d2a76f915af3..397a6867f51ac 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -24,6 +24,7 @@ import org.apache.lucene.search.join.QueryBitSetProducer; import org.apache.lucene.search.join.ScoreMode; import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest; +import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.CheckedBiFunction; @@ -63,6 +64,9 @@ import org.elasticsearch.search.LeafNestedDocuments; import org.elasticsearch.search.NestedDocuments; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.test.ClusterServiceUtils; +import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; @@ -70,7 +74,10 @@ import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryWrapper; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.model.TestModel; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.junit.After; import org.junit.AssumptionViolatedException; +import org.junit.Before; import java.io.IOException; import java.util.Collection; @@ -80,6 +87,7 @@ import java.util.Map; import java.util.Set; import java.util.function.BiConsumer; +import java.util.function.Supplier; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKS_FIELD; @@ -101,10 +109,22 @@ public class SemanticTextFieldMapperTests extends MapperTestCase { private final boolean useLegacyFormat; + private TestThreadPool threadPool; + public SemanticTextFieldMapperTests(boolean useLegacyFormat) { this.useLegacyFormat = useLegacyFormat; } + @Before + private void startThreadPool() { + threadPool = createThreadPool(); + } + + @After + private void stopThreadPool() { + threadPool.close(); + } + @ParametersFactory public static Iterable parameters() throws Exception { return List.of(new Object[] { true }, new Object[] { false }); @@ -112,7 +132,20 @@ public static Iterable parameters() throws Exception { @Override protected Collection getPlugins() { - return List.of(new InferencePlugin(Settings.EMPTY), new XPackClientPlugin()); + var clusterService = ClusterServiceUtils.createClusterService(threadPool); + var modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool)); + modelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) { + @Override + public boolean localNodeMaster() { + return false; + } + }); + return List.of(new InferencePlugin(Settings.EMPTY) { + @Override + protected Supplier getModelRegistry() { + return () -> modelRegistry; + } + }, new XPackClientPlugin()); } private MapperService createMapperService(XContentBuilder mappings, boolean useLegacyFormat) throws IOException { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index c4a6b92ac033c..e0ba14c8959fb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -22,12 +22,14 @@ import org.elasticsearch.action.ActionType; import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.compress.CompressedXContent; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.IOUtils; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; import org.elasticsearch.index.mapper.MapperService; @@ -46,6 +48,9 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.test.AbstractQueryTestCase; +import org.elasticsearch.test.ClusterServiceUtils; +import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; @@ -60,6 +65,8 @@ import org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; @@ -70,6 +77,7 @@ import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.function.Supplier; import static org.apache.lucene.search.BooleanClause.Occur.FILTER; import static org.apache.lucene.search.BooleanClause.Occur.MUST; @@ -118,6 +126,24 @@ public static void setInferenceResultType() { useSearchInferenceId = randomBoolean(); } + @BeforeClass + public static void startModelRegistry() { + threadPool = new TestThreadPool(SemanticQueryBuilderTests.class.getName()); + var clusterService = ClusterServiceUtils.createClusterService(threadPool); + modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool)); + modelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) { + @Override + public boolean localNodeMaster() { + return false; + } + }); + } + + @AfterClass + public static void stopModelRegistry() { + IOUtils.closeWhileHandlingException(threadPool); + } + @Override @Before public void setUp() throws Exception { @@ -127,7 +153,7 @@ public void setUp() throws Exception { @Override protected Collection> getPlugins() { - return List.of(XPackClientPlugin.class, InferencePlugin.class, FakeMlPlugin.class); + return List.of(XPackClientPlugin.class, InferencePluginWithModelRegistry.class, FakeMlPlugin.class); } @Override @@ -394,4 +420,18 @@ public List getNamedWriteables() { return new MlInferenceNamedXContentProvider().getNamedWriteables(); } } + + private static TestThreadPool threadPool; + private static ModelRegistry modelRegistry; + + public static class InferencePluginWithModelRegistry extends InferencePlugin { + public InferencePluginWithModelRegistry(Settings settings) { + super(settings); + } + + @Override + protected Supplier getModelRegistry() { + return () -> modelRegistry; + } + } } diff --git a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java index da01459b057b6..41b577a32bcaa 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java +++ b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java @@ -9,13 +9,20 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import org.elasticsearch.client.Request; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.test.cluster.ElasticsearchCluster; import org.elasticsearch.test.cluster.local.distribution.DistributionType; import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate; import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase; +import org.junit.After; import org.junit.ClassRule; +import java.io.IOException; +import java.util.List; +import java.util.Map; + public class InferenceRestIT extends ESClientYamlSuiteTestCase { @ClassRule @@ -50,4 +57,24 @@ protected String getTestRestCluster() { public static Iterable parameters() throws Exception { return ESClientYamlSuiteTestCase.createParameters(); } + + @SuppressWarnings("unchecked") + static List> getAllModels() throws IOException { + var request = new Request("GET", "_inference/_all"); + var response = client().performRequest(request); + return (List>) entityAsMap(response).get("endpoints"); + } + + @After + public void cleanup() throws Exception { + for (var model : getAllModels()) { + var inferenceId = model.get("inference_id"); + try { + var endpoint = Strings.format("_inference/%s?force", inferenceId); + adminClient().performRequest(new Request("DELETE", endpoint)); + } catch (Exception ex) { + logger.warn(() -> "failed to delete inference endpoint " + inferenceId, ex); + } + } + } } From a8bc4a2095cffbd909e83ebbe2d821ed0fbaf8ec Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Fri, 11 Apr 2025 09:39:16 +0200 Subject: [PATCH 2/2] cleanup --- .../mapper/SemanticTextFieldMapper.java | 17 ++++++---- .../inference/registry/ModelRegistry.java | 33 +++++++++++-------- .../xpack/inference/InferenceRestIT.java | 14 ++++---- 3 files changed, 36 insertions(+), 28 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 5e63b79e07bbe..b1c9db52d01c0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -280,6 +280,7 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) { if (useLegacyFormat && multiFieldsBuilder.hasMultiFields()) { throw new IllegalArgumentException(CONTENT_TYPE + " field [" + leafName() + "] does not support multi-fields"); } + if (modelSettings.get() == null) { try { var resolvedModelSettings = modelRegistry.getMinimalServiceSettings(inferenceId.get()); @@ -290,19 +291,21 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) { // We allow the inference ID to be unregistered at this point. // This will delay the creation of sub-fields, so indexing and querying for this field won't work // until the corresponding inference endpoint is created. - logger.warn( - "The field [{}] references an unknown inference ID [{}]. " - + "Indexing and querying this field will not work correctly until the corresponding " - + "inference endpoint is created.", - leafName(), - inferenceId.get() - ); } } if (modelSettings.get() != null) { validateServiceSettings(modelSettings.get()); + } else { + logger.warn( + "The field [{}] references an unknown inference ID [{}]. " + + "Indexing and querying this field will not work correctly until the corresponding " + + "inference endpoint is created.", + leafName(), + inferenceId.get() + ); } + final String fullName = context.buildFullName(leafName()); if (context.isInNestedContext()) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 04d4fa2c1d3c6..64cde6dce939e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -36,6 +36,7 @@ import org.elasticsearch.cluster.ClusterStateAckListener; import org.elasticsearch.cluster.ClusterStateListener; import org.elasticsearch.cluster.SimpleBatchedAckListenerTaskExecutor; +import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.cluster.service.ClusterService; @@ -140,7 +141,6 @@ public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) private static final String MODEL_ID_FIELD = "model_id"; private static final Logger logger = LogManager.getLogger(ModelRegistry.class); - private final ClusterService clusterService; private final OriginSettingClient client; private final Map defaultConfigIds; @@ -148,12 +148,11 @@ public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) private final AtomicBoolean upgradeMetadataInProgress = new AtomicBoolean(false); private final Set preventDeletionLock = Collections.newSetFromMap(new ConcurrentHashMap<>()); - private volatile ClusterState lastClusterState; + private volatile Metadata lastMetadata; public ModelRegistry(ClusterService clusterService, Client client) { this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN); this.defaultConfigIds = new ConcurrentHashMap<>(); - this.clusterService = clusterService; var executor = new SimpleBatchedAckListenerTaskExecutor() { @Override public Tuple executeTask(MetadataTask task, ClusterState clusterState) throws Exception { @@ -227,8 +226,8 @@ public void clearDefaultIds() { */ public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId) throws ResourceNotFoundException { synchronized (this) { - assert lastClusterState != null : "initial cluster state not set yet"; - if (lastClusterState == null) { + assert lastMetadata != null : "initial cluster state not set yet"; + if (lastMetadata == null) { throw new IllegalStateException("initial cluster state not set yet"); } } @@ -236,7 +235,8 @@ public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId if (config != null) { return config.settings(); } - var state = ModelRegistryMetadata.fromState(lastClusterState.projectState().metadata()); + var project = lastMetadata.getProject(ProjectId.DEFAULT); + var state = ModelRegistryMetadata.fromState(project); var existing = state.getMinimalServiceSettings(inferenceEntityId); if (state.isUpgraded() && existing == null) { throw new ResourceNotFoundException(inferenceEntityId + " does not exist in this cluster."); @@ -692,10 +692,14 @@ private ActionListener getStoreIndexListener( if (updateClusterState) { var storeListener = getStoreMetadataListener(inferenceEntityId, listener); try { - var projectId = clusterService.state().projectState().projectId(); metadataTaskQueue.submitTask( "add model [" + inferenceEntityId + "]", - new AddModelMetadataTask(projectId, inferenceEntityId, new MinimalServiceSettings(model), storeListener), + new AddModelMetadataTask( + ProjectId.DEFAULT, + inferenceEntityId, + new MinimalServiceSettings(model), + storeListener + ), timeout ); } catch (Exception exc) { @@ -862,10 +866,9 @@ public void onFailure(Exception exc) { } }; try { - var projectId = clusterService.state().projectState().projectId(); metadataTaskQueue.submitTask( "delete models [" + inferenceEntityIds + "]", - new DeleteModelMetadataTask(projectId, inferenceEntityIds, clusterStateListener), + new DeleteModelMetadataTask(ProjectId.DEFAULT, inferenceEntityIds, clusterStateListener), null ); } catch (Exception exc) { @@ -943,9 +946,11 @@ static List taskTypeMatchedDefaults( @Override public void clusterChanged(ClusterChangedEvent event) { - // keep track of the last applied cluster state - synchronized (this) { - lastClusterState = event.state(); + if (lastMetadata == null || event.metadataChanged()) { + // keep track of the last applied cluster state + synchronized (this) { + lastMetadata = event.state().metadata(); + } } if (event.localNodeMaster() == false) { @@ -997,7 +1002,7 @@ public void onResponse(GetInferenceModelAction.Response response) { metadataTaskQueue.submitTask( "model registry auto upgrade", new UpgradeModelsMetadataTask( - clusterService.state().metadata().getProject().id(), + ProjectId.DEFAULT, map, ActionListener.running(() -> upgradeMetadataInProgress.set(false)) ), diff --git a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java index 41b577a32bcaa..f39b3f2b01368 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java +++ b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java @@ -58,13 +58,6 @@ public static Iterable parameters() throws Exception { return ESClientYamlSuiteTestCase.createParameters(); } - @SuppressWarnings("unchecked") - static List> getAllModels() throws IOException { - var request = new Request("GET", "_inference/_all"); - var response = client().performRequest(request); - return (List>) entityAsMap(response).get("endpoints"); - } - @After public void cleanup() throws Exception { for (var model : getAllModels()) { @@ -77,4 +70,11 @@ public void cleanup() throws Exception { } } } + + @SuppressWarnings("unchecked") + static List> getAllModels() throws IOException { + var request = new Request("GET", "_inference/_all"); + var response = client().performRequest(request); + return (List>) entityAsMap(response).get("endpoints"); + } }