Skip to content

Commit aeee542

Browse files
committed
minor tweaks, add knn tester param
1 parent 4f9982e commit aeee542

File tree

9 files changed

+45
-19
lines changed

9 files changed

+45
-19
lines changed

qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ record CmdLineArgs(
5353
VectorEncoding vectorEncoding,
5454
int dimensions,
5555
boolean earlyTermination,
56-
KnnIndexTester.MergePolicyType mergePolicy
56+
KnnIndexTester.MergePolicyType mergePolicy,
57+
float vectorsRatio
5758
) implements ToXContentObject {
5859

5960
static final ParseField DOC_VECTORS_FIELD = new ParseField("doc_vectors");
@@ -81,6 +82,7 @@ record CmdLineArgs(
8182
static final ParseField FILTER_SELECTIVITY_FIELD = new ParseField("filter_selectivity");
8283
static final ParseField SEED_FIELD = new ParseField("seed");
8384
static final ParseField MERGE_POLICY_FIELD = new ParseField("merge_policy");
85+
static final ParseField VECTORS_RATIO = new ParseField("vectors_ratio");
8486

8587
static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
8688
Builder builder = PARSER.apply(parser, null);
@@ -115,6 +117,7 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
115117
PARSER.declareFloat(Builder::setFilterSelectivity, FILTER_SELECTIVITY_FIELD);
116118
PARSER.declareLong(Builder::setSeed, SEED_FIELD);
117119
PARSER.declareString(Builder::setMergePolicy, MERGE_POLICY_FIELD);
120+
PARSER.declareFloat(Builder::setVectorsRatio, VECTORS_RATIO);
118121
}
119122

120123
@Override
@@ -149,6 +152,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
149152
builder.field(EARLY_TERMINATION_FIELD.getPreferredName(), earlyTermination);
150153
builder.field(FILTER_SELECTIVITY_FIELD.getPreferredName(), filterSelectivity);
151154
builder.field(SEED_FIELD.getPreferredName(), seed);
155+
builder.field(VECTORS_RATIO.getPreferredName(), vectorsRatio);
152156
return builder.endObject();
153157
}
154158

