diff --git a/docs/changelog/126629.yaml b/docs/changelog/126629.yaml
new file mode 100644
index 0000000000000..49d04856c0b64
--- /dev/null
+++ b/docs/changelog/126629.yaml
@@ -0,0 +1,5 @@
+pr: 126629
+summary: Default new `semantic_text` fields to use BBQ when models are compatible
+area: Relevance
+type: enhancement
+issues: []
diff --git a/docs/changelog/127285.yaml b/docs/changelog/127285.yaml
new file mode 100644
index 0000000000000..e735580b5f310
--- /dev/null
+++ b/docs/changelog/127285.yaml
@@ -0,0 +1,5 @@
+pr: 127285
+summary: Restore model registry validation for the semantic text field
+area: Search
+type: enhancement
+issues: []
diff --git a/server/src/main/java/org/elasticsearch/index/IndexVersions.java b/server/src/main/java/org/elasticsearch/index/IndexVersions.java
index 96386ceca6c9f..b1cf51d89ad31 100644
--- a/server/src/main/java/org/elasticsearch/index/IndexVersions.java
+++ b/server/src/main/java/org/elasticsearch/index/IndexVersions.java
@@ -162,6 +162,7 @@ private static Version parseUnchecked(String version) {
     public static final IndexVersion UPGRADE_TO_LUCENE_10_2_0 = def(9_022_00_0, Version.LUCENE_10_2_0);
     public static final IndexVersion UPGRADE_TO_LUCENE_10_2_1 = def(9_023_00_0, Version.LUCENE_10_2_1);
     public static final IndexVersion DEFAULT_OVERSAMPLE_VALUE_FOR_BBQ = def(9_024_0_00, Version.LUCENE_10_2_1);
+    public static final IndexVersion SEMANTIC_TEXT_DEFAULTS_TO_BBQ = def(9_025_0_00, Version.LUCENE_10_2_1);
     /*
      * STOP! READ THIS FIRST! No, really,
      * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
index 53a7e8836e3db..7de743fb251fa 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
@@ -295,6 +295,11 @@ public Builder elementType(ElementType elementType) {
             return this;
         }
 
+        public Builder indexOptions(IndexOptions indexOptions) {
+            this.indexOptions.setValue(indexOptions);
+            return this;
+        }
+
         @Override
         public DenseVectorFieldMapper build(MapperBuilderContext context) {
             // Validate again here because the dimensions or element type could have been set programmatically,
@@ -1226,7 +1231,7 @@ public final String toString() {
         public abstract VectorSimilarityFunction vectorSimilarityFunction(IndexVersion indexVersion, ElementType elementType);
     }
 
-    abstract static class IndexOptions implements ToXContent {
+    public abstract static class IndexOptions implements ToXContent {
         final VectorIndexType type;
 
         IndexOptions(VectorIndexType type) {
@@ -1235,21 +1240,36 @@ abstract static class IndexOptions implements ToXContent {
 
         abstract KnnVectorsFormat getVectorsFormat(ElementType elementType);
 
-        final void validateElementType(ElementType elementType) {
-            if (type.supportsElementType(elementType) == false) {
+        public boolean validate(ElementType elementType, int dim, boolean throwOnError) {
+            return validateElementType(elementType, throwOnError) && validateDimension(dim, throwOnError);
+        }
+
+        public boolean validateElementType(ElementType elementType) {
+            return validateElementType(elementType, true);
+        }
+
+        final boolean validateElementType(ElementType elementType, boolean throwOnError) {
+            boolean validElementType = type.supportsElementType(elementType);
+            if (throwOnError && validElementType == false) {
                 throw new IllegalArgumentException(
                     "[element_type] cannot be [" + elementType.toString() + "] when using index type [" + type + "]"
                 );
             }
+            return validElementType;
         }
 
         abstract boolean updatableTo(IndexOptions update);
 
-        public void validateDimension(int dim) {
-            if (type.supportsDimension(dim)) {
-                return;
+        public boolean validateDimension(int dim) {
+            return validateDimension(dim, true);
+        }
+
+        public boolean validateDimension(int dim, boolean throwOnError) {
+            boolean supportsDimension = type.supportsDimension(dim);
+            if (throwOnError && supportsDimension == false) {
+                throw new IllegalArgumentException(type.name + " only supports even dimensions; provided=" + dim);
             }
-            throw new IllegalArgumentException(type.name + " only supports even dimensions; provided=" + dim);
+            return supportsDimension;
         }
 
         abstract boolean doEquals(IndexOptions other);
@@ -1758,12 +1778,12 @@ boolean updatableTo(IndexOptions update) {
         }
     }
 
-    static class Int8HnswIndexOptions extends QuantizedIndexOptions {
+    public static class Int8HnswIndexOptions extends QuantizedIndexOptions {
         private final int m;
         private final int efConstruction;
         private final Float confidenceInterval;
 
-        Int8HnswIndexOptions(int m, int efConstruction, Float confidenceInterval, RescoreVector rescoreVector) {
+        public Int8HnswIndexOptions(int m, int efConstruction, Float confidenceInterval, RescoreVector rescoreVector) {
             super(VectorIndexType.INT8_HNSW, rescoreVector);
             this.m = m;
             this.efConstruction = efConstruction;
@@ -1901,11 +1921,11 @@ public String toString() {
         }
     }
 
-    static class BBQHnswIndexOptions extends QuantizedIndexOptions {
+    public static class BBQHnswIndexOptions extends QuantizedIndexOptions {
         private final int m;
         private final int efConstruction;
 
-        BBQHnswIndexOptions(int m, int efConstruction, RescoreVector rescoreVector) {
+        public BBQHnswIndexOptions(int m, int efConstruction, RescoreVector rescoreVector) {
             super(VectorIndexType.BBQ_HNSW, rescoreVector);
             this.m = m;
             this.efConstruction = efConstruction;
@@ -1947,11 +1967,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
         }
 
         @Override
-        public void validateDimension(int dim) {
-            if (type.supportsDimension(dim)) {
-                return;
+        public boolean validateDimension(int dim, boolean throwOnError) {
+            boolean supportsDimension = type.supportsDimension(dim);
+            if (throwOnError && supportsDimension == false) {
+                throw new IllegalArgumentException(
+                    type.name + " does not support dimensions fewer than " + BBQ_MIN_DIMS + "; provided=" + dim
+                );
             }
-            throw new IllegalArgumentException(type.name + " does not support dimensions fewer than " + BBQ_MIN_DIMS + "; provided=" + dim);
+            return supportsDimension;
         }
     }
 
@@ -1995,15 +2018,19 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
         }
 
         @Override
-        public void validateDimension(int dim) {
-            if (type.supportsDimension(dim)) {
-                return;
+        public boolean validateDimension(int dim, boolean throwOnError) {
+            boolean supportsDimension = type.supportsDimension(dim);
+            if (throwOnError && supportsDimension == false) {
+                throw new IllegalArgumentException(
+                    type.name + " does not support dimensions fewer than " + BBQ_MIN_DIMS + "; provided=" + dim
+                );
             }
-            throw new IllegalArgumentException(type.name + " does not support dimensions fewer than " + BBQ_MIN_DIMS + "; provided=" + dim);
+            return supportsDimension;
         }
+
     }
 
-    record RescoreVector(float oversample) implements ToXContentObject {
+    public record RescoreVector(float oversample) implements ToXContentObject {
         static final String NAME = "rescore_vector";
         static final String OVERSAMPLE = "oversample";
@@ -2323,7 +2350,7 @@ ElementType getElementType() {
         return elementType;
     }
 
-    IndexOptions getIndexOptions() {
+    public IndexOptions getIndexOptions() {
         return indexOptions;
     }
 }
diff --git a/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java b/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java
index e4d8ab24f6e73..b9d2696b347c7 100644
--- a/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java
+++ b/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java
@@ -249,10 +249,6 @@ private static void validateFieldNotPresent(String field, Object fieldValue, Tas
         }
     }
 
-    public ModelConfigurations toModelConfigurations(String inferenceEntityId) {
-        return new ModelConfigurations(inferenceEntityId, taskType, service == null ? UNKNOWN_SERVICE : service, this);
-    }
-
     /**
      * Checks if the given {@link MinimalServiceSettings} is equivalent to the current definition.
      */
@@ -260,7 +256,6 @@ public boolean canMergeWith(MinimalServiceSettings other) {
         return taskType == other.taskType
             && Objects.equals(dimensions, other.dimensions)
             && similarity == other.similarity
-            && elementType == other.elementType
-            && (service == null || service.equals(other.service));
+            && elementType == other.elementType;
     }
 }
diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java
index b62e400826836..99d9bbf30158b 100644
--- a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java
+++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java
@@ -207,6 +207,13 @@ protected final MapperService createMapperService(Settings settings, String mapp
         return mapperService;
     }
 
+    protected final MapperService createMapperService(IndexVersion indexVersion, Settings settings, XContentBuilder mappings)
+        throws IOException {
+        MapperService mapperService = createMapperService(indexVersion, settings, () -> true, mappings);
+        merge(mapperService, mappings);
+        return mapperService;
+    }
+
     protected final MapperService createMapperService(IndexVersion version, XContentBuilder mapping) throws IOException {
         return createMapperService(version, getIndexSettings(), () -> true, mapping);
     }
diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java
index 2b77426e64323..074678bbea095 100644
--- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java
+++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java
@@ -26,11 +26,14 @@
 import org.elasticsearch.index.mapper.SourceFieldMapper;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils;
+import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.license.LicenseSettings;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.test.ESIntegTestCase;
+import org.elasticsearch.test.InternalTestCluster;
+import org.elasticsearch.xpack.inference.InferenceIndex;
 import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
 import org.elasticsearch.xpack.inference.Utils;
 import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
@@ -45,7 +48,9 @@
 import java.util.Locale;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.Function;
 
+import static org.elasticsearch.xpack.inference.Utils.storeModel;
 import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput;
 import static org.hamcrest.Matchers.containsString;
@@ -56,6 +61,7 @@ public class ShardBulkInferenceActionFilterIT extends ESIntegTestCase {
 
     private final boolean useLegacyFormat;
     private final boolean useSyntheticSource;
+    private ModelRegistry modelRegistry;
 
     public ShardBulkInferenceActionFilterIT(boolean useLegacyFormat, boolean useSyntheticSource) {
         this.useLegacyFormat = useLegacyFormat;
@@ -74,7 +80,7 @@ public static Iterable<Object[]> parameters() throws Exception {
 
     @Before
     public void setup() throws Exception {
-        ModelRegistry modelRegistry = internalCluster().getCurrentMasterNodeInstance(ModelRegistry.class);
+        modelRegistry = internalCluster().getCurrentMasterNodeInstance(ModelRegistry.class);
         DenseVectorFieldMapper.ElementType elementType = randomFrom(DenseVectorFieldMapper.ElementType.values());
         // dot product means that we need normalized vectors; it's not worth doing that in this test
         SimilarityMeasure similarity = randomValueOtherThan(
@@ -135,32 +141,131 @@ public void testBulkOperations() throws Exception {
                 TestDenseInferenceServiceExtension.TestInferenceService.NAME
             )
         ).get();
+        assertRandomBulkOperations(INDEX_NAME, isIndexRequest -> {
+            Map<String, Object> map = new HashMap<>();
+            map.put("sparse_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
+            map.put("dense_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
+            return map;
+        });
+    }
+
+    public void testItemFailures() {
+        prepareCreate(INDEX_NAME).setMapping(
+            String.format(
+                Locale.ROOT,
+                """
+                    {
+                        "properties": {
+                            "sparse_field": {
+                                "type": "semantic_text",
+                                "inference_id": "%s"
+                            },
+                            "dense_field": {
+                                "type": "semantic_text",
+                                "inference_id": "%s"
+                            }
+                        }
+                    }
+                    """,
+                TestSparseInferenceServiceExtension.TestInferenceService.NAME,
+                TestDenseInferenceServiceExtension.TestInferenceService.NAME
+            )
+        ).get();
+
+        BulkRequestBuilder bulkReqBuilder = client().prepareBulk();
+        int totalBulkSize = randomIntBetween(100, 200); // Use a bulk request size large enough to require batching
+        for (int bulkSize = 0; bulkSize < totalBulkSize; bulkSize++) {
+            String id = Integer.toString(bulkSize);
+
+            // Set field values that will cause errors when generating inference requests
+            Map<String, Object> source = new HashMap<>();
+            source.put("sparse_field", List.of(Map.of("foo", "bar"), Map.of("baz", "bar")));
+            source.put("dense_field", List.of(Map.of("foo", "bar"), Map.of("baz", "bar")));
+
+            bulkReqBuilder.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(id).setSource(source));
+        }
+
+        BulkResponse bulkResponse = bulkReqBuilder.get();
+        assertThat(bulkResponse.hasFailures(), equalTo(true));
+        for (BulkItemResponse bulkItemResponse : bulkResponse.getItems()) {
+            assertThat(bulkItemResponse.isFailed(), equalTo(true));
+            assertThat(bulkItemResponse.getFailureMessage(), containsString("expected [String|Number|Boolean]"));
+        }
+    }
+
+    public void testRestart() throws Exception {
+        Model model1 = new TestSparseInferenceServiceExtension.TestSparseModel(
+            "another_inference_endpoint",
+            new TestSparseInferenceServiceExtension.TestServiceSettings("sparse_model", null, false)
+        );
+        storeModel(modelRegistry, model1);
+        prepareCreate("index_restart").setMapping("""
+            {
+                "properties": {
+                    "sparse_field": {
+                        "type": "semantic_text",
+                        "inference_id": "new_inference_endpoint"
+                    },
+                    "other_field": {
+                        "type": "semantic_text",
+                        "inference_id": "another_inference_endpoint"
+                    }
+                }
+            }
+            """).get();
+        Model model2 = new TestSparseInferenceServiceExtension.TestSparseModel(
+            "new_inference_endpoint",
+            new TestSparseInferenceServiceExtension.TestServiceSettings("sparse_model", null, false)
+        );
+        storeModel(modelRegistry, model2);
+
+        internalCluster().fullRestart(new InternalTestCluster.RestartCallback());
+        ensureGreen(InferenceIndex.INDEX_NAME, "index_restart");
+        assertRandomBulkOperations("index_restart", isIndexRequest -> {
+            Map<String, Object> map = new HashMap<>();
+            map.put("sparse_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
+            map.put("other_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
+            return map;
+        });
+
+        internalCluster().fullRestart(new InternalTestCluster.RestartCallback());
+        ensureGreen(InferenceIndex.INDEX_NAME, "index_restart");
+
+        assertRandomBulkOperations("index_restart", isIndexRequest -> {
+            Map<String, Object> map = new HashMap<>();
+            map.put("sparse_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
+            map.put("other_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
+            return map;
+        });
+    }
+
+    private void assertRandomBulkOperations(String indexName, Function<Boolean, Map<String, Object>> sourceSupplier) throws Exception {
+        int numHits = numHits(indexName);
         int totalBulkReqs = randomIntBetween(2, 100);
-        long totalDocs = 0;
+        long totalDocs = numHits;
         Set<String> ids = new HashSet<>();
-        for (int bulkReqs = 0; bulkReqs < totalBulkReqs; bulkReqs++) {
+
+        for (int bulkReqs = numHits; bulkReqs < totalBulkReqs; bulkReqs++) {
             BulkRequestBuilder bulkReqBuilder = client().prepareBulk();
             int totalBulkSize = randomIntBetween(1, 100);
             for (int bulkSize = 0; bulkSize < totalBulkSize; bulkSize++) {
                 if (ids.size() > 0 && rarely(random())) {
                     String id = randomFrom(ids);
                     ids.remove(id);
-                    DeleteRequestBuilder request = new DeleteRequestBuilder(client(), INDEX_NAME).setId(id);
+                    DeleteRequestBuilder request = new DeleteRequestBuilder(client(), indexName).setId(id);
                     bulkReqBuilder.add(request);
                     continue;
                 }
                 String id = Long.toString(totalDocs++);
                 boolean isIndexRequest = randomBoolean();
-                Map<String, Object> source = new HashMap<>();
-                source.put("sparse_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
-                source.put("dense_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
+                Map<String, Object> source = sourceSupplier.apply(isIndexRequest);
                 if (isIndexRequest) {
-                    bulkReqBuilder.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(id).setSource(source));
+                    bulkReqBuilder.add(new IndexRequestBuilder(client()).setIndex(indexName).setId(id).setSource(source));
                     ids.add(id);
                 } else {
                     boolean isUpsert = randomBoolean();
-                    UpdateRequestBuilder request = new UpdateRequestBuilder(client()).setIndex(INDEX_NAME).setDoc(source);
+                    UpdateRequestBuilder request = new UpdateRequestBuilder(client()).setIndex(indexName).setDoc(source);
                     if (isUpsert || ids.size() == 0) {
                         request.setDocAsUpsert(true);
                     } else {
@@ -188,59 +293,17 @@ public void testBulkOperations() throws Exception {
             }
             assertFalse(bulkResponse.hasFailures());
         }
+        client().admin().indices().refresh(new RefreshRequest(indexName)).get();
+        assertThat(numHits(indexName), equalTo(ids.size() + numHits));
+    }
 
-        client().admin().indices().refresh(new RefreshRequest(INDEX_NAME)).get();
-
+    private int numHits(String indexName) throws Exception {
         SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().size(0).trackTotalHits(true);
-        SearchResponse searchResponse = client().search(new SearchRequest(INDEX_NAME).source(sourceBuilder)).get();
+        SearchResponse searchResponse = client().search(new SearchRequest(indexName).source(sourceBuilder)).get();
         try {
-            assertThat(searchResponse.getHits().getTotalHits().value(), equalTo((long) ids.size()));
+            return (int) searchResponse.getHits().getTotalHits().value();
         } finally {
             searchResponse.decRef();
         }
     }
-
-    public void testItemFailures() {
-        prepareCreate(INDEX_NAME).setMapping(
-            String.format(
-                Locale.ROOT,
-                """
-                    {
-                        "properties": {
-                            "sparse_field": {
-                                "type": "semantic_text",
-                                "inference_id": "%s"
-                            },
-                            "dense_field": {
-                                "type": "semantic_text",
-                                "inference_id": "%s"
-                            }
-                        }
-                    }
-                    """,
-                TestSparseInferenceServiceExtension.TestInferenceService.NAME,
-                TestDenseInferenceServiceExtension.TestInferenceService.NAME
-            )
-        ).get();
-
-        BulkRequestBuilder bulkReqBuilder = client().prepareBulk();
-        int totalBulkSize = randomIntBetween(100, 200); // Use a bulk request size large enough to require batching
-        for (int bulkSize = 0; bulkSize < totalBulkSize; bulkSize++) {
-            String id = Integer.toString(bulkSize);
-
-            // Set field values that will cause errors when generating inference requests
-            Map<String, Object> source = new HashMap<>();
-            source.put("sparse_field", List.of(Map.of("foo", "bar"), Map.of("baz", "bar")));
-            source.put("dense_field", List.of(Map.of("foo", "bar"), Map.of("baz", "bar")));
-
-            bulkReqBuilder.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(id).setSource(source));
-        }
-
-        BulkResponse bulkResponse = bulkReqBuilder.get();
-        assertThat(bulkResponse.hasFailures(), equalTo(true));
-        for (BulkItemResponse bulkItemResponse : bulkResponse.getItems()) {
-            assertThat(bulkItemResponse.isFailed(), equalTo(true));
-            assertThat(bulkItemResponse.getFailureMessage(), containsString("expected [String|Number|Boolean]"));
-        }
-    }
 }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
index 114d9eaedfa53..7edb724132ac6 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
@@ -197,6 +197,7 @@ public class InferencePlugin extends Plugin
     private final SetOnce<ElasticInferenceServiceComponents> elasticInferenceServiceComponents = new SetOnce<>();
     private final SetOnce<InferenceServiceRegistry> inferenceServiceRegistry = new SetOnce<>();
     private final SetOnce<ShardBulkInferenceActionFilter> shardBulkInferenceActionFilter = new SetOnce<>();
+    private final SetOnce<ModelRegistry> modelRegistry = new SetOnce<>();
     private List<InferenceServiceExtension> inferenceServiceExtensions;
 
     public InferencePlugin(Settings settings) {
@@ -260,8 +261,8 @@ public Collection<?> createComponents(PluginServices services) {
         var amazonBedrockRequestSenderFactory = new AmazonBedrockRequestSender.Factory(serviceComponents.get(), services.clusterService());
         amazonBedrockFactory.set(amazonBedrockRequestSenderFactory);
 
-        ModelRegistry modelRegistry = new ModelRegistry(services.clusterService(), services.client());
-        services.clusterService().addListener(modelRegistry);
+        modelRegistry.set(new ModelRegistry(services.clusterService(), services.client()));
+        services.clusterService().addListener(modelRegistry.get());
 
         if (inferenceServiceExtensions == null) {
             inferenceServiceExtensions = new ArrayList<>();
@@ -299,7 +300,7 @@ public Collection<?> createComponents(PluginServices services) {
                     elasicInferenceServiceFactory.get(),
                     serviceComponents.get(),
                     inferenceServiceSettings,
-                    modelRegistry,
+                    modelRegistry.get(),
                     authorizationHandler
                 )
             )
@@ -317,14 +318,14 @@ public Collection<?> createComponents(PluginServices services) {
         var serviceRegistry = new InferenceServiceRegistry(inferenceServices, factoryContext);
         serviceRegistry.init(services.client());
         for (var service : serviceRegistry.getServices().values()) {
-            service.defaultConfigIds().forEach(modelRegistry::addDefaultIds);
+            service.defaultConfigIds().forEach(modelRegistry.get()::addDefaultIds);
         }
         inferenceServiceRegistry.set(serviceRegistry);
 
         var actionFilter = new ShardBulkInferenceActionFilter(
             services.clusterService(),
             serviceRegistry,
-            modelRegistry,
+            modelRegistry.get(),
             getLicenseState(),
             services.indexingPressure()
         );
@@ -334,7 +335,7 @@ public Collection<?> createComponents(PluginServices services) {
         var inferenceStats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));
 
         components.add(serviceRegistry);
-        components.add(modelRegistry);
+        components.add(modelRegistry.get());
         components.add(httpClientManager);
         components.add(inferenceStats);
 
@@ -498,11 +499,16 @@ public Map<String, MetadataFieldMapper.TypeParser> getMetadataMappers() {
         return Map.of(SemanticInferenceMetadataFieldsMapper.NAME, SemanticInferenceMetadataFieldsMapper.PARSER);
     }
 
+    // Overridable for tests
+    protected Supplier<ModelRegistry> getModelRegistry() {
+        return () -> modelRegistry.get();
+    }
+
     @Override
     public Map<String, Mapper.TypeParser> getMappers() {
         return Map.of(
             SemanticTextFieldMapper.CONTENT_TYPE,
-            SemanticTextFieldMapper.PARSER,
+            SemanticTextFieldMapper.parser(getModelRegistry()),
             OffsetSourceFieldMapper.CONTENT_TYPE,
             OffsetSourceFieldMapper.PARSER
         );
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
index 3a942a8e73537..f170f79809a59 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
@@ -7,6 +7,9 @@
 
 package org.elasticsearch.xpack.inference.mapper;
 
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
 import org.apache.lucene.index.FieldInfos;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.search.DocIdSetIterator;
@@ -18,6 +21,7 @@
 import org.apache.lucene.search.join.BitSetProducer;
 import org.apache.lucene.search.join.ScoreMode;
 import org.apache.lucene.util.BitSet;
+import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.bytes.BytesReference;
@@ -42,6 +46,7 @@
 import org.elasticsearch.index.mapper.Mapper;
 import org.elasticsearch.index.mapper.MapperBuilderContext;
 import org.elasticsearch.index.mapper.MapperMergeContext;
+import org.elasticsearch.index.mapper.MapperService;
 import org.elasticsearch.index.mapper.MappingLookup;
 import org.elasticsearch.index.mapper.MappingParserContext;
 import org.elasticsearch.index.mapper.NestedObjectMapper;
@@ -75,6 +80,7 @@
 import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
 import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
 import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter;
+import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 
 import java.io.IOException;
 import java.io.UncheckedIOException;
@@ -89,7 +95,9 @@
 import java.util.Set;
 import java.util.function.BiConsumer;
 import java.util.function.Function;
+import java.util.function.Supplier;
 
+import static org.elasticsearch.index.IndexVersions.SEMANTIC_TEXT_DEFAULTS_TO_BBQ;
 import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING;
 import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING;
 import static org.elasticsearch.search.SearchService.DEFAULT_SIZE;
@@ -112,6 +120,7 @@
  * A {@link FieldMapper} for semantic text fields.
  */
 public class SemanticTextFieldMapper extends FieldMapper implements InferenceFieldMapper {
+    private static final Logger logger = LogManager.getLogger(SemanticTextFieldMapper.class);
     public static final NodeFeature SEMANTIC_TEXT_IN_OBJECT_FIELD_FIX = new NodeFeature("semantic_text.in_object_field_fix");
     public static final NodeFeature SEMANTIC_TEXT_SINGLE_FIELD_UPDATE_FIX = new NodeFeature("semantic_text.single_field_update_fix");
     public static final NodeFeature SEMANTIC_TEXT_DELETE_FIX = new NodeFeature("semantic_text.delete_fix");
@@ -127,10 +136,14 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
     public static final String CONTENT_TYPE = "semantic_text";
     public static final String DEFAULT_ELSER_2_INFERENCE_ID = DEFAULT_ELSER_ID;
 
-    public static final TypeParser PARSER = new TypeParser(
-        (n, c) -> new Builder(n, c::bitSetProducer, c.getIndexSettings()),
-        List.of(validateParserContext(CONTENT_TYPE))
-    );
+    public static final float DEFAULT_RESCORE_OVERSAMPLE = 3.0f;
+
+    public static final TypeParser parser(Supplier<ModelRegistry> modelRegistry) {
+        return new TypeParser(
+            (n, c) -> new Builder(n, c::bitSetProducer, c.getIndexSettings(), modelRegistry.get()),
+            List.of(validateParserContext(CONTENT_TYPE))
+        );
+    }
 
     public static BiConsumer<String, MappingParserContext> validateParserContext(String type) {
         return (n, c) -> {
@@ -142,6 +155,7 @@
     }
 
     public static class Builder extends FieldMapper.Builder {
+        private final ModelRegistry modelRegistry;
         private final boolean useLegacyFormat;
 
         private final Parameter<String> inferenceId = Parameter.stringParam(
@@ -193,26 +207,34 @@ public static class Builder extends FieldMapper.Builder {
 
         private final Parameter<Map<String, String>> meta = Parameter.metaParam();
 
+        private MinimalServiceSettings resolvedModelSettings;
         private Function<MapperBuilderContext, ObjectMapper> inferenceFieldBuilder;
 
         public static Builder from(SemanticTextFieldMapper mapper) {
             Builder builder = new Builder(
                 mapper.leafName(),
                 mapper.fieldType().getChunksField().bitsetProducer(),
-                mapper.fieldType().getChunksField().indexSettings()
+                mapper.fieldType().getChunksField().indexSettings(),
+                mapper.modelRegistry
             );
             builder.init(mapper);
             return builder;
         }
 
-        public Builder(String name, Function<Query, BitSetProducer> bitSetProducer, IndexSettings indexSettings) {
+        public Builder(
+            String name,
+            Function<Query, BitSetProducer> bitSetProducer,
+            IndexSettings indexSettings,
+            ModelRegistry modelRegistry
+        ) {
             super(name);
+            this.modelRegistry = modelRegistry;
             this.useLegacyFormat = InferenceMetadataFieldsMapper.isEnabled(indexSettings.getSettings()) == false;
             this.inferenceFieldBuilder = c -> createInferenceField(
                 c,
                 indexSettings.getIndexVersionCreated(),
                 useLegacyFormat,
-                modelSettings.get(),
+                resolvedModelSettings,
                 bitSetProducer,
                 indexSettings
             );
@@ -264,9 +286,39 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {
             if (useLegacyFormat && multiFieldsBuilder.hasMultiFields()) {
                 throw new IllegalArgumentException(CONTENT_TYPE + " field [" + leafName() + "] does not support multi-fields");
             }
+
+            if (context.getMergeReason() != MapperService.MergeReason.MAPPING_RECOVERY && modelSettings.get() == null) {
+                try {
+                    /*
+                     * If the model is not already set and we are not in a recovery scenario, resolve it using the registry.
+                     * Note: We do not set the model in the mapping at this stage. Instead, the model will be added through
+                     * a mapping update during the first ingestion.
+                     * This approach allows mappings to reference inference endpoints that may not yet exist.
+                     * The only requirement is that the referenced inference endpoint must be available at the time of ingestion.
+                     */
+                    resolvedModelSettings = modelRegistry.getMinimalServiceSettings(inferenceId.get());
+                } catch (ResourceNotFoundException exc) {
+                    /* We allow the inference ID to be unregistered at this point.
+                     * This will delay the creation of sub-fields, so indexing and querying for this field won't work
+                     * until the corresponding inference endpoint is created.
+                     */
+                }
+            } else {
+                resolvedModelSettings = modelSettings.get();
+            }
+
             if (modelSettings.get() != null) {
-                validateServiceSettings(modelSettings.get());
+                validateServiceSettings(modelSettings.get(), resolvedModelSettings);
+            } else {
+                logger.warn(
+                    "The field [{}] references an unknown inference ID [{}]. "
+                        + "Indexing and querying this field will not work correctly until the corresponding "
+                        + "inference endpoint is created.",
+                    leafName(),
+                    inferenceId.get()
+                );
             }
+
             final String fullName = context.buildFullName(leafName());
 
             if (context.isInNestedContext()) {
@@ -287,11 +339,12 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {
                     useLegacyFormat,
                     meta.getValue()
                 ),
-                builderParams(this, context)
+                builderParams(this, context),
+                modelRegistry
             );
         }
 
-        private void validateServiceSettings(MinimalServiceSettings settings) {
+        private void validateServiceSettings(MinimalServiceSettings settings, MinimalServiceSettings resolved) {
             switch (settings.taskType()) {
                 case SPARSE_EMBEDDING, TEXT_EMBEDDING -> {
                 }
@@ -306,6 +359,16 @@ private void validateServiceSettings(MinimalServiceSettings settings) {
                         + settings.taskType().name()
                 );
             }
+            if (resolved != null && settings.canMergeWith(resolved) == false) {
+                throw new IllegalArgumentException(
+                    "Mismatch between provided and registered inference model settings. "
+                        + "Provided: ["
+                        + settings
+                        + "], Expected: ["
+                        + resolved
+                        + "]."
+                );
+            }
         }
 
         /**
@@ -328,9 +391,17 @@ private SemanticTextFieldMapper copySettings(SemanticTextFieldMapper mapper, Map
         }
     }
 
-    private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, BuilderParams builderParams) {
+    private final ModelRegistry modelRegistry;
+
+    private SemanticTextFieldMapper(
+        String simpleName,
+        MappedFieldType mappedFieldType,
+        BuilderParams builderParams,
+        ModelRegistry modelRegistry
+    ) {
         super(simpleName, mappedFieldType, builderParams);
         ensureMultiFields(builderParams.multiFields().iterator());
+        this.modelRegistry = modelRegistry;
     }
 
     private void ensureMultiFields(Iterator<FieldMapper> mappers) {
@@ -1006,12 +1077,30 @@ private static Mapper.Builder createEmbeddingsField(
                 denseVectorMapperBuilder.dimensions(modelSettings.dimensions());
                 denseVectorMapperBuilder.elementType(modelSettings.elementType());
 
+                DenseVectorFieldMapper.IndexOptions defaultIndexOptions = null;
+                if (indexVersionCreated.onOrAfter(SEMANTIC_TEXT_DEFAULTS_TO_BBQ)) {
+                    defaultIndexOptions = defaultSemanticDenseIndexOptions();
+                }
+                if (defaultIndexOptions != null
+                    && defaultIndexOptions.validate(modelSettings.elementType(), modelSettings.dimensions(), false)) {
+                    denseVectorMapperBuilder.indexOptions(defaultIndexOptions);
+                }
+
                 yield denseVectorMapperBuilder;
             }
             default -> throw new IllegalArgumentException("Invalid task_type in model_settings [" + modelSettings.taskType().name() + "]");
         };
     }
 
+    static DenseVectorFieldMapper.IndexOptions defaultSemanticDenseIndexOptions() {
+        // As embedding models for text perform better with BBQ, we aggressively default semantic_text fields to use optimized index
+        // options outside of dense_vector defaults
+        int m = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
+        int efConstruction = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
+        DenseVectorFieldMapper.RescoreVector rescoreVector = new DenseVectorFieldMapper.RescoreVector(DEFAULT_RESCORE_OVERSAMPLE);
+        return new DenseVectorFieldMapper.BBQHnswIndexOptions(m, efConstruction, rescoreVector);
+    }
+
     private static boolean canMergeModelSettings(MinimalServiceSettings previous, MinimalServiceSettings current, Conflicts conflicts) {
         if (previous != null && current != null && previous.canMergeWith(current)) {
             return true;
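
Aside: for indices created on or after SEMANTIC_TEXT_DEFAULTS_TO_BBQ, a compatible text-embedding model (float element type, at least BBQ_MIN_DIMS dimensions) now gets BBQ HNSW with stock Lucene graph parameters and a 3.0x rescore oversample instead of the plain dense_vector defaults. Spelled out as a sketch (the 16/100 values are the Lucene99HnswVectorsFormat defaults referenced above; shown for illustration, not copied output):

    DenseVectorFieldMapper.IndexOptions defaults = new DenseVectorFieldMapper.BBQHnswIndexOptions(
        Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,    // m = 16
        Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,  // ef_construction = 100
        new DenseVectorFieldMapper.RescoreVector(3.0f) // DEFAULT_RESCORE_OVERSAMPLE
    );
    // roughly equivalent to declaring on the embeddings sub-field:
    // "index_options": { "type": "bbq_hnsw", "m": 16, "ef_construction": 100,
    //                    "rescore_vector": { "oversample": 3.0 } }

Models that fail the non-throwing validate(...) probe (for example a byte element type, or too few dimensions) keep dense_vector's own defaults, as the tests below verify.
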
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java
index 16533f2dc9b8a..37a82b2160595 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java
@@ -36,6 +36,7 @@
 import org.elasticsearch.cluster.ClusterStateAckListener;
 import org.elasticsearch.cluster.ClusterStateListener;
 import org.elasticsearch.cluster.SimpleBatchedAckListenerTaskExecutor;
+import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.metadata.ProjectId;
 import org.elasticsearch.cluster.metadata.ProjectMetadata;
 import org.elasticsearch.cluster.service.ClusterService;
@@ -140,7 +141,6 @@ public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap)
     private static final String MODEL_ID_FIELD = "model_id";
     private static final Logger logger = LogManager.getLogger(ModelRegistry.class);
 
-    private final ClusterService clusterService;
     private final OriginSettingClient client;
     private final Map<String, InferenceService.DefaultConfigId> defaultConfigIds;
 
@@ -148,10 +148,11 @@ public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap)
     private final AtomicBoolean upgradeMetadataInProgress = new AtomicBoolean(false);
     private final Set<String> preventDeletionLock = Collections.newSetFromMap(new ConcurrentHashMap<>());
 
+    private volatile Metadata lastMetadata;
+
     public ModelRegistry(ClusterService clusterService, Client client) {
         this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN);
         this.defaultConfigIds = new ConcurrentHashMap<>();
-        this.clusterService = clusterService;
         var executor = new SimpleBatchedAckListenerTaskExecutor<MetadataTask>() {
             @Override
             public Tuple<ClusterState, ClusterStateAckListener> executeTask(MetadataTask task, ClusterState clusterState) throws Exception {
@@ -224,11 +225,17 @@ public void clearDefaultIds() {
      * @throws ResourceNotFoundException if the specified id is guaranteed to not exist in the cluster.
      */
     public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId) throws ResourceNotFoundException {
+        synchronized (this) {
+            if (lastMetadata == null) {
+                throw new IllegalStateException("initial cluster state not set yet");
+            }
+        }
         var config = defaultConfigIds.get(inferenceEntityId);
         if (config != null) {
             return config.settings();
         }
-        var state = ModelRegistryMetadata.fromState(clusterService.state().projectState().metadata());
+        var project = lastMetadata.getProject(ProjectId.DEFAULT);
+        var state = ModelRegistryMetadata.fromState(project);
         var existing = state.getMinimalServiceSettings(inferenceEntityId);
         if (state.isUpgraded() && existing == null) {
             throw new ResourceNotFoundException(inferenceEntityId + " does not exist in this cluster.");
@@ -684,10 +691,14 @@ private ActionListener<BulkResponse> getStoreIndexListener(
         if (updateClusterState) {
             var storeListener = getStoreMetadataListener(inferenceEntityId, listener);
             try {
-                var projectId = clusterService.state().projectState().projectId();
                 metadataTaskQueue.submitTask(
                     "add model [" + inferenceEntityId + "]",
-                    new AddModelMetadataTask(projectId, inferenceEntityId, new MinimalServiceSettings(model), storeListener),
+                    new AddModelMetadataTask(
+                        ProjectId.DEFAULT,
+                        inferenceEntityId,
+                        new MinimalServiceSettings(model),
+                        storeListener
+                    ),
                     timeout
                 );
             } catch (Exception exc) {
@@ -854,10 +865,9 @@ public void onFailure(Exception exc) {
             }
         };
         try {
-            var projectId = clusterService.state().projectState().projectId();
             metadataTaskQueue.submitTask(
                 "delete models [" + inferenceEntityIds + "]",
-                new DeleteModelMetadataTask(projectId, inferenceEntityIds, clusterStateListener),
+                new DeleteModelMetadataTask(ProjectId.DEFAULT, inferenceEntityIds, clusterStateListener),
                 null
             );
         } catch (Exception exc) {
@@ -935,6 +945,13 @@ static List<InferenceService.DefaultConfigId> taskTypeMatchedDefaults(
 
     @Override
     public void clusterChanged(ClusterChangedEvent event) {
+        if (lastMetadata == null || event.metadataChanged()) {
+            // keep track of the last applied cluster state
+            synchronized (this) {
+                lastMetadata = event.state().metadata();
+            }
+        }
+
         if (event.localNodeMaster() == false) {
             return;
         }
@@ -984,7 +1001,7 @@ public void onResponse(GetInferenceModelAction.Response response) {
                 metadataTaskQueue.submitTask(
                     "model registry auto upgrade",
                     new UpgradeModelsMetadataTask(
-                        clusterService.state().metadata().getProject().id(),
+                        ProjectId.DEFAULT,
                         map,
                         ActionListener.running(() -> upgradeMetadataInProgress.set(false))
                     ),
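
Aside: ModelRegistry now answers getMinimalServiceSettings from the last cluster state it applied (lastMetadata) rather than reaching back into ClusterService, which is what makes it safe to call during mapping parsing. A hedged sketch of the resulting caller-side contract (this mirrors how SemanticTextFieldMapper.Builder consumes it earlier in the diff):

    try {
        MinimalServiceSettings settings = modelRegistry.getMinimalServiceSettings(inferenceId);
        // non-null: endpoint known; null: indeterminate until the registry metadata is upgraded
    } catch (ResourceNotFoundException e) {
        // endpoint guaranteed absent: semantic_text defers sub-field creation to first ingestion
    } catch (IllegalStateException e) {
        // no cluster state applied yet (for example, very early in node startup)
    }
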
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsMapperTests.java
index 23519ec86cbc4..f21d4ab8cf4de 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsMapperTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsMapperTests.java
@@ -37,7 +37,10 @@ public void testIsEnabled() {
         assertFalse(InferenceMetadataFieldsMapper.isEnabled(settings));
 
         settings = Settings.builder()
-            .put(IndexMetadata.SETTING_INDEX_VERSION_CREATED.getKey(), getRandomCompatibleIndexVersion(true))
+            .put(
+                IndexMetadata.SETTING_INDEX_VERSION_CREATED.getKey(),
+                getRandomCompatibleIndexVersion(true, IndexVersionUtils.getPreviousVersion(IndexVersions.INFERENCE_METADATA_FIELDS))
+            )
             .put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), false)
             .build();
         assertFalse(InferenceMetadataFieldsMapper.isEnabled(settings));
@@ -114,18 +117,18 @@ public MappedFieldType getMappedFieldType() {
     }
 
     static IndexVersion getRandomCompatibleIndexVersion(boolean useLegacyFormat) {
+        return getRandomCompatibleIndexVersion(useLegacyFormat, IndexVersion.current());
+    }
+
+    static IndexVersion getRandomCompatibleIndexVersion(boolean useLegacyFormat, IndexVersion maxVersion) {
         if (useLegacyFormat) {
             if (randomBoolean()) {
-                return IndexVersionUtils.randomVersionBetween(
-                    random(),
-                    IndexVersions.UPGRADE_TO_LUCENE_10_0_0,
-                    IndexVersionUtils.getPreviousVersion(IndexVersions.INFERENCE_METADATA_FIELDS)
-                );
+                return IndexVersionUtils.randomVersionBetween(random(), IndexVersions.UPGRADE_TO_LUCENE_10_0_0, maxVersion);
             }
             return IndexVersionUtils.randomPreviousCompatibleVersion(random(), IndexVersions.INFERENCE_METADATA_FIELDS_BACKPORT);
         } else {
            if (randomBoolean()) {
-                return IndexVersionUtils.randomVersionBetween(random(), IndexVersions.INFERENCE_METADATA_FIELDS, IndexVersion.current());
+                return IndexVersionUtils.randomVersionBetween(random(), IndexVersions.INFERENCE_METADATA_FIELDS, maxVersion);
             }
             return IndexVersionUtils.randomVersionBetween(
                 random(),
@@ -134,4 +137,5 @@ static IndexVersion getRandomCompatibleIndexVersion(boolean useLegacyFormat) {
             );
         }
     }
+
 }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
index 4d2a76f915af3..25772345a3065 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
@@ -9,6 +9,7 @@
 
 import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
 
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
 import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.FieldInfos;
 import org.apache.lucene.index.IndexableField;
@@ -24,6 +25,7 @@
 import org.apache.lucene.search.join.QueryBitSetProducer;
 import org.apache.lucene.search.join.ScoreMode;
 import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest;
+import org.elasticsearch.cluster.ClusterChangedEvent;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.common.CheckedBiConsumer;
 import org.elasticsearch.common.CheckedBiFunction;
@@ -34,6 +36,7 @@
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.CheckedConsumer;
 import org.elasticsearch.index.IndexVersion;
+import org.elasticsearch.index.IndexVersions;
 import org.elasticsearch.index.mapper.DocumentMapper;
 import org.elasticsearch.index.mapper.DocumentParsingException;
 import org.elasticsearch.index.mapper.FieldMapper;
@@ -63,6 +66,10 @@
 import org.elasticsearch.search.LeafNestedDocuments;
 import org.elasticsearch.search.NestedDocuments;
 import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.test.ClusterServiceUtils;
+import org.elasticsearch.test.client.NoOpClient;
+import org.elasticsearch.test.index.IndexVersionUtils;
+import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xcontent.json.JsonXContent;
@@ -70,7 +77,10 @@
 import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryWrapper;
 import org.elasticsearch.xpack.inference.InferencePlugin;
 import org.elasticsearch.xpack.inference.model.TestModel;
+import org.elasticsearch.xpack.inference.registry.ModelRegistry;
+import org.junit.After;
 import org.junit.AssumptionViolatedException;
+import org.junit.Before;
 
 import java.io.IOException;
 import java.util.Collection;
@@ -80,6 +90,7 @@
 import java.util.Map;
 import java.util.Set;
 import java.util.function.BiConsumer;
+import java.util.function.Supplier;
 
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKS_FIELD;
@@ -101,10 +112,22 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
 
     private final boolean useLegacyFormat;
 
+    private TestThreadPool threadPool;
+
     public SemanticTextFieldMapperTests(boolean useLegacyFormat) {
         this.useLegacyFormat = useLegacyFormat;
     }
 
+    @Before
+    private void startThreadPool() {
+        threadPool = createThreadPool();
+    }
+
+    @After
+    private void stopThreadPool() {
+        threadPool.close();
+    }
+
     @ParametersFactory
     public static Iterable<Object[]> parameters() throws Exception {
         return List.of(new Object[] { true }, new Object[] { false });
@@ -112,18 +135,61 @@ public static Iterable<Object[]> parameters() throws Exception {
 
     @Override
     protected Collection<? extends Plugin> getPlugins() {
-        return List.of(new InferencePlugin(Settings.EMPTY), new XPackClientPlugin());
+        var clusterService = ClusterServiceUtils.createClusterService(threadPool);
+        var modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool));
+        modelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) {
+            @Override
+            public boolean localNodeMaster() {
+                return false;
+            }
+        });
+        return List.of(new InferencePlugin(Settings.EMPTY) {
+            @Override
+            protected Supplier<ModelRegistry> getModelRegistry() {
+                return () -> modelRegistry;
+            }
+        }, new XPackClientPlugin());
     }
 
     private MapperService createMapperService(XContentBuilder mappings, boolean useLegacyFormat) throws IOException {
+        IndexVersion indexVersion = SemanticInferenceMetadataFieldsMapperTests.getRandomCompatibleIndexVersion(useLegacyFormat);
+        return createMapperService(mappings, useLegacyFormat, indexVersion, indexVersion, false);
+    }
+
+    private MapperService createMapperService(XContentBuilder mappings, boolean useLegacyFormat, IndexVersion minIndexVersion)
+        throws IOException {
+        return createMapperService(mappings, useLegacyFormat, minIndexVersion, IndexVersion.current(), false);
+    }
+
+    private MapperService createMapperService(
+        XContentBuilder mappings,
+        boolean useLegacyFormat,
+        IndexVersion minIndexVersion,
+        IndexVersion maxIndexVersion,
+        boolean propagateIndexVersion
+    ) throws IOException {
+        validateIndexVersion(minIndexVersion, useLegacyFormat);
+        IndexVersion indexVersion = IndexVersionUtils.randomVersionBetween(random(), minIndexVersion, maxIndexVersion);
         var settings = Settings.builder()
-            .put(
-                IndexMetadata.SETTING_INDEX_VERSION_CREATED.getKey(),
-                SemanticInferenceMetadataFieldsMapperTests.getRandomCompatibleIndexVersion(useLegacyFormat)
-            )
+            .put(IndexMetadata.SETTING_INDEX_VERSION_CREATED.getKey(), indexVersion)
            .put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat)
             .build();
-        return createMapperService(settings, mappings);
+        // TODO - This is added, because we discovered a bug where the index version was not being correctly propagated
+        // in our mappings even though we were specifying the index version in settings. We will fix this in a followup and
+        // remove the boolean flag accordingly.
+        if (propagateIndexVersion) {
+            return createMapperService(indexVersion, settings, mappings);
+        } else {
+            return createMapperService(settings, mappings);
+        }
+    }
+
+    private static void validateIndexVersion(IndexVersion indexVersion, boolean useLegacyFormat) {
+        if (useLegacyFormat == false
+            && indexVersion.before(IndexVersions.INFERENCE_METADATA_FIELDS)
+            && indexVersion.between(IndexVersions.INFERENCE_METADATA_FIELDS_BACKPORT, IndexVersions.UPGRADE_TO_LUCENE_10_0_0) == false) {
+            throw new IllegalArgumentException("Index version " + indexVersion + " does not support new semantic text format");
+        }
     }
 
     @Override
@@ -569,14 +635,15 @@ public void testUpdateSearchInferenceId() throws IOException {
     }
 
     private static void assertSemanticTextField(MapperService mapperService, String fieldName, boolean expectedModelSettings) {
-        assertSemanticTextField(mapperService, fieldName, expectedModelSettings, null);
+        assertSemanticTextField(mapperService, fieldName, expectedModelSettings, null, null);
     }
 
     private static void assertSemanticTextField(
         MapperService mapperService,
         String fieldName,
         boolean expectedModelSettings,
-        ChunkingSettings expectedChunkingSettings
+        ChunkingSettings expectedChunkingSettings,
+        DenseVectorFieldMapper.IndexOptions expectedIndexOptions
     ) {
         Mapper mapper = mapperService.mappingLookup().getMapper(fieldName);
         assertNotNull(mapper);
@@ -622,8 +689,17 @@ private static void assertSemanticTextField(
                     assertThat(embeddingsMapper, instanceOf(SparseVectorFieldMapper.class));
                     SparseVectorFieldMapper sparseMapper = (SparseVectorFieldMapper) embeddingsMapper;
                     assertEquals(sparseMapper.fieldType().isStored(), semanticTextFieldType.useLegacyFormat() == false);
+                    assertNull(expectedIndexOptions);
+                }
+                case TEXT_EMBEDDING -> {
+                    assertThat(embeddingsMapper, instanceOf(DenseVectorFieldMapper.class));
+                    DenseVectorFieldMapper denseVectorFieldMapper = (DenseVectorFieldMapper) embeddingsMapper;
+                    if (expectedIndexOptions != null) {
+                        assertEquals(expectedIndexOptions, denseVectorFieldMapper.fieldType().getIndexOptions());
+                    } else {
+                        assertNull(denseVectorFieldMapper.fieldType().getIndexOptions());
+                    }
                 }
-                case TEXT_EMBEDDING -> assertThat(embeddingsMapper, instanceOf(DenseVectorFieldMapper.class));
                 default -> throw new AssertionError("Invalid task type");
             }
         } else {
@@ -918,11 +994,11 @@ public void testSettingAndUpdatingChunkingSettings() throws IOException {
             mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId(), null, chunkingSettings)),
             useLegacyFormat
         );
-        assertSemanticTextField(mapperService, fieldName, false, chunkingSettings);
+        assertSemanticTextField(mapperService, fieldName, false, chunkingSettings, null);
 
         ChunkingSettings newChunkingSettings = generateRandomChunkingSettingsOtherThan(chunkingSettings);
         merge(mapperService, mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId(), null, newChunkingSettings)));
-        assertSemanticTextField(mapperService, fieldName, false, newChunkingSettings);
+        assertSemanticTextField(mapperService, fieldName, false, newChunkingSettings, null);
     }
 
     public void testModelSettingsRequiredWithChunks() throws IOException {
@@ -1052,6 +1128,74 @@ public void testExistsQueryDenseVector() throws IOException {
         assertThat(existsQuery, instanceOf(ESToParentBlockJoinQuery.class));
     }
 
+    private static DenseVectorFieldMapper.IndexOptions defaultDenseVectorIndexOptions() {
+        // These are the default index options for dense_vector fields, and used for semantic_text fields incompatible with BBQ.
+        int m = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
+        int efConstruction = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
+        return new DenseVectorFieldMapper.Int8HnswIndexOptions(m, efConstruction, null, null);
+    }
+
+    public void testDefaultIndexOptions() throws IOException {
+
+        // We default to BBQ for eligible dense vectors
+        var mapperService = createMapperService(fieldMapping(b -> {
+            b.field("type", "semantic_text");
+            b.field("inference_id", "another_inference_id");
+            b.startObject("model_settings");
+            b.field("task_type", "text_embedding");
+            b.field("dimensions", 100);
+            b.field("similarity", "cosine");
+            b.field("element_type", "float");
+            b.endObject();
+        }), useLegacyFormat, IndexVersions.SEMANTIC_TEXT_DEFAULTS_TO_BBQ);
+        assertSemanticTextField(mapperService, "field", true, null, SemanticTextFieldMapper.defaultSemanticDenseIndexOptions());
+
+        // Element types that are incompatible with BBQ will continue to use dense_vector defaults
+        mapperService = createMapperService(fieldMapping(b -> {
+            b.field("type", "semantic_text");
+            b.field("inference_id", "another_inference_id");
+            b.startObject("model_settings");
+            b.field("task_type", "text_embedding");
+            b.field("dimensions", 100);
+            b.field("similarity", "cosine");
+            b.field("element_type", "byte");
+            b.endObject();
+        }), useLegacyFormat, IndexVersions.SEMANTIC_TEXT_DEFAULTS_TO_BBQ);
+        assertSemanticTextField(mapperService, "field", true, null, null);
+
+        // A dim count of 10 is too small to support BBQ, so we continue to use dense_vector defaults
+        mapperService = createMapperService(fieldMapping(b -> {
+            b.field("type", "semantic_text");
+            b.field("inference_id", "another_inference_id");
+            b.startObject("model_settings");
+            b.field("task_type", "text_embedding");
+            b.field("dimensions", 10);
+            b.field("similarity", "cosine");
+            b.field("element_type", "float");
+            b.endObject();
+        }), useLegacyFormat, IndexVersions.SEMANTIC_TEXT_DEFAULTS_TO_BBQ);
+        assertSemanticTextField(mapperService, "field", true, null, defaultDenseVectorIndexOptions());
+
+        // Previous index versions do not set BBQ index options
+        mapperService = createMapperService(fieldMapping(b -> {
+            b.field("type", "semantic_text");
+            b.field("inference_id", "another_inference_id");
+            b.startObject("model_settings");
+            b.field("task_type", "text_embedding");
+            b.field("dimensions", 100);
+            b.field("similarity", "cosine");
+            b.field("element_type", "float");
+            b.endObject();
+        }),
+            useLegacyFormat,
+            IndexVersions.INFERENCE_METADATA_FIELDS,
+            IndexVersionUtils.getPreviousVersion(IndexVersions.SEMANTIC_TEXT_DEFAULTS_TO_BBQ),
+            true
+        );
+        assertSemanticTextField(mapperService, "field", true, null, defaultDenseVectorIndexOptions());
+
+    }
+
     @Override
     protected void assertExistsQuery(MappedFieldType fieldType, Query query, LuceneDocument fields) {
         // Until a doc is indexed, the query is rewritten as match no docs
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java
index c4a6b92ac033c..e0ba14c8959fb 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java
@@ -22,12 +22,14 @@
 import org.elasticsearch.action.ActionType;
 import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest;
 import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.cluster.ClusterChangedEvent;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.compress.CompressedXContent;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.index.IndexVersion;
 import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
 import org.elasticsearch.index.mapper.MapperService;
@@ -46,6 +48,9 @@
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.test.AbstractQueryTestCase;
+import org.elasticsearch.test.ClusterServiceUtils;
+import org.elasticsearch.test.client.NoOpClient;
+import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xcontent.json.JsonXContent;
@@ -60,6 +65,8 @@
 import org.elasticsearch.xpack.core.ml.search.WeightedToken;
 import org.elasticsearch.xpack.inference.InferencePlugin;
 import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
+import org.elasticsearch.xpack.inference.registry.ModelRegistry;
+import org.junit.AfterClass;
 import org.junit.Before;
 import org.junit.BeforeClass;
 
@@ -70,6 +77,7 @@
 import java.util.Collection;
 import java.util.List;
 import java.util.Map;
+import java.util.function.Supplier;
 
 import static org.apache.lucene.search.BooleanClause.Occur.FILTER;
 import static org.apache.lucene.search.BooleanClause.Occur.MUST;
@@ -118,6 +126,24 @@ public static void setInferenceResultType() {
         useSearchInferenceId = randomBoolean();
     }
 
+    @BeforeClass
+    public static void startModelRegistry() {
+        threadPool = new TestThreadPool(SemanticQueryBuilderTests.class.getName());
+        var clusterService = ClusterServiceUtils.createClusterService(threadPool);
+        modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool));
+        modelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) {
+            @Override
+            public boolean localNodeMaster() {
+                return false;
+            }
+        });
+    }
+
+    @AfterClass
+    public static void stopModelRegistry() {
+        IOUtils.closeWhileHandlingException(threadPool);
+    }
+
     @Override
     @Before
     public void setUp() throws Exception {
@@ -127,7 +153,7 @@ public void setUp() throws Exception {
 
     @Override
     protected Collection<Class<? extends Plugin>> getPlugins() {
-        return List.of(XPackClientPlugin.class, InferencePlugin.class, FakeMlPlugin.class);
+        return List.of(XPackClientPlugin.class, InferencePluginWithModelRegistry.class, FakeMlPlugin.class);
     }
 
     @Override
@@ -394,4 +420,18 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
             return new MlInferenceNamedXContentProvider().getNamedWriteables();
         }
     }
+
+    private static TestThreadPool threadPool;
+    private static ModelRegistry modelRegistry;
+
+    public static class InferencePluginWithModelRegistry extends InferencePlugin {
+        public InferencePluginWithModelRegistry(Settings settings) {
+            super(settings);
+        }
+
+        @Override
+        protected Supplier<ModelRegistry> getModelRegistry() {
+            return () -> modelRegistry;
+        }
+    }
 }
diff --git a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java
index da01459b057b6..f39b3f2b01368 100644
--- a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java
+++ b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java
@@ -9,13 +9,20 @@
 
 import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
 
+import org.elasticsearch.client.Request;
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.test.cluster.ElasticsearchCluster;
 import org.elasticsearch.test.cluster.local.distribution.DistributionType;
 import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate;
 import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase;
+import org.junit.After;
 import org.junit.ClassRule;
 
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
 public class InferenceRestIT extends ESClientYamlSuiteTestCase {
 
     @ClassRule
@@ -50,4 +57,24 @@ protected String getTestRestCluster() {
     public static Iterable<Object[]> parameters() throws Exception {
         return ESClientYamlSuiteTestCase.createParameters();
     }
+
+    @After
+    public void cleanup() throws Exception {
+        for (var model : getAllModels()) {
+            var inferenceId = model.get("inference_id");
+            try {
+                var endpoint = Strings.format("_inference/%s?force", inferenceId);
+                adminClient().performRequest(new Request("DELETE", endpoint));
+            } catch (Exception ex) {
+                logger.warn(() -> "failed to delete inference endpoint " + inferenceId, ex);
+            }
+        }
+    }
+
+    @SuppressWarnings("unchecked")
+    static List<Map<String, Object>> getAllModels() throws IOException {
+        var request = new Request("GET", "_inference/_all");
+        var response = client().performRequest(request);
+        return (List<Map<String, Object>>) entityAsMap(response).get("endpoints");
+    }
 }
diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/30_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/30_semantic_text_inference.yml
index 68c2658c66234..3dba90c0f89ae 100644
--- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/30_semantic_text_inference.yml
+++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/30_semantic_text_inference.yml
@@ -1176,3 +1176,50 @@ setup:
   - exists: hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_field.1.embeddings
   - match: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_field.1.start_offset: 20 }
   - match: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_field.1.end_offset: 35 }
+
+---
+"inference endpoint late creation":
+  - do:
+      indices.create:
+        index: new-index
+        body:
+          mappings:
+            properties:
+              inference_field:
+                type: semantic_text
+                inference_id: new_inference_endpoint
+
+  - do:
+      inference.put:
+        task_type: sparse_embedding
+        inference_id: new_inference_endpoint
+        body: >
+          {
+            "service": "test_service",
+            "service_settings": {
+              "model": "my_model",
+              "api_key": "abc64"
+            },
+            "task_settings": {
+            }
+          }
+
+
+  - do:
+      index:
+        index: new-index
+        id: doc_1
+        body:
+          inference_field: "inference test"
+        refresh: true
+
+  - do:
+      search:
+        index: new-index
+        body:
+          query:
+            exists:
+              field: "inference_field"
+
+  - match: { hits.total.value: 1 }
+
diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/30_semantic_text_inference_bwc.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/30_semantic_text_inference_bwc.yml
index 75df28148c148..2924ce8a108ef 100644
--- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/30_semantic_text_inference_bwc.yml
+++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/30_semantic_text_inference_bwc.yml
@@ -739,3 +739,54 @@ setup:
   - exists: hits.hits.0._source.sparse_field.inference.chunks.0.embeddings
   - match: { hits.hits.0._source.sparse_field.inference.chunks.1.text: "now with chunks" }
   - exists: hits.hits.0._source.sparse_field.inference.chunks.1.embeddings
+
+---
+"inference endpoint late creation":
+  - do:
+      indices.create:
+        index: new-index
+        body:
+          settings:
+            index:
+              mapping:
+                semantic_text:
+                  use_legacy_format: true
+          mappings:
+            properties:
+              inference_field:
+                type: semantic_text
+                inference_id: new_inference_endpoint
+
+  - do:
+      inference.put:
+        task_type: sparse_embedding
+        inference_id: new_inference_endpoint
+        body: >
+          {
+            "service": "test_service",
+            "service_settings": {
+              "model": "my_model",
+              "api_key": "abc64"
+            },
+            "task_settings": {
+            }
+          }
+
+
+  - do:
+      index:
+        index: new-index
+        id: doc_1
+        body:
+          inference_field: "inference test"
+        refresh: true
+
+  - do:
+      search:
+        index: new-index
+        body:
+          query:
+            exists:
+              field: "inference_field"
+
+  - match: { hits.total.value: 1 }