Skip to content

Commit c906cc0

Browse files
authored
Expose model registry to SemanticTextFieldMapper (#126635)
This change integrates the new model registry with the `SemanticTextFieldMapper`, allowing inference IDs to be eagerly resolved at parse time. It also preserves the existing lenient behavior: no error is thrown if the specified inference id does not exist, only a warning is logged.
1 parent 358b724 commit c906cc0

File tree

7 files changed

+198
-31
lines changed

7 files changed

+198
-31
lines changed

server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -249,18 +249,13 @@ private static void validateFieldNotPresent(String field, Object fieldValue, Tas
249249
}
250250
}
251251

252-
public ModelConfigurations toModelConfigurations(String inferenceEntityId) {
253-
return new ModelConfigurations(inferenceEntityId, taskType, service == null ? UNKNOWN_SERVICE : service, this);
254-
}
255-
256252
/**
257253
* Checks if the given {@link MinimalServiceSettings} is equivalent to the current definition.
258254
*/
259255
public boolean canMergeWith(MinimalServiceSettings other) {
260256
return taskType == other.taskType
261257
&& Objects.equals(dimensions, other.dimensions)
262258
&& similarity == other.similarity
263-
&& elementType == other.elementType
264-
&& (service == null || service.equals(other.service));
259+
&& elementType == other.elementType;
265260
}
266261
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ public class InferencePlugin extends Plugin
197197
private final SetOnce<ElasticInferenceServiceComponents> elasticInferenceServiceComponents = new SetOnce<>();
198198
private final SetOnce<InferenceServiceRegistry> inferenceServiceRegistry = new SetOnce<>();
199199
private final SetOnce<ShardBulkInferenceActionFilter> shardBulkInferenceActionFilter = new SetOnce<>();
200+
private final SetOnce<ModelRegistry> modelRegistry = new SetOnce<>();
200201
private List<InferenceServiceExtension> inferenceServiceExtensions;
201202

