Commit 4d091b7

[8.19] Fix inference model validation for the semantic text field (elastic#127559)
This PR is a partial backport of elastic#127285 that fixes the validation of the inference ID when mappings are restored or dynamically updated. This change doesn't include defaulting semantic_text dense vectors to BBQ, since that requires elastic#124581 to be backported first.
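As an illustration of the behavior this validation preserves, here is a hypothetical sketch in the style of the integration tests below; the index name "my_index" and endpoint id "my_inference_endpoint" are made up. A semantic_text mapping may reference an inference endpoint that is not registered yet; the inference ID is only resolved and validated against the model registry when the mapping is dynamically updated or documents are ingested.

// Hypothetical sketch, mirroring the integration-test style used in this PR.
// Creating the index succeeds even if "my_inference_endpoint" does not exist
// yet; the endpoint only has to be registered by the time of ingestion, and
// its settings are then validated against the mapping.
prepareCreate("my_index").setMapping("""
    {
        "properties": {
            "semantic_field": {
                "type": "semantic_text",
                "inference_id": "my_inference_endpoint"
            }
        }
    }
    """).get();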
Parent: b41fb2d

5 files changed: 250 additions, 69 deletions

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java

Lines changed: 124 additions & 60 deletions
@@ -26,13 +26,15 @@
 import org.elasticsearch.index.mapper.SourceFieldMapper;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils;
+import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.license.LicenseSettings;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.test.ESIntegTestCase;
+import org.elasticsearch.test.InternalTestCluster;
+import org.elasticsearch.xpack.inference.InferenceIndex;
 import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
-import org.elasticsearch.xpack.inference.Utils;
 import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
 import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
@@ -45,7 +47,11 @@
 import java.util.Locale;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.Function;
 
+import static org.elasticsearch.xpack.inference.Utils.storeDenseModel;
+import static org.elasticsearch.xpack.inference.Utils.storeModel;
+import static org.elasticsearch.xpack.inference.Utils.storeSparseModel;
 import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput;
 import static org.hamcrest.Matchers.containsString;
@@ -56,6 +62,7 @@ public class ShardBulkInferenceActionFilterIT extends ESIntegTestCase {
 
     private final boolean useLegacyFormat;
     private final boolean useSyntheticSource;
+    private ModelRegistry modelRegistry;
 
     public ShardBulkInferenceActionFilterIT(boolean useLegacyFormat, boolean useSyntheticSource) {
         this.useLegacyFormat = useLegacyFormat;
@@ -74,16 +81,16 @@ public static Iterable<Object[]> parameters() throws Exception {
 
     @Before
     public void setup() throws Exception {
-        ModelRegistry modelRegistry = internalCluster().getCurrentMasterNodeInstance(ModelRegistry.class);
+        modelRegistry = internalCluster().getCurrentMasterNodeInstance(ModelRegistry.class);
         DenseVectorFieldMapper.ElementType elementType = randomFrom(DenseVectorFieldMapper.ElementType.values());
         // dot product means that we need normalized vectors; it's not worth doing that in this test
         SimilarityMeasure similarity = randomValueOtherThan(
             SimilarityMeasure.DOT_PRODUCT,
             () -> randomFrom(DenseVectorFieldMapperTestUtils.getSupportedSimilarities(elementType))
         );
         int dimensions = DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 100);
-        Utils.storeSparseModel(modelRegistry);
-        Utils.storeDenseModel(modelRegistry, dimensions, similarity, elementType);
+        storeSparseModel(modelRegistry);
+        storeDenseModel(modelRegistry, dimensions, similarity, elementType);
     }
 
     @Override
@@ -135,32 +142,131 @@ public void testBulkOperations() throws Exception {
                 TestDenseInferenceServiceExtension.TestInferenceService.NAME
             )
         ).get();
+        assertRandomBulkOperations(INDEX_NAME, isIndexRequest -> {
+            Map<String, Object> map = new HashMap<>();
+            map.put("sparse_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
+            map.put("dense_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
+            return map;
+        });
+    }
+
+    public void testItemFailures() {
+        prepareCreate(INDEX_NAME).setMapping(
+            String.format(
+                Locale.ROOT,
+                """
+                    {
+                        "properties": {
+                            "sparse_field": {
+                                "type": "semantic_text",
+                                "inference_id": "%s"
+                            },
+                            "dense_field": {
+                                "type": "semantic_text",
+                                "inference_id": "%s"
+                            }
+                        }
+                    }
+                    """,
+                TestSparseInferenceServiceExtension.TestInferenceService.NAME,
+                TestDenseInferenceServiceExtension.TestInferenceService.NAME
+            )
+        ).get();
+
+        BulkRequestBuilder bulkReqBuilder = client().prepareBulk();
+        int totalBulkSize = randomIntBetween(100, 200); // Use a bulk request size large enough to require batching
+        for (int bulkSize = 0; bulkSize < totalBulkSize; bulkSize++) {
+            String id = Integer.toString(bulkSize);
+
+            // Set field values that will cause errors when generating inference requests
+            Map<String, Object> source = new HashMap<>();
+            source.put("sparse_field", List.of(Map.of("foo", "bar"), Map.of("baz", "bar")));
+            source.put("dense_field", List.of(Map.of("foo", "bar"), Map.of("baz", "bar")));
+
+            bulkReqBuilder.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(id).setSource(source));
+        }
+
+        BulkResponse bulkResponse = bulkReqBuilder.get();
+        assertThat(bulkResponse.hasFailures(), equalTo(true));
+        for (BulkItemResponse bulkItemResponse : bulkResponse.getItems()) {
+            assertThat(bulkItemResponse.isFailed(), equalTo(true));
+            assertThat(bulkItemResponse.getFailureMessage(), containsString("expected [String|Number|Boolean]"));
+        }
+    }
+
+    public void testRestart() throws Exception {
+        Model model1 = new TestSparseInferenceServiceExtension.TestSparseModel(
+            "another_inference_endpoint",
+            new TestSparseInferenceServiceExtension.TestServiceSettings("sparse_model", null, false)
+        );
+        storeModel(modelRegistry, model1);
+        prepareCreate("index_restart").setMapping("""
+            {
+                "properties": {
+                    "sparse_field": {
+                        "type": "semantic_text",
+                        "inference_id": "new_inference_endpoint"
+                    },
+                    "other_field": {
+                        "type": "semantic_text",
+                        "inference_id": "another_inference_endpoint"
+                    }
+                }
+            }
+            """).get();
+        Model model2 = new TestSparseInferenceServiceExtension.TestSparseModel(
+            "new_inference_endpoint",
+            new TestSparseInferenceServiceExtension.TestServiceSettings("sparse_model", null, false)
+        );
+        storeModel(modelRegistry, model2);
+
+        internalCluster().fullRestart(new InternalTestCluster.RestartCallback());
+        ensureGreen(InferenceIndex.INDEX_NAME, "index_restart");
 
+        assertRandomBulkOperations("index_restart", isIndexRequest -> {
+            Map<String, Object> map = new HashMap<>();
+            map.put("sparse_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
+            map.put("other_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
+            return map;
+        });
+
+        internalCluster().fullRestart(new InternalTestCluster.RestartCallback());
+        ensureGreen(InferenceIndex.INDEX_NAME, "index_restart");
+
+        assertRandomBulkOperations("index_restart", isIndexRequest -> {
+            Map<String, Object> map = new HashMap<>();
+            map.put("sparse_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
+            map.put("other_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
+            return map;
+        });
+    }
+
+    private void assertRandomBulkOperations(String indexName, Function<Boolean, Map<String, Object>> sourceSupplier) throws Exception {
+        int numHits = numHits(indexName);
         int totalBulkReqs = randomIntBetween(2, 100);
-        long totalDocs = 0;
+        long totalDocs = numHits;
         Set<String> ids = new HashSet<>();
-        for (int bulkReqs = 0; bulkReqs < totalBulkReqs; bulkReqs++) {
+
+        for (int bulkReqs = numHits; bulkReqs < totalBulkReqs; bulkReqs++) {
             BulkRequestBuilder bulkReqBuilder = client().prepareBulk();
             int totalBulkSize = randomIntBetween(1, 100);
             for (int bulkSize = 0; bulkSize < totalBulkSize; bulkSize++) {
                 if (ids.size() > 0 && rarely(random())) {
                     String id = randomFrom(ids);
                     ids.remove(id);
-                    DeleteRequestBuilder request = new DeleteRequestBuilder(client(), INDEX_NAME).setId(id);
+                    DeleteRequestBuilder request = new DeleteRequestBuilder(client(), indexName).setId(id);
                     bulkReqBuilder.add(request);
                     continue;
                 }
                 String id = Long.toString(totalDocs++);
                 boolean isIndexRequest = randomBoolean();
-                Map<String, Object> source = new HashMap<>();
-                source.put("sparse_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
-                source.put("dense_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
+                Map<String, Object> source = sourceSupplier.apply(isIndexRequest);
                 if (isIndexRequest) {
-                    bulkReqBuilder.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(id).setSource(source));
+                    bulkReqBuilder.add(new IndexRequestBuilder(client()).setIndex(indexName).setId(id).setSource(source));
                     ids.add(id);
                 } else {
                     boolean isUpsert = randomBoolean();
-                    UpdateRequestBuilder request = new UpdateRequestBuilder(client()).setIndex(INDEX_NAME).setDoc(source);
+                    UpdateRequestBuilder request = new UpdateRequestBuilder(client()).setIndex(indexName).setDoc(source);
                     if (isUpsert || ids.size() == 0) {
                         request.setDocAsUpsert(true);
                     } else {
@@ -188,59 +294,17 @@ public void testBulkOperations() throws Exception {
             }
             assertFalse(bulkResponse.hasFailures());
         }
+        client().admin().indices().refresh(new RefreshRequest(indexName)).get();
+        assertThat(numHits(indexName), equalTo(ids.size() + numHits));
+    }
 
-        client().admin().indices().refresh(new RefreshRequest(INDEX_NAME)).get();
-
+    private int numHits(String indexName) throws Exception {
         SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().size(0).trackTotalHits(true);
-        SearchResponse searchResponse = client().search(new SearchRequest(INDEX_NAME).source(sourceBuilder)).get();
+        SearchResponse searchResponse = client().search(new SearchRequest(indexName).source(sourceBuilder)).get();
         try {
-            assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) ids.size()));
+            return (int) searchResponse.getHits().getTotalHits().value;
         } finally {
             searchResponse.decRef();
         }
     }
-
-    public void testItemFailures() {
-        prepareCreate(INDEX_NAME).setMapping(
-            String.format(
-                Locale.ROOT,
-                """
-                    {
-                        "properties": {
-                            "sparse_field": {
-                                "type": "semantic_text",
-                                "inference_id": "%s"
-                            },
-                            "dense_field": {
-                                "type": "semantic_text",
-                                "inference_id": "%s"
-                            }
-                        }
-                    }
-                    """,
-                TestSparseInferenceServiceExtension.TestInferenceService.NAME,
-                TestDenseInferenceServiceExtension.TestInferenceService.NAME
-            )
-        ).get();
-
-        BulkRequestBuilder bulkReqBuilder = client().prepareBulk();
-        int totalBulkSize = randomIntBetween(100, 200); // Use a bulk request size large enough to require batching
-        for (int bulkSize = 0; bulkSize < totalBulkSize; bulkSize++) {
-            String id = Integer.toString(bulkSize);
-
-            // Set field values that will cause errors when generating inference requests
-            Map<String, Object> source = new HashMap<>();
-            source.put("sparse_field", List.of(Map.of("foo", "bar"), Map.of("baz", "bar")));
-            source.put("dense_field", List.of(Map.of("foo", "bar"), Map.of("baz", "bar")));
-
-            bulkReqBuilder.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(id).setSource(source));
-        }
-
-        BulkResponse bulkResponse = bulkReqBuilder.get();
-        assertThat(bulkResponse.hasFailures(), equalTo(true));
-        for (BulkItemResponse bulkItemResponse : bulkResponse.getItems()) {
-            assertThat(bulkItemResponse.isFailed(), equalTo(true));
-            assertThat(bulkItemResponse.getFailureMessage(), containsString("expected [String|Number|Boolean]"));
-        }
-    }
 }

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java

Lines changed: 31 additions & 8 deletions
@@ -45,6 +45,7 @@
 import org.elasticsearch.index.mapper.Mapper;
 import org.elasticsearch.index.mapper.MapperBuilderContext;
 import org.elasticsearch.index.mapper.MapperMergeContext;
+import org.elasticsearch.index.mapper.MapperService;
 import org.elasticsearch.index.mapper.MappingLookup;
 import org.elasticsearch.index.mapper.MappingParserContext;
 import org.elasticsearch.index.mapper.NestedObjectMapper;
@@ -204,6 +205,7 @@ public static class Builder extends FieldMapper.Builder {
 
     private final Parameter<Map<String, String>> meta = Parameter.metaParam();
 
+    private MinimalServiceSettings resolvedModelSettings;
     private Function<MapperBuilderContext, ObjectMapper> inferenceFieldBuilder;
 
     public static Builder from(SemanticTextFieldMapper mapper) {
@@ -283,21 +285,31 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {
             throw new IllegalArgumentException(CONTENT_TYPE + " field [" + leafName() + "] does not support multi-fields");
         }
 
-        if (modelSettings.get() == null) {
+        if (context.getMergeReason() != MapperService.MergeReason.MAPPING_RECOVERY && modelSettings.get() == null) {
             try {
-                var resolvedModelSettings = modelRegistry.getMinimalServiceSettings(inferenceId.get());
+                /*
+                 * If the model is not already set and we are not in a recovery scenario, resolve it using the registry.
+                 * Note: We do not set the model in the mapping at this stage. Instead, the model will be added through
+                 * a mapping update during the first ingestion.
+                 * This approach allows mappings to reference inference endpoints that may not yet exist.
+                 * The only requirement is that the referenced inference endpoint must be available at the time of ingestion.
+                 */
+                resolvedModelSettings = modelRegistry.getMinimalServiceSettings(inferenceId.get());
                 if (resolvedModelSettings != null) {
-                    modelSettings.setValue(resolvedModelSettings);
+                    validateServiceSettings(resolvedModelSettings, null);
                 }
             } catch (ResourceNotFoundException exc) {
-                // We allow the inference ID to be unregistered at this point.
-                // This will delay the creation of sub-fields, so indexing and querying for this field won't work
-                // until the corresponding inference endpoint is created.
+                /* We allow the inference ID to be unregistered at this point.
+                 * This will delay the creation of sub-fields, so indexing and querying for this field won't work
+                 * until the corresponding inference endpoint is created.
+                 */
             }
+        } else {
+            resolvedModelSettings = modelSettings.get();
         }
 
         if (modelSettings.get() != null) {
-            validateServiceSettings(modelSettings.get());
+            validateServiceSettings(modelSettings.get(), resolvedModelSettings);
         } else {
             logger.warn(
                 "The field [{}] references an unknown inference ID [{}]. "
@@ -333,7 +345,7 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {
         );
     }
 
-    private void validateServiceSettings(MinimalServiceSettings settings) {
+    private void validateServiceSettings(MinimalServiceSettings settings, MinimalServiceSettings resolved) {
         switch (settings.taskType()) {
             case SPARSE_EMBEDDING, TEXT_EMBEDDING -> {
             }
@@ -348,6 +360,17 @@ private void validateServiceSettings(MinimalServiceSettings settings) {
                     + settings.taskType().name()
                 );
         }
+
+        if (resolved != null && settings.canMergeWith(resolved) == false) {
+            throw new IllegalArgumentException(
+                "Mismatch between provided and registered inference model settings. "
+                    + "Provided: ["
+                    + settings
+                    + "], Expected: ["
+                    + resolved
+                    + "]."
+            );
+        }
     }
 
     /**
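To make the new validation concrete, here is a hedged sketch of the check that validateServiceSettings(settings, resolved) now performs when a mapping carries model settings (for example, settings restored from a snapshot) that disagree with the endpoint registered under the same inference ID. The variable names `declared` and `registered` are invented for illustration; `canMergeWith` is the MinimalServiceSettings method used in the diff above.

// Illustrative sketch only; `declared` and `registered` are hypothetical
// stand-ins for the mapping's model settings and the registry lookup result.
MinimalServiceSettings declared = modelSettings.get();
MinimalServiceSettings registered = resolvedModelSettings;
if (registered != null && declared.canMergeWith(registered) == false) {
    // Same error the diff above throws: restoring or updating a mapping with
    // incompatible model settings is rejected instead of silently accepted.
    throw new IllegalArgumentException(
        "Mismatch between provided and registered inference model settings. "
            + "Provided: [" + declared + "], Expected: [" + registered + "]."
    );
}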

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java

Lines changed: 0 additions & 1 deletion
@@ -223,7 +223,6 @@ public void clearDefaultIds() {
     */
    public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId) throws ResourceNotFoundException {
        synchronized (this) {
-            assert lastMetadata != null : "initial cluster state not set yet";
            if (lastMetadata == null) {
                throw new IllegalStateException("initial cluster state not set yet");
            }
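Caller-side, removing the assert means a lookup that runs before the registry has seen the initial cluster state (e.g. while mappings are being restored on startup) surfaces as the existing IllegalStateException rather than an assertion failure in test builds. A hypothetical caller sketch, using only the exception types visible in the code above:

// Hypothetical caller; the exception types match the method's contract above.
// "my_inference_endpoint" is a made-up endpoint id.
try {
    MinimalServiceSettings settings = modelRegistry.getMinimalServiceSettings("my_inference_endpoint");
} catch (ResourceNotFoundException e) {
    // The inference ID is not registered yet: sub-field creation is deferred
    // until the endpoint exists (see SemanticTextFieldMapper above).
} catch (IllegalStateException e) {
    // The registry has not received the initial cluster state yet.
}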
