Skip to content

Commit b7ff74b

Browse files
committed
Revert "Expose model registry to SemanticTextFieldMapper (elastic#126635)"
This reverts commit c906cc0.
1 parent b5b4e1f commit b7ff74b

File tree

7 files changed: 31 additions and 198 deletions.

server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -249,13 +249,18 @@ private static void validateFieldNotPresent(String field, Object fieldValue, Tas
249249
}
250250
}
251251

252+
public ModelConfigurations toModelConfigurations(String inferenceEntityId) {
253+
return new ModelConfigurations(inferenceEntityId, taskType, service == null ? UNKNOWN_SERVICE : service, this);
254+
}
255+
252256
/**
253257
* Checks if the given {@link MinimalServiceSettings} is equivalent to the current definition.
254258
*/
255259
public boolean canMergeWith(MinimalServiceSettings other) {
256260
return taskType == other.taskType
257261
&& Objects.equals(dimensions, other.dimensions)
258262
&& similarity == other.similarity
259-
&& elementType == other.elementType;
263+
&& elementType == other.elementType
264+
&& (service == null || service.equals(other.service));
260265
}
261266
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 7 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -197,7 +197,6 @@ public class InferencePlugin extends Plugin
197197
private final SetOnce<ElasticInferenceServiceComponents> elasticInferenceServiceComponents = new SetOnce<>();
198198
private final SetOnce<InferenceServiceRegistry> inferenceServiceRegistry = new SetOnce<>();
199199
private final SetOnce<ShardBulkInferenceActionFilter> shardBulkInferenceActionFilter = new SetOnce<>();
200-
private final SetOnce<ModelRegistry> modelRegistry = new SetOnce<>();
201200
private List<InferenceServiceExtension> inferenceServiceExtensions;
202201