202203
public InferencePlugin(Settings settings) {
@@ -260,8 +261,8 @@ public Collection<?> createComponents(PluginServices services) {
260261
var amazonBedrockRequestSenderFactory = new AmazonBedrockRequestSender.Factory(serviceComponents.get(), services.clusterService());
261262
amazonBedrockFactory.set(amazonBedrockRequestSenderFactory);
262263

263-
ModelRegistry modelRegistry = new ModelRegistry(services.clusterService(), services.client());
264-
services.clusterService().addListener(modelRegistry);
264+
modelRegistry.set(new ModelRegistry(services.clusterService(), services.client()));
265+
services.clusterService().addListener(modelRegistry.get());
265266

266267
if (inferenceServiceExtensions == null) {
267268
inferenceServiceExtensions = new ArrayList<>();
@@ -299,7 +300,7 @@ public Collection<?> createComponents(PluginServices services) {
299300
elasicInferenceServiceFactory.get(),
300301
serviceComponents.get(),
301302
inferenceServiceSettings,
302-
modelRegistry,
303+
modelRegistry.get(),
303304
authorizationHandler
304305
)
305306
)
@@ -317,14 +318,14 @@ public Collection<?> createComponents(PluginServices services) {
317318
var serviceRegistry = new InferenceServiceRegistry(inferenceServices, factoryContext);
318319
serviceRegistry.init(services.client());
319320
for (var service : serviceRegistry.getServices().values()) {
320-
service.defaultConfigIds().forEach(modelRegistry::addDefaultIds);
321+
service.defaultConfigIds().forEach(modelRegistry.get()::addDefaultIds);
321322
}
322323
inferenceServiceRegistry.set(serviceRegistry);
323324

324325
var actionFilter = new ShardBulkInferenceActionFilter(
325326
services.clusterService(),
326327
serviceRegistry,
327-
modelRegistry,
328+
modelRegistry.get(),
328329
getLicenseState(),
329330
services.indexingPressure()
330331
);
@@ -334,7 +335,7 @@ public Collection<?> createComponents(PluginServices services) {
334335
var inferenceStats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));
335336

336337
components.add(serviceRegistry);
337-
components.add(modelRegistry);
338+
components.add(modelRegistry.get());
338339
components.add(httpClientManager);
339340
components.add(inferenceStats);
340341

@@ -498,11 +499,16 @@ public Map<String, MetadataFieldMapper.TypeParser> getMetadataMappers() {
498499
return Map.of(SemanticInferenceMetadataFieldsMapper.NAME, SemanticInferenceMetadataFieldsMapper.PARSER);
499500
}
500501

502+
// Overridable for tests
503+
protected Supplier<ModelRegistry> getModelRegistry() {
504+
return () -> modelRegistry.get();
505+
}
506+
501507
@Override
502508
public Map<String, Mapper.TypeParser> getMappers() {
503509
return Map.of(
504510
SemanticTextFieldMapper.CONTENT_TYPE,
505-
SemanticTextFieldMapper.PARSER,
511+
SemanticTextFieldMapper.parser(getModelRegistry()),
506512
OffsetSourceFieldMapper.CONTENT_TYPE,
507513
OffsetSourceFieldMapper.PARSER
508514
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
package org.elasticsearch.xpack.inference.mapper;
99

10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
1012
import org.apache.lucene.index.FieldInfos;
1113
import org.apache.lucene.index.LeafReaderContext;
1214
import org.apache.lucene.search.DocIdSetIterator;
@@ -18,6 +20,7 @@
1820
import org.apache.lucene.search.join.BitSetProducer;
1921
import org.apache.lucene.search.join.ScoreMode;
2022
import org.apache.lucene.util.BitSet;
23+
import org.elasticsearch.ResourceNotFoundException;
2124
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
2225
import org.elasticsearch.common.Strings;
2326
import org.elasticsearch.common.bytes.BytesReference;
@@ -75,6 +78,7 @@
7578
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
7679
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
7780
import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter;
81+
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
7882

7983
import java.io.IOException;
8084
import java.io.UncheckedIOException;
@@ -89,6 +93,7 @@
8993
import java.util.Set;
9094
import java.util.function.BiConsumer;
9195
import java.util.function.Function;
96+
import java.util.function.Supplier;
9297

9398
import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING;
9499
import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING;
@@ -112,6 +117,7 @@
112117
* A {@link FieldMapper} for semantic text fields.
113118
*/
114119
public class SemanticTextFieldMapper extends FieldMapper implements InferenceFieldMapper {
120+
private static final Logger logger = LogManager.getLogger(SemanticTextFieldMapper.class);
115121
public static final NodeFeature SEMANTIC_TEXT_IN_OBJECT_FIELD_FIX = new NodeFeature("semantic_text.in_object_field_fix");
116122
public static final NodeFeature SEMANTIC_TEXT_SINGLE_FIELD_UPDATE_FIX = new NodeFeature("semantic_text.single_field_update_fix");
117123
public static final NodeFeature SEMANTIC_TEXT_DELETE_FIX = new NodeFeature("semantic_text.delete_fix");
@@ -127,10 +133,12 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
127133
public static final String CONTENT_TYPE = "semantic_text";
128134
public static final String DEFAULT_ELSER_2_INFERENCE_ID = DEFAULT_ELSER_ID;
129135

130-
public static final TypeParser PARSER = new TypeParser(
131-
(n, c) -> new Builder(n, c::bitSetProducer, c.getIndexSettings()),
132-
List.of(validateParserContext(CONTENT_TYPE))
133-
);
136+
public static final TypeParser parser(Supplier<ModelRegistry> modelRegistry) {
137+
return new TypeParser(
138+
(n, c) -> new Builder(n, c::bitSetProducer, c.getIndexSettings(), modelRegistry.get()),
139+
List.of(validateParserContext(CONTENT_TYPE))
140+
);
141+
}
134142

135143
public static BiConsumer<String, MappingParserContext> validateParserContext(String type) {
136144
return (n, c) -> {
@@ -142,6 +150,7 @@ public static BiConsumer<String, MappingParserContext> validateParserContext(Str
142150
}
143151

144152
public static class Builder extends FieldMapper.Builder {
153+
private final ModelRegistry modelRegistry;
145154
private final boolean useLegacyFormat;
146155

147156
private final Parameter<String> inferenceId = Parameter.stringParam(
@@ -199,14 +208,21 @@ public static Builder from(SemanticTextFieldMapper mapper) {
199208
Builder builder = new Builder(
200209
mapper.leafName(),
201210
mapper.fieldType().getChunksField().bitsetProducer(),
202-
mapper.fieldType().getChunksField().indexSettings()
211+
mapper.fieldType().getChunksField().indexSettings(),
212+
mapper.modelRegistry
203213
);
204214
builder.init(mapper);
205215
return builder;
206216
}
207217

208-
public Builder(String name, Function<Query, BitSetProducer> bitSetProducer, IndexSettings indexSettings) {
218+
public Builder(
219+
String name,
220+
Function<Query, BitSetProducer> bitSetProducer,
221+
IndexSettings indexSettings,
222+
ModelRegistry modelRegistry
223+
) {
209224
super(name);
225+
this.modelRegistry = modelRegistry;
210226
this.useLegacyFormat = InferenceMetadataFieldsMapper.isEnabled(indexSettings.getSettings()) == false;
211227
this.inferenceFieldBuilder = c -> createInferenceField(
212228
c,
@@ -264,9 +280,32 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {
264280
if (useLegacyFormat && multiFieldsBuilder.hasMultiFields()) {
265281
throw new IllegalArgumentException(CONTENT_TYPE + " field [" + leafName() + "] does not support multi-fields");
266282
}
283+
284+
if (modelSettings.get() == null) {
285+
try {
286+
var resolvedModelSettings = modelRegistry.getMinimalServiceSettings(inferenceId.get());
287+
if (resolvedModelSettings != null) {
288+
modelSettings.setValue(resolvedModelSettings);
289+
}
290+
} catch (ResourceNotFoundException exc) {
291+
// We allow the inference ID to be unregistered at this point.
292+
// This will delay the creation of sub-fields, so indexing and querying for this field won't work
293+
// until the corresponding inference endpoint is created.
294+
}
295+
}
296+
267297
if (modelSettings.get() != null) {
268298
validateServiceSettings(modelSettings.get());
299+
} else {
300+
logger.warn(
301+
"The field [{}] references an unknown inference ID [{}]. "
302+
+ "Indexing and querying this field will not work correctly until the corresponding "
303+
+ "inference endpoint is created.",
304+
leafName(),
305+
inferenceId.get()
306+
);
269307
}
308+
270309
final String fullName = context.buildFullName(leafName());
271310

272311
if (context.isInNestedContext()) {
@@ -287,7 +326,8 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {
287326
useLegacyFormat,
288327
meta.getValue()
289328
),
290-
builderParams(this, context)
329+
builderParams(this, context),
330+
modelRegistry
291331
);
292332
}
293333

@@ -328,9 +368,17 @@ private SemanticTextFieldMapper copySettings(SemanticTextFieldMapper mapper, Map
328368
}
329369
}
330370

331-
private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, BuilderParams builderParams) {
371+
private final ModelRegistry modelRegistry;
372+
373+
private SemanticTextFieldMapper(
374+
String simpleName,
375+
MappedFieldType mappedFieldType,
376+
BuilderParams builderParams,
377+
ModelRegistry modelRegistry
378+
) {
332379
super(simpleName, mappedFieldType, builderParams);
333380
ensureMultiFields(builderParams.multiFields().iterator());
381+
this.modelRegistry = modelRegistry;
334382
}
335383

336384
private void ensureMultiFields(Iterator<FieldMapper> mappers) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import org.elasticsearch.cluster.ClusterStateAckListener;
3737
import org.elasticsearch.cluster.ClusterStateListener;
3838
import org.elasticsearch.cluster.SimpleBatchedAckListenerTaskExecutor;
39+
import org.elasticsearch.cluster.metadata.Metadata;
3940
import org.elasticsearch.cluster.metadata.ProjectId;
4041
import org.elasticsearch.cluster.metadata.ProjectMetadata;
4142
import org.elasticsearch.cluster.service.ClusterService;
@@ -140,18 +141,18 @@ public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap)
140141
private static final String MODEL_ID_FIELD = "model_id";
141142
private static final Logger logger = LogManager.getLogger(ModelRegistry.class);
142143

143-
private final ClusterService clusterService;
144144
private final OriginSettingClient client;
145145
private final Map<String, InferenceService.DefaultConfigId> defaultConfigIds;
146146

147147
private final MasterServiceTaskQueue<MetadataTask> metadataTaskQueue;
148148
private final AtomicBoolean upgradeMetadataInProgress = new AtomicBoolean(false);
149149
private final Set<String> preventDeletionLock = Collections.newSetFromMap(new ConcurrentHashMap<>());
150150

151+
private volatile Metadata lastMetadata;
152+
151153
public ModelRegistry(ClusterService clusterService, Client client) {
152154
this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN);
153155
this.defaultConfigIds = new ConcurrentHashMap<>();
154-
this.clusterService = clusterService;
155156
var executor = new SimpleBatchedAckListenerTaskExecutor<MetadataTask>() {
156157
@Override
157158
public Tuple<ClusterState, ClusterStateAckListener> executeTask(MetadataTask task, ClusterState clusterState) throws Exception {
@@ -224,11 +225,18 @@ public void clearDefaultIds() {
224225
* @throws ResourceNotFoundException if the specified id is guaranteed to not exist in the cluster.
225226
*/
226227
public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId) throws ResourceNotFoundException {
228+
synchronized (this) {
229+
assert lastMetadata != null : "initial cluster state not set yet";
230+
if (lastMetadata == null) {
231+
throw new IllegalStateException("initial cluster state not set yet");
232+
}
233+
}
227234
var config = defaultConfigIds.get(inferenceEntityId);
228235
if (config != null) {
229236
return config.settings();
230237
}
231-
var state = ModelRegistryMetadata.fromState(clusterService.state().projectState().metadata());
238+
var project = lastMetadata.getProject(ProjectId.DEFAULT);
239+
var state = ModelRegistryMetadata.fromState(project);
232240
var existing = state.getMinimalServiceSettings(inferenceEntityId);
233241
if (state.isUpgraded() && existing == null) {
234242
throw new ResourceNotFoundException(inferenceEntityId + " does not exist in this cluster.");
@@ -684,10 +692,14 @@ private ActionListener<BulkResponse> getStoreIndexListener(
684692
if (updateClusterState) {
685693
var storeListener = getStoreMetadataListener(inferenceEntityId, listener);
686694
try {
687-
var projectId = clusterService.state().projectState().projectId();
688695
metadataTaskQueue.submitTask(
689696
"add model [" + inferenceEntityId + "]",
690-
new AddModelMetadataTask(projectId, inferenceEntityId, new MinimalServiceSettings(model), storeListener),
697+
new AddModelMetadataTask(
698+
ProjectId.DEFAULT,
699+
inferenceEntityId,
700+
new MinimalServiceSettings(model),
701+
storeListener
702+
),
691703
timeout
692704
);
693705
} catch (Exception exc) {
@@ -854,10 +866,9 @@ public void onFailure(Exception exc) {
854866
}
855867
};
856868
try {
857-
var projectId = clusterService.state().projectState().projectId();
858869
metadataTaskQueue.submitTask(
859870
"delete models [" + inferenceEntityIds + "]",
860-
new DeleteModelMetadataTask(projectId, inferenceEntityIds, clusterStateListener),
871+
new DeleteModelMetadataTask(ProjectId.DEFAULT, inferenceEntityIds, clusterStateListener),
861872
null
862873
);
863874
} catch (Exception exc) {
@@ -935,6 +946,13 @@ static List<InferenceService.DefaultConfigId> taskTypeMatchedDefaults(
935946

936947
@Override
937948
public void clusterChanged(ClusterChangedEvent event) {
949+
if (lastMetadata == null || event.metadataChanged()) {
950+
// keep track of the last applied cluster state
951+
synchronized (this) {
952+
lastMetadata = event.state().metadata();
953+
}
954+
}
955+
938956
if (event.localNodeMaster() == false) {
939957
return;
940958
}
@@ -984,7 +1002,7 @@ public void onResponse(GetInferenceModelAction.Response response) {
9841002
metadataTaskQueue.submitTask(
9851003
"model registry auto upgrade",
9861004
new UpgradeModelsMetadataTask(
987-
clusterService.state().metadata().getProject().id(),
1005+
ProjectId.DEFAULT,
9881006
map,
9891007
ActionListener.running(() -> upgradeMetadataInProgress.set(false))
9901008
),

0 commit comments

Comments
 (0)