Skip to content

Commit c14d90f

Browse files
committed
Expose model registry to SemanticTextFieldMapper
This change integrates the new model registry with the `SemanticTextFieldMapper`, allowing inference IDs to be eagerly resolved at parse time. It also preserves the existing lenient behavior: no error is thrown if the specified inference id does not exist, only a warning is logged.
1 parent a73f923 commit c14d90f

File tree

7 files changed

+188
-24
lines changed

7 files changed

+188
-24
lines changed

server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -249,18 +249,13 @@ private static void validateFieldNotPresent(String field, Object fieldValue, Tas
249249
}
250250
}
251251

252-
public ModelConfigurations toModelConfigurations(String inferenceEntityId) {
253-
return new ModelConfigurations(inferenceEntityId, taskType, service == null ? UNKNOWN_SERVICE : service, this);
254-
}
255-
256252
/**
257253
* Checks if the given {@link MinimalServiceSettings} is equivalent to the current definition.
258254
*/
259255
public boolean canMergeWith(MinimalServiceSettings other) {
260256
return taskType == other.taskType
261257
&& Objects.equals(dimensions, other.dimensions)
262258
&& similarity == other.similarity
263-
&& elementType == other.elementType
264-
&& (service == null || service.equals(other.service));
259+
&& elementType == other.elementType;
265260
}
266261
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ public class InferencePlugin extends Plugin
197197
private final SetOnce<ElasticInferenceServiceComponents> elasticInferenceServiceComponents = new SetOnce<>();
198198
private final SetOnce<InferenceServiceRegistry> inferenceServiceRegistry = new SetOnce<>();
199199
private final SetOnce<ShardBulkInferenceActionFilter> shardBulkInferenceActionFilter = new SetOnce<>();
200+
private final SetOnce<ModelRegistry> modelRegistry = new SetOnce<>();
200201
private List<InferenceServiceExtension> inferenceServiceExtensions;
201202

202203
public InferencePlugin(Settings settings) {
@@ -260,8 +261,8 @@ public Collection<?> createComponents(PluginServices services) {
260261
var amazonBedrockRequestSenderFactory = new AmazonBedrockRequestSender.Factory(serviceComponents.get(), services.clusterService());
261262
amazonBedrockFactory.set(amazonBedrockRequestSenderFactory);
262263

263-
ModelRegistry modelRegistry = new ModelRegistry(services.clusterService(), services.client());
264-
services.clusterService().addListener(modelRegistry);
264+
modelRegistry.set(new ModelRegistry(services.clusterService(), services.client()));
265+
services.clusterService().addListener(modelRegistry.get());
265266

266267
if (inferenceServiceExtensions == null) {
267268
inferenceServiceExtensions = new ArrayList<>();
@@ -299,7 +300,7 @@ public Collection<?> createComponents(PluginServices services) {
299300
elasicInferenceServiceFactory.get(),
300301
serviceComponents.get(),
301302
inferenceServiceSettings,
302-
modelRegistry,
303+
modelRegistry.get(),
303304
authorizationHandler
304305
)
305306
)
@@ -317,18 +318,23 @@ public Collection<?> createComponents(PluginServices services) {
317318
var serviceRegistry = new InferenceServiceRegistry(inferenceServices, factoryContext);
318319
serviceRegistry.init(services.client());
319320
for (var service : serviceRegistry.getServices().values()) {
320-
service.defaultConfigIds().forEach(modelRegistry::addDefaultIds);
321+
service.defaultConfigIds().forEach(modelRegistry.get()::addDefaultIds);
321322
}
322323
inferenceServiceRegistry.set(serviceRegistry);
323324

324-
var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry, getLicenseState());
325+
var actionFilter = new ShardBulkInferenceActionFilter(
326+
services.clusterService(),
327+
serviceRegistry,
328+
modelRegistry.get(),
329+
getLicenseState()
330+
);
325331
shardBulkInferenceActionFilter.set(actionFilter);
326332

327333
var meterRegistry = services.telemetryProvider().getMeterRegistry();
328334
var inferenceStats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));
329335

330336
components.add(serviceRegistry);
331-
components.add(modelRegistry);
337+
components.add(modelRegistry.get());
332338
components.add(httpClientManager);
333339
components.add(inferenceStats);
334340