@@ -183,6 +187,7 @@ static class Builder {
183187
private float filterSelectivity = 1f;
184188
private long seed = 1751900822751L;
185189
private KnnIndexTester.MergePolicyType mergePolicy = null;
190+
private float vectorsRatio = 1f;
186191

187192
public Builder setDocVectors(List<String> docVectors) {
188193
if (docVectors == null || docVectors.isEmpty()) {
@@ -313,6 +318,11 @@ public Builder setMergePolicy(String mergePolicy) {
313318
return this;
314319
}
315320

321+
public Builder setVectorsRatio(float vectorsRatio) {
322+
this.vectorsRatio = vectorsRatio;
323+
return this;
324+
}
325+
316326
public CmdLineArgs build() {
317327
if (docVectors == null) {
318328
throw new IllegalArgumentException("Document vectors path must be provided");
@@ -347,7 +357,8 @@ public CmdLineArgs build() {
347357
vectorEncoding,
348358
dimensions,
349359
earlyTermination,
350-
mergePolicy
360+
mergePolicy,
361+
vectorsRatio
351362
);
352363
}
353364
}

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ class KnnSearcher {
115115
private final float overSamplingFactor;
116116
private final int searchThreads;
117117
private final int numSearchers;
118+
private final float vectorsRatio;
118119

119120
KnnSearcher(Path indexPath, CmdLineArgs cmdLineArgs, int nProbe) {
120121
this.docPath = cmdLineArgs.docVectors();
@@ -137,6 +138,7 @@ class KnnSearcher {
137138
this.numSearchers = cmdLineArgs.numSearchers();
138139
this.randomSeed = cmdLineArgs.seed();
139140
this.selectivity = cmdLineArgs.filterSelectivity();
141+
this.vectorsRatio = cmdLineArgs.vectorsRatio();
140142
}
141143

142144
void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) throws IOException {
@@ -424,7 +426,7 @@ TopDocs doVectorQuery(float[] vector, IndexSearcher searcher, Query filterQuery,
424426
}
425427
int efSearch = Math.max(topK, this.efSearch);
426428
if (indexType == KnnIndexTester.IndexType.IVF) {
427-
knnQuery = new IVFKnnFloatVectorQuery(VECTOR_FIELD, vector, topK, efSearch, filterQuery, nProbe);
429+
knnQuery = new IVFKnnFloatVectorQuery(VECTOR_FIELD, vector, topK, efSearch, filterQuery, nProbe, vectorsRatio);
428430
} else {
429431
knnQuery = new ESKnnFloatVectorQuery(
430432
VECTOR_FIELD,

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
257257
// TODO do we need to handle nested doc counts similarly to how we handle
258258
// filtering? E.g. keep exploring until we hit an expected number of parent documents vs. child vectors?
259259
while (centroidIterator.hasNext() && (centroidsVisited < nProbe || knnCollectorImpl.numCollected() < knnCollector.k())) {
260+
// TODO: check the previous centroid's max score and exit early?
260261
++centroidsVisited;
261262
// TODO: do we actually need to know the score?
262263
long offset = centroidIterator.nextPostingListOffset();

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2748,9 +2748,10 @@ && isNotUnitVector(squaredMagnitude)) {
27482748
numCands,
27492749
filter,
27502750
parentFilter,
2751-
bbqIndexOptions.defaultNProbe
2751+
bbqIndexOptions.defaultNProbe,
2752+
1f
27522753
)
2753-
: new IVFKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, bbqIndexOptions.defaultNProbe);
2754+
: new IVFKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, bbqIndexOptions.defaultNProbe, 1);
27542755
} else {
27552756
knnQuery = parentFilter != null
27562757
? new ESDiversifyingChildrenFloatKnnVectorQuery(

server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,16 +60,15 @@ abstract class AbstractIVFKnnVectorQuery extends Query implements QueryProfilerP
6060

6161
static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
6262

63-
private static final double VECTOR_VISITED_PERCENTAGE_BUDGET = 0.05;
64-
6563
protected final String field;
6664
protected final int nProbe;
6765
protected final int k;
6866
protected final int numCands;
6967
protected final Query filter;
68+
private final float visitedRatio;
7069
protected int vectorOpsCount;
7170

72-
protected AbstractIVFKnnVectorQuery(String field, int nProbe, int k, int numCands, Query filter) {
71+
protected AbstractIVFKnnVectorQuery(String field, int nProbe, int k, int numCands, float visitedRatio, Query filter) {
7372
if (k < 1) {
7473
throw new IllegalArgumentException("k must be at least 1, got: " + k);
7574
}
@@ -84,6 +83,7 @@ protected AbstractIVFKnnVectorQuery(String field, int nProbe, int k, int numCand
8483
this.k = k;
8584
this.filter = filter;
8685
this.numCands = numCands;
86+
this.visitedRatio = visitedRatio;
8787
}
8888

8989
@Override
@@ -132,7 +132,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
132132
TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
133133
List<LeafReaderContext> leafReaderContexts = reader.leaves();
134134

135-
int totalBudget = (int) (reader.numDocs() * VECTOR_VISITED_PERCENTAGE_BUDGET);
135+
int totalBudget = (int) (reader.numDocs() * visitedRatio);
136136

137137
List<Callable<TopDocs>> tasks;
138138
if (leafReaderContexts.isEmpty() == false) {
@@ -146,12 +146,12 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
146146
.mapToDouble(Double::doubleValue)
147147
.toArray();
148148

149-
// max affinity for decreasing nProbe
150149
double averageAffinity = Arrays.stream(affinityScores).average().orElse(Double.NaN);
150+
// max affinity for decreasing nProbe
151151
double maxAffinity = Arrays.stream(affinityScores).max().orElse(Double.NaN);
152152
double lowerAffinity = (maxAffinity + averageAffinity) * 0.5;
153153
double cutoffAffinity = lowerAffinity * 0.5; // minimum affinity score for a segment to be considered
154-
double affinityTreshold = (maxAffinity + lowerAffinity) * 0.66; // min affinity for increasing nProbe
154+
double affinityThreshold = (maxAffinity + lowerAffinity) * 0.66; // min affinity for increasing nProbe
155155
int maxAdjustments = (int) (nProbe * 1.5);
156156

157157
if (Double.isNaN(maxAffinity) || Double.isNaN(averageAffinity)) {
@@ -171,10 +171,10 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
171171
if (score < cutoffAffinity) {
172172
continue;
173173
}
174-
int adjustedNProbe = adjustNProbeForSegment(score, affinityTreshold, maxAdjustments);
174+
int adjustedNProbe = adjustNProbeForSegment(score, affinityThreshold, maxAdjustments);
175175
LeafReaderContext context = segmentAffinity.context();
176176

177-
// budgetᵢ = total_budget × (affinityᵢ × |vectors|ᵢ) / ∑ (affinityⱼ × |vectors|ⱼ)
177+
// distribute the budget according to: budgetᵢ = total_budget × (affinityᵢ × |vectors|ᵢ) / ∑ (affinityⱼ × |vectors|ⱼ)
178178
int segmentBudget = (int) (totalBudget * (score * context.reader().numDocs()) / scoreVectorsSum);
179179
tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager, adjustedNProbe, Math.max(1, segmentBudget)));
180180
}
@@ -257,6 +257,7 @@ private List<SegmentAffinity> calculateSegmentAffinities(List<LeafReaderContext>
257257

258258
// with larger clusters, global centroid might not be a good representative,
259259
// so we want to include "some" centroids' scores for higher quality estimate
260+
// TODO: tweak the threshold numCentroids here
260261
if (numCentroids > 64) {
261262
float[] centroidScores = reader.getCentroidsScores(
262263
fieldInfo,

server/src/main/java/org/elasticsearch/search/vectors/DiversifyingChildrenIVFKnnFloatVectorQuery.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,10 @@ public DiversifyingChildrenIVFKnnFloatVectorQuery(
3838
int numCands,
3939
Query childFilter,
4040
BitSetProducer parentsFilter,
41-
int nProbe
41+
int nProbe,
42+
float visitedRatio
4243
) {
43-
super(field, query, k, numCands, childFilter, nProbe);
44+
super(field, query, k, numCands, childFilter, nProbe, visitedRatio);
4445
this.parentsFilter = parentsFilter;
4546
}
4647

server/src/main/java/org/elasticsearch/search/vectors/IVFKnnFloatVectorQuery.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ public class IVFKnnFloatVectorQuery extends AbstractIVFKnnVectorQuery {
3535
* @param filter the filter to apply to the results
3636
* @param nProbe the number of probes to use for the IVF search strategy
3737
*/
38-
public IVFKnnFloatVectorQuery(String field, float[] query, int k, int numCands, Query filter, int nProbe) {
39-
super(field, nProbe, k, numCands, filter);
38+
public IVFKnnFloatVectorQuery(String field, float[] query, int k, int numCands, Query filter, int nProbe, float visitedRatio) {
39+
super(field, nProbe, k, numCands, visitedRatio, filter);
4040
this.query = query;
4141
}
4242

server/src/test/java/org/elasticsearch/search/vectors/DiversifyingChildrenIVFKnnFloatVectorQueryTests.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,16 @@ public class DiversifyingChildrenIVFKnnFloatVectorQueryTests extends AbstractDiv
1818

1919
@Override
2020
Query getDiversifyingChildrenKnnQuery(String fieldName, float[] queryVector, Query childFilter, int k, BitSetProducer parentBitSet) {
21-
return new DiversifyingChildrenIVFKnnFloatVectorQuery(fieldName, queryVector, k, k, childFilter, parentBitSet, -1);
21+
return new DiversifyingChildrenIVFKnnFloatVectorQuery(
22+
fieldName,
23+
queryVector,
24+
k,
25+
k,
26+
childFilter,
27+
parentBitSet,
28+
-1,
29+
1f
30+
);
2231
}
2332

2433
@Override

server/src/test/java/org/elasticsearch/search/vectors/IVFKnnFloatVectorQueryTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public class IVFKnnFloatVectorQueryTests extends AbstractIVFKnnVectorQueryTestCa
2727

2828
@Override
2929
IVFKnnFloatVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter, int nProbe) {
30-
return new IVFKnnFloatVectorQuery(field, query, k, k, queryFilter, nProbe);
30+
return new IVFKnnFloatVectorQuery(field, query, k, k, queryFilter, nProbe, 1);
3131
}
3232

3333
@Override

0 commit comments

Comments
 (0)