Skip to content

Commit 52597fa

Browse files
committed
Make it possible to extend Patience/Seeded knn queries (apache#14838)
(cherry picked from commit 2b47cd3)
1 parent f93a0ed commit 52597fa

File tree

5 files changed

+96
-11
lines changed

5 files changed

+96
-11
lines changed

lucene/CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ Bug Fixes
104104

105105
* GITHUB#14755: Fix too many documents collected when only bool-filter condition is present. (Ke Wei)
106106

107+
* GITHUB#14838: Make it possible to extend Patience/Seeded knn queries (Tommaso Teofili)
108+
107109
Build
108110
---------------------
109111
* Upgrade forbiddenapis to version 3.9. (Uwe Schindler)

lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ public class HnswQueueSaturationCollector extends KnnCollector.Decorator {
3737
private int previousQueueSize;
3838
private int currentQueueSize;
3939

40-
HnswQueueSaturationCollector(KnnCollector delegate, double saturationThreshold, int patience) {
40+
public HnswQueueSaturationCollector(
41+
KnnCollector delegate, double saturationThreshold, int patience) {
4142
super(delegate);
4243
this.delegate = delegate;
4344
this.previousQueueSize = 0;

lucene/core/src/java/org/apache/lucene/search/PatienceKnnVectorQuery.java

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public class PatienceKnnVectorQuery extends AbstractKnnVectorQuery {
4848
/**
4949
* Construct a new PatienceKnnVectorQuery instance for a float vector field
5050
*
51-
* @param knnQuery the knn query to be seeded
51+
* @param knnQuery the knn query to be wrapped
5252
* @param saturationThreshold the early exit saturation threshold
5353
* @param patience the patience parameter
5454
* @return a new PatienceKnnVectorQuery instance
@@ -62,7 +62,7 @@ public static PatienceKnnVectorQuery fromFloatQuery(
6262
/**
6363
* Construct a new PatienceKnnVectorQuery instance for a float vector field
6464
*
65-
* @param knnQuery the knn query to be seeded
65+
* @param knnQuery the knn query to be wrapped
6666
* @return a new PatienceKnnVectorQuery instance
6767
* @lucene.experimental
6868
*/
@@ -74,7 +74,7 @@ public static PatienceKnnVectorQuery fromFloatQuery(KnnFloatVectorQuery knnQuery
7474
/**
7575
* Construct a new PatienceKnnVectorQuery instance for a byte vector field
7676
*
77-
* @param knnQuery the knn query to be seeded
77+
* @param knnQuery the knn query to be wrapped
7878
* @param saturationThreshold the early exit saturation threshold
7979
* @param patience the patience parameter
8080
* @return a new PatienceKnnVectorQuery instance
@@ -124,13 +124,55 @@ public static PatienceKnnVectorQuery fromSeededQuery(SeededKnnVectorQuery knnQue
124124
}
125125

126126
PatienceKnnVectorQuery(
127-
AbstractKnnVectorQuery knnQuery, double saturationThreshold, int patience) {
128-
super(knnQuery.field, knnQuery.k, knnQuery.filter, knnQuery.searchStrategy);
127+
AbstractKnnVectorQuery knnQuery,
128+
String field,
129+
int k,
130+
Query filter,
131+
KnnSearchStrategy searchStrategy,
132+
double saturationThreshold,
133+
int patience) {
134+
super(field, k, filter, searchStrategy);
129135
this.delegate = knnQuery;
130136
this.saturationThreshold = saturationThreshold;
131137
this.patience = patience;
132138
}
133139

140+
public PatienceKnnVectorQuery(
141+
SeededKnnVectorQuery knnQuery, double saturationThreshold, int patience) {
142+
this(
143+
knnQuery,
144+
knnQuery.field,
145+
knnQuery.k,
146+
knnQuery.filter,
147+
knnQuery.searchStrategy,
148+
saturationThreshold,
149+
patience);
150+
}
151+
152+
public PatienceKnnVectorQuery(
153+
KnnFloatVectorQuery knnQuery, double saturationThreshold, int patience) {
154+
this(
155+
knnQuery,
156+
knnQuery.field,
157+
knnQuery.k,
158+
knnQuery.filter,
159+
knnQuery.searchStrategy,
160+
saturationThreshold,
161+
patience);
162+
}
163+
164+
public PatienceKnnVectorQuery(
165+
KnnByteVectorQuery knnQuery, double saturationThreshold, int patience) {
166+
this(
167+
knnQuery,
168+
knnQuery.field,
169+
knnQuery.k,
170+
knnQuery.filter,
171+
knnQuery.searchStrategy,
172+
saturationThreshold,
173+
patience);
174+
}
175+
134176
private static int defaultPatience(AbstractKnnVectorQuery delegate) {
135177
return Math.max(7, (int) (delegate.k * 0.3));
136178
}
@@ -243,7 +285,11 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
243285
new SeededKnnVectorQuery(
244286
seededKnnVectorQuery.delegate,
245287
seededKnnVectorQuery.seed,
246-
seededKnnVectorQuery.createSeedWeight(indexSearcher));
288+
seededKnnVectorQuery.createSeedWeight(indexSearcher),
289+
delegate.field,
290+
delegate.k,
291+
delegate.filter,
292+
delegate.searchStrategy);
247293
}
248294
return super.rewrite(indexSearcher);
249295
}

lucene/core/src/java/org/apache/lucene/search/SeededKnnVectorQuery.java

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,42 @@ public static SeededKnnVectorQuery fromByteQuery(KnnByteVectorQuery knnQuery, Qu
6969
return new SeededKnnVectorQuery(knnQuery, seed, null);
7070
}
7171

72-
SeededKnnVectorQuery(AbstractKnnVectorQuery knnQuery, Query seed, Weight seedWeight) {
73-
super(knnQuery.field, knnQuery.k, knnQuery.filter, knnQuery.searchStrategy);
72+
SeededKnnVectorQuery(
73+
AbstractKnnVectorQuery knnQuery,
74+
Query seed,
75+
Weight seedWeight,
76+
String field,
77+
int k,
78+
Query filter,
79+
KnnSearchStrategy searchStrategy) {
80+
super(field, k, filter, searchStrategy);
7481
this.delegate = knnQuery;
7582
this.seed = Objects.requireNonNull(seed);
7683
this.seedWeight = seedWeight;
7784
}
7885

86+
public SeededKnnVectorQuery(KnnFloatVectorQuery knnQuery, Query seed, Weight seedWeight) {
87+
this(
88+
knnQuery,
89+
seed,
90+
seedWeight,
91+
knnQuery.field,
92+
knnQuery.k,
93+
knnQuery.filter,
94+
knnQuery.searchStrategy);
95+
}
96+
97+
public SeededKnnVectorQuery(KnnByteVectorQuery knnQuery, Query seed, Weight seedWeight) {
98+
this(
99+
knnQuery,
100+
seed,
101+
seedWeight,
102+
knnQuery.field,
103+
knnQuery.k,
104+
knnQuery.filter,
105+
knnQuery.searchStrategy);
106+
}
107+
79108
@Override
80109
public String toString(String field) {
81110
return "SeededKnnVectorQuery{"
@@ -94,7 +123,14 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
94123
return super.rewrite(indexSearcher);
95124
}
96125
SeededKnnVectorQuery rewritten =
97-
new SeededKnnVectorQuery(delegate, seed, createSeedWeight(indexSearcher));
126+
new SeededKnnVectorQuery(
127+
delegate,
128+
seed,
129+
createSeedWeight(indexSearcher),
130+
delegate.field,
131+
delegate.k,
132+
delegate.filter,
133+
delegate.searchStrategy);
98134
return rewritten.rewrite(indexSearcher);
99135
}
100136

lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ static class AssertingSeededKnnVectorQuery extends SeededKnnVectorQuery {
238238

239239
public AssertingSeededKnnVectorQuery(
240240
AbstractKnnVectorQuery query, Query seed, Weight seedWeight, AtomicInteger seedCalls) {
241-
super(query, seed, seedWeight);
241+
super(query, seed, seedWeight, query.field, query.k, query.filter, query.searchStrategy);
242242
this.seedCalls = seedCalls;
243243
}
244244

0 commit comments

Comments
 (0)