@@ -492,11 +498,16 @@ public Map<String, MetadataFieldMapper.TypeParser> getMetadataMappers() {
492498
return Map.of(SemanticInferenceMetadataFieldsMapper.NAME, SemanticInferenceMetadataFieldsMapper.PARSER);
493499
}
494500

501+
// Overridable for tests
502+
protected Supplier<ModelRegistry> getModelRegistry() {
503+
return () -> modelRegistry.get();
504+
}
505+
495506
@Override
496507
public Map<String, Mapper.TypeParser> getMappers() {
497508
return Map.of(
498509
SemanticTextFieldMapper.CONTENT_TYPE,
499-
SemanticTextFieldMapper.PARSER,
510+
SemanticTextFieldMapper.parser(getModelRegistry()),
500511
OffsetSourceFieldMapper.CONTENT_TYPE,
501512
OffsetSourceFieldMapper.PARSER
502513
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
package org.elasticsearch.xpack.inference.mapper;
99

10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
1012
import org.apache.lucene.index.FieldInfos;
1113
import org.apache.lucene.index.LeafReaderContext;
1214
import org.apache.lucene.search.DocIdSetIterator;
@@ -18,6 +20,7 @@
1820
import org.apache.lucene.search.join.BitSetProducer;
1921
import org.apache.lucene.search.join.ScoreMode;
2022
import org.apache.lucene.util.BitSet;
23+
import org.elasticsearch.ResourceNotFoundException;
2124
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
2225
import org.elasticsearch.common.Strings;
2326
import org.elasticsearch.common.bytes.BytesReference;
@@ -75,6 +78,7 @@
7578
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
7679
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
7780
import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter;
81+
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
7882

7983
import java.io.IOException;
8084
import java.io.UncheckedIOException;
@@ -89,6 +93,7 @@
8993
import java.util.Set;
9094
import java.util.function.BiConsumer;
9195
import java.util.function.Function;
96+
import java.util.function.Supplier;
9297

9398
import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING;
9499
import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING;
@@ -112,6 +117,7 @@
112117
* A {@link FieldMapper} for semantic text fields.
113118
*/
114119
public class SemanticTextFieldMapper extends FieldMapper implements InferenceFieldMapper {
120+
private static final Logger logger = LogManager.getLogger(SemanticTextFieldMapper.class);
115121
public static final NodeFeature SEMANTIC_TEXT_IN_OBJECT_FIELD_FIX = new NodeFeature("semantic_text.in_object_field_fix");
116122
public static final NodeFeature SEMANTIC_TEXT_SINGLE_FIELD_UPDATE_FIX = new NodeFeature("semantic_text.single_field_update_fix");
117123
public static final NodeFeature SEMANTIC_TEXT_DELETE_FIX = new NodeFeature("semantic_text.delete_fix");
@@ -127,10 +133,12 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
127133
public static final String CONTENT_TYPE = "semantic_text";
128134
public static final String DEFAULT_ELSER_2_INFERENCE_ID = DEFAULT_ELSER_ID;
129135

130-
public static final TypeParser PARSER = new TypeParser(
131-
(n, c) -> new Builder(n, c::bitSetProducer, c.getIndexSettings()),
132-
List.of(validateParserContext(CONTENT_TYPE))
133-
);
136+
public static final TypeParser parser(Supplier<ModelRegistry> modelRegistry) {
137+
return new TypeParser(
138+
(n, c) -> new Builder(n, c::bitSetProducer, c.getIndexSettings(), modelRegistry.get()),
139+
List.of(validateParserContext(CONTENT_TYPE))
140+
);
141+
}
134142

135143
public static BiConsumer<String, MappingParserContext> validateParserContext(String type) {
136144
return (n, c) -> {
@@ -142,6 +150,7 @@ public static BiConsumer<String, MappingParserContext> validateParserContext(Str
142150
}
143151

144152
public static class Builder extends FieldMapper.Builder {
153+
private final ModelRegistry modelRegistry;
145154
private final boolean useLegacyFormat;
146155

147156
private final Parameter<String> inferenceId = Parameter.stringParam(
@@ -199,14 +208,21 @@ public static Builder from(SemanticTextFieldMapper mapper) {
199208
Builder builder = new Builder(
200209
mapper.leafName(),
201210
mapper.fieldType().getChunksField().bitsetProducer(),
202-
mapper.fieldType().getChunksField().indexSettings()
211+
mapper.fieldType().getChunksField().indexSettings(),
212+
mapper.modelRegistry
203213
);
204214
builder.init(mapper);
205215
return builder;
206216
}
207217

208-
public Builder(String name, Function<Query, BitSetProducer> bitSetProducer, IndexSettings indexSettings) {
218+
public Builder(
219+
String name,
220+
Function<Query, BitSetProducer> bitSetProducer,
221+
IndexSettings indexSettings,
222+
ModelRegistry modelRegistry
223+
) {
209224
super(name);
225+
this.modelRegistry = modelRegistry;
210226
this.useLegacyFormat = InferenceMetadataFieldsMapper.isEnabled(indexSettings.getSettings()) == false;
211227
this.inferenceFieldBuilder = c -> createInferenceField(
212228
c,
@@ -264,6 +280,26 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {
264280
if (useLegacyFormat && multiFieldsBuilder.hasMultiFields()) {
265281
throw new IllegalArgumentException(CONTENT_TYPE + " field [" + leafName() + "] does not support multi-fields");
266282
}
283+
if (modelSettings.get() == null) {
284+
try {
285+
var resolvedModelSettings = modelRegistry.getMinimalServiceSettings(inferenceId.get());
286+
if (resolvedModelSettings != null) {
287+
modelSettings.setValue(resolvedModelSettings);
288+
}
289+
} catch (ResourceNotFoundException exc) {
290+
// We allow the inference ID to be unregistered at this point.
291+
// This will delay the creation of sub-fields, so indexing and querying for this field won't work
292+
// until the corresponding inference endpoint is created.
293+
logger.warn(
294+
"The field [{}] references an unknown inference ID [{}]. "
295+
+ "Indexing and querying this field will not work correctly until the corresponding "
296+
+ "inference endpoint is created.",
297+
leafName(),
298+
inferenceId.get()
299+
);
300+
}
301+
}
302+
267303
if (modelSettings.get() != null) {
268304
validateServiceSettings(modelSettings.get());
269305
}
@@ -287,7 +323,8 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {
287323
useLegacyFormat,
288324
meta.getValue()
289325
),
290-
builderParams(this, context)
326+
builderParams(this, context),
327+
modelRegistry
291328
);
292329
}
293330

@@ -328,9 +365,17 @@ private SemanticTextFieldMapper copySettings(SemanticTextFieldMapper mapper, Map
328365
}
329366
}
330367

331-
private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, BuilderParams builderParams) {
368+
private final ModelRegistry modelRegistry;
369+
370+
private SemanticTextFieldMapper(
371+
String simpleName,
372+
MappedFieldType mappedFieldType,
373+
BuilderParams builderParams,
374+
ModelRegistry modelRegistry
375+
) {
332376
super(simpleName, mappedFieldType, builderParams);
333377
ensureMultiFields(builderParams.multiFields().iterator());
378+
this.modelRegistry = modelRegistry;
334379
}
335380

336381
private void ensureMultiFields(Iterator<FieldMapper> mappers) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap)
148148
private final AtomicBoolean upgradeMetadataInProgress = new AtomicBoolean(false);
149149
private final Set<String> preventDeletionLock = Collections.newSetFromMap(new ConcurrentHashMap<>());
150150

151+
private volatile ClusterState lastClusterState;
152+
151153
public ModelRegistry(ClusterService clusterService, Client client) {
152154
this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN);
153155
this.defaultConfigIds = new ConcurrentHashMap<>();
@@ -224,11 +226,17 @@ public void clearDefaultIds() {
224226
* @throws ResourceNotFoundException if the specified id is guaranteed to not exist in the cluster.
225227
*/
226228
public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId) throws ResourceNotFoundException {
229+
synchronized (this) {
230+
assert lastClusterState != null : "initial cluster state not set yet";
231+
if (lastClusterState == null) {
232+
throw new IllegalStateException("initial cluster state not set yet");
233+
}
234+
}
227235
var config = defaultConfigIds.get(inferenceEntityId);
228236
if (config != null) {
229237
return config.settings();
230238
}
231-
var state = ModelRegistryMetadata.fromState(clusterService.state().projectState().metadata());
239+
var state = ModelRegistryMetadata.fromState(lastClusterState.projectState().metadata());
232240
var existing = state.getMinimalServiceSettings(inferenceEntityId);
233241
if (state.isUpgraded() && existing == null) {
234242
throw new ResourceNotFoundException(inferenceEntityId + " does not exist in this cluster.");
@@ -935,6 +943,11 @@ static List<InferenceService.DefaultConfigId> taskTypeMatchedDefaults(
935943

936944
@Override
937945
public void clusterChanged(ClusterChangedEvent event) {
946+
// keep track of the last applied cluster state
947+
synchronized (this) {
948+
lastClusterState = event.state();
949+
}
950+
938951
if (event.localNodeMaster() == false) {
939952
return;
940953
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.apache.lucene.search.join.QueryBitSetProducer;
2525
import org.apache.lucene.search.join.ScoreMode;
2626
import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest;
27+
import org.elasticsearch.cluster.ClusterChangedEvent;
2728
import org.elasticsearch.cluster.metadata.IndexMetadata;
2829
import org.elasticsearch.common.CheckedBiConsumer;
2930
import org.elasticsearch.common.CheckedBiFunction;
@@ -63,14 +64,20 @@
6364
import org.elasticsearch.search.LeafNestedDocuments;
6465
import org.elasticsearch.search.NestedDocuments;
6566
import org.elasticsearch.search.SearchHit;
67+
import org.elasticsearch.test.ClusterServiceUtils;
68+
import org.elasticsearch.test.client.NoOpClient;
69+
import org.elasticsearch.threadpool.TestThreadPool;
6670
import org.elasticsearch.xcontent.XContentBuilder;
6771
import org.elasticsearch.xcontent.XContentType;
6872
import org.elasticsearch.xcontent.json.JsonXContent;
6973
import org.elasticsearch.xpack.core.XPackClientPlugin;
7074
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryWrapper;
7175
import org.elasticsearch.xpack.inference.InferencePlugin;
7276
import org.elasticsearch.xpack.inference.model.TestModel;
77+
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
78+
import org.junit.After;
7379
import org.junit.AssumptionViolatedException;
80+
import org.junit.Before;
7481

7582
import java.io.IOException;
7683
import java.util.Collection;
@@ -80,6 +87,7 @@
8087
import java.util.Map;
8188
import java.util.Set;
8289
import java.util.function.BiConsumer;
90+
import java.util.function.Supplier;
8391

8492
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD;
8593
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKS_FIELD;
@@ -101,18 +109,43 @@
101109
public class SemanticTextFieldMapperTests extends MapperTestCase {
102110
private final boolean useLegacyFormat;
103111

112+
private TestThreadPool threadPool;
113+
104114
public SemanticTextFieldMapperTests(boolean useLegacyFormat) {
105115
this.useLegacyFormat = useLegacyFormat;
106116
}
107117

118+
@Before
119+
private void startThreadPool() {
120+
threadPool = createThreadPool();
121+
}
122+
123+
@After
124+
private void stopThreadPool() {
125+
threadPool.close();
126+
}
127+
108128
@ParametersFactory
109129
public static Iterable<Object[]> parameters() throws Exception {
110130
return List.of(new Object[] { true }, new Object[] { false });
111131
}
112132

113133
@Override
114134
protected Collection<? extends Plugin> getPlugins() {
115-
return List.of(new InferencePlugin(Settings.EMPTY), new XPackClientPlugin());
135+
var clusterService = ClusterServiceUtils.createClusterService(threadPool);
136+
var modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool));
137+
modelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) {
138+
@Override
139+
public boolean localNodeMaster() {
140+
return false;
141+
}
142+
});
143+
return List.of(new InferencePlugin(Settings.EMPTY) {
144+
@Override
145+
protected Supplier<ModelRegistry> getModelRegistry() {
146+
return () -> modelRegistry;
147+
}
148+
}, new XPackClientPlugin());
116149
}
117150

118151
private MapperService createMapperService(XContentBuilder mappings, boolean useLegacyFormat) throws IOException {

0 commit comments

Comments
 (0)