@@ -36,7 +36,7 @@ record CmdLineArgs(
KnnIndexTester.IndexType indexType,
int numCandidates,
int k,
int[] nProbes,
double[] visitPercentages,
int ivfClusterSize,
int overSamplingFactor,
int hnswM,
@@ -63,7 +63,8 @@ record CmdLineArgs(
static final ParseField INDEX_TYPE_FIELD = new ParseField("index_type");
static final ParseField NUM_CANDIDATES_FIELD = new ParseField("num_candidates");
static final ParseField K_FIELD = new ParseField("k");
static final ParseField N_PROBE_FIELD = new ParseField("n_probe");
// static final ParseField N_PROBE_FIELD = new ParseField("n_probe");
static final ParseField VISIT_PERCENTAGE_FIELD = new ParseField("visit_percentage");
static final ParseField IVF_CLUSTER_SIZE_FIELD = new ParseField("ivf_cluster_size");
static final ParseField OVER_SAMPLING_FACTOR_FIELD = new ParseField("over_sampling_factor");
static final ParseField HNSW_M_FIELD = new ParseField("hnsw_m");
@@ -97,7 +98,8 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
PARSER.declareString(Builder::setIndexType, INDEX_TYPE_FIELD);
PARSER.declareInt(Builder::setNumCandidates, NUM_CANDIDATES_FIELD);
PARSER.declareInt(Builder::setK, K_FIELD);
PARSER.declareIntArray(Builder::setNProbe, N_PROBE_FIELD);
// PARSER.declareIntArray(Builder::setNProbe, N_PROBE_FIELD);
PARSER.declareDoubleArray(Builder::setVisitPercentages, VISIT_PERCENTAGE_FIELD);
PARSER.declareInt(Builder::setIvfClusterSize, IVF_CLUSTER_SIZE_FIELD);
PARSER.declareInt(Builder::setOverSamplingFactor, OVER_SAMPLING_FACTOR_FIELD);
PARSER.declareInt(Builder::setHnswM, HNSW_M_FIELD);
@@ -132,7 +134,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(INDEX_TYPE_FIELD.getPreferredName(), indexType.name().toLowerCase(Locale.ROOT));
builder.field(NUM_CANDIDATES_FIELD.getPreferredName(), numCandidates);
builder.field(K_FIELD.getPreferredName(), k);
builder.field(N_PROBE_FIELD.getPreferredName(), nProbes);
// builder.field(N_PROBE_FIELD.getPreferredName(), nProbes);
builder.field(VISIT_PERCENTAGE_FIELD.getPreferredName(), visitPercentages);
builder.field(IVF_CLUSTER_SIZE_FIELD.getPreferredName(), ivfClusterSize);
builder.field(OVER_SAMPLING_FACTOR_FIELD.getPreferredName(), overSamplingFactor);
builder.field(HNSW_M_FIELD.getPreferredName(), hnswM);
@@ -165,7 +168,7 @@ static class Builder {
private KnnIndexTester.IndexType indexType = KnnIndexTester.IndexType.HNSW;
private int numCandidates = 1000;
private int k = 10;
private int[] nProbes = new int[] { 10 };
private double[] visitPercentages = new double[] { 1.0 };
private int ivfClusterSize = 1000;
private int overSamplingFactor = 1;
private int hnswM = 16;
@@ -223,8 +226,8 @@ public Builder setK(int k) {
return this;
}

public Builder setNProbe(List<Integer> nProbes) {
this.nProbes = nProbes.stream().mapToInt(Integer::intValue).toArray();
public Builder setVisitPercentages(List<Double> visitPercentages) {
this.visitPercentages = visitPercentages.stream().mapToDouble(Double::doubleValue).toArray();
return this;
}

@@ -330,7 +333,7 @@ public CmdLineArgs build() {
indexType,
numCandidates,
k,
nProbes,
visitPercentages,
ivfClusterSize,
overSamplingFactor,
hnswM,
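Aside, for illustration only: where a tester config previously swept an array of `n_probe` integers, it now sweeps `visit_percentage` doubles, defaulting to 1.0 (visit 1% of the vectors per query). A minimal sketch of the new Builder setter; the class name, package placement (the Builder is package-private), and the omitted doc/query vector paths are assumptions, not part of the change:

import java.util.List;

class VisitPercentageSweepSketch {
    // Assumes same package as CmdLineArgs; the required vector paths are set elsewhere before build().
    CmdLineArgs.Builder newStyleArgs() {
        return new CmdLineArgs.Builder()
            .setK(10)
            .setNumCandidates(1000)
            // previously: .setNProbe(List.of(10, 20, 50))
            .setVisitPercentages(List.of(0.5, 1.0, 5.0)); // percent of all vectors to visit per query
    }
}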
@@ -191,18 +191,18 @@ public static void main(String[] args) throws Exception {
FormattedResults formattedResults = new FormattedResults();

for (CmdLineArgs cmdLineArgs : cmdLineArgsList) {
int[] nProbes = cmdLineArgs.indexType().equals(IndexType.IVF) && cmdLineArgs.numQueries() > 0
? cmdLineArgs.nProbes()
: new int[] { 0 };
double[] visitPercentages = cmdLineArgs.indexType().equals(IndexType.IVF) && cmdLineArgs.numQueries() > 0
? cmdLineArgs.visitPercentages()
: new double[] { 0 };
String indexType = cmdLineArgs.indexType().name().toLowerCase(Locale.ROOT);
Results indexResults = new Results(
cmdLineArgs.docVectors().get(0).getFileName().toString(),
indexType,
cmdLineArgs.numDocs(),
cmdLineArgs.filterSelectivity()
);
Results[] results = new Results[nProbes.length];
for (int i = 0; i < nProbes.length; i++) {
Results[] results = new Results[visitPercentages.length];
for (int i = 0; i < visitPercentages.length; i++) {
results[i] = new Results(
cmdLineArgs.docVectors().get(0).getFileName().toString(),
indexType,
@@ -240,8 +240,7 @@ public static void main(String[] args) throws Exception {
numSegments(indexPath, indexResults);
if (cmdLineArgs.queryVectors() != null && cmdLineArgs.numQueries() > 0) {
for (int i = 0; i < results.length; i++) {
int nProbe = nProbes[i];
KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs, nProbe);
KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs, visitPercentages[i]);
knnSearcher.runSearch(results[i], cmdLineArgs.earlyTermination());
}
}
@@ -293,7 +292,7 @@ public String toString() {
String[] searchHeaders = {
"index_name",
"index_type",
"n_probe",
"visit_percentage(%)",
"latency(ms)",
"net_cpu_time(ms)",
"avg_cpu_count",
@@ -324,7 +323,7 @@ public String toString() {
queryResultsArray[i] = new String[] {
queryResult.indexName,
queryResult.indexType,
Integer.toString(queryResult.nProbe),
String.format(Locale.ROOT, "%.2f", queryResult.visitPercentage),
String.format(Locale.ROOT, "%.2f", queryResult.avgLatency),
String.format(Locale.ROOT, "%.2f", queryResult.netCpuTimeMS),
String.format(Locale.ROOT, "%.2f", queryResult.avgCpuCount),
@@ -400,7 +399,7 @@ static class Results {
long indexTimeMS;
long forceMergeTimeMS;
int numSegments;
int nProbe;
double visitPercentage;
double avgLatency;
double qps;
double avgRecall;
@@ -107,7 +107,7 @@ class KnnSearcher {
private final float selectivity;
private final int topK;
private final int efSearch;
private final int nProbe;
private final double visitPercentage;
private final KnnIndexTester.IndexType indexType;
private int dim;
private final VectorSimilarityFunction similarityFunction;
@@ -116,7 +116,7 @@ class KnnSearcher {
private final int searchThreads;
private final int numSearchers;

KnnSearcher(Path indexPath, CmdLineArgs cmdLineArgs, int nProbe) {
KnnSearcher(Path indexPath, CmdLineArgs cmdLineArgs, double visitPercentage) {
this.docPath = cmdLineArgs.docVectors();
this.indexPath = indexPath;
this.queryPath = cmdLineArgs.queryVectors();
@@ -131,7 +131,7 @@ class KnnSearcher {
throw new IllegalArgumentException("numQueryVectors must be > 0");
}
this.efSearch = cmdLineArgs.numCandidates();
this.nProbe = nProbe;
this.visitPercentage = visitPercentage;
this.indexType = cmdLineArgs.indexType();
this.searchThreads = cmdLineArgs.searchThreads();
this.numSearchers = cmdLineArgs.numSearchers();
@@ -298,7 +298,7 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) th
}
logger.info("checking results");
int[][] nn = getOrCalculateExactNN(offsetByteSize, filterQuery);
finalResults.nProbe = indexType == KnnIndexTester.IndexType.IVF ? nProbe : 0;
finalResults.visitPercentage = indexType == KnnIndexTester.IndexType.IVF ? visitPercentage : 0;
finalResults.avgRecall = checkResults(resultIds, nn, topK);
finalResults.qps = (1000f * numQueryVectors) / elapsed;
finalResults.avgLatency = (float) elapsed / numQueryVectors;
@@ -424,7 +424,8 @@ TopDocs doVectorQuery(float[] vector, IndexSearcher searcher, Query filterQuery,
}
int efSearch = Math.max(topK, this.efSearch);
if (indexType == KnnIndexTester.IndexType.IVF) {
knnQuery = new IVFKnnFloatVectorQuery(VECTOR_FIELD, vector, topK, efSearch, filterQuery, nProbe);
float visitRatio = (float) (visitPercentage / 100);
knnQuery = new IVFKnnFloatVectorQuery(VECTOR_FIELD, vector, topK, efSearch, filterQuery, visitRatio);
} else {
knnQuery = new ESKnnFloatVectorQuery(
VECTOR_FIELD,
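For intuition (illustrative numbers, not from the PR): the searcher turns the configured percentage into a 0-1 ratio before building the IVF query, and the reader change further below converts that ratio into a vector-visit budget, doubling it because a SOAR-replicated vector can be visited twice. A self-contained sketch of that arithmetic:

public class VisitBudgetSketch {
    public static void main(String[] args) {
        double visitPercentage = 1.0;   // from the tester config
        int numVectors = 1_000_000;     // assumed index size, for illustration only

        // as in KnnSearcher#doVectorQuery
        float visitRatio = (float) (visitPercentage / 100);
        // as in IVFVectorsReader: 2x because SOAR can place a vector in two posting lists
        long maxVectorsVisited = (long) (2.0 * visitRatio * numVectors);

        System.out.println("visitRatio = " + visitRatio);               // 0.01
        System.out.println("maxVectorsVisited = " + maxVectorsVisited); // 20000
    }
}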
@@ -60,8 +60,8 @@ public class IVFVectorsFormat extends KnnVectorsFormat {
);

// This dynamically sets the cluster probe based on the `k` requested and the number of clusters.
// useful when searching with 'efSearch' type parameters instead of requiring a specific nprobe.
public static final int DYNAMIC_NPROBE = -1;
// useful when searching with 'efSearch' type parameters instead of requiring a specific ratio.
public static final float DYNAMIC_VISIT_RATIO = 0.0f;
public static final int DEFAULT_VECTORS_PER_CLUSTER = 384;
public static final int MIN_VECTORS_PER_CLUSTER = 64;
public static final int MAX_VECTORS_PER_CLUSTER = 1 << 16; // 65536
@@ -35,7 +35,7 @@
import java.io.IOException;

import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.DYNAMIC_NPROBE;
import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.DYNAMIC_VISIT_RATIO;

/**
* Reader for IVF vectors. This reader is used to read the IVF vectors from the index.
@@ -222,34 +222,38 @@ public final void search(String field, float[] target, KnnCollector knnCollector
percentFiltered = Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length()));
}
int numVectors = rawVectorsReader.getFloatVectorValues(field).size();
int nProbe = DYNAMIC_NPROBE;
float visitRatio = DYNAMIC_VISIT_RATIO;
// Search strategy may be null if this is being called from checkIndex (e.g. from a test)
if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfSearchStrategy) {
nProbe = ivfSearchStrategy.getNProbe();
visitRatio = ivfSearchStrategy.getVisitRatio();
}

FieldEntry entry = fields.get(fieldInfo.number);
if (nProbe == DYNAMIC_NPROBE) {
long maxVectorVisited;
if (visitRatio == DYNAMIC_VISIT_RATIO) {
// empirically based, and a good dynamic to get decent recall while scaling a la "efSearch"
// scaling by the number of centroids vs. the nearest neighbors requested
// scaling by the number of vectors vs. the nearest neighbors requested
// not perfect, but a comparative heuristic.
// we might want to utilize the total vector count as well, but this is a good start
nProbe = (int) Math.round(Math.log10(entry.numCentroids) * Math.sqrt(knnCollector.k()));
// clip to be between 1 and the number of centroids
nProbe = Math.max(Math.min(nProbe, entry.numCentroids), 1);
// TODO: we might want to consider the density of the centroids, as experiments show that the fewer vectors per centroid,
// the fewer vectors we need to score to get good recall.
maxVectorVisited = Math.round(1.75f * Math.log10(knnCollector.k()) * Math.log10(knnCollector.k()) * (knnCollector.k()));
// clip so we visit at least one vector
maxVectorVisited = Math.max(maxVectorVisited, 1L);
} else {
// we account for soar vectors here. We can potentially visit a vector twice so we multiply by 2 here.
maxVectorVisited = (long) (2.0 * visitRatio * numVectors);
}
CentroidIterator centroidIterator = getCentroidIterator(fieldInfo, entry.numCentroids, entry.centroidSlice(ivfCentroids), target);
PostingVisitor scorer = getPostingVisitor(fieldInfo, entry.postingListSlice(ivfClusters), target, acceptDocs);
int centroidsVisited = 0;

long expectedDocs = 0;
long actualDocs = 0;
// initially we visit only the "centroids to search"
// Note, numCollected is doing the bare minimum here.
// TODO do we need to handle nested doc counts similarly to how we handle
// filtering? E.g. keep exploring until we hit an expected number of parent documents vs. child vectors?
while (centroidIterator.hasNext()
&& (centroidsVisited < nProbe || knnCollector.minCompetitiveSimilarity() == Float.NEGATIVE_INFINITY)) {
++centroidsVisited;
&& (maxVectorVisited > actualDocs || knnCollector.minCompetitiveSimilarity() == Float.NEGATIVE_INFINITY)) {
// todo do we actually need to know the score???
long offset = centroidIterator.nextPostingListOffset();
// todo do we need direct access to the raw centroid???, this is used for quantizing, maybe hydrating and quantizing
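Also for illustration: when the query supplies no ratio (DYNAMIC_VISIT_RATIO), the reader no longer derives an nProbe from the centroid count but computes a vector-visit budget directly from `k`. A sketch (not PR code) evaluating the formula from the hunk above for a few values of `k`:

public class DynamicVisitBudgetSketch {
    static long dynamicBudget(int k) {
        // formula from IVFVectorsReader when the visit ratio is DYNAMIC_VISIT_RATIO
        long budget = Math.round(1.75f * Math.log10(k) * Math.log10(k) * k);
        return Math.max(budget, 1L); // visit at least one vector
    }

    public static void main(String[] args) {
        System.out.println(dynamicBudget(10));    // 18
        System.out.println(dynamicBudget(100));   // 700
        System.out.println(dynamicBudget(1000));  // 15750
    }
}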
@@ -1697,18 +1697,22 @@ public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map<String, ?
if (rescoreVector == null) {
rescoreVector = new RescoreVector(DEFAULT_OVERSAMPLE);
}
Object nProbeNode = indexOptionsMap.remove("default_n_probe");
int nProbe = -1;
if (nProbeNode != null) {
nProbe = XContentMapValues.nodeIntegerValue(nProbeNode);
if (nProbe < 1 && nProbe != -1) {
Object visitPercentageNode = indexOptionsMap.remove("default_visit_percentage");
double visitPercentage = 0d;
if (visitPercentageNode != null) {
visitPercentage = (float) XContentMapValues.nodeDoubleValue(visitPercentageNode);
if (visitPercentage < 0d || visitPercentage > 100d) {
throw new IllegalArgumentException(
"default_n_probe must be at least 1 or exactly -1, got: " + nProbe + " for field [" + fieldName + "]"
"default_visit_percentage must be between 0.0 and 100.0, got: "
+ visitPercentage
+ " for field ["
+ fieldName
+ "]"
);
}
}
MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap);
return new BBQIVFIndexOptions(clusterSize, nProbe, rescoreVector);
return new BBQIVFIndexOptions(clusterSize, visitPercentage, rescoreVector);
}

@Override
@@ -2301,12 +2305,12 @@ public boolean validateDimension(int dim, boolean throwOnError) {

static class BBQIVFIndexOptions extends QuantizedIndexOptions {
final int clusterSize;
final int defaultNProbe;
final double defaultVisitPercentage;

BBQIVFIndexOptions(int clusterSize, int defaultNProbe, RescoreVector rescoreVector) {
BBQIVFIndexOptions(int clusterSize, double defaultVisitPercentage, RescoreVector rescoreVector) {
super(VectorIndexType.BBQ_DISK, rescoreVector);
this.clusterSize = clusterSize;
this.defaultNProbe = defaultNProbe;
this.defaultVisitPercentage = defaultVisitPercentage;
}

@Override
@@ -2324,13 +2328,13 @@ public boolean updatableTo(DenseVectorIndexOptions update) {
boolean doEquals(DenseVectorIndexOptions other) {
BBQIVFIndexOptions that = (BBQIVFIndexOptions) other;
return clusterSize == that.clusterSize
&& defaultNProbe == that.defaultNProbe
&& defaultVisitPercentage == that.defaultVisitPercentage
&& Objects.equals(rescoreVector, that.rescoreVector);
}

@Override
int doHashCode() {
return Objects.hash(clusterSize, defaultNProbe, rescoreVector);
return Objects.hash(clusterSize, defaultVisitPercentage, rescoreVector);
}

@Override
@@ -2343,7 +2347,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.startObject();
builder.field("type", type);
builder.field("cluster_size", clusterSize);
builder.field("default_n_probe", defaultNProbe);
builder.field("default_visit_percentage", defaultVisitPercentage);
if (rescoreVector != null) {
rescoreVector.toXContent(builder, params);
}
@@ -2740,6 +2744,7 @@ && isNotUnitVector(squaredMagnitude)) {
.add(filter, BooleanClause.Occur.FILTER)
.build();
} else if (indexOptions instanceof BBQIVFIndexOptions bbqIndexOptions) {
float defaultVisitRatio = (float) (bbqIndexOptions.defaultVisitPercentage / 100d);
knnQuery = parentFilter != null
? new DiversifyingChildrenIVFKnnFloatVectorQuery(
name(),
Expand All @@ -2748,9 +2753,9 @@ && isNotUnitVector(squaredMagnitude)) {
numCands,
filter,
parentFilter,
bbqIndexOptions.defaultNProbe
defaultVisitRatio
)
: new IVFKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, bbqIndexOptions.defaultNProbe);
: new IVFKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, defaultVisitRatio);
} else {
knnQuery = parentFilter != null
? new ESDiversifyingChildrenFloatKnnVectorQuery(
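Finally, on the mapper side: `default_n_probe` (at least 1, or exactly -1) becomes `default_visit_percentage`, a double validated into [0.0, 100.0] and divided by 100 at query time. A small sketch of that validation and conversion, with an illustrative value:

public class DefaultVisitPercentageSketch {
    public static void main(String[] args) {
        double defaultVisitPercentage = 2.5; // e.g. "default_visit_percentage": 2.5 in index_options

        if (defaultVisitPercentage < 0d || defaultVisitPercentage > 100d) {
            throw new IllegalArgumentException(
                "default_visit_percentage must be between 0.0 and 100.0, got: " + defaultVisitPercentage
            );
        }

        // query time: the percentage becomes the ratio handed to the IVF KNN queries
        float defaultVisitRatio = (float) (defaultVisitPercentage / 100d);
        System.out.println(defaultVisitRatio); // 0.025
    }
}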