diff --git a/docs/changelog/131058.yaml b/docs/changelog/131058.yaml new file mode 100644 index 0000000000000..f6a9096064d71 --- /dev/null +++ b/docs/changelog/131058.yaml @@ -0,0 +1,5 @@ +pr: 131058 +summary: Adds sparse vector index options settings to semantic_text field +area: Search +type: enhancement +issues: [] diff --git a/docs/reference/elasticsearch/mapping-reference/semantic-text.md b/docs/reference/elasticsearch/mapping-reference/semantic-text.md index 3b94d500f5e0b..5a7895b0f6d11 100644 --- a/docs/reference/elasticsearch/mapping-reference/semantic-text.md +++ b/docs/reference/elasticsearch/mapping-reference/semantic-text.md @@ -156,9 +156,11 @@ to create the endpoint. If not specified, the {{infer}} endpoint defined by `index_options` {applies_to}`stack: ga 9.1` : (Optional, object) Specifies the index options to override default values -for the field. Currently, `dense_vector` index options are supported. -For text embeddings, `index_options` may match any allowed -[dense_vector index options](/reference/elasticsearch/mapping-reference/dense-vector.md#dense-vector-index-options). +for the field. Currently, `dense_vector` and `sparse_vector` index options are supported. +Depending on the embedding type, `index_options` may match any of the allowed options: + +* For text embeddings: [dense_vector index options](/reference/elasticsearch/mapping-reference/dense-vector.md#dense-vector-index-options). +* For sparse embeddings: [sparse_vector index options](/reference/elasticsearch/mapping-reference/sparse-vector.md#sparse-vectors-params). {applies_to}`stack: ga 9.2` `chunking_settings` {applies_to}`stack: ga 9.1` : (Optional, object) Settings for chunking text into smaller passages. 
@@ -410,7 +412,7 @@ stack: ga 9.0 In case you want to customize data indexing, use the [`sparse_vector`](/reference/elasticsearch/mapping-reference/sparse-vector.md) or [`dense_vector`](/reference/elasticsearch/mapping-reference/dense-vector.md) -field types and create an ingest pipeline with an +field types and create an ingest pipeline with an [{{infer}} processor](/reference/enrich-processor/inference-processor.md) to generate the embeddings. [This tutorial](docs-content://solutions/search/semantic-search/semantic-search-inference.md) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java index 8eacc68b45c88..a91c84405b295 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java @@ -48,7 +48,6 @@ import org.elasticsearch.xcontent.DeprecationHandler; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParser.Token; @@ -98,7 +97,7 @@ public static class Builder extends FieldMapper.Builder { private final Parameter stored = Parameter.storeParam(m -> toType(m).fieldType().isStored(), false); private final Parameter> meta = Parameter.metaParam(); - private final Parameter indexOptions = new Parameter<>( + private final Parameter indexOptions = new Parameter<>( SPARSE_VECTOR_INDEX_OPTIONS, true, () -> null, @@ -128,9 +127,9 @@ protected Parameter[] getParameters() { @Override public SparseVectorFieldMapper build(MapperBuilderContext context) { - IndexOptions builderIndexOptions = indexOptions.getValue(); + SparseVectorIndexOptions builderIndexOptions = 
indexOptions.getValue(); if (builderIndexOptions == null) { - builderIndexOptions = getDefaultIndexOptions(indexVersionCreated); + builderIndexOptions = SparseVectorIndexOptions.getDefaultIndexOptions(indexVersionCreated); } final boolean syntheticVectorFinal = context.isSourceSynthetic() == false && isSyntheticVector; @@ -149,33 +148,34 @@ public SparseVectorFieldMapper build(MapperBuilderContext context) { ); } - private IndexOptions getDefaultIndexOptions(IndexVersion indexVersion) { - return (indexVersion.onOrAfter(SPARSE_VECTOR_PRUNING_INDEX_OPTIONS_VERSION) - || indexVersion.between(SPARSE_VECTOR_PRUNING_INDEX_OPTIONS_VERSION_8_X, IndexVersions.UPGRADE_TO_LUCENE_10_0_0)) - ? IndexOptions.DEFAULT_PRUNING_INDEX_OPTIONS - : null; + private boolean indexOptionsSerializerCheck(boolean includeDefaults, boolean isConfigured, SparseVectorIndexOptions value) { + return includeDefaults || (SparseVectorIndexOptions.isDefaultOptions(value, indexVersionCreated) == false); } - private boolean indexOptionsSerializerCheck(boolean includeDefaults, boolean isConfigured, IndexOptions value) { - return includeDefaults || (IndexOptions.isDefaultOptions(value, indexVersionCreated) == false); + public void setIndexOptions(SparseVectorIndexOptions sparseVectorIndexOptions) { + indexOptions.setValue(sparseVectorIndexOptions); } } - public IndexOptions getIndexOptions() { + public SparseVectorIndexOptions getIndexOptions() { return fieldType().getIndexOptions(); } - private static final ConstructingObjectParser INDEX_OPTIONS_PARSER = new ConstructingObjectParser<>( + private static final ConstructingObjectParser INDEX_OPTIONS_PARSER = new ConstructingObjectParser<>( SPARSE_VECTOR_INDEX_OPTIONS, - args -> new IndexOptions((Boolean) args[0], (TokenPruningConfig) args[1]) + args -> new SparseVectorIndexOptions((Boolean) args[0], (TokenPruningConfig) args[1]) ); static { - INDEX_OPTIONS_PARSER.declareBoolean(optionalConstructorArg(), IndexOptions.PRUNE_FIELD_NAME); - 
INDEX_OPTIONS_PARSER.declareObject(optionalConstructorArg(), TokenPruningConfig.PARSER, IndexOptions.PRUNING_CONFIG_FIELD_NAME); + INDEX_OPTIONS_PARSER.declareBoolean(optionalConstructorArg(), SparseVectorIndexOptions.PRUNE_FIELD_NAME); + INDEX_OPTIONS_PARSER.declareObject( + optionalConstructorArg(), + TokenPruningConfig.PARSER, + SparseVectorIndexOptions.PRUNING_CONFIG_FIELD_NAME + ); } - private static SparseVectorFieldMapper.IndexOptions parseIndexOptions(MappingParserContext context, Object propNode) { + private static SparseVectorIndexOptions parseIndexOptions(MappingParserContext context, Object propNode) { if (propNode == null) { return null; } @@ -212,7 +212,7 @@ private static SparseVectorFieldMapper.IndexOptions parseIndexOptions(MappingPar public static final class SparseVectorFieldType extends MappedFieldType { private final IndexVersion indexVersionCreated; - private final IndexOptions indexOptions; + private final SparseVectorIndexOptions indexOptions; public SparseVectorFieldType(IndexVersion indexVersionCreated, String name, boolean isStored, Map meta) { this(indexVersionCreated, name, isStored, meta, null); @@ -223,14 +223,14 @@ public SparseVectorFieldType( String name, boolean isStored, Map meta, - @Nullable SparseVectorFieldMapper.IndexOptions indexOptions + @Nullable SparseVectorIndexOptions indexOptions ) { super(name, true, isStored, false, TextSearchInfo.SIMPLE_MATCH_ONLY, meta); this.indexVersionCreated = indexVersionCreated; this.indexOptions = indexOptions; } - public IndexOptions getIndexOptions() { + public SparseVectorIndexOptions getIndexOptions() { return indexOptions; } @@ -560,15 +560,18 @@ public void reset() { } } - public static class IndexOptions implements ToXContent { + public static class SparseVectorIndexOptions implements IndexOptions { public static final ParseField PRUNE_FIELD_NAME = new ParseField("prune"); public static final ParseField PRUNING_CONFIG_FIELD_NAME = new ParseField("pruning_config"); - public static 
final IndexOptions DEFAULT_PRUNING_INDEX_OPTIONS = new IndexOptions(true, new TokenPruningConfig()); + public static final SparseVectorIndexOptions DEFAULT_PRUNING_INDEX_OPTIONS = new SparseVectorIndexOptions( + true, + new TokenPruningConfig() + ); final Boolean prune; final TokenPruningConfig pruningConfig; - IndexOptions(@Nullable Boolean prune, @Nullable TokenPruningConfig pruningConfig) { + public SparseVectorIndexOptions(@Nullable Boolean prune, @Nullable TokenPruningConfig pruningConfig) { if (pruningConfig != null && (prune == null || prune == false)) { throw new IllegalArgumentException( "[" @@ -585,14 +588,37 @@ public static class IndexOptions implements ToXContent { this.pruningConfig = pruningConfig; } - public static boolean isDefaultOptions(IndexOptions indexOptions, IndexVersion indexVersion) { - IndexOptions defaultIndexOptions = indexVersionSupportsDefaultPruningConfig(indexVersion) + public static boolean isDefaultOptions(SparseVectorIndexOptions indexOptions, IndexVersion indexVersion) { + SparseVectorIndexOptions defaultIndexOptions = indexVersionSupportsDefaultPruningConfig(indexVersion) ? DEFAULT_PRUNING_INDEX_OPTIONS : null; return Objects.equals(indexOptions, defaultIndexOptions); } + public static SparseVectorIndexOptions getDefaultIndexOptions(IndexVersion indexVersion) { + return indexVersionSupportsDefaultPruningConfig(indexVersion) ? 
DEFAULT_PRUNING_INDEX_OPTIONS : null; + } + + public static SparseVectorIndexOptions parseFromMap(Map map) { + if (map == null) { + return null; + } + // try-with-resources closes the parser even when parsing throws; a failure in close() is handled by the catch below. + try ( + XContentParser parser = new MapXContentParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + map, + XContentType.JSON + ) + ) { + return INDEX_OPTIONS_PARSER.parse(parser, null); + } catch (IOException ioEx) { + throw new UncheckedIOException(ioEx); + } + } + public Boolean getPrune() { return prune; } @@ -626,7 +652,7 @@ public final boolean equals(Object other) { return false; } - IndexOptions otherAsIndexOptions = (IndexOptions) other; + SparseVectorIndexOptions otherAsIndexOptions = (SparseVectorIndexOptions) other; return Objects.equals(prune, otherAsIndexOptions.prune) && Objects.equals(pruningConfig, otherAsIndexOptions.pruningConfig); } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapperTests.java index dc21e5f8f57cc..b3899db76dd7c 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapperTests.java @@ -906,4 +906,8 @@ private Map toFloats(Map value) { } return result; } + + public static IndexVersion getIndexOptionsCompatibleIndexVersion() { + return IndexVersionUtils.randomVersionBetween(random(), SPARSE_VECTOR_PRUNING_INDEX_OPTIONS_SUPPORT, IndexVersion.current()); + } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldTypeTests.java index 0503204886abb..037a44a3c9e25 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldTypeTests.java +++ 
b/server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldTypeTests.java @@ -40,4 +40,29 @@ public void testIsNotAggregatable() { MappedFieldType fieldType = new SparseVectorFieldMapper.SparseVectorFieldType(indexVersion, "field", false, Collections.emptyMap()); assertFalse(fieldType.isAggregatable()); } + + public static SparseVectorFieldMapper.SparseVectorIndexOptions randomSparseVectorIndexOptions() { + return randomSparseVectorIndexOptions(true); + } + + /** Random index options: null (only if includeNull), prune unset, prune disabled, or prune enabled with a random config. */ + public static SparseVectorFieldMapper.SparseVectorIndexOptions randomSparseVectorIndexOptions(boolean includeNull) { + if (includeNull && randomBoolean()) { + return null; + } + + Boolean prune = randomBoolean() ? null : randomBoolean(); + if (prune == null) { + return new SparseVectorFieldMapper.SparseVectorIndexOptions(null, null); + } + + if (prune == false) { + return new SparseVectorFieldMapper.SparseVectorIndexOptions(false, null); + } + + return new SparseVectorFieldMapper.SparseVectorIndexOptions( + true, + new TokenPruningConfig(randomFloatBetween(1.0f, 100.0f, true), randomFloatBetween(0.0f, 1.0f, true), randomBoolean()) + ); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java index c35ac4c413773..996f4e601289a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java @@ -19,6 +19,7 @@ import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_EXCLUDE_SUB_FIELDS_FROM_FIELD_CAPS; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_INDEX_OPTIONS; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_INDEX_OPTIONS_WITH_DEFAULTS; +import static 
org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_SPARSE_VECTOR_INDEX_OPTIONS; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG; import static org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor.SEMANTIC_KNN_FILTER_FIX; import static org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor.SEMANTIC_KNN_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED; @@ -78,7 +79,8 @@ public Set getTestFeatures() { COHERE_V2_API, SEMANTIC_TEXT_INDEX_OPTIONS_WITH_DEFAULTS, SEMANTIC_QUERY_REWRITE_INTERCEPTORS_PROPAGATE_BOOST_AND_QUERY_NAME_FIX, - SEMANTIC_TEXT_HIGHLIGHTING_FLAT + SEMANTIC_TEXT_HIGHLIGHTING_FLAT, + SEMANTIC_TEXT_SPARSE_VECTOR_INDEX_OPTIONS ) ); if (RERANK_SNIPPETS.isEnabled()) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 9972fa9e5ae0b..0d260a557f602 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -143,6 +143,9 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie public static final NodeFeature SEMANTIC_TEXT_INDEX_OPTIONS_WITH_DEFAULTS = new NodeFeature( "semantic_text.index_options_with_defaults" ); + public static final NodeFeature SEMANTIC_TEXT_SPARSE_VECTOR_INDEX_OPTIONS = new NodeFeature( + "semantic_text.sparse_vector_index_options" + ); public static final String CONTENT_TYPE = "semantic_text"; public static final String DEFAULT_ELSER_2_INFERENCE_ID = DEFAULT_ELSER_ID; @@ -458,8 +461,20 @@ private void validateIndexOptions(SemanticTextIndexOptions indexOptions, String ); } - if (indexOptions.type() == 
SemanticTextIndexOptions.SupportedIndexOptions.DENSE_VECTOR) { + if (indexOptions.type() == SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR) { + if (modelSettings.taskType() != SPARSE_EMBEDDING) { + throw new IllegalArgumentException( + "Invalid task type for index options, required [" + + SPARSE_EMBEDDING + + "] but was [" + + modelSettings.taskType() + + "]" + ); + } + return; + } + if (indexOptions.type() == SemanticTextIndexOptions.SupportedIndexOptions.DENSE_VECTOR) { if (modelSettings.taskType() != TEXT_EMBEDDING) { throw new IllegalArgumentException( "Invalid task type for index options, required [" + TEXT_EMBEDDING + "] but was [" + modelSettings.taskType() + "]" @@ -471,7 +486,6 @@ private void validateIndexOptions(SemanticTextIndexOptions indexOptions, String (DenseVectorFieldMapper.DenseVectorIndexOptions) indexOptions.indexOptions(); denseVectorIndexOptions.validate(modelSettings.elementType(), dims, true); } - } /** @@ -1169,9 +1183,17 @@ private static Mapper.Builder createEmbeddingsField( boolean useLegacyFormat ) { return switch (modelSettings.taskType()) { - case SPARSE_EMBEDDING -> new SparseVectorFieldMapper.Builder(CHUNKED_EMBEDDINGS_FIELD, indexVersionCreated, false).setStored( - useLegacyFormat == false - ); + case SPARSE_EMBEDDING -> { + SparseVectorFieldMapper.Builder sparseVectorMapperBuilder = new SparseVectorFieldMapper.Builder( + CHUNKED_EMBEDDINGS_FIELD, + indexVersionCreated, + false + ).setStored(useLegacyFormat == false); + + configureSparseVectorMapperBuilder(indexVersionCreated, sparseVectorMapperBuilder, indexOptions); + + yield sparseVectorMapperBuilder; + } case TEXT_EMBEDDING -> { DenseVectorFieldMapper.Builder denseVectorMapperBuilder = new DenseVectorFieldMapper.Builder( CHUNKED_EMBEDDINGS_FIELD, @@ -1179,45 +1201,7 @@ private static Mapper.Builder createEmbeddingsField( false ); - SimilarityMeasure similarity = modelSettings.similarity(); - if (similarity != null) { - switch (similarity) { - case COSINE -> 
denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.COSINE); - case DOT_PRODUCT -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT); - case L2_NORM -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.L2_NORM); - default -> throw new IllegalArgumentException( - "Unknown similarity measure in model_settings [" + similarity.name() + "]" - ); - } - } - denseVectorMapperBuilder.dimensions(modelSettings.dimensions()); - denseVectorMapperBuilder.elementType(modelSettings.elementType()); - // Here is where we persist index_options. If they are specified by the user, we will use those index_options, - // otherwise we will determine if we can set default index options. If we can't, we won't persist any index_options - // and the field will use the defaults for the dense_vector field. - if (indexOptions != null) { - DenseVectorFieldMapper.DenseVectorIndexOptions denseVectorIndexOptions = - (DenseVectorFieldMapper.DenseVectorIndexOptions) indexOptions.indexOptions(); - denseVectorMapperBuilder.indexOptions(denseVectorIndexOptions); - denseVectorIndexOptions.validate(modelSettings.elementType(), modelSettings.dimensions(), true); - } else { - DenseVectorFieldMapper.DenseVectorIndexOptions defaultIndexOptions = defaultDenseVectorIndexOptions( - indexVersionCreated, - modelSettings - ); - if (defaultIndexOptions != null) { - denseVectorMapperBuilder.indexOptions(defaultIndexOptions); - } - } - - boolean hasUserSpecifiedIndexOptions = indexOptions != null; - DenseVectorFieldMapper.DenseVectorIndexOptions denseVectorIndexOptions = hasUserSpecifiedIndexOptions - ? 
(DenseVectorFieldMapper.DenseVectorIndexOptions) indexOptions.indexOptions() - : defaultDenseVectorIndexOptions(indexVersionCreated, modelSettings); - - if (denseVectorIndexOptions != null) { - denseVectorMapperBuilder.indexOptions(denseVectorIndexOptions); - } + configureDenseVectorMapperBuilder(indexVersionCreated, denseVectorMapperBuilder, modelSettings, indexOptions); yield denseVectorMapperBuilder; } @@ -1225,6 +1209,62 @@ private static Mapper.Builder createEmbeddingsField( }; } + private static void configureSparseVectorMapperBuilder( + IndexVersion indexVersionCreated, + SparseVectorFieldMapper.Builder sparseVectorMapperBuilder, + SemanticTextIndexOptions indexOptions + ) { + if (indexOptions != null) { + SparseVectorFieldMapper.SparseVectorIndexOptions sparseVectorIndexOptions = + (SparseVectorFieldMapper.SparseVectorIndexOptions) indexOptions.indexOptions(); + + sparseVectorMapperBuilder.setIndexOptions(sparseVectorIndexOptions); + } else { + SparseVectorFieldMapper.SparseVectorIndexOptions defaultIndexOptions = SparseVectorFieldMapper.SparseVectorIndexOptions + .getDefaultIndexOptions(indexVersionCreated); + if (defaultIndexOptions != null) { + sparseVectorMapperBuilder.setIndexOptions(defaultIndexOptions); + } + } + } + + private static void configureDenseVectorMapperBuilder( + IndexVersion indexVersionCreated, + DenseVectorFieldMapper.Builder denseVectorMapperBuilder, + MinimalServiceSettings modelSettings, + SemanticTextIndexOptions indexOptions + ) { + SimilarityMeasure similarity = modelSettings.similarity(); + if (similarity != null) { + switch (similarity) { + case COSINE -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.COSINE); + case DOT_PRODUCT -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT); + case L2_NORM -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.L2_NORM); + default -> throw new IllegalArgumentException("Unknown similarity measure in 
model_settings [" + similarity.name() + "]"); + } + } + + denseVectorMapperBuilder.dimensions(modelSettings.dimensions()); + denseVectorMapperBuilder.elementType(modelSettings.elementType()); + // Here is where we persist index_options. If they are specified by the user, we will use those index_options, + // otherwise we will determine if we can set default index options. If we can't, we won't persist any index_options + // and the field will use the defaults for the dense_vector field. + if (indexOptions != null) { + DenseVectorFieldMapper.DenseVectorIndexOptions denseVectorIndexOptions = + (DenseVectorFieldMapper.DenseVectorIndexOptions) indexOptions.indexOptions(); + denseVectorMapperBuilder.indexOptions(denseVectorIndexOptions); + denseVectorIndexOptions.validate(modelSettings.elementType(), modelSettings.dimensions(), true); + } else { + DenseVectorFieldMapper.DenseVectorIndexOptions defaultIndexOptions = defaultDenseVectorIndexOptions( + indexVersionCreated, + modelSettings + ); + if (defaultIndexOptions != null) { + denseVectorMapperBuilder.indexOptions(defaultIndexOptions); + } + } + } + static DenseVectorFieldMapper.DenseVectorIndexOptions defaultDenseVectorIndexOptions( IndexVersion indexVersionCreated, MinimalServiceSettings modelSettings @@ -1259,23 +1299,30 @@ public static DenseVectorFieldMapper.DenseVectorIndexOptions defaultBbqHnswDense } static SemanticTextIndexOptions defaultIndexOptions(IndexVersion indexVersionCreated, MinimalServiceSettings modelSettings) { - if (modelSettings == null) { return null; } - SemanticTextIndexOptions defaultIndexOptions = null; if (modelSettings.taskType() == TaskType.TEXT_EMBEDDING) { DenseVectorFieldMapper.DenseVectorIndexOptions denseVectorIndexOptions = defaultDenseVectorIndexOptions( indexVersionCreated, modelSettings ); - defaultIndexOptions = denseVectorIndexOptions == null + return denseVectorIndexOptions == null ? 
null : new SemanticTextIndexOptions(SemanticTextIndexOptions.SupportedIndexOptions.DENSE_VECTOR, denseVectorIndexOptions); } - return defaultIndexOptions; + if (modelSettings.taskType() == SPARSE_EMBEDDING) { + SparseVectorFieldMapper.SparseVectorIndexOptions sparseVectorIndexOptions = SparseVectorFieldMapper.SparseVectorIndexOptions + .getDefaultIndexOptions(indexVersionCreated); + + return sparseVectorIndexOptions == null + ? null + : new SemanticTextIndexOptions(SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR, sparseVectorIndexOptions); + } + + return null; } private static boolean canMergeModelSettings(MinimalServiceSettings previous, MinimalServiceSettings current, Conflicts conflicts) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextIndexOptions.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextIndexOptions.java index db647499f446f..7d12995876f4b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextIndexOptions.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextIndexOptions.java @@ -13,6 +13,7 @@ import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.IndexOptions; +import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentBuilder; @@ -76,6 +77,12 @@ public enum SupportedIndexOptions { public IndexOptions parseIndexOptions(String fieldName, Map map, IndexVersion indexVersion) { return parseDenseVectorIndexOptionsFromMap(fieldName, map, indexVersion); } + }, + SPARSE_VECTOR("sparse_vector") { + @Override + public IndexOptions parseIndexOptions(String fieldName, Map map, IndexVersion indexVersion) { + return 
parseSparseVectorIndexOptionsFromMap(map); + } }; public final String value; @@ -127,4 +134,12 @@ private static DenseVectorFieldMapper.DenseVectorIndexOptions parseDenseVectorIn throw new ElasticsearchException(exc); } } + + private static SparseVectorFieldMapper.SparseVectorIndexOptions parseSparseVectorIndexOptionsFromMap(Map map) { + try { + return SparseVectorFieldMapper.SparseVectorIndexOptions.parseFromMap(map); + } catch (Exception exc) { + throw new ElasticsearchException(exc); + } + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index cc87edf59e9d3..7ee178cbe2af6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -56,6 +56,9 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldTypeTests; import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapperTests; +import org.elasticsearch.index.mapper.vectors.SparseVectorFieldTypeTests; +import org.elasticsearch.index.mapper.vectors.TokenPruningConfig; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.index.search.ESToParentBlockJoinQuery; import org.elasticsearch.inference.ChunkingSettings; @@ -95,6 +98,7 @@ import java.util.function.Supplier; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldTypeTests.randomIndexOptionsAll; +import static org.elasticsearch.index.mapper.vectors.SparseVectorFieldTypeTests.randomSparseVectorIndexOptions; import static 
org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKS_FIELD; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_FIELD; @@ -113,6 +117,9 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; public class SemanticTextFieldMapperTests extends MapperTestCase { private final boolean useLegacyFormat; @@ -123,9 +130,20 @@ public SemanticTextFieldMapperTests(boolean useLegacyFormat) { this.useLegacyFormat = useLegacyFormat; } + ModelRegistry globalModelRegistry; + @Before private void startThreadPool() { threadPool = createThreadPool(); + var clusterService = ClusterServiceUtils.createClusterService(threadPool); + var modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool)); + globalModelRegistry = spy(modelRegistry); + globalModelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) { + @Override + public boolean localNodeMaster() { + return false; + } + }); } @After @@ -140,18 +158,10 @@ public static Iterable parameters() throws Exception { @Override protected Collection getPlugins() { - var clusterService = ClusterServiceUtils.createClusterService(threadPool); - var modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool)); - modelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) { - @Override - public boolean localNodeMaster() { - return false; - } - }); return List.of(new InferencePlugin(Settings.EMPTY) { @Override protected Supplier getModelRegistry() { - return () -> modelRegistry; + return () -> globalModelRegistry; } }, new XPackClientPlugin()); } @@ 
-174,6 +184,11 @@ private MapperService createMapperService( ) throws IOException { validateIndexVersion(minIndexVersion, useLegacyFormat); IndexVersion indexVersion = IndexVersionUtils.randomVersionBetween(random(), minIndexVersion, maxIndexVersion); + return createMapperServiceWithIndexVersion(mappings, useLegacyFormat, indexVersion); + } + + private MapperService createMapperServiceWithIndexVersion(XContentBuilder mappings, boolean useLegacyFormat, IndexVersion indexVersion) + throws IOException { var settings = Settings.builder() .put(IndexMetadata.SETTING_INDEX_VERSION_CREATED.getKey(), indexVersion) .put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat) @@ -189,17 +204,6 @@ private static void validateIndexVersion(IndexVersion indexVersion, boolean useL } } - private MapperService createMapperService(String mappings, boolean useLegacyFormat) throws IOException { - var settings = Settings.builder() - .put( - IndexMetadata.SETTING_INDEX_VERSION_CREATED.getKey(), - SemanticInferenceMetadataFieldsMapperTests.getRandomCompatibleIndexVersion(useLegacyFormat) - ) - .put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat) - .build(); - return createMapperService(settings, mappings); - } - @Override protected Settings getIndexSettings() { return Settings.builder() @@ -380,6 +384,14 @@ public void testInvalidInferenceEndpoints() { } } + private SemanticTextIndexOptions getDefaultSparseVectorIndexOptionsForMapper(MapperService mapperService) { + var mapperIndexVersion = mapperService.getIndexSettings().getIndexVersionCreated(); + var defaultSparseVectorIndexOptions = SparseVectorFieldMapper.SparseVectorIndexOptions.getDefaultIndexOptions(mapperIndexVersion); + return defaultSparseVectorIndexOptions == null + ? 
null + : new SemanticTextIndexOptions(SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR, defaultSparseVectorIndexOptions); + } + public void testInvalidTaskTypes() { for (var taskType : TaskType.values()) { if (taskType == TaskType.TEXT_EMBEDDING || taskType == TaskType.SPARSE_EMBEDDING) { @@ -415,7 +427,13 @@ public void testMultiFieldsSupport() throws IOException { }), useLegacyFormat)); assertThat(e.getMessage(), containsString("Field [semantic] of type [semantic_text] can't be used in multifields")); } else { - var mapperService = createMapperService(fieldMapping(b -> { + IndexVersion indexVersion = SparseVectorFieldMapperTests.getIndexOptionsCompatibleIndexVersion(); + SparseVectorFieldMapper.SparseVectorIndexOptions expectedIndexOptions = SparseVectorFieldMapper.SparseVectorIndexOptions + .getDefaultIndexOptions(indexVersion); + SemanticTextIndexOptions semanticTextIndexOptions = expectedIndexOptions == null + ? null + : new SemanticTextIndexOptions(SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR, expectedIndexOptions); + var mapperService = createMapperServiceWithIndexVersion(fieldMapping(b -> { b.field("type", "text"); b.startObject("fields"); b.startObject("semantic"); @@ -426,10 +444,10 @@ public void testMultiFieldsSupport() throws IOException { b.endObject(); b.endObject(); b.endObject(); - }), useLegacyFormat); - assertSemanticTextField(mapperService, "field.semantic", true, null, null); + }), useLegacyFormat, indexVersion); + assertSemanticTextField(mapperService, "field.semantic", true, null, semanticTextIndexOptions); - mapperService = createMapperService(fieldMapping(b -> { + mapperService = createMapperServiceWithIndexVersion(fieldMapping(b -> { b.field("type", "semantic_text"); b.field("inference_id", "my_inference_id"); b.startObject("model_settings"); @@ -440,10 +458,10 @@ public void testMultiFieldsSupport() throws IOException { b.field("type", "text"); b.endObject(); b.endObject(); - }), useLegacyFormat); - 
assertSemanticTextField(mapperService, "field", true, null, null); + }), useLegacyFormat, indexVersion); + assertSemanticTextField(mapperService, "field", true, null, semanticTextIndexOptions); - mapperService = createMapperService(fieldMapping(b -> { + mapperService = createMapperServiceWithIndexVersion(fieldMapping(b -> { b.field("type", "semantic_text"); b.field("inference_id", "my_inference_id"); b.startObject("model_settings"); @@ -458,9 +476,9 @@ public void testMultiFieldsSupport() throws IOException { b.endObject(); b.endObject(); b.endObject(); - }), useLegacyFormat); - assertSemanticTextField(mapperService, "field", true, null, null); - assertSemanticTextField(mapperService, "field.semantic", true, null, null); + }), useLegacyFormat, indexVersion); + assertSemanticTextField(mapperService, "field", true, null, semanticTextIndexOptions); + assertSemanticTextField(mapperService, "field.semantic", true, null, semanticTextIndexOptions); Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> { b.field("type", "semantic_text"); @@ -472,7 +490,6 @@ public void testMultiFieldsSupport() throws IOException { b.endObject(); }), useLegacyFormat)); assertThat(e.getMessage(), containsString("is already used by another field")); - } } @@ -504,7 +521,8 @@ public void testDynamicUpdate() throws IOException { inferenceId, new MinimalServiceSettings("service", TaskType.SPARSE_EMBEDDING, null, null, null) ); - assertSemanticTextField(mapperService, fieldName, true, null, null); + var expectedIndexOptions = getDefaultSparseVectorIndexOptionsForMapper(mapperService); + assertSemanticTextField(mapperService, fieldName, true, null, expectedIndexOptions); assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); } @@ -515,7 +533,8 @@ public void testDynamicUpdate() throws IOException { searchInferenceId, new MinimalServiceSettings("service", TaskType.SPARSE_EMBEDDING, null, null, null) ); - 
assertSemanticTextField(mapperService, fieldName, true, null, null); + var expectedIndexOptions = getDefaultSparseVectorIndexOptionsForMapper(mapperService); + assertSemanticTextField(mapperService, fieldName, true, null, expectedIndexOptions); assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId); } } @@ -559,14 +578,16 @@ public void testUpdateModelSettings() throws IOException { .endObject() ) ); - assertSemanticTextField(mapperService, fieldName, true, null, null); + var expectedIndexOptions = getDefaultSparseVectorIndexOptionsForMapper(mapperService); + assertSemanticTextField(mapperService, fieldName, true, null, expectedIndexOptions); } { merge( mapperService, mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject()) ); - assertSemanticTextField(mapperService, fieldName, true, null, null); + var expectedIndexOptions = getDefaultSparseVectorIndexOptionsForMapper(mapperService); + assertSemanticTextField(mapperService, fieldName, true, null, expectedIndexOptions); } { Exception exc = expectThrows( @@ -614,6 +635,87 @@ public void testDenseVectorIndexOptionValidation() throws IOException { } } + private void addSparseVectorModelSettingsToBuilder(XContentBuilder b) throws IOException { + b.startObject("model_settings"); + b.field("task_type", TaskType.SPARSE_EMBEDDING); + b.endObject(); + } + + public void testSparseVectorIndexOptionsValidationAndMapping() throws IOException { + for (int depth = 1; depth < 5; depth++) { + String inferenceId = "test_model"; + String fieldName = randomFieldName(depth); + IndexVersion indexVersion = SparseVectorFieldMapperTests.getIndexOptionsCompatibleIndexVersion(); + var sparseVectorIndexOptions = SparseVectorFieldTypeTests.randomSparseVectorIndexOptions(); + var expectedIndexOptions = sparseVectorIndexOptions == null + ? 
null + : new SemanticTextIndexOptions(SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR, sparseVectorIndexOptions); + + // should not throw an exception + MapperService mapper = createMapperServiceWithIndexVersion(mapping(b -> { + b.startObject(fieldName); + { + b.field("type", SemanticTextFieldMapper.CONTENT_TYPE); + b.field(INFERENCE_ID_FIELD, inferenceId); + addSparseVectorModelSettingsToBuilder(b); + if (sparseVectorIndexOptions != null) { + b.startObject(INDEX_OPTIONS_FIELD); + { + b.field(SparseVectorFieldMapper.CONTENT_TYPE); + sparseVectorIndexOptions.toXContent(b, null); + } + b.endObject(); + } + } + b.endObject(); + }), useLegacyFormat, indexVersion); + + assertSemanticTextField(mapper, fieldName, true, null, expectedIndexOptions); + } + } + + public void testSparseVectorMappingUpdate() throws IOException { + for (int i = 0; i < 5; i++) { + Model model = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); + when(globalModelRegistry.getMinimalServiceSettings(anyString())).thenAnswer( + invocation -> { return new MinimalServiceSettings(model); } + ); + + final ChunkingSettings chunkingSettings = generateRandomChunkingSettings(false); + IndexVersion indexVersion = SparseVectorFieldMapperTests.getIndexOptionsCompatibleIndexVersion(); + final SemanticTextIndexOptions indexOptions = randomSemanticTextIndexOptions(TaskType.SPARSE_EMBEDDING); + String fieldName = "field"; + + MapperService mapperService = createMapperServiceWithIndexVersion( + mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId(), null, chunkingSettings, indexOptions)), + useLegacyFormat, + indexVersion + ); + var expectedIndexOptions = (indexOptions == null) + ? 
new SemanticTextIndexOptions( + SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR, + SparseVectorFieldMapper.SparseVectorIndexOptions.getDefaultIndexOptions(indexVersion) + ) + : indexOptions; + assertSemanticTextField(mapperService, fieldName, false, chunkingSettings, expectedIndexOptions); + + final SemanticTextIndexOptions newIndexOptions = randomSemanticTextIndexOptions(TaskType.SPARSE_EMBEDDING); + expectedIndexOptions = (newIndexOptions == null) + ? new SemanticTextIndexOptions( + SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR, + SparseVectorFieldMapper.SparseVectorIndexOptions.getDefaultIndexOptions(indexVersion) + ) + : newIndexOptions; + + ChunkingSettings newChunkingSettings = generateRandomChunkingSettingsOtherThan(chunkingSettings); + merge( + mapperService, + mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId(), null, newChunkingSettings, newIndexOptions)) + ); + assertSemanticTextField(mapperService, fieldName, false, newChunkingSettings, expectedIndexOptions); + } + } + public void testUpdateSearchInferenceId() throws IOException { final String inferenceId = "test_inference_id"; final String searchInferenceId1 = "test_search_inference_id_1"; @@ -650,27 +752,24 @@ public void testUpdateSearchInferenceId() throws IOException { inferenceId, new MinimalServiceSettings("my-service", TaskType.SPARSE_EMBEDDING, null, null, null) ); - assertSemanticTextField(mapperService, fieldName, true, null, null); + var expectedIndexOptions = getDefaultSparseVectorIndexOptionsForMapper(mapperService); + assertSemanticTextField(mapperService, fieldName, true, null, expectedIndexOptions); assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); merge(mapperService, buildMapping.apply(fieldName, searchInferenceId1)); - assertSemanticTextField(mapperService, fieldName, true, null, null); + assertSemanticTextField(mapperService, fieldName, true, null, expectedIndexOptions); 
assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId1); merge(mapperService, buildMapping.apply(fieldName, searchInferenceId2)); - assertSemanticTextField(mapperService, fieldName, true, null, null); + assertSemanticTextField(mapperService, fieldName, true, null, expectedIndexOptions); assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId2); merge(mapperService, buildMapping.apply(fieldName, null)); - assertSemanticTextField(mapperService, fieldName, true, null, null); + assertSemanticTextField(mapperService, fieldName, true, null, expectedIndexOptions); assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); } } - private static void assertSemanticTextField(MapperService mapperService, String fieldName, boolean expectedModelSettings) { - assertSemanticTextField(mapperService, fieldName, expectedModelSettings, null, null); - } - private static void assertSemanticTextField( MapperService mapperService, String fieldName, @@ -720,9 +819,20 @@ private static void assertSemanticTextField( switch (semanticFieldMapper.fieldType().getModelSettings().taskType()) { case SPARSE_EMBEDDING -> { assertThat(embeddingsMapper, instanceOf(SparseVectorFieldMapper.class)); - SparseVectorFieldMapper sparseMapper = (SparseVectorFieldMapper) embeddingsMapper; - assertEquals(sparseMapper.fieldType().isStored(), semanticTextFieldType.useLegacyFormat() == false); - assertNull(expectedIndexOptions); + SparseVectorFieldMapper sparseVectorFieldMapper = (SparseVectorFieldMapper) embeddingsMapper; + assertEquals(sparseVectorFieldMapper.fieldType().isStored(), semanticTextFieldType.useLegacyFormat() == false); + + SparseVectorFieldMapper.SparseVectorIndexOptions applied = sparseVectorFieldMapper.fieldType().getIndexOptions(); + SparseVectorFieldMapper.SparseVectorIndexOptions expected = expectedIndexOptions == null + ? 
null + : (SparseVectorFieldMapper.SparseVectorIndexOptions) expectedIndexOptions.indexOptions(); + if (expected == null && applied != null) { + var indexVersionCreated = mapperService.getIndexSettings().getIndexVersionCreated(); + if (SparseVectorFieldMapper.SparseVectorIndexOptions.isDefaultOptions(applied, indexVersionCreated)) { + expected = SparseVectorFieldMapper.SparseVectorIndexOptions.getDefaultIndexOptions(indexVersionCreated); + } + } + assertEquals(expected, applied); } case TEXT_EMBEDDING -> { assertThat(embeddingsMapper, instanceOf(DenseVectorFieldMapper.class)); @@ -763,6 +873,8 @@ private static void assertInferenceEndpoints( public void testSuccessfulParse() throws IOException { for (int depth = 1; depth < 4; depth++) { + final IndexVersion indexVersion = SemanticInferenceMetadataFieldsMapperTests.getRandomCompatibleIndexVersion(useLegacyFormat); + final String fieldName1 = randomFieldName(depth); final String fieldName2 = randomFieldName(depth + 1); final String searchInferenceId = randomAlphaOfLength(8); @@ -771,6 +883,18 @@ public void testSuccessfulParse() throws IOException { TaskType taskType = TaskType.SPARSE_EMBEDDING; Model model1 = TestModel.createRandomInstance(taskType); Model model2 = TestModel.createRandomInstance(taskType); + + when(globalModelRegistry.getMinimalServiceSettings(anyString())).thenAnswer(invocation -> { + var modelId = (String) invocation.getArguments()[0]; + if (modelId.equals(model1.getInferenceEntityId())) { + return new MinimalServiceSettings(model1); + } + if (modelId.equals(model2.getInferenceEntityId())) { + return new MinimalServiceSettings(model2); + } + return null; + }); + ChunkingSettings chunkingSettings = null; // Some chunking settings configs can produce different Lucene docs counts SemanticTextIndexOptions indexOptions = randomSemanticTextIndexOptions(taskType); XContentBuilder mapping = mapping(b -> { @@ -792,15 +916,22 @@ public void testSuccessfulParse() throws IOException { ); }); - MapperService 
mapperService = createMapperService(mapping, useLegacyFormat); - assertSemanticTextField(mapperService, fieldName1, false, null, null); + var expectedIndexOptions = (indexOptions == null) + ? new SemanticTextIndexOptions( + SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR, + SparseVectorFieldMapper.SparseVectorIndexOptions.getDefaultIndexOptions(indexVersion) + ) + : indexOptions; + + MapperService mapperService = createMapperServiceWithIndexVersion(mapping, useLegacyFormat, indexVersion); + assertSemanticTextField(mapperService, fieldName1, false, null, expectedIndexOptions); assertInferenceEndpoints( mapperService, fieldName1, model1.getInferenceEntityId(), setSearchInferenceId ? searchInferenceId : model1.getInferenceEntityId() ); - assertSemanticTextField(mapperService, fieldName2, false, null, null); + assertSemanticTextField(mapperService, fieldName2, false, null, expectedIndexOptions); assertInferenceEndpoints( mapperService, fieldName2, @@ -1015,24 +1146,19 @@ public void testDenseVectorElementType() throws IOException { public void testSettingAndUpdatingChunkingSettings() throws IOException { Model model = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); + when(globalModelRegistry.getMinimalServiceSettings(anyString())).thenAnswer( + invocation -> { return new MinimalServiceSettings(model); } + ); + final ChunkingSettings chunkingSettings = generateRandomChunkingSettings(false); - final SemanticTextIndexOptions indexOptions = null; + final SemanticTextIndexOptions indexOptions = randomSemanticTextIndexOptions(TaskType.SPARSE_EMBEDDING); String fieldName = "field"; - SemanticTextField randomSemanticText = randomSemanticText( - useLegacyFormat, - fieldName, - model, - chunkingSettings, - List.of("a"), - XContentType.JSON - ); - MapperService mapperService = createMapperService( mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId(), null, chunkingSettings, indexOptions)), useLegacyFormat ); - 
assertSemanticTextField(mapperService, fieldName, false, chunkingSettings, null); + assertSemanticTextField(mapperService, fieldName, false, chunkingSettings, indexOptions); ChunkingSettings newChunkingSettings = generateRandomChunkingSettingsOtherThan(chunkingSettings); merge( @@ -1046,6 +1172,11 @@ public void testModelSettingsRequiredWithChunks() throws IOException { // Create inference results where model settings are set to null and chunks are provided TaskType taskType = TaskType.SPARSE_EMBEDDING; Model model = TestModel.createRandomInstance(taskType); + + when(globalModelRegistry.getMinimalServiceSettings(anyString())).thenAnswer( + invocation -> { return new MinimalServiceSettings(model); } + ); + ChunkingSettings chunkingSettings = generateRandomChunkingSettings(false); SemanticTextIndexOptions indexOptions = randomSemanticTextIndexOptions(taskType); SemanticTextField randomSemanticText = randomSemanticText( @@ -1196,6 +1327,13 @@ private static SemanticTextIndexOptions defaultBbqHnswSemanticTextIndexOptions() ); } + private static SemanticTextIndexOptions defaultSparseVectorIndexOptions(IndexVersion indexVersion) { + return new SemanticTextIndexOptions( + SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR, + SparseVectorFieldMapper.SparseVectorIndexOptions.getDefaultIndexOptions(indexVersion) + ); + } + public void testDefaultIndexOptions() throws IOException { // We default to BBQ for eligible dense vectors @@ -1318,6 +1456,42 @@ public void testDefaultIndexOptions() throws IOException { IndexVersionUtils.getPreviousVersion(IndexVersions.SEMANTIC_TEXT_DEFAULTS_TO_BBQ_BACKPORT_8_X) ); assertSemanticTextField(mapperService, "field", true, null, defaultDenseVectorSemanticIndexOptions()); + + mapperService = createMapperService(fieldMapping(b -> { + b.field("type", "semantic_text"); + b.field("inference_id", "another_inference_id"); + b.startObject("model_settings"); + b.field("task_type", "sparse_embedding"); + b.endObject(); + }), + 
useLegacyFormat, + IndexVersionUtils.getPreviousVersion(IndexVersions.SPARSE_VECTOR_PRUNING_INDEX_OPTIONS_SUPPORT), + IndexVersions.SPARSE_VECTOR_PRUNING_INDEX_OPTIONS_SUPPORT + ); + + assertSemanticTextField( + mapperService, + "field", + true, + null, + defaultSparseVectorIndexOptions(mapperService.getIndexSettings().getIndexVersionCreated()) + ); + } + + public void testSparseVectorIndexOptionsDefaultsBeforeSupport() throws IOException { + var mapperService = createMapperService(fieldMapping(b -> { + b.field("type", "semantic_text"); + b.field("inference_id", "another_inference_id"); + b.startObject("model_settings"); + b.field("task_type", "sparse_embedding"); + b.endObject(); + }), + useLegacyFormat, + IndexVersions.INFERENCE_METADATA_FIELDS, + IndexVersionUtils.getPreviousVersion(IndexVersions.SPARSE_VECTOR_PRUNING_INDEX_OPTIONS_SUPPORT) + ); + + assertSemanticTextField(mapperService, "field", true, null, null); } public void testSpecifiedDenseVectorIndexOptions() throws IOException { @@ -1428,7 +1602,74 @@ public void testSpecifiedDenseVectorIndexOptions() throws IOException { b.endObject(); }), useLegacyFormat, IndexVersions.INFERENCE_METADATA_FIELDS_BACKPORT)); assertThat(e.getMessage(), containsString("Unsupported index options type invalid")); + } + + public void testSpecificSparseVectorIndexOptions() throws IOException { + for (int i = 0; i < 10; i++) { + SparseVectorFieldMapper.SparseVectorIndexOptions testIndexOptions = randomSparseVectorIndexOptions(false); + var mapperService = createMapperService(fieldMapping(b -> { + b.field("type", SemanticTextFieldMapper.CONTENT_TYPE); + b.field(INFERENCE_ID_FIELD, "test_inference_id"); + addSparseVectorModelSettingsToBuilder(b); + b.startObject(INDEX_OPTIONS_FIELD); + { + b.field(SparseVectorFieldMapper.CONTENT_TYPE); + testIndexOptions.toXContent(b, null); + } + b.endObject(); + }), useLegacyFormat, IndexVersions.INFERENCE_METADATA_FIELDS_BACKPORT); + assertSemanticTextField( + mapperService, + "field", + 
true, + null, + new SemanticTextIndexOptions(SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR, testIndexOptions) + ); + } + } + + public void testSparseVectorIndexOptionsValidations() throws IOException { + Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> { + b.field("type", SemanticTextFieldMapper.CONTENT_TYPE); + b.field(INFERENCE_ID_FIELD, "test_inference_id"); + b.startObject(INDEX_OPTIONS_FIELD); + { + b.startObject(SparseVectorFieldMapper.CONTENT_TYPE); + { + b.field("prune", false); + b.startObject("pruning_config"); + { + b.field(TokenPruningConfig.TOKENS_FREQ_RATIO_THRESHOLD.getPreferredName(), 5.0f); + } + b.endObject(); + } + b.endObject(); + } + b.endObject(); + }), useLegacyFormat, IndexVersions.INFERENCE_METADATA_FIELDS_BACKPORT)); + assertThat(e.getMessage(), containsString("failed to parse field [pruning_config]")); + + e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> { + b.field("type", SemanticTextFieldMapper.CONTENT_TYPE); + b.field(INFERENCE_ID_FIELD, "test_inference_id"); + b.startObject(INDEX_OPTIONS_FIELD); + { + b.startObject(SparseVectorFieldMapper.CONTENT_TYPE); + { + b.field("prune", true); + b.startObject("pruning_config"); + { + b.field(TokenPruningConfig.TOKENS_FREQ_RATIO_THRESHOLD.getPreferredName(), 1000.0f); + } + b.endObject(); + } + b.endObject(); + } + b.endObject(); + }), useLegacyFormat, IndexVersions.INFERENCE_METADATA_FIELDS_BACKPORT)); + var innerClause = e.getCause().getCause().getCause().getCause(); + assertThat(innerClause.getMessage(), containsString("[tokens_freq_ratio_threshold] must be between [1] and [100], got 1000.0")); } public static SemanticTextIndexOptions randomSemanticTextIndexOptions() { @@ -1437,13 +1678,21 @@ public static SemanticTextIndexOptions randomSemanticTextIndexOptions() { } public static SemanticTextIndexOptions randomSemanticTextIndexOptions(TaskType taskType) { - if (taskType == 
TaskType.TEXT_EMBEDDING) { return randomBoolean() ? null : new SemanticTextIndexOptions(SemanticTextIndexOptions.SupportedIndexOptions.DENSE_VECTOR, randomIndexOptionsAll()); } + if (taskType == TaskType.SPARSE_EMBEDDING) { + return randomBoolean() + ? null + : new SemanticTextIndexOptions( + SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR, + randomSparseVectorIndexOptions(false) + ); + } + return null; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java index c3b50cdb4a670..d5ba1859d8c5d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java @@ -52,13 +52,10 @@ public static TestModel createRandomInstance(TaskType taskType, List excludedSimilarities, int maxDimensions) { - var elementType = taskType == TaskType.TEXT_EMBEDDING ? randomFrom(DenseVectorFieldMapper.ElementType.values()) : null; - var dimensions = taskType == TaskType.TEXT_EMBEDDING - ? DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, maxDimensions) - : null; - - SimilarityMeasure similarity = null; if (taskType == TaskType.TEXT_EMBEDDING) { + var elementType = randomFrom(DenseVectorFieldMapper.ElementType.values()); + var dimensions = DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, maxDimensions); + List supportedSimilarities = new ArrayList<>( DenseVectorFieldMapperTestUtils.getSupportedSimilarities(elementType) ); @@ -75,17 +72,30 @@ public static TestModel createRandomInstance(TaskType taskType, List