203202
public InferencePlugin(Settings settings) {
@@ -261,8 +260,8 @@ public Collection<?> createComponents(PluginServices services) {
261260
var amazonBedrockRequestSenderFactory = new AmazonBedrockRequestSender.Factory(serviceComponents.get(), services.clusterService());
262261
amazonBedrockFactory.set(amazonBedrockRequestSenderFactory);
263262

264-
modelRegistry.set(new ModelRegistry(services.clusterService(), services.client()));
265-
services.clusterService().addListener(modelRegistry.get());
263+
ModelRegistry modelRegistry = new ModelRegistry(services.clusterService(), services.client());
264+
services.clusterService().addListener(modelRegistry);
266265

267266
if (inferenceServiceExtensions == null) {
268267
inferenceServiceExtensions = new ArrayList<>();
@@ -300,7 +299,7 @@ public Collection<?> createComponents(PluginServices services) {
300299
elasicInferenceServiceFactory.get(),
301300
serviceComponents.get(),
302301
inferenceServiceSettings,
303-
modelRegistry.get(),
302+
modelRegistry,
304303
authorizationHandler
305304
)
306305
)
@@ -318,14 +317,14 @@ public Collection<?> createComponents(PluginServices services) {
318317
var serviceRegistry = new InferenceServiceRegistry(inferenceServices, factoryContext);
319318
serviceRegistry.init(services.client());
320319
for (var service : serviceRegistry.getServices().values()) {
321-
service.defaultConfigIds().forEach(modelRegistry.get()::addDefaultIds);
320+
service.defaultConfigIds().forEach(modelRegistry::addDefaultIds);
322321
}
323322
inferenceServiceRegistry.set(serviceRegistry);
324323

325324
var actionFilter = new ShardBulkInferenceActionFilter(
326325
services.clusterService(),
327326
serviceRegistry,
328-
modelRegistry.get(),
327+
modelRegistry,
329328
getLicenseState(),
330329
services.indexingPressure()
331330
);
@@ -335,7 +334,7 @@ public Collection<?> createComponents(PluginServices services) {
335334
var inferenceStats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));
336335

337336
components.add(serviceRegistry);
338-
components.add(modelRegistry.get());
337+
components.add(modelRegistry);
339338
components.add(httpClientManager);
340339
components.add(inferenceStats);
341340

@@ -499,16 +498,11 @@ public Map<String, MetadataFieldMapper.TypeParser> getMetadataMappers() {
499498
return Map.of(SemanticInferenceMetadataFieldsMapper.NAME, SemanticInferenceMetadataFieldsMapper.PARSER);
500499
}
501500

502-
// Overridable for tests
503-
protected Supplier<ModelRegistry> getModelRegistry() {
504-
return () -> modelRegistry.get();
505-
}
506-
507501
@Override
508502
public Map<String, Mapper.TypeParser> getMappers() {
509503
return Map.of(
510504
SemanticTextFieldMapper.CONTENT_TYPE,
511-
SemanticTextFieldMapper.parser(getModelRegistry()),
505+
SemanticTextFieldMapper.PARSER,
512506
OffsetSourceFieldMapper.CONTENT_TYPE,
513507
OffsetSourceFieldMapper.PARSER
514508
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java

Lines changed: 8 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77

88
package org.elasticsearch.xpack.inference.mapper;
99

10-
import org.apache.logging.log4j.LogManager;
11-
import org.apache.logging.log4j.Logger;
1210
import org.apache.lucene.index.FieldInfos;
1311
import org.apache.lucene.index.LeafReaderContext;
1412
import org.apache.lucene.search.DocIdSetIterator;
@@ -20,7 +18,6 @@
2018
import org.apache.lucene.search.join.BitSetProducer;
2119
import org.apache.lucene.search.join.ScoreMode;
2220
import org.apache.lucene.util.BitSet;
23-
import org.elasticsearch.ResourceNotFoundException;
2421
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
2522
import org.elasticsearch.common.Strings;
2623
import org.elasticsearch.common.bytes.BytesReference;
@@ -78,7 +75,6 @@
7875
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
7976
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
8077
import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter;
81-
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
8278

8379
import java.io.IOException;
8480
import java.io.UncheckedIOException;
@@ -93,7 +89,6 @@
9389
import java.util.Set;
9490
import java.util.function.BiConsumer;
9591
import java.util.function.Function;
96-
import java.util.function.Supplier;
9792

9893
import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING;
9994
import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING;
@@ -117,7 +112,6 @@
117112
* A {@link FieldMapper} for semantic text fields.
118113
*/
119114
public class SemanticTextFieldMapper extends FieldMapper implements InferenceFieldMapper {
120-
private static final Logger logger = LogManager.getLogger(SemanticTextFieldMapper.class);
121115
public static final NodeFeature SEMANTIC_TEXT_IN_OBJECT_FIELD_FIX = new NodeFeature("semantic_text.in_object_field_fix");
122116
public static final NodeFeature SEMANTIC_TEXT_SINGLE_FIELD_UPDATE_FIX = new NodeFeature("semantic_text.single_field_update_fix");
123117
public static final NodeFeature SEMANTIC_TEXT_DELETE_FIX = new NodeFeature("semantic_text.delete_fix");
@@ -133,12 +127,10 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
133127
public static final String CONTENT_TYPE = "semantic_text";
134128
public static final String DEFAULT_ELSER_2_INFERENCE_ID = DEFAULT_ELSER_ID;
135129

136-
public static final TypeParser parser(Supplier<ModelRegistry> modelRegistry) {
137-
return new TypeParser(
138-
(n, c) -> new Builder(n, c::bitSetProducer, c.getIndexSettings(), modelRegistry.get()),
139-
List.of(validateParserContext(CONTENT_TYPE))
140-
);
141-
}
130+
public static final TypeParser PARSER = new TypeParser(
131+
(n, c) -> new Builder(n, c::bitSetProducer, c.getIndexSettings()),
132+
List.of(validateParserContext(CONTENT_TYPE))
133+
);
142134

143135
public static BiConsumer<String, MappingParserContext> validateParserContext(String type) {
144136
return (n, c) -> {
@@ -150,7 +142,6 @@ public static BiConsumer<String, MappingParserContext> validateParserContext(Str
150142
}
151143

152144
public static class Builder extends FieldMapper.Builder {
153-
private final ModelRegistry modelRegistry;
154145
private final boolean useLegacyFormat;
155146

156147
private final Parameter<String> inferenceId = Parameter.stringParam(
@@ -208,21 +199,14 @@ public static Builder from(SemanticTextFieldMapper mapper) {
208199
Builder builder = new Builder(
209200
mapper.leafName(),
210201
mapper.fieldType().getChunksField().bitsetProducer(),
211-
mapper.fieldType().getChunksField().indexSettings(),
212-
mapper.modelRegistry
202+
mapper.fieldType().getChunksField().indexSettings()
213203
);
214204
builder.init(mapper);
215205
return builder;
216206
}
217207

218-
public Builder(
219-
String name,
220-
Function<Query, BitSetProducer> bitSetProducer,
221-
IndexSettings indexSettings,
222-
ModelRegistry modelRegistry
223-
) {
208+
public Builder(String name, Function<Query, BitSetProducer> bitSetProducer, IndexSettings indexSettings) {
224209
super(name);
225-
this.modelRegistry = modelRegistry;
226210
this.useLegacyFormat = InferenceMetadataFieldsMapper.isEnabled(indexSettings.getSettings()) == false;
227211
this.inferenceFieldBuilder = c -> createInferenceField(
228212
c,
@@ -280,32 +264,9 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {
280264
if (useLegacyFormat && multiFieldsBuilder.hasMultiFields()) {
281265
throw new IllegalArgumentException(CONTENT_TYPE + " field [" + leafName() + "] does not support multi-fields");
282266
}
283-
284-
if (modelSettings.get() == null) {
285-
try {
286-
var resolvedModelSettings = modelRegistry.getMinimalServiceSettings(inferenceId.get());
287-
if (resolvedModelSettings != null) {
288-
modelSettings.setValue(resolvedModelSettings);
289-
}
290-
} catch (ResourceNotFoundException exc) {
291-
// We allow the inference ID to be unregistered at this point.
292-
// This will delay the creation of sub-fields, so indexing and querying for this field won't work
293-
// until the corresponding inference endpoint is created.
294-
}
295-
}
296-
297267
if (modelSettings.get() != null) {
298268
validateServiceSettings(modelSettings.get());
299-
} else {
300-
logger.warn(
301-
"The field [{}] references an unknown inference ID [{}]. "
302-
+ "Indexing and querying this field will not work correctly until the corresponding "
303-
+ "inference endpoint is created.",
304-
leafName(),
305-
inferenceId.get()
306-
);
307269
}
308-
309270
final String fullName = context.buildFullName(leafName());
310271

311272
if (context.isInNestedContext()) {
@@ -326,8 +287,7 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {
326287
useLegacyFormat,
327288
meta.getValue()
328289
),
329-
builderParams(this, context),
330-
modelRegistry
290+
builderParams(this, context)
331291
);
332292
}
333293

@@ -368,17 +328,9 @@ private SemanticTextFieldMapper copySettings(SemanticTextFieldMapper mapper, Map
368328
}
369329
}
370330

371-
private final ModelRegistry modelRegistry;
372-
373-
private SemanticTextFieldMapper(
374-
String simpleName,
375-
MappedFieldType mappedFieldType,
376-
BuilderParams builderParams,
377-
ModelRegistry modelRegistry
378-
) {
331+
private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, BuilderParams builderParams) {
379332
super(simpleName, mappedFieldType, builderParams);
380333
ensureMultiFields(builderParams.multiFields().iterator());
381-
this.modelRegistry = modelRegistry;
382334
}
383335

384336
private void ensureMultiFields(Iterator<FieldMapper> mappers) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java

Lines changed: 8 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,6 @@
3636
import org.elasticsearch.cluster.ClusterStateAckListener;
3737
import org.elasticsearch.cluster.ClusterStateListener;
3838
import org.elasticsearch.cluster.SimpleBatchedAckListenerTaskExecutor;
39-
import org.elasticsearch.cluster.metadata.Metadata;
4039
import org.elasticsearch.cluster.metadata.ProjectId;
4140
import org.elasticsearch.cluster.metadata.ProjectMetadata;
4241
import org.elasticsearch.cluster.service.ClusterService;
@@ -141,18 +140,18 @@ public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap)
141140
private static final String MODEL_ID_FIELD = "model_id";
142141
private static final Logger logger = LogManager.getLogger(ModelRegistry.class);
143142

143+
private final ClusterService clusterService;
144144
private final OriginSettingClient client;
145145
private final Map<String, InferenceService.DefaultConfigId> defaultConfigIds;
146146

147147
private final MasterServiceTaskQueue<MetadataTask> metadataTaskQueue;
148148
private final AtomicBoolean upgradeMetadataInProgress = new AtomicBoolean(false);
149149
private final Set<String> preventDeletionLock = Collections.newSetFromMap(new ConcurrentHashMap<>());
150150

151-
private volatile Metadata lastMetadata;
152-
153151
public ModelRegistry(ClusterService clusterService, Client client) {
154152
this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN);
155153
this.defaultConfigIds = new ConcurrentHashMap<>();
154+
this.clusterService = clusterService;
156155
var executor = new SimpleBatchedAckListenerTaskExecutor<MetadataTask>() {
157156
@Override
158157
public Tuple<ClusterState, ClusterStateAckListener> executeTask(MetadataTask task, ClusterState clusterState) throws Exception {
@@ -225,18 +224,11 @@ public void clearDefaultIds() {
225224
* @throws ResourceNotFoundException if the specified id is guaranteed to not exist in the cluster.
226225
*/
227226
public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId) throws ResourceNotFoundException {
228-
synchronized (this) {
229-
assert lastMetadata != null : "initial cluster state not set yet";
230-
if (lastMetadata == null) {
231-
throw new IllegalStateException("initial cluster state not set yet");
232-
}
233-
}
234227
var config = defaultConfigIds.get(inferenceEntityId);
235228
if (config != null) {
236229
return config.settings();
237230
}
238-
var project = lastMetadata.getProject(ProjectId.DEFAULT);
239-
var state = ModelRegistryMetadata.fromState(project);
231+
var state = ModelRegistryMetadata.fromState(clusterService.state().projectState().metadata());
240232
var existing = state.getMinimalServiceSettings(inferenceEntityId);
241233
if (state.isUpgraded() && existing == null) {
242234
throw new ResourceNotFoundException(inferenceEntityId + " does not exist in this cluster.");
@@ -692,14 +684,10 @@ private ActionListener<BulkResponse> getStoreIndexListener(
692684
if (updateClusterState) {
693685
var storeListener = getStoreMetadataListener(inferenceEntityId, listener);
694686
try {
687+
var projectId = clusterService.state().projectState().projectId();
695688
metadataTaskQueue.submitTask(
696689
"add model [" + inferenceEntityId + "]",
697-
new AddModelMetadataTask(
698-
ProjectId.DEFAULT,
699-
inferenceEntityId,
700-
new MinimalServiceSettings(model),
701-
storeListener
702-
),
690+
new AddModelMetadataTask(projectId, inferenceEntityId, new MinimalServiceSettings(model), storeListener),
703691
timeout
704692
);
705693
} catch (Exception exc) {
@@ -866,9 +854,10 @@ public void onFailure(Exception exc) {
866854
}
867855
};
868856
try {
857+
var projectId = clusterService.state().projectState().projectId();
869858
metadataTaskQueue.submitTask(
870859
"delete models [" + inferenceEntityIds + "]",
871-
new DeleteModelMetadataTask(ProjectId.DEFAULT, inferenceEntityIds, clusterStateListener),
860+
new DeleteModelMetadataTask(projectId, inferenceEntityIds, clusterStateListener),
872861
null
873862
);
874863
} catch (Exception exc) {
@@ -946,13 +935,6 @@ static List<InferenceService.DefaultConfigId> taskTypeMatchedDefaults(
946935

947936
@Override
948937
public void clusterChanged(ClusterChangedEvent event) {
949-
if (lastMetadata == null || event.metadataChanged()) {
950-
// keep track of the last applied cluster state
951-
synchronized (this) {
952-
lastMetadata = event.state().metadata();
953-
}
954-
}
955-
956938
if (event.localNodeMaster() == false) {
957939
return;
958940
}
@@ -1002,7 +984,7 @@ public void onResponse(GetInferenceModelAction.Response response) {
1002984
metadataTaskQueue.submitTask(
1003985
"model registry auto upgrade",
1004986
new UpgradeModelsMetadataTask(
1005-
ProjectId.DEFAULT,
987+
clusterService.state().metadata().getProject().id(),
1006988
map,
1007989
ActionListener.running(() -> upgradeMetadataInProgress.set(false))
1008990
),

0 commit comments

Comments
 (0)