diff --git a/docs/changelog/126876.yaml b/docs/changelog/126876.yaml new file mode 100644 index 0000000000000..895af10840d84 --- /dev/null +++ b/docs/changelog/126876.yaml @@ -0,0 +1,5 @@ +pr: 126876 +summary: Improve HNSW filtered search speed through new heuristic +area: Vector Search +type: enhancement +issues: [] diff --git a/docs/reference/elasticsearch/index-settings/index-modules.md b/docs/reference/elasticsearch/index-settings/index-modules.md index 682a6fa2a39d4..4ab35b9d80a88 100644 --- a/docs/reference/elasticsearch/index-settings/index-modules.md +++ b/docs/reference/elasticsearch/index-settings/index-modules.md @@ -249,6 +249,12 @@ $$$index-final-pipeline$$$ $$$index-hidden$$$ `index.hidden` : Indicates whether the index should be hidden by default. Hidden indices are not returned by default when using a wildcard expression. This behavior is controlled per request through the use of the `expand_wildcards` parameter. Possible values are `true` and `false` (default). +$$$index-dense-vector-hnsw-filter-heuristic$$$ `index.dense_vector.hnsw_filter_heuristic` +: The heuristic to utilize when executing a filtered search against vectors in an HNSW graph. This setting is in technical preview may be changed or removed in a future release. It can be set to: + +* `acorn` (default) - Only vectors that match the filter criteria are searched. This is the fastest option, and generally provides faster searches at similar recall to `fanout`, but `num_candidates` might need to be increased for exceptionally high recall requirements. +* `fanout` - All vectors are compared with the query vector, but only those passing the criteria are added to the search results. Can be slower than `acorn`, but may yield higher recall. + $$$index-esql-stored-fields-sequential-proportion$$$ `index.esql.stored_fields_sequential_proportion` diff --git a/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java b/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java index a84bc7d00578c..a4b239c10ba6a 100644 --- a/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java +++ b/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java @@ -36,6 +36,7 @@ import org.elasticsearch.index.mapper.IgnoredSourceFieldMapper; import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.similarity.SimilarityService; import org.elasticsearch.index.store.FsDirectoryFactory; import org.elasticsearch.index.store.Store; @@ -157,6 +158,7 @@ public final class IndexScopedSettings extends AbstractScopedSettings { IndexSettings.INDEX_TRANSLOG_RETENTION_AGE_SETTING, IndexSettings.INDEX_TRANSLOG_RETENTION_SIZE_SETTING, IndexSettings.INDEX_SEARCH_IDLE_AFTER, + DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC, IndexFieldDataService.INDEX_FIELDDATA_CACHE_KEY, IndexSettings.IGNORE_ABOVE_SETTING, FieldMapper.IGNORE_MALFORMED_SETTING, diff --git a/server/src/main/java/org/elasticsearch/index/IndexSettings.java b/server/src/main/java/org/elasticsearch/index/IndexSettings.java index e7ff0f6d1e137..7a3b2ad938d1b 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexSettings.java +++ b/server/src/main/java/org/elasticsearch/index/IndexSettings.java @@ -29,6 +29,7 @@ import org.elasticsearch.index.mapper.IgnoredSourceFieldMapper; import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.index.mapper.SourceFieldMapper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.translog.Translog; import org.elasticsearch.indices.recovery.RecoverySettings; import org.elasticsearch.ingest.IngestService; @@ -896,6 +897,7 @@ private void setRetentionLeaseMillis(final TimeValue retentionLease) { private volatile int maxTokenCount; private volatile int maxNgramDiff; private volatile int maxShingleDiff; + private volatile DenseVectorFieldMapper.FilterHeuristic hnswFilterHeuristic; private volatile TimeValue searchIdleAfter; private volatile int maxAnalyzedOffset; private volatile boolean weightMatchesEnabled; @@ -1091,6 +1093,7 @@ public IndexSettings(final IndexMetadata indexMetadata, final Settings nodeSetti logsdbAddHostNameField = scopedSettings.get(LOGSDB_ADD_HOST_NAME_FIELD); skipIgnoredSourceWrite = scopedSettings.get(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_WRITE_SETTING); skipIgnoredSourceRead = scopedSettings.get(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING); + hnswFilterHeuristic = scopedSettings.get(DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC); indexMappingSourceMode = scopedSettings.get(INDEX_MAPPER_SOURCE_MODE_SETTING); recoverySourceEnabled = RecoverySettings.INDICES_RECOVERY_SOURCE_ENABLED_SETTING.get(nodeSettings); recoverySourceSyntheticEnabled = DiscoveryNode.isStateless(nodeSettings) == false @@ -1203,6 +1206,7 @@ public IndexSettings(final IndexMetadata indexMetadata, final Settings nodeSetti this::setSkipIgnoredSourceWrite ); scopedSettings.addSettingsUpdateConsumer(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING, this::setSkipIgnoredSourceRead); + scopedSettings.addSettingsUpdateConsumer(DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC, this::setHnswFilterHeuristic); } private void setSearchIdleAfter(TimeValue searchIdleAfter) { @@ -1821,4 +1825,16 @@ public TimestampBounds getTimestampBounds() { public IndexRouting getIndexRouting() { return indexRouting; } + + /** + * The heuristic to utilize when executing filtered search on vectors indexed + * in HNSW format. + */ + public DenseVectorFieldMapper.FilterHeuristic getHnswFilterHeuristic() { + return this.hnswFilterHeuristic; + } + + private void setHnswFilterHeuristic(DenseVectorFieldMapper.FilterHeuristic heuristic) { + this.hnswFilterHeuristic = heuristic; + } } diff --git a/server/src/main/java/org/elasticsearch/index/IndexVersions.java b/server/src/main/java/org/elasticsearch/index/IndexVersions.java index 3bc3c82288122..4191e2a5a4598 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexVersions.java +++ b/server/src/main/java/org/elasticsearch/index/IndexVersions.java @@ -166,6 +166,7 @@ private static Version parseUnchecked(String version) { public static final IndexVersion UPGRADE_TO_LUCENE_10_2_1 = def(9_023_00_0, Version.LUCENE_10_2_1); public static final IndexVersion DEFAULT_OVERSAMPLE_VALUE_FOR_BBQ = def(9_024_0_00, Version.LUCENE_10_2_1); public static final IndexVersion SEMANTIC_TEXT_DEFAULTS_TO_BBQ = def(9_025_0_00, Version.LUCENE_10_2_1); + public static final IndexVersion DEFAULT_TO_ACORN_HNSW_FILTER_HEURISTIC = def(9_026_0_00, Version.LUCENE_10_2_1); /* * STOP! READ THIS FIRST! No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index ea889d69b7304..95caaa6ccf316 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -33,10 +33,12 @@ import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.knn.KnnSearchStrategy; import org.apache.lucene.util.BitUtil; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.VectorUtil; import org.elasticsearch.common.ParsingException; +import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.IndexVersion; @@ -93,6 +95,7 @@ import java.util.function.Supplier; import java.util.stream.Stream; +import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_INDEX_VERSION_CREATED; import static org.elasticsearch.common.Strings.format; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; @@ -108,6 +111,51 @@ public static boolean isNotUnitVector(float magnitude) { return Math.abs(magnitude - 1.0f) > EPS; } + /** + * The heuristic to utilize when executing a filtered search against vectors indexed in an HNSW graph. + */ + public enum FilterHeuristic { + /** + * This heuristic searches the entire graph, doing vector comparisons in all immediate neighbors + * but only collects vectors that match the filtering criteria. + */ + FANOUT { + static final KnnSearchStrategy FANOUT_STRATEGY = new KnnSearchStrategy.Hnsw(0); + + @Override + public KnnSearchStrategy getKnnSearchStrategy() { + return FANOUT_STRATEGY; + } + }, + /** + * This heuristic will only compare vectors that match the filtering criteria. + */ + ACORN { + static final KnnSearchStrategy ACORN_STRATEGY = new KnnSearchStrategy.Hnsw(60); + + @Override + public KnnSearchStrategy getKnnSearchStrategy() { + return ACORN_STRATEGY; + } + }; + + public abstract KnnSearchStrategy getKnnSearchStrategy(); + } + + public static final Setting HNSW_FILTER_HEURISTIC = Setting.enumSetting(FilterHeuristic.class, s -> { + IndexVersion version = SETTING_INDEX_VERSION_CREATED.get(s); + if (version.onOrAfter(IndexVersions.DEFAULT_TO_ACORN_HNSW_FILTER_HEURISTIC)) { + return FilterHeuristic.ACORN.toString(); + } + return FilterHeuristic.FANOUT.toString(); + }, + "index.dense_vector.hnsw_filter_heuristic", + fh -> {}, + Setting.Property.IndexScope, + Setting.Property.ServerlessPublic, + Setting.Property.Dynamic + ); + private static boolean hasRescoreIndexVersion(IndexVersion version) { return version.onOrAfter(IndexVersions.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS) || version.between(IndexVersions.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS_BACKPORT_8_X, IndexVersions.UPGRADE_TO_LUCENE_10_0_0); @@ -2210,15 +2258,25 @@ public Query createKnnQuery( Float oversample, Query filter, Float similarityThreshold, - BitSetProducer parentFilter + BitSetProducer parentFilter, + DenseVectorFieldMapper.FilterHeuristic heuristic ) { if (isIndexed() == false) { throw new IllegalArgumentException( "to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]" ); } + KnnSearchStrategy knnSearchStrategy = heuristic.getKnnSearchStrategy(); return switch (getElementType()) { - case BYTE -> createKnnByteQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter); + case BYTE -> createKnnByteQuery( + queryVector.asByteVector(), + k, + numCands, + filter, + similarityThreshold, + parentFilter, + knnSearchStrategy + ); case FLOAT -> createKnnFloatQuery( queryVector.asFloatVector(), k, @@ -2226,9 +2284,18 @@ public Query createKnnQuery( oversample, filter, similarityThreshold, - parentFilter + parentFilter, + knnSearchStrategy + ); + case BIT -> createKnnBitQuery( + queryVector.asByteVector(), + k, + numCands, + filter, + similarityThreshold, + parentFilter, + knnSearchStrategy ); - case BIT -> createKnnBitQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter); }; } @@ -2246,12 +2313,13 @@ private Query createKnnBitQuery( int numCands, Query filter, Float similarityThreshold, - BitSetProducer parentFilter + BitSetProducer parentFilter, + KnnSearchStrategy searchStrategy ) { elementType.checkDimensions(dims, queryVector.length); Query knnQuery = parentFilter != null - ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter) - : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter); + ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy) + : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy); if (similarityThreshold != null) { knnQuery = new VectorSimilarityQuery( knnQuery, @@ -2268,7 +2336,8 @@ private Query createKnnByteQuery( int numCands, Query filter, Float similarityThreshold, - BitSetProducer parentFilter + BitSetProducer parentFilter, + KnnSearchStrategy searchStrategy ) { elementType.checkDimensions(dims, queryVector.length); @@ -2277,8 +2346,8 @@ private Query createKnnByteQuery( elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude); } Query knnQuery = parentFilter != null - ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter) - : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter); + ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy) + : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy); if (similarityThreshold != null) { knnQuery = new VectorSimilarityQuery( knnQuery, @@ -2296,7 +2365,8 @@ private Query createKnnFloatQuery( Float queryOversample, Query filter, Float similarityThreshold, - BitSetProducer parentFilter + BitSetProducer parentFilter, + KnnSearchStrategy knnSearchStrategy ) { elementType.checkDimensions(dims, queryVector.length); elementType.checkVectorBounds(queryVector); @@ -2330,8 +2400,16 @@ && isNotUnitVector(squaredMagnitude)) { numCands = Math.max(adjustedK, numCands); } Query knnQuery = parentFilter != null - ? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, adjustedK, numCands, parentFilter) - : new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter); + ? new ESDiversifyingChildrenFloatKnnVectorQuery( + name(), + queryVector, + filter, + adjustedK, + numCands, + parentFilter, + knnSearchStrategy + ) + : new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, knnSearchStrategy); if (rescore) { knnQuery = new RescoreKnnVectorQuery( name(), diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java index b7f129f674036..46b2f0a09cf7f 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java @@ -13,6 +13,7 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; +import org.apache.lucene.search.knn.KnnSearchStrategy; import org.elasticsearch.search.profile.query.QueryProfiler; public class ESDiversifyingChildrenByteKnnVectorQuery extends DiversifyingChildrenByteKnnVectorQuery implements QueryProfilerProvider { @@ -25,9 +26,10 @@ public ESDiversifyingChildrenByteKnnVectorQuery( Query childFilter, Integer k, int numCands, - BitSetProducer parentsFilter + BitSetProducer parentsFilter, + KnnSearchStrategy strategy ) { - super(field, query, childFilter, numCands, parentsFilter); + super(field, query, childFilter, numCands, parentsFilter, strategy); this.kParam = k; } @@ -42,4 +44,8 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { public void profile(QueryProfiler queryProfiler) { queryProfiler.addVectorOpsCount(vectorOpsCount); } + + public KnnSearchStrategy getStrategy() { + return searchStrategy; + } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java index cb323bbe3932a..5635281ab0e8a 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java @@ -13,6 +13,7 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; +import org.apache.lucene.search.knn.KnnSearchStrategy; import org.elasticsearch.search.profile.query.QueryProfiler; public class ESDiversifyingChildrenFloatKnnVectorQuery extends DiversifyingChildrenFloatKnnVectorQuery implements QueryProfilerProvider { @@ -25,9 +26,10 @@ public ESDiversifyingChildrenFloatKnnVectorQuery( Query childFilter, Integer k, int numCands, - BitSetProducer parentsFilter + BitSetProducer parentsFilter, + KnnSearchStrategy strategy ) { - super(field, query, childFilter, numCands, parentsFilter); + super(field, query, childFilter, numCands, parentsFilter, strategy); this.kParam = k; } @@ -42,4 +44,8 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { public void profile(QueryProfiler queryProfiler) { queryProfiler.addVectorOpsCount(vectorOpsCount); } + + public KnnSearchStrategy getStrategy() { + return searchStrategy; + } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java index 5c199f42093b1..295efd8f9b05e 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java @@ -12,14 +12,15 @@ import org.apache.lucene.search.KnnByteVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.knn.KnnSearchStrategy; import org.elasticsearch.search.profile.query.QueryProfiler; public class ESKnnByteVectorQuery extends KnnByteVectorQuery implements QueryProfilerProvider { private final Integer kParam; private long vectorOpsCount; - public ESKnnByteVectorQuery(String field, byte[] target, Integer k, int numCands, Query filter) { - super(field, target, numCands, filter); + public ESKnnByteVectorQuery(String field, byte[] target, Integer k, int numCands, Query filter, KnnSearchStrategy strategy) { + super(field, target, numCands, filter, strategy); this.kParam = k; } @@ -39,4 +40,8 @@ public void profile(QueryProfiler queryProfiler) { public Integer kParam() { return kParam; } + + public KnnSearchStrategy getStrategy() { + return searchStrategy; + } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java index b7b9d092ceeac..8ef4aad147049 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java @@ -12,14 +12,15 @@ import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.knn.KnnSearchStrategy; import org.elasticsearch.search.profile.query.QueryProfiler; public class ESKnnFloatVectorQuery extends KnnFloatVectorQuery implements QueryProfilerProvider { private final Integer kParam; private long vectorOpsCount; - public ESKnnFloatVectorQuery(String field, float[] target, Integer k, int numCands, Query filter) { - super(field, target, numCands, filter); + public ESKnnFloatVectorQuery(String field, float[] target, Integer k, int numCands, Query filter, KnnSearchStrategy strategy) { + super(field, target, numCands, filter, strategy); this.kParam = k; } @@ -39,4 +40,8 @@ public void profile(QueryProfiler queryProfiler) { public Integer kParam() { return kParam; } + + public KnnSearchStrategy getStrategy() { + return searchStrategy; + } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index 565fd7325a5ac..87f9a50c64c17 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -552,8 +552,17 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet); } } - - return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, oversample, filterQuery, vectorSimilarity, parentBitSet); + DenseVectorFieldMapper.FilterHeuristic heuristic = context.getIndexSettings().getHnswFilterHeuristic(); + return vectorFieldType.createKnnQuery( + queryVector, + k, + adjustedNumCands, + oversample, + filterQuery, + vectorSimilarity, + parentBitSet, + heuristic + ); } @Override diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index 382f22851b45a..c1c21ccda580a 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -1881,7 +1881,16 @@ public void testByteVectorQueryBoundaries() throws IOException { Exception e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 128, 0, 0 }), 3, 3, null, null, null, null) + () -> denseVectorFieldType.createKnnQuery( + VectorData.fromFloats(new float[] { 128, 0, 0 }), + 3, + 3, + null, + null, + null, + null, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + ) ); assertThat( e.getMessage(), @@ -1897,7 +1906,8 @@ public void testByteVectorQueryBoundaries() throws IOException { null, null, null, - null + null, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) ) ); assertThat( @@ -1907,7 +1917,16 @@ public void testByteVectorQueryBoundaries() throws IOException { e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }), 3, 3, null, null, null, null) + () -> denseVectorFieldType.createKnnQuery( + VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }), + 3, + 3, + null, + null, + null, + null, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + ) ); assertThat( e.getMessage(), @@ -1916,7 +1935,16 @@ public void testByteVectorQueryBoundaries() throws IOException { e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }), 3, 3, null, null, null, null) + () -> denseVectorFieldType.createKnnQuery( + VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }), + 3, + 3, + null, + null, + null, + null, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + ) ); assertThat( e.getMessage(), @@ -1932,7 +1960,8 @@ public void testByteVectorQueryBoundaries() throws IOException { null, null, null, - null + null, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) ) ); assertThat(e.getMessage(), containsString("element_type [byte] vectors do not support NaN values but found [NaN] at dim [0];")); @@ -1946,7 +1975,8 @@ public void testByteVectorQueryBoundaries() throws IOException { null, null, null, - null + null, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) ) ); assertThat( @@ -1963,7 +1993,8 @@ public void testByteVectorQueryBoundaries() throws IOException { null, null, null, - null + null, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) ) ); assertThat( @@ -1997,7 +2028,8 @@ public void testFloatVectorQueryBoundaries() throws IOException { null, null, null, - null + null, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) ) ); assertThat(e.getMessage(), containsString("element_type [float] vectors do not support NaN values but found [NaN] at dim [0];")); @@ -2011,7 +2043,8 @@ public void testFloatVectorQueryBoundaries() throws IOException { null, null, null, - null + null, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) ) ); assertThat( @@ -2028,7 +2061,8 @@ public void testFloatVectorQueryBoundaries() throws IOException { null, null, null, - null + null, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) ) ); assertThat( diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index b6df46d17b598..5877ce9003ff5 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -15,6 +15,8 @@ import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; +import org.apache.lucene.search.knn.KnnSearchStrategy; +import org.elasticsearch.core.Tuple; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.mapper.FieldTypeTestCase; @@ -32,8 +34,10 @@ import java.util.Collections; import java.util.List; import java.util.Set; +import java.util.function.Function; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BIT; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BYTE; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.FLOAT; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.OVERSAMPLE_LIMIT; @@ -216,7 +220,16 @@ public void testCreateNestedKnnQuery() { for (int i = 0; i < dims; i++) { queryVector[i] = randomFloat(); } - Query query = field.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, producer); + Query query = field.createKnnQuery( + VectorData.fromFloats(queryVector), + 10, + 10, + null, + null, + null, + producer, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + ); if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) { query = rescoreKnnVectorQuery.innerQuery(); } @@ -240,11 +253,29 @@ public void testCreateNestedKnnQuery() { floatQueryVector[i] = queryVector[i]; } VectorData vectorData = new VectorData(null, queryVector); - Query query = field.createKnnQuery(vectorData, 10, 10, null, null, null, producer); + Query query = field.createKnnQuery( + vectorData, + 10, + 10, + null, + null, + null, + producer, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + ); assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); vectorData = new VectorData(floatQueryVector, null); - query = field.createKnnQuery(vectorData, 10, 10, null, null, null, producer); + query = field.createKnnQuery( + vectorData, + 10, + 10, + null, + null, + null, + producer, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + ); assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); } } @@ -312,7 +343,8 @@ public void testFloatCreateKnnQuery() { null, null, null, - null + null, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) ) ); assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]")); @@ -333,7 +365,16 @@ public void testFloatCreateKnnQuery() { } e = expectThrows( IllegalArgumentException.class, - () -> dotProductField.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, null) + () -> dotProductField.createKnnQuery( + VectorData.fromFloats(queryVector), + 10, + 10, + null, + null, + null, + null, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + ) ); assertThat(e.getMessage(), containsString("The [dot_product] similarity can only be used with unit-length vectors.")); @@ -349,7 +390,16 @@ public void testFloatCreateKnnQuery() { ); e = expectThrows( IllegalArgumentException.class, - () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[BBQ_MIN_DIMS]), 10, 10, null, null, null, null) + () -> cosineField.createKnnQuery( + VectorData.fromFloats(new float[BBQ_MIN_DIMS]), + 10, + 10, + null, + null, + null, + null, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + ) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); } @@ -370,7 +420,16 @@ public void testCreateKnnQueryMaxDims() { for (int i = 0; i < 4096; i++) { queryVector[i] = randomFloat(); } - Query query = fieldWith4096dims.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, null); + Query query = fieldWith4096dims.createKnnQuery( + VectorData.fromFloats(queryVector), + 10, + 10, + null, + null, + null, + null, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + ); if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) { query = rescoreKnnVectorQuery.innerQuery(); } @@ -393,7 +452,16 @@ public void testCreateKnnQueryMaxDims() { queryVector[i] = randomByte(); } VectorData vectorData = new VectorData(null, queryVector); - Query query = fieldWith4096dims.createKnnQuery(vectorData, 10, 10, null, null, null, null); + Query query = fieldWith4096dims.createKnnQuery( + vectorData, + 10, + 10, + null, + null, + null, + null, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + ); assertThat(query, instanceOf(KnnByteVectorQuery.class)); } } @@ -411,7 +479,16 @@ public void testByteCreateKnnQuery() { ); IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> unindexedField.createKnnQuery(VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f }), 10, 10, null, null, null, null) + () -> unindexedField.createKnnQuery( + VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f }), + 10, + 10, + null, + null, + null, + null, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + ) ); assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]")); @@ -427,13 +504,31 @@ public void testByteCreateKnnQuery() { ); e = expectThrows( IllegalArgumentException.class, - () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.0f, 0.0f }), 10, 10, null, null, null, null) + () -> cosineField.createKnnQuery( + VectorData.fromFloats(new float[] { 0.0f, 0.0f, 0.0f }), + 10, + 10, + null, + null, + null, + null, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + ) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); e = expectThrows( IllegalArgumentException.class, - () -> cosineField.createKnnQuery(new VectorData(null, new byte[] { 0, 0, 0 }), 10, 10, null, null, null, null) + () -> cosineField.createKnnQuery( + new VectorData(null, new byte[] { 0, 0, 0 }), + 10, + 10, + null, + null, + null, + null, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + ) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); } @@ -458,7 +553,8 @@ public void testRescoreOversampleUsedWithoutQuantization() { randomFloatBetween(1.0F, 10.0F, false), null, null, - null + null, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) ); if (elementType == BYTE) { @@ -504,7 +600,16 @@ public void testRescoreOversampleQueryOverrides() { randomIndexOptionsHnswQuantized(new DenseVectorFieldMapper.RescoreVector(randomFloatBetween(1.1f, 9.9f, false))), Collections.emptyMap() ); - Query query = fieldType.createKnnQuery(VectorData.fromFloats(new float[] { 1, 4, 10 }), 10, 100, 0f, null, null, null); + Query query = fieldType.createKnnQuery( + VectorData.fromFloats(new float[] { 1, 4, 10 }), + 10, + 100, + 0f, + null, + null, + null, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + ); assertTrue(query instanceof ESKnnFloatVectorQuery); // verify we can override a `0` to a positive number @@ -518,7 +623,16 @@ public void testRescoreOversampleQueryOverrides() { randomIndexOptionsHnswQuantized(new DenseVectorFieldMapper.RescoreVector(0)), Collections.emptyMap() ); - query = fieldType.createKnnQuery(VectorData.fromFloats(new float[] { 1, 4, 10 }), 10, 100, 2f, null, null, null); + query = fieldType.createKnnQuery( + VectorData.fromFloats(new float[] { 1, 4, 10 }), + 10, + 100, + 2f, + null, + null, + null, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) + ); assertTrue(query instanceof RescoreKnnVectorQuery); assertThat(((RescoreKnnVectorQuery) query).k(), equalTo(10)); ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) ((RescoreKnnVectorQuery) query).innerQuery(); @@ -526,6 +640,55 @@ public void testRescoreOversampleQueryOverrides() { } + public void testFilterSearchThreshold() { + List>> cases = List.of( + Tuple.tuple(FLOAT, q -> ((ESKnnFloatVectorQuery) q).getStrategy()), + Tuple.tuple(BYTE, q -> ((ESKnnByteVectorQuery) q).getStrategy()), + Tuple.tuple(BIT, q -> ((ESKnnByteVectorQuery) q).getStrategy()) + ); + for (var tuple : cases) { + DenseVectorFieldType fieldType = new DenseVectorFieldType( + "f", + IndexVersion.current(), + tuple.v1(), + tuple.v1() == BIT ? 3 * 8 : 3, + true, + VectorSimilarity.COSINE, + randomIndexOptionsHnswQuantized(), + Collections.emptyMap() + ); + + // Test with a filter search threshold + Query query = fieldType.createKnnQuery( + VectorData.fromFloats(new float[] { 1, 4, 10 }), + 10, + 100, + 0f, + null, + null, + null, + DenseVectorFieldMapper.FilterHeuristic.FANOUT + ); + KnnSearchStrategy strategy = tuple.v2().apply(query); + assertTrue(strategy instanceof KnnSearchStrategy.Hnsw); + assertThat(((KnnSearchStrategy.Hnsw) strategy).filteredSearchThreshold(), equalTo(0)); + + query = fieldType.createKnnQuery( + VectorData.fromFloats(new float[] { 1, 4, 10 }), + 10, + 100, + 0f, + null, + null, + null, + DenseVectorFieldMapper.FilterHeuristic.ACORN + ); + strategy = tuple.v2().apply(query); + assertTrue(strategy instanceof KnnSearchStrategy.Hnsw); + assertThat(((KnnSearchStrategy.Hnsw) strategy).filteredSearchThreshold(), equalTo(60)); + } + } + private static void checkRescoreQueryParameters( DenseVectorFieldType fieldType, int k, @@ -542,7 +705,8 @@ private static void checkRescoreQueryParameters( oversample, null, null, - null + null, + randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) ); RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query; ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) rescoreQuery.innerQuery(); diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index 9499edc71b4a6..27549b3c4030b 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -13,6 +13,7 @@ import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.knn.KnnSearchStrategy; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.support.PlainActionFuture; @@ -21,6 +22,7 @@ import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.query.InnerHitsRewriteContext; @@ -216,9 +218,29 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que numCands = Math.max(numCands, k); } + final KnnSearchStrategy expectedStrategy = context.getIndexSettings() + .getIndexVersionCreated() + .onOrAfter(IndexVersions.DEFAULT_TO_ACORN_HNSW_FILTER_HEURISTIC) + ? DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy() + : DenseVectorFieldMapper.FilterHeuristic.FANOUT.getKnnSearchStrategy(); + Query knnVectorQueryBuilt = switch (elementType()) { - case BYTE, BIT -> new ESKnnByteVectorQuery(VECTOR_FIELD, queryBuilder.queryVector().asByteVector(), k, numCands, filterQuery); - case FLOAT -> new ESKnnFloatVectorQuery(VECTOR_FIELD, queryBuilder.queryVector().asFloatVector(), k, numCands, filterQuery); + case BYTE, BIT -> new ESKnnByteVectorQuery( + VECTOR_FIELD, + queryBuilder.queryVector().asByteVector(), + k, + numCands, + filterQuery, + expectedStrategy + ); + case FLOAT -> new ESKnnFloatVectorQuery( + VECTOR_FIELD, + queryBuilder.queryVector().asFloatVector(), + k, + numCands, + filterQuery, + expectedStrategy + ); }; if (query instanceof VectorSimilarityQuery vectorSimilarityQuery) { query = vectorSimilarityQuery.getInnerKnnQuery(); diff --git a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/TransportResumeFollowAction.java b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/TransportResumeFollowAction.java index 8b80935ca4df5..18e125a7ae1ce 100644 --- a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/TransportResumeFollowAction.java +++ b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/TransportResumeFollowAction.java @@ -38,6 +38,7 @@ import org.elasticsearch.index.cache.bitset.BitsetFilterCache; import org.elasticsearch.index.engine.EngineConfig; import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.IndicesRequestCache; import org.elasticsearch.indices.IndicesService; @@ -529,7 +530,8 @@ static String[] extractLeaderShardHistoryUUIDs(Map ccrIndexMetad EngineConfig.INDEX_CODEC_SETTING, DataTier.TIER_PREFERENCE_SETTING, IndexSettings.BLOOM_FILTER_ID_FIELD_ENABLED_SETTING, - MetadataIndexStateService.VERIFIED_READ_ONLY_SETTING + MetadataIndexStateService.VERIFIED_READ_ONLY_SETTING, + DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC ); public static Settings filter(Settings originalSettings) {