Skip to content

Commit 15ae296

Browse files
authored
[DiskBBQ] Replace n_probe, related to the number of centroids, with visit_percentage, related to the number of documents (elastic#132722)
This commit changes the way we budget how many vectors we are going to score during search.
1 parent 97a6dc8 commit 15ae296

File tree

14 files changed

+138
-99
lines changed

14 files changed

+138
-99
lines changed

qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ record CmdLineArgs(
3636
KnnIndexTester.IndexType indexType,
3737
int numCandidates,
3838
int k,
39-
int[] nProbes,
39+
double[] visitPercentages,
4040
int ivfClusterSize,
4141
int overSamplingFactor,
4242
int hnswM,
@@ -63,7 +63,8 @@ record CmdLineArgs(
6363
static final ParseField INDEX_TYPE_FIELD = new ParseField("index_type");
6464
static final ParseField NUM_CANDIDATES_FIELD = new ParseField("num_candidates");
6565
static final ParseField K_FIELD = new ParseField("k");
66-
static final ParseField N_PROBE_FIELD = new ParseField("n_probe");
66+
// static final ParseField N_PROBE_FIELD = new ParseField("n_probe");
67+
static final ParseField VISIT_PERCENTAGE_FIELD = new ParseField("visit_percentage");
6768
static final ParseField IVF_CLUSTER_SIZE_FIELD = new ParseField("ivf_cluster_size");
6869
static final ParseField OVER_SAMPLING_FACTOR_FIELD = new ParseField("over_sampling_factor");
6970
static final ParseField HNSW_M_FIELD = new ParseField("hnsw_m");
@@ -97,7 +98,8 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
9798
PARSER.declareString(Builder::setIndexType, INDEX_TYPE_FIELD);
9899
PARSER.declareInt(Builder::setNumCandidates, NUM_CANDIDATES_FIELD);
99100
PARSER.declareInt(Builder::setK, K_FIELD);
100-
PARSER.declareIntArray(Builder::setNProbe, N_PROBE_FIELD);
101+
// PARSER.declareIntArray(Builder::setNProbe, N_PROBE_FIELD);
102+
PARSER.declareDoubleArray(Builder::setVisitPercentages, VISIT_PERCENTAGE_FIELD);
101103
PARSER.declareInt(Builder::setIvfClusterSize, IVF_CLUSTER_SIZE_FIELD);
102104
PARSER.declareInt(Builder::setOverSamplingFactor, OVER_SAMPLING_FACTOR_FIELD);
103105
PARSER.declareInt(Builder::setHnswM, HNSW_M_FIELD);
@@ -132,7 +134,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
132134
builder.field(INDEX_TYPE_FIELD.getPreferredName(), indexType.name().toLowerCase(Locale.ROOT));
133135
builder.field(NUM_CANDIDATES_FIELD.getPreferredName(), numCandidates);
134136
builder.field(K_FIELD.getPreferredName(), k);
135-
builder.field(N_PROBE_FIELD.getPreferredName(), nProbes);
137+
// builder.field(N_PROBE_FIELD.getPreferredName(), nProbes);
138+
builder.field(VISIT_PERCENTAGE_FIELD.getPreferredName(), visitPercentages);
136139
builder.field(IVF_CLUSTER_SIZE_FIELD.getPreferredName(), ivfClusterSize);
137140
builder.field(OVER_SAMPLING_FACTOR_FIELD.getPreferredName(), overSamplingFactor);
138141
builder.field(HNSW_M_FIELD.getPreferredName(), hnswM);
@@ -165,7 +168,7 @@ static class Builder {
165168
private KnnIndexTester.IndexType indexType = KnnIndexTester.IndexType.HNSW;
166169
private int numCandidates = 1000;
167170
private int k = 10;
168-
private int[] nProbes = new int[] { 10 };
171+
private double[] visitPercentages = new double[] { 1.0 };
169172
private int ivfClusterSize = 1000;
170173
private int overSamplingFactor = 1;
171174
private int hnswM = 16;
@@ -223,8 +226,8 @@ public Builder setK(int k) {
223226
return this;
224227
}
225228

226-
public Builder setNProbe(List<Integer> nProbes) {
227-
this.nProbes = nProbes.stream().mapToInt(Integer::intValue).toArray();
229+
public Builder setVisitPercentages(List<Double> visitPercentages) {
230+
this.visitPercentages = visitPercentages.stream().mapToDouble(Double::doubleValue).toArray();
228231
return this;
229232
}
230233

@@ -330,7 +333,7 @@ public CmdLineArgs build() {
330333
indexType,
331334
numCandidates,
332335
k,
333-
nProbes,
336+
visitPercentages,
334337
ivfClusterSize,
335338
overSamplingFactor,
336339
hnswM,

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -191,18 +191,18 @@ public static void main(String[] args) throws Exception {
191191
FormattedResults formattedResults = new FormattedResults();
192192

193193
for (CmdLineArgs cmdLineArgs : cmdLineArgsList) {
194-
int[] nProbes = cmdLineArgs.indexType().equals(IndexType.IVF) && cmdLineArgs.numQueries() > 0
195-
? cmdLineArgs.nProbes()
196-
: new int[] { 0 };
194+
double[] visitPercentages = cmdLineArgs.indexType().equals(IndexType.IVF) && cmdLineArgs.numQueries() > 0
195+
? cmdLineArgs.visitPercentages()
196+
: new double[] { 0 };
197197
String indexType = cmdLineArgs.indexType().name().toLowerCase(Locale.ROOT);
198198
Results indexResults = new Results(
199199
cmdLineArgs.docVectors().get(0).getFileName().toString(),
200200
indexType,
201201
cmdLineArgs.numDocs(),
202202
cmdLineArgs.filterSelectivity()
203203
);
204-
Results[] results = new Results[nProbes.length];
205-
for (int i = 0; i < nProbes.length; i++) {
204+
Results[] results = new Results[visitPercentages.length];
205+
for (int i = 0; i < visitPercentages.length; i++) {
206206
results[i] = new Results(
207207
cmdLineArgs.docVectors().get(0).getFileName().toString(),
208208
indexType,
@@ -240,8 +240,7 @@ public static void main(String[] args) throws Exception {
240240
numSegments(indexPath, indexResults);
241241
if (cmdLineArgs.queryVectors() != null && cmdLineArgs.numQueries() > 0) {
242242
for (int i = 0; i < results.length; i++) {
243-
int nProbe = nProbes[i];
244-
KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs, nProbe);
243+
KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs, visitPercentages[i]);
245244
knnSearcher.runSearch(results[i], cmdLineArgs.earlyTermination());
246245
}
247246
}
@@ -293,7 +292,7 @@ public String toString() {
293292
String[] searchHeaders = {
294293
"index_name",
295294
"index_type",
296-
"n_probe",
295+
"visit_percentage(%)",
297296
"latency(ms)",
298297
"net_cpu_time(ms)",
299298
"avg_cpu_count",
@@ -324,7 +323,7 @@ public String toString() {
324323
queryResultsArray[i] = new String[] {
325324
queryResult.indexName,
326325
queryResult.indexType,
327-
Integer.toString(queryResult.nProbe),
326+
String.format(Locale.ROOT, "%.2f", queryResult.visitPercentage),
328327
String.format(Locale.ROOT, "%.2f", queryResult.avgLatency),
329328
String.format(Locale.ROOT, "%.2f", queryResult.netCpuTimeMS),
330329
String.format(Locale.ROOT, "%.2f", queryResult.avgCpuCount),
@@ -400,7 +399,7 @@ static class Results {
400399
long indexTimeMS;
401400
long forceMergeTimeMS;
402401
int numSegments;
403-
int nProbe;
402+
double visitPercentage;
404403
double avgLatency;
405404
double qps;
406405
double avgRecall;

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ class KnnSearcher {
107107
private final float selectivity;
108108
private final int topK;
109109
private final int efSearch;
110-
private final int nProbe;
110+
private final double visitPercentage;
111111
private final KnnIndexTester.IndexType indexType;
112112
private int dim;
113113
private final VectorSimilarityFunction similarityFunction;
@@ -116,7 +116,7 @@ class KnnSearcher {
116116
private final int searchThreads;
117117
private final int numSearchers;
118118

119-
KnnSearcher(Path indexPath, CmdLineArgs cmdLineArgs, int nProbe) {
119+
KnnSearcher(Path indexPath, CmdLineArgs cmdLineArgs, double visitPercentage) {
120120
this.docPath = cmdLineArgs.docVectors();
121121
this.indexPath = indexPath;
122122
this.queryPath = cmdLineArgs.queryVectors();
@@ -131,7 +131,7 @@ class KnnSearcher {
131131
throw new IllegalArgumentException("numQueryVectors must be > 0");
132132
}
133133
this.efSearch = cmdLineArgs.numCandidates();
134-
this.nProbe = nProbe;
134+
this.visitPercentage = visitPercentage;
135135
this.indexType = cmdLineArgs.indexType();
136136
this.searchThreads = cmdLineArgs.searchThreads();
137137
this.numSearchers = cmdLineArgs.numSearchers();
@@ -298,7 +298,7 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) th
298298
}
299299
logger.info("checking results");
300300
int[][] nn = getOrCalculateExactNN(offsetByteSize, filterQuery);
301-
finalResults.nProbe = indexType == KnnIndexTester.IndexType.IVF ? nProbe : 0;
301+
finalResults.visitPercentage = indexType == KnnIndexTester.IndexType.IVF ? visitPercentage : 0;
302302
finalResults.avgRecall = checkResults(resultIds, nn, topK);
303303
finalResults.qps = (1000f * numQueryVectors) / elapsed;
304304
finalResults.avgLatency = (float) elapsed / numQueryVectors;
@@ -424,7 +424,8 @@ TopDocs doVectorQuery(float[] vector, IndexSearcher searcher, Query filterQuery,
424424
}
425425
int efSearch = Math.max(topK, this.efSearch);
426426
if (indexType == KnnIndexTester.IndexType.IVF) {
427-
knnQuery = new IVFKnnFloatVectorQuery(VECTOR_FIELD, vector, topK, efSearch, filterQuery, nProbe);
427+
float visitRatio = (float) (visitPercentage / 100);
428+
knnQuery = new IVFKnnFloatVectorQuery(VECTOR_FIELD, vector, topK, efSearch, filterQuery, visitRatio);
428429
} else {
429430
knnQuery = new ESKnnFloatVectorQuery(
430431
VECTOR_FIELD,

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ public class IVFVectorsFormat extends KnnVectorsFormat {
6060
);
6161

6262
// Sentinel value: the visit ratio is chosen dynamically based on the `k` requested and the number of vectors.
63-
// useful when searching with 'efSearch' type parameters instead of requiring a specific nprobe.
64-
public static final int DYNAMIC_NPROBE = -1;
63+
// useful when searching with 'efSearch' type parameters instead of requiring a specific ratio.
64+
public static final float DYNAMIC_VISIT_RATIO = 0.0f;
6565
public static final int DEFAULT_VECTORS_PER_CLUSTER = 384;
6666
public static final int MIN_VECTORS_PER_CLUSTER = 64;
6767
public static final int MAX_VECTORS_PER_CLUSTER = 1 << 16; // 65536

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
import java.io.IOException;
3636

3737
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
38-
import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.DYNAMIC_NPROBE;
38+
import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.DYNAMIC_VISIT_RATIO;
3939

4040
/**
4141
* Reader for IVF vectors. This reader is used to read the IVF vectors from the index.
@@ -222,34 +222,36 @@ public final void search(String field, float[] target, KnnCollector knnCollector
222222
percentFiltered = Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length()));
223223
}
224224
int numVectors = rawVectorsReader.getFloatVectorValues(field).size();
225-
int nProbe = DYNAMIC_NPROBE;
225+
float visitRatio = DYNAMIC_VISIT_RATIO;
226226
// Search strategy may be null if this is being called from checkIndex (e.g. from a test)
227227
if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfSearchStrategy) {
228-
nProbe = ivfSearchStrategy.getNProbe();
228+
visitRatio = ivfSearchStrategy.getVisitRatio();
229229
}
230230

231231
FieldEntry entry = fields.get(fieldInfo.number);
232-
if (nProbe == DYNAMIC_NPROBE) {
232+
if (visitRatio == DYNAMIC_VISIT_RATIO) {
233233
// empirically based, and a good dynamic to get decent recall while scaling a la "efSearch"
234-
// scaling by the number of centroids vs. the nearest neighbors requested
234+
// scaling by the number of vectors vs. the nearest neighbors requested
235235
// not perfect, but a comparative heuristic.
236-
// we might want to utilize the total vector count as well, but this is a good start
237-
nProbe = (int) Math.round(Math.log10(entry.numCentroids) * Math.sqrt(knnCollector.k()));
238-
// clip to be between 1 and the number of centroids
239-
nProbe = Math.max(Math.min(nProbe, entry.numCentroids), 1);
236+
// TODO: we might want to consider the density of the centroids, as experiments show that the fewer vectors per centroid,
237+
// the fewer vectors we need to score to get good recall.
238+
float estimated = Math.round(Math.log10(numVectors) * Math.log10(numVectors) * (knnCollector.k()));
239+
// convert the estimated number of vectors to visit into a ratio of the total (NOTE: no clipping is applied here)
240+
visitRatio = estimated / numVectors;
240241
}
242+
// we account for soar vectors here. We can potentially visit a vector twice so we multiply by 2 here.
243+
long maxVectorVisited = (long) (2.0 * visitRatio * numVectors);
241244
CentroidIterator centroidIterator = getCentroidIterator(fieldInfo, entry.numCentroids, entry.centroidSlice(ivfCentroids), target);
242245
PostingVisitor scorer = getPostingVisitor(fieldInfo, entry.postingListSlice(ivfClusters), target, acceptDocs);
243-
int centroidsVisited = 0;
246+
244247
long expectedDocs = 0;
245248
long actualDocs = 0;
246249
// initially we visit only the "centroids to search"
247250
// Note, numCollected is doing the bare minimum here.
248251
// TODO do we need to handle nested doc counts similarly to how we handle
249252
// filtering? E.g. keep exploring until we hit an expected number of parent documents vs. child vectors?
250253
while (centroidIterator.hasNext()
251-
&& (centroidsVisited < nProbe || knnCollector.minCompetitiveSimilarity() == Float.NEGATIVE_INFINITY)) {
252-
++centroidsVisited;
254+
&& (maxVectorVisited > actualDocs || knnCollector.minCompetitiveSimilarity() == Float.NEGATIVE_INFINITY)) {
253255
// todo do we actually need to know the score???
254256
long offset = centroidIterator.nextPostingListOffset();
255257
// todo do we need direct access to the raw centroid???, this is used for quantizing, maybe hydrating and quantizing

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1693,18 +1693,22 @@ public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map<String, ?
16931693
if (rescoreVector == null) {
16941694
rescoreVector = new RescoreVector(DEFAULT_OVERSAMPLE);
16951695
}
1696-
Object nProbeNode = indexOptionsMap.remove("default_n_probe");
1697-
int nProbe = -1;
1698-
if (nProbeNode != null) {
1699-
nProbe = XContentMapValues.nodeIntegerValue(nProbeNode);
1700-
if (nProbe < 1 && nProbe != -1) {
1696+
Object visitPercentageNode = indexOptionsMap.remove("default_visit_percentage");
1697+
double visitPercentage = 0d;
1698+
if (visitPercentageNode != null) {
1699+
visitPercentage = (float) XContentMapValues.nodeDoubleValue(visitPercentageNode);
1700+
if (visitPercentage < 0d || visitPercentage > 100d) {
17011701
throw new IllegalArgumentException(
1702-
"default_n_probe must be at least 1 or exactly -1, got: " + nProbe + " for field [" + fieldName + "]"
1702+
"default_visit_percentage must be between 0.0 and 100.0, got: "
1703+
+ visitPercentage
1704+
+ " for field ["
1705+
+ fieldName
1706+
+ "]"
17031707
);
17041708
}
17051709
}
17061710
MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap);
1707-
return new BBQIVFIndexOptions(clusterSize, nProbe, rescoreVector);
1711+
return new BBQIVFIndexOptions(clusterSize, visitPercentage, rescoreVector);
17081712
}
17091713

17101714
@Override
@@ -2297,12 +2301,12 @@ public boolean validateDimension(int dim, boolean throwOnError) {
22972301

22982302
static class BBQIVFIndexOptions extends QuantizedIndexOptions {
22992303
final int clusterSize;
2300-
final int defaultNProbe;
2304+
final double defaultVisitPercentage;
23012305

2302-
BBQIVFIndexOptions(int clusterSize, int defaultNProbe, RescoreVector rescoreVector) {
2306+
BBQIVFIndexOptions(int clusterSize, double defaultVisitPercentage, RescoreVector rescoreVector) {
23032307
super(VectorIndexType.BBQ_DISK, rescoreVector);
23042308
this.clusterSize = clusterSize;
2305-
this.defaultNProbe = defaultNProbe;
2309+
this.defaultVisitPercentage = defaultVisitPercentage;
23062310
}
23072311

23082312
@Override
@@ -2320,13 +2324,13 @@ public boolean updatableTo(DenseVectorIndexOptions update) {
23202324
boolean doEquals(DenseVectorIndexOptions other) {
23212325
BBQIVFIndexOptions that = (BBQIVFIndexOptions) other;
23222326
return clusterSize == that.clusterSize
2323-
&& defaultNProbe == that.defaultNProbe
2327+
&& defaultVisitPercentage == that.defaultVisitPercentage
23242328
&& Objects.equals(rescoreVector, that.rescoreVector);
23252329
}
23262330

23272331
@Override
23282332
int doHashCode() {
2329-
return Objects.hash(clusterSize, defaultNProbe, rescoreVector);
2333+
return Objects.hash(clusterSize, defaultVisitPercentage, rescoreVector);
23302334
}
23312335

23322336
@Override
@@ -2339,7 +2343,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
23392343
builder.startObject();
23402344
builder.field("type", type);
23412345
builder.field("cluster_size", clusterSize);
2342-
builder.field("default_n_probe", defaultNProbe);
2346+
builder.field("default_visit_percentage", defaultVisitPercentage);
23432347
if (rescoreVector != null) {
23442348
rescoreVector.toXContent(builder, params);
23452349
}
@@ -2736,6 +2740,7 @@ private Query createKnnFloatQuery(
27362740
.add(filter, BooleanClause.Occur.FILTER)
27372741
.build();
27382742
} else if (indexOptions instanceof BBQIVFIndexOptions bbqIndexOptions) {
2743+
float defaultVisitRatio = (float) (bbqIndexOptions.defaultVisitPercentage / 100d);
27392744
knnQuery = parentFilter != null
27402745
? new DiversifyingChildrenIVFKnnFloatVectorQuery(
27412746
name(),
@@ -2744,9 +2749,9 @@ private Query createKnnFloatQuery(
27442749
numCands,
27452750
filter,
27462751
parentFilter,
2747-
bbqIndexOptions.defaultNProbe
2752+
defaultVisitRatio
27482753
)
2749-
: new IVFKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, bbqIndexOptions.defaultNProbe);
2754+
: new IVFKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, defaultVisitRatio);
27502755
} else {
27512756
knnQuery = parentFilter != null
27522757
? new ESDiversifyingChildrenFloatKnnVectorQuery(

0 commit comments

Comments
 (0)