Commit 0a354cb

Enabling histogram collection for PointRangeQuery (#14560)
1 parent a5e5e9a commit 0a354cb

File tree

4 files changed: 169 additions, 14 deletions

  lucene/CHANGES.txt
  lucene/sandbox/src/java/org/apache/lucene/sandbox/facet/plain/histograms/HistogramCollector.java
  lucene/sandbox/src/java/org/apache/lucene/sandbox/facet/plain/histograms/PointTreeBulkCollector.java
  lucene/sandbox/src/test/org/apache/lucene/sandbox/facet/plain/histograms/TestHistogramCollectorManager.java

lucene/CHANGES.txt

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ Optimizations
 ---------------------
 * GITHUB#14418: Quick exit on filter query matching no docs when rewriting knn query. (Pan Guixin)
 
-* GITHUB#14439: Efficient Histogram Collection using multi range traversal over PointTrees (Ankit Jain)
+* GITHUB#14439, GITHUB#14560: Efficient Histogram Collection using multi range traversal over PointTrees (Ankit Jain)
 
 * GITHUB#14268: PointInSetQuery early exit on non-matching segments. (hanbj)
 
lucene/sandbox/src/java/org/apache/lucene/sandbox/facet/plain/histograms/HistogramCollector.java

Lines changed: 54 additions & 2 deletions
@@ -17,7 +17,10 @@
 package org.apache.lucene.sandbox.facet.plain.histograms;
 
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
 import java.util.concurrent.ConcurrentMap;
+import java.util.function.Function;
 import org.apache.lucene.index.DocValues;
 import org.apache.lucene.index.DocValuesSkipper;
 import org.apache.lucene.index.DocValuesType;
@@ -27,10 +30,16 @@
 import org.apache.lucene.index.PointValues;
 import org.apache.lucene.index.SortedNumericDocValues;
 import org.apache.lucene.internal.hppc.LongIntHashMap;
+import org.apache.lucene.queries.function.FunctionScoreQuery;
+import org.apache.lucene.search.BoostQuery;
 import org.apache.lucene.search.CollectionTerminatedException;
 import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.ConstantScoreQuery;
 import org.apache.lucene.search.DocIdStream;
+import org.apache.lucene.search.IndexOrDocValuesQuery;
 import org.apache.lucene.search.LeafCollector;
+import org.apache.lucene.search.PointRangeQuery;
+import org.apache.lucene.search.Query;
 import org.apache.lucene.search.Scorable;
 import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.Weight;
@@ -67,13 +76,15 @@ public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
     // We can use multi range traversal logic to collect the histogram on numeric
     // field indexed as point for MATCH_ALL cases. In future, this can be extended
     // for Point Range Query cases as well
-    if (weight != null && weight.count(context) == context.reader().maxDoc()) {
+    final PointRangeQuery pointRangeQuery = getPointRangeQuery(field);
+    if (isMatchAll(context) || pointRangeQuery != null) {
       final PointValues pointValues = context.reader().getPointValues(field);
       if (PointTreeBulkCollector.canCollectEfficiently(pointValues, bucketWidth)) {
         // In case of intra segment concurrency, only one collector should collect
         // documents for all the partitions to avoid duplications across collectors
         if (leafBulkCollected.putIfAbsent(context, true) == null) {
-          PointTreeBulkCollector.collect(pointValues, bucketWidth, counts, maxBuckets);
+          PointTreeBulkCollector.collect(
+              pointValues, pointRangeQuery, bucketWidth, counts, maxBuckets);
         }
         // Either the collection is finished on this collector, or some other collector
         // already started that collection, so this collector can finish early!
@@ -330,4 +341,45 @@ static void checkMaxBuckets(int size, int maxBuckets) {
   public void setWeight(Weight weight) {
     this.weight = weight;
   }
+
+  private boolean isMatchAll(LeafReaderContext context) throws IOException {
+    return weight != null && weight.count(context) == context.reader().maxDoc();
+  }
+
+  private static final Map<Class<?>, Function<Query, Query>> queryWrappers;
+
+  // Initialize the wrapper map for unwrapping the query
+  static {
+    queryWrappers = new HashMap<>();
+    queryWrappers.put(BoostQuery.class, q -> ((BoostQuery) q).getQuery());
+    queryWrappers.put(ConstantScoreQuery.class, q -> ((ConstantScoreQuery) q).getQuery());
+    queryWrappers.put(FunctionScoreQuery.class, q -> ((FunctionScoreQuery) q).getWrappedQuery());
+    queryWrappers.put(
+        IndexOrDocValuesQuery.class, q -> ((IndexOrDocValuesQuery) q).getIndexQuery());
+  }
+
+  /** Recursively unwraps query into the concrete form for applying the optimization */
+  private static Query unwrapIntoConcreteQuery(Query query) {
+    while (queryWrappers.containsKey(query.getClass())) {
+      query = queryWrappers.get(query.getClass()).apply(query);
+    }
+
+    return query;
+  }
+
+  private PointRangeQuery getPointRangeQuery(final String field) {
+    if (weight == null || weight.getQuery() == null) {
+      return null;
+    }
+
+    final Query concreteQuery = unwrapIntoConcreteQuery(weight.getQuery());
+
+    if (concreteQuery instanceof PointRangeQuery prq) {
+      if (prq.getField().equals(field)) {
+        return prq;
+      }
+    }
+
+    return null;
+  }
 }
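Before this change the bulk PointTree path was only taken for MATCH_ALL segments; the new getPointRangeQuery/unwrapIntoConcreteQuery logic also accepts a PointRangeQuery on the histogram field, even when it is wrapped in BoostQuery, ConstantScoreQuery, FunctionScoreQuery or IndexOrDocValuesQuery. The sketch below is not part of the commit: it is a minimal usage example, assuming a hypothetical field named "price" indexed both as a LongPoint and as numeric doc values, with the class name PriceHistogramExample invented for illustration.

import org.apache.lucene.document.Document;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.internal.hppc.LongIntHashMap;
import org.apache.lucene.sandbox.facet.plain.histograms.HistogramCollectorManager;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;

public class PriceHistogramExample {
  public static void main(String[] args) throws Exception {
    try (Directory dir = new ByteBuffersDirectory();
        IndexWriter w = new IndexWriter(dir, new IndexWriterConfig())) {
      for (long price = 0; price < 5_000; price++) {
        Document doc = new Document();
        // Index the value both as a point (enables the PointTree bulk path)
        // and as numeric doc values (used by the default per-document path).
        doc.add(new LongPoint("price", price));
        doc.add(new NumericDocValuesField("price", price));
        w.addDocument(doc);
      }
      w.commit();
      try (DirectoryReader reader = DirectoryReader.open(w)) {
        IndexSearcher searcher = new IndexSearcher(reader);
        // A point range query on the same field, wrapped in ConstantScoreQuery;
        // the collector unwraps it and can still take the optimized path.
        Query range = new ConstantScoreQuery(LongPoint.newRangeQuery("price", 1_500, 3_499));
        LongIntHashMap counts =
            searcher.search(range, new HistogramCollectorManager("price", 1_000));
        System.out.println(counts); // keys are bucket ordinals: floorDiv(value, 1000)
      }
    }
  }
}

Whether the bulk PointTree path or the per-document doc-values path ends up being used is internal to the collector; both are expected to produce the same bucket counts.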

lucene/sandbox/src/java/org/apache/lucene/sandbox/facet/plain/histograms/PointTreeBulkCollector.java

Lines changed: 40 additions & 11 deletions
@@ -22,7 +22,9 @@
 import java.util.function.Function;
 import org.apache.lucene.index.PointValues;
 import org.apache.lucene.internal.hppc.LongIntHashMap;
+import org.apache.lucene.search.CollectionTerminatedException;
 import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.PointRangeQuery;
 import org.apache.lucene.util.NumericUtils;
 
 /**
@@ -76,15 +78,23 @@ static boolean canCollectEfficiently(final PointValues pointValues, final long bucketWidth)
 
   static void collect(
       final PointValues pointValues,
+      final PointRangeQuery prq,
       final long bucketWidth,
      final LongIntHashMap collectorCounts,
       final int maxBuckets)
       throws IOException {
     final Function<byte[], Long> byteToLong = bytesToLong(pointValues.getBytesPerDimension());
+    long leafMin = byteToLong.apply(pointValues.getMinPackedValue());
+    long leafMax = byteToLong.apply(pointValues.getMaxPackedValue());
+    if (prq != null) {
+      leafMin = Math.max(leafMin, byteToLong.apply(prq.getLowerPoint()));
+      leafMax = Math.min(leafMax, byteToLong.apply(prq.getUpperPoint()));
+    }
     BucketManager collector =
         new BucketManager(
             collectorCounts,
-            byteToLong.apply(pointValues.getMinPackedValue()),
+            leafMin,
+            leafMax + 1, // the max value is exclusive for collector
             bucketWidth,
             byteToLong,
             maxBuckets);
@@ -135,6 +145,11 @@ public void visit(int docID) {
       public void visit(int docID, byte[] packedValue) throws IOException {
         if (!collector.withinUpperBound(packedValue)) {
           collector.finalizePreviousBucket(packedValue);
+          // If the packedValue is not within upper bound even after updating upper bound,
+          // we have exhausted the max value and should throw early termination error
+          if (!collector.withinUpperBound(packedValue)) {
+            throw new CollectionTerminatedException();
+          }
         }
 
         if (collector.withinRange(packedValue)) {
@@ -146,6 +161,11 @@ public void visit(int docID, byte[] packedValue) throws IOException {
       public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException {
         if (!collector.withinUpperBound(packedValue)) {
           collector.finalizePreviousBucket(packedValue);
+          // If the packedValue is not within upper bound even after updating upper bound,
+          // we have exhausted the max value and should throw early termination error
+          if (!collector.withinUpperBound(packedValue)) {
+            throw new CollectionTerminatedException();
+          }
         }
 
         if (collector.withinRange(packedValue)) {
@@ -157,9 +177,14 @@ public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException {
 
       @Override
       public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
-        // try to find the first range that may collect values from this cell
+        // Try to find the first range that may collect values from this cell
        if (!collector.withinUpperBound(minPackedValue)) {
           collector.finalizePreviousBucket(minPackedValue);
+          // If the minPackedValue is not within upper bound even after updating upper bound,
+          // we have exhausted the max value and should throw early termination error
+          if (!collector.withinUpperBound(minPackedValue)) {
+            throw new CollectionTerminatedException();
+          }
         }
 
         // Not possible to have the CELL_OUTSIDE_QUERY, as bucket lower bound is updated
@@ -176,6 +201,7 @@ private static class BucketManager {
     private final LongIntHashMap collectorCounts;
     private int counter = 0;
     private long startValue;
+    private long maxValue;
     private long endValue;
     private int nonZeroBuckets = 0;
     private int maxBuckets;
@@ -185,13 +211,16 @@ private static class BucketManager {
     public BucketManager(
         LongIntHashMap collectorCounts,
         long minValue,
+        long maxValue,
         long bucketWidth,
         Function<byte[], Long> byteToLong,
         int maxBuckets) {
       this.collectorCounts = collectorCounts;
       this.bucketWidth = bucketWidth;
-      this.startValue = Math.floorDiv(minValue, bucketWidth) * bucketWidth;
-      this.endValue = startValue + bucketWidth;
+      this.startValue = minValue;
+      this.endValue =
+          Math.min((Math.floorDiv(startValue, bucketWidth) + 1) * bucketWidth, maxValue);
+      this.maxValue = maxValue;
       this.byteToLong = byteToLong;
       this.maxBuckets = maxBuckets;
     }
@@ -205,19 +234,19 @@ private void countNode(int count) {
     }
 
     private void finalizePreviousBucket(byte[] packedValue) {
-      // TODO: Can counter ever be 0?
+      // counter can be 0 for first bucket in case
+      // of Point Range Query
       if (counter > 0) {
         collectorCounts.addTo(Math.floorDiv(startValue, bucketWidth), counter);
-        if (packedValue != null) {
-          startValue = byteToLong.apply(packedValue);
-          // Align the start value with bucket width
-          startValue = Math.floorDiv(startValue, bucketWidth) * bucketWidth;
-          endValue = startValue + bucketWidth;
-        }
         nonZeroBuckets++;
         counter = 0;
         HistogramCollector.checkMaxBuckets(nonZeroBuckets, maxBuckets);
       }
+
+      if (packedValue != null) {
+        startValue = byteToLong.apply(packedValue);
+        endValue = Math.min((Math.floorDiv(startValue, bucketWidth) + 1) * bucketWidth, maxValue);
+      }
     }
 
     private boolean withinLowerBound(byte[] value) {
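The BucketManager changes above stop aligning the first bucket's start to a bucket boundary and instead clamp each bucket to the effective [min, max) range, which may now come from the PointRangeQuery rather than the leaf's packed min/max, so the first and last buckets can be partial. The standalone snippet below (not Lucene code, hypothetical numbers) just replays that arithmetic for bucketWidth = 1000 and a query range of [1500, 3499].

public class BucketBoundsSketch {
  public static void main(String[] args) {
    long bucketWidth = 1000;
    long leafMin = 0, leafMax = 4999;           // segment's packed min/max
    long queryLower = 1500, queryUpper = 3499;  // PointRangeQuery bounds

    // Effective range: intersection of the leaf range and the query range.
    long startValue = Math.max(leafMin, queryLower);    // 1500
    long maxValue = Math.min(leafMax, queryUpper) + 1;  // 3500, exclusive

    // First bucket ends at the next bucket boundary, but never past maxValue.
    long endValue =
        Math.min((Math.floorDiv(startValue, bucketWidth) + 1) * bucketWidth, maxValue); // 2000

    // Counts are recorded against the bucket key floorDiv(startValue, bucketWidth).
    System.out.println("first bucket key = " + Math.floorDiv(startValue, bucketWidth)); // 1
    System.out.println("first bucket covers [" + startValue + ", " + endValue + ")");   // [1500, 2000)
  }
}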

lucene/sandbox/src/test/org/apache/lucene/sandbox/facet/plain/histograms/TestHistogramCollectorManager.java

Lines changed: 74 additions & 0 deletions
@@ -26,16 +26,23 @@
 import org.apache.lucene.index.IndexWriter;
 import org.apache.lucene.index.IndexWriterConfig;
 import org.apache.lucene.internal.hppc.LongIntHashMap;
+import org.apache.lucene.queries.function.FunctionScoreQuery;
 import org.apache.lucene.search.BooleanClause.Occur;
 import org.apache.lucene.search.BooleanQuery;
+import org.apache.lucene.search.BoostQuery;
+import org.apache.lucene.search.ConstantScoreQuery;
+import org.apache.lucene.search.DoubleValuesSource;
+import org.apache.lucene.search.IndexOrDocValuesQuery;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.search.PointRangeQuery;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.Sort;
 import org.apache.lucene.search.SortField;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.util.LuceneTestCase;
 import org.apache.lucene.tests.util.TestUtil;
+import org.apache.lucene.util.NumericUtils;
 
 public class TestHistogramCollectorManager extends LuceneTestCase {
 
@@ -136,6 +143,8 @@ public void testMultiRangePointTreeCollector() throws IOException {
     DirectoryReader reader = DirectoryReader.open(w);
     w.close();
     IndexSearcher searcher = newSearcher(reader);
+
+    // Validate the MATCH_ALL case
     LongIntHashMap actualCounts =
         searcher.search(new MatchAllDocsQuery(), new HistogramCollectorManager("f", 1000));
     LongIntHashMap expectedCounts = new LongIntHashMap();
@@ -144,6 +153,71 @@
     }
     assertEquals(expectedCounts, actualCounts);
 
+    // Validate the Point Range Query case
+    int lowerBound = random().nextInt(0, 1500);
+    int upperBound = random().nextInt(3500, 5000);
+
+    byte[] lowerPoint = new byte[Long.BYTES];
+    byte[] upperPoint = new byte[Long.BYTES];
+    NumericUtils.longToSortableBytes(lowerBound, lowerPoint, 0);
+    NumericUtils.longToSortableBytes(upperBound, upperPoint, 0);
+    final PointRangeQuery prq =
+        new PointRangeQuery("f", lowerPoint, upperPoint, 1) {
+          @Override
+          protected String toString(int dimension, byte[] value) {
+            return Long.toString(NumericUtils.sortableBytesToLong(value, 0));
+          }
+        };
+
+    actualCounts = searcher.search(prq, new HistogramCollectorManager("f", 1000));
+    expectedCounts = new LongIntHashMap();
+    for (long value : values) {
+      if (value >= lowerBound && value <= upperBound) {
+        expectedCounts.addTo(Math.floorDiv(value, 1000), 1);
+      }
+    }
+    assertEquals(expectedCounts, actualCounts);
+
+    // Validate the BoostQuery case
+    actualCounts =
+        searcher.search(new BoostQuery(prq, 1.5f), new HistogramCollectorManager("f", 1000));
+    // Don't need to compute expectedCounts again as underlying point range
+    // query is not changing
+    assertEquals(expectedCounts, actualCounts);
+
+    // Validate the ConstantScoreQuery case
+    actualCounts =
+        searcher.search(new ConstantScoreQuery(prq), new HistogramCollectorManager("f", 1000));
+    // Don't need to compute expectedCounts again as underlying point range query is not changing
+    assertEquals(expectedCounts, actualCounts);
+
+    // Validate the FunctionScoreQuery case
+    actualCounts =
+        searcher.search(
+            new FunctionScoreQuery(prq, DoubleValuesSource.SCORES),
+            new HistogramCollectorManager("f", 1000));
+    // Don't need to compute expectedCounts again as underlying point range query is not changing
+    assertEquals(expectedCounts, actualCounts);
+
+    // Validate the IndexOrDocValuesQuery case
+    actualCounts =
+        searcher.search(
+            new IndexOrDocValuesQuery(prq, prq), new HistogramCollectorManager("f", 1000));
+    // Don't need to compute expectedCounts again as underlying point range query is not changing
+    assertEquals(expectedCounts, actualCounts);
+
+    // Validate the recursive wrapping case
+    actualCounts =
+        searcher.search(
+            new ConstantScoreQuery(
+                new BoostQuery(
+                    new FunctionScoreQuery(
+                        new IndexOrDocValuesQuery(prq, prq), DoubleValuesSource.SCORES),
+                    1.5f)),
+            new HistogramCollectorManager("f", 1000));
+    // Don't need to compute expectedCounts again as underlying point range query is not changing
+    assertEquals(expectedCounts, actualCounts);
+
     reader.close();
     dir.close();
   }
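The new assertions derive expectedCounts by filtering the indexed values against the random lower/upper bounds and bucketing with Math.floorDiv(value, 1000). As a worked illustration only (fixed, hypothetical inputs rather than the test's randomized bounds and values), the snippet below shows the counts one would expect for one document per integer in [0, 5000) and an inclusive range of [1500, 3500].

public class ExpectedCountsSketch {
  public static void main(String[] args) {
    long bucketWidth = 1000;
    long lowerBound = 1500, upperBound = 3500;
    java.util.Map<Long, Integer> expected = new java.util.TreeMap<>();
    for (long value = 0; value < 5000; value++) {
      if (value >= lowerBound && value <= upperBound) {
        expected.merge(Math.floorDiv(value, bucketWidth), 1, Integer::sum);
      }
    }
    // Prints {1=500, 2=1000, 3=501}: bucket 1 covers [1500, 1999],
    // bucket 2 covers [2000, 2999], bucket 3 covers [3000, 3500].
    System.out.println(expected);
  }
}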
