diff --git a/docs/changelog/132396.yaml b/docs/changelog/132396.yaml new file mode 100644 index 0000000000000..8645c199bb8a9 --- /dev/null +++ b/docs/changelog/132396.yaml @@ -0,0 +1,6 @@ +pr: 132396 +summary: DiskBBQ - Adapt `visitRatio` based on query - segment affinity in multi segment + scenario +area: Vector Search +type: enhancement +issues: [] diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java index 50c8fd107c444..ab2f91c3b538c 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java @@ -63,7 +63,6 @@ record CmdLineArgs( static final ParseField INDEX_TYPE_FIELD = new ParseField("index_type"); static final ParseField NUM_CANDIDATES_FIELD = new ParseField("num_candidates"); static final ParseField K_FIELD = new ParseField("k"); - // static final ParseField N_PROBE_FIELD = new ParseField("n_probe"); static final ParseField VISIT_PERCENTAGE_FIELD = new ParseField("visit_percentage"); static final ParseField IVF_CLUSTER_SIZE_FIELD = new ParseField("ivf_cluster_size"); static final ParseField OVER_SAMPLING_FACTOR_FIELD = new ParseField("over_sampling_factor"); @@ -98,7 +97,6 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException { PARSER.declareString(Builder::setIndexType, INDEX_TYPE_FIELD); PARSER.declareInt(Builder::setNumCandidates, NUM_CANDIDATES_FIELD); PARSER.declareInt(Builder::setK, K_FIELD); - // PARSER.declareIntArray(Builder::setNProbe, N_PROBE_FIELD); PARSER.declareDoubleArray(Builder::setVisitPercentages, VISIT_PERCENTAGE_FIELD); PARSER.declareInt(Builder::setIvfClusterSize, IVF_CLUSTER_SIZE_FIELD); PARSER.declareInt(Builder::setOverSamplingFactor, OVER_SAMPLING_FACTOR_FIELD); @@ -134,7 +132,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(INDEX_TYPE_FIELD.getPreferredName(), indexType.name().toLowerCase(Locale.ROOT)); builder.field(NUM_CANDIDATES_FIELD.getPreferredName(), numCandidates); builder.field(K_FIELD.getPreferredName(), k); - // builder.field(N_PROBE_FIELD.getPreferredName(), nProbes); builder.field(VISIT_PERCENTAGE_FIELD.getPreferredName(), visitPercentages); builder.field(IVF_CLUSTER_SIZE_FIELD.getPreferredName(), ivfClusterSize); builder.field(OVER_SAMPLING_FACTOR_FIELD.getPreferredName(), overSamplingFactor); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java index a2914682ac93f..1d90cf43a6c21 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java @@ -348,4 +348,12 @@ interface PostingVisitor { /** returns the number of scored documents */ int visit(KnnCollector collector) throws IOException; } + + public IndexInput getIvfCentroids(FieldInfo fieldInfo) throws IOException { + return fields.get(fieldInfo.number).centroidSlice(ivfCentroids); + } + + public float[] getGlobalCentroid(FieldInfo fieldInfo) { + return fields.get(fieldInfo.number).globalCentroid; + } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java index 00e083e0a6781..4ac04544dbf1f 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java @@ -11,10 +11,14 @@ import com.carrotsearch.hppc.IntHashSet; -import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.SegmentReader; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.DocIdSetIterator; @@ -30,26 +34,37 @@ import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocsCollector; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.search.Weight; import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.search.knn.KnnSearchStrategy; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; +import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.index.codec.vectors.DefaultIVFVectorsReader; +import org.elasticsearch.index.codec.vectors.IVFVectorsReader; import org.elasticsearch.search.profile.query.QueryProfiler; import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.concurrent.Callable; import java.util.concurrent.atomic.LongAccumulator; +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; import static org.elasticsearch.search.vectors.AbstractMaxScoreKnnCollector.LEAST_COMPETITIVE; abstract class AbstractIVFKnnVectorQuery extends Query implements QueryProfilerProvider { static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS; + private static final float MIN_VISIT_RATIO_FOR_AFFINITY_ADJUSTMENT = 0.004f; + private static final float MAX_AFFINITY_MULTIPLIER_ADJUSTMENT = 1.1f; + private static final float MIN_AFFINITY_MULTIPLIER_ADJUSTMENT = 0.75f; + private static final float MIN_AFFINITY = 0.001f; + private static final float MAX_AFFINITY = 1f; protected final String field; protected final float providedVisitRatio; @@ -125,30 +140,71 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { TaskExecutor taskExecutor = indexSearcher.getTaskExecutor(); List leafReaderContexts = reader.leaves(); + int totalDocsWVectors = 0; assert this instanceof IVFKnnFloatVectorQuery; - int totalVectors = 0; + int[] costs = new int[leafReaderContexts.size()]; + int i = 0; for (LeafReaderContext leafReaderContext : leafReaderContexts) { LeafReader leafReader = leafReaderContext.reader(); - FloatVectorValues floatVectorValues = leafReader.getFloatVectorValues(field); - if (floatVectorValues != null) { - totalVectors += floatVectorValues.size(); + FieldInfo fieldInfo = leafReader.getFieldInfos().fieldInfo(field); + VectorScorer scorer = createVectorScorer(leafReaderContext, fieldInfo); + int cost; + if (scorer != null) { + cost = (int) scorer.iterator().cost(); + totalDocsWVectors += cost; + } else { + cost = 0; } + costs[i] = cost; + i++; } final float visitRatio; if (providedVisitRatio == 0.0f) { // dynamically set the percentage float expected = (float) Math.round( - Math.log10(totalVectors) * Math.log10(totalVectors) * (Math.min(10_000, Math.max(numCands, 5 * k))) + Math.log10(totalDocsWVectors) * Math.log10(totalDocsWVectors) * (Math.min(10_000, Math.max(numCands, 5 * k))) ); - visitRatio = expected / totalVectors; + visitRatio = expected / totalDocsWVectors; } else { visitRatio = providedVisitRatio; } - List> tasks = new ArrayList<>(leafReaderContexts.size()); - for (LeafReaderContext context : leafReaderContexts) { - tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager, visitRatio)); + List> tasks; + if (leafReaderContexts.isEmpty() == false) { + if (visitRatio > MIN_VISIT_RATIO_FOR_AFFINITY_ADJUSTMENT) { + // calculate the affinity of each segment to the query vector + List segmentAffinities = calculateSegmentAffinities(leafReaderContexts, getQueryVector(), costs); + segmentAffinities.sort((a, b) -> Double.compare(b.affinityScore(), a.affinityScore())); + + if (filterWeight != null // TODO : enable affinity optimization for filtered case + || leafReaderContexts.size() == 1) { + tasks = new ArrayList<>(leafReaderContexts.size()); + for (LeafReaderContext context : leafReaderContexts) { + tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager, visitRatio)); + } + } else { + tasks = new ArrayList<>(segmentAffinities.size()); + for (SegmentAffinity segmentAffinity : segmentAffinities) { + double affinityScore = segmentAffinity.affinityScore; + + float adjustedVisitRatio = adjustVisitRatioForSegment( + affinityScore, + segmentAffinities.get(segmentAffinities.size() / 10).affinityScore, + visitRatio + ); + + tasks.add(() -> searchLeaf(segmentAffinity.context(), filterWeight, knnCollectorManager, adjustedVisitRatio)); + } + } + } else { + tasks = new ArrayList<>(leafReaderContexts.size()); + for (LeafReaderContext context : leafReaderContexts) { + tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager, visitRatio)); + } + } + } else { + tasks = Collections.emptyList(); } TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new); @@ -161,6 +217,89 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { return new KnnScoreDocQuery(topK.scoreDocs, reader); } + private float adjustVisitRatioForSegment(double affinityScore, double affinityThreshold, float visitRatio) { + // for high affinity scores, increase visited ratio + if (affinityScore > affinityThreshold) { + double adjustment = Math.min(1 + (affinityScore - affinityThreshold), MAX_AFFINITY_MULTIPLIER_ADJUSTMENT); + return Math.min((float) (visitRatio * adjustment), MAX_AFFINITY); + } + + // for low affinity scores, decrease visited ratio + if (affinityScore < affinityThreshold) { + double adjustment = Math.max(1 - (affinityThreshold - affinityScore), MIN_AFFINITY_MULTIPLIER_ADJUSTMENT); + return (float) Math.max(visitRatio * adjustment, MIN_AFFINITY); + } + + return visitRatio; + } + + abstract VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException; + + abstract float[] getQueryVector() throws IOException; + + private IVFVectorsReader unwrapReader(KnnVectorsReader knnVectorsReader) { + IVFVectorsReader result = null; + if (knnVectorsReader instanceof DefaultIVFVectorsReader IVFVectorsReader) { + result = IVFVectorsReader; + } else if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader r) { + KnnVectorsReader fieldReader = r.getFieldReader(field); + if (fieldReader != null) { + result = unwrapReader(fieldReader); + } + } + return result; + } + + private List calculateSegmentAffinities(List leafReaderContexts, float[] queryVector, int[] costs) { + List segmentAffinities = new ArrayList<>(leafReaderContexts.size()); + + int i = 0; + for (LeafReaderContext context : leafReaderContexts) { + LeafReader leafReader = context.reader(); + FieldInfo fieldInfo = leafReader.getFieldInfos().fieldInfo(field); + if (fieldInfo == null) { + continue; + } + VectorSimilarityFunction similarityFunction = fieldInfo.getVectorSimilarityFunction(); + if (leafReader instanceof SegmentReader segmentReader) { + KnnVectorsReader vectorReader = segmentReader.getVectorReader(); + IVFVectorsReader reader = unwrapReader(vectorReader); + if (reader != null) { + float[] globalCentroid = reader.getGlobalCentroid(fieldInfo); + + if (similarityFunction == COSINE) { + VectorUtil.l2normalize(queryVector); + } + + if (queryVector.length != fieldInfo.getVectorDimension()) { + throw new IllegalArgumentException( + "vector query dimension: " + + queryVector.length + + " differs from field dimension: " + + fieldInfo.getVectorDimension() + ); + } + + float centroidsScore = similarityFunction.compare(queryVector, globalCentroid); + + int numVectors = costs[i]; + + // TODO : we may want to include some actual centroids' scores for higher quality estimate + double affinityScore = centroidsScore * (Math.log10(numVectors)); + + segmentAffinities.add(new SegmentAffinity(context, affinityScore)); + } else { + segmentAffinities.add(new SegmentAffinity(context, Float.NaN)); + } + } + i++; + } + + return segmentAffinities; + } + + private record SegmentAffinity(LeafReaderContext context, double affinityScore) {} + private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, IVFCollectorManager knnCollectorManager, float visitRatio) throws IOException { TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager, visitRatio); diff --git a/server/src/main/java/org/elasticsearch/search/vectors/IVFKnnFloatVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/IVFKnnFloatVectorQuery.java index da452ecc992db..aaf1019a8873b 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/IVFKnnFloatVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/IVFKnnFloatVectorQuery.java @@ -8,11 +8,13 @@ */ package org.elasticsearch.search.vectors; +import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Query; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.util.Bits; import java.io.IOException; @@ -70,6 +72,17 @@ public int hashCode() { return result; } + @Override + VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException { + LeafReader reader = context.reader(); + FloatVectorValues vectorValues = reader.getFloatVectorValues(field); + if (vectorValues == null) { + FloatVectorValues.checkField(reader, field); + return null; + } + return vectorValues.scorer(query); + } + @Override protected TopDocs approximateSearch( LeafReaderContext context, @@ -97,4 +110,9 @@ protected TopDocs approximateSearch( TopDocs results = knnCollector.topDocs(); return results != null ? results : NO_RESULTS; } + + @Override + float[] getQueryVector() { + return query; + } }