Skip to content

Commit a7704f5

Browse files
authored
Convert more PriorityQueues to use Comparator (#14761)
1 parent a6c60e5 commit a7704f5

File tree

30 files changed

+273
-402
lines changed

30 files changed

+273
-402
lines changed

lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene60/bkd/BKDWriter60.java

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.io.IOException;
2121
import java.util.ArrayList;
2222
import java.util.Arrays;
23+
import java.util.Comparator;
2324
import java.util.List;
2425
import java.util.function.IntFunction;
2526
import org.apache.lucene.codecs.CodecUtil;
@@ -453,29 +454,14 @@ public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
453454
}
454455
}
455456

456-
private static class BKDMergeQueue extends PriorityQueue<MergeReader> {
457-
private final int bytesPerDim;
458-
459-
public BKDMergeQueue(int bytesPerDim, int maxSize) {
460-
super(maxSize);
461-
this.bytesPerDim = bytesPerDim;
462-
}
463-
464-
@Override
465-
public boolean lessThan(MergeReader a, MergeReader b) {
466-
assert a != b;
467-
468-
int cmp =
469-
Arrays.compareUnsigned(a.packedValue, 0, bytesPerDim, b.packedValue, 0, bytesPerDim);
470-
if (cmp < 0) {
471-
return true;
472-
} else if (cmp > 0) {
473-
return false;
474-
}
475-
476-
// Tie break by sorting smaller docIDs earlier:
477-
return a.docID < b.docID;
478-
}
457+
private static Comparator<MergeReader> mergeComparator(int bytesPerDim) {
458+
return ((Comparator<MergeReader>)
459+
(a, b) -> {
460+
assert a != b;
461+
return Arrays.compareUnsigned(
462+
a.packedValue, 0, bytesPerDim, b.packedValue, 0, bytesPerDim);
463+
})
464+
.thenComparingInt(mr -> mr.docID);
479465
}
480466

481467
/**
@@ -642,7 +628,8 @@ public long merge(IndexOutput out, List<MergeState.DocMap> docMaps, List<PointVa
642628
throws IOException {
643629
assert docMaps == null || readers.size() == docMaps.size();
644630

645-
BKDMergeQueue queue = new BKDMergeQueue(config.bytesPerDim(), readers.size());
631+
PriorityQueue<MergeReader> queue =
632+
PriorityQueue.usingComparator(readers.size(), mergeComparator(config.bytesPerDim()));
646633

647634
for (int i = 0; i < readers.size(); i++) {
648635
PointValues pointValues = readers.get(i);

lucene/classification/src/java/org/apache/lucene/classification/utils/NearestFuzzyQuery.java

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.io.IOException;
2020
import java.util.ArrayList;
21+
import java.util.Comparator;
2122
import java.util.HashMap;
2223
import java.util.HashSet;
2324
import java.util.Objects;
@@ -40,6 +41,7 @@
4041
import org.apache.lucene.search.QueryVisitor;
4142
import org.apache.lucene.search.TermQuery;
4243
import org.apache.lucene.util.BytesRef;
44+
import org.apache.lucene.util.FloatComparator;
4345
import org.apache.lucene.util.PriorityQueue;
4446
import org.apache.lucene.util.automaton.LevenshteinAutomata;
4547

@@ -125,7 +127,8 @@ public void addTerms(String queryString, String fieldName) {
125127
fieldVals.add(new FieldVals(fieldName, maxEdits, queryString));
126128
}
127129

128-
private void addTerms(IndexReader reader, FieldVals f, ScoreTermQueue q) throws IOException {
130+
private void addTerms(IndexReader reader, FieldVals f, PriorityQueue<ScoreTerm> q)
131+
throws IOException {
129132
if (f.queryString == null) return;
130133
final Terms terms = MultiTerms.getTerms(reader, f.fieldName);
131134
if (terms == null) {
@@ -141,8 +144,8 @@ private void addTerms(IndexReader reader, FieldVals f, ScoreTermQueue q) throws
141144
String term = termAtt.toString();
142145
if (!processedTerms.contains(term)) {
143146
processedTerms.add(term);
144-
ScoreTermQueue variantsQ =
145-
new ScoreTermQueue(
147+
PriorityQueue<ScoreTerm> variantsQ =
148+
createScoreTermQueue(
146149
MAX_VARIANTS_PER_TERM); // maxNum variants considered for any one term
147150
float minScore = 0;
148151
Term startTerm = new Term(f.fieldName, term);
@@ -214,7 +217,7 @@ private Query newTermQuery(IndexReader reader, Term term) throws IOException {
214217
@Override
215218
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
216219
IndexReader reader = indexSearcher.getIndexReader();
217-
ScoreTermQueue q = new ScoreTermQueue(MAX_NUM_TERMS);
220+
PriorityQueue<ScoreTerm> q = createScoreTermQueue(MAX_NUM_TERMS);
218221
// load up the list of possible terms
219222
for (FieldVals f : fieldVals) {
220223
addTerms(reader, f, q);
@@ -275,19 +278,11 @@ private static class ScoreTerm {
275278
}
276279
}
277280

278-
private static class ScoreTermQueue extends PriorityQueue<ScoreTerm> {
279-
ScoreTermQueue(int size) {
280-
super(size);
281-
}
282-
283-
/* (non-Javadoc)
284-
* @see org.apache.lucene.util.PriorityQueue#lessThan(java.lang.Object, java.lang.Object)
285-
*/
286-
@Override
287-
protected boolean lessThan(ScoreTerm termA, ScoreTerm termB) {
288-
if (termA.score == termB.score) return termA.term.compareTo(termB.term) > 0;
289-
else return termA.score < termB.score;
290-
}
281+
private static PriorityQueue<ScoreTerm> createScoreTermQueue(int size) {
282+
return PriorityQueue.usingComparator(
283+
size,
284+
FloatComparator.<ScoreTerm>comparing(st -> st.score)
285+
.thenComparing(st -> st.term, Comparator.reverseOrder()));
291286
}
292287

293288
@Override

lucene/codecs/src/java/org/apache/lucene/codecs/uniformsplit/sharedterms/STUniformSplitTermsWriter.java

Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -164,13 +164,14 @@ private Collection<FieldMetadata> writeSingleSegment(
164164
throws IOException {
165165
List<FieldMetadata> fieldMetadataList =
166166
createFieldMetadataList(new FieldsIterator(fields, fieldInfos), maxDoc);
167-
TermIteratorQueue<FieldTerms> fieldTermsQueue =
167+
PriorityQueue<TermIterator<FieldTerms>> fieldTermsQueue =
168168
createFieldTermsQueue(fields, fieldMetadataList);
169169
List<TermIterator<FieldTerms>> groupedFieldTerms = new ArrayList<>(fieldTermsQueue.size());
170170
List<FieldMetadataTermState> termStates = new ArrayList<>(fieldTermsQueue.size());
171171

172172
while (fieldTermsQueue.size() != 0) {
173-
TermIterator<FieldTerms> topFieldTerms = fieldTermsQueue.popTerms();
173+
TermIterator<FieldTerms> topFieldTerms = fieldTermsQueue.pop();
174+
assert topFieldTerms != null && topFieldTerms.term != null;
174175
BytesRef term = BytesRef.deepCopyOf(topFieldTerms.term);
175176
groupByTerm(fieldTermsQueue, topFieldTerms, groupedFieldTerms);
176177
writePostingLines(term, groupedFieldTerms, normsProducer, termStates);
@@ -190,9 +191,10 @@ private List<FieldMetadata> createFieldMetadataList(Iterator<FieldInfo> fieldInf
190191
return fieldMetadataList;
191192
}
192193

193-
private TermIteratorQueue<FieldTerms> createFieldTermsQueue(
194+
private PriorityQueue<TermIterator<FieldTerms>> createFieldTermsQueue(
194195
Fields fields, List<FieldMetadata> fieldMetadataList) throws IOException {
195-
TermIteratorQueue<FieldTerms> fieldQueue = new TermIteratorQueue<>(fieldMetadataList.size());
196+
PriorityQueue<TermIterator<FieldTerms>> fieldQueue =
197+
PriorityQueue.usingComparator(fieldMetadataList.size(), Comparator.naturalOrder());
196198
for (FieldMetadata fieldMetadata : fieldMetadataList) {
197199
Terms terms = fields.terms(fieldMetadata.getFieldInfo().name);
198200
if (terms != null) {
@@ -207,7 +209,7 @@ private TermIteratorQueue<FieldTerms> createFieldTermsQueue(
207209
}
208210

209211
private <T> void groupByTerm(
210-
TermIteratorQueue<T> termIteratorQueue,
212+
PriorityQueue<TermIterator<T>> termIteratorQueue,
211213
TermIterator<T> topTermIterator,
212214
List<TermIterator<T>> groupedTermIterators) {
213215
groupedTermIterators.clear();
@@ -243,7 +245,8 @@ private void writePostingLines(
243245
}
244246

245247
private <T> void nextTermForIterators(
246-
List<? extends TermIterator<T>> termIterators, TermIteratorQueue<T> termIteratorQueue)
248+
List<? extends TermIterator<T>> termIterators,
249+
PriorityQueue<TermIterator<T>> termIteratorQueue)
247250
throws IOException {
248251
for (TermIterator<T> termIterator : termIterators) {
249252
if (termIterator.nextTerm()) {
@@ -330,15 +333,17 @@ private Collection<FieldMetadata> mergeSegments(
330333
mergeState.mergeFieldInfos.iterator(), mergeState.segmentInfo.maxDoc());
331334
Map<String, MergingFieldTerms> fieldTermsMap =
332335
createMergingFieldTermsMap(fieldMetadataList, mergeState.fieldsProducers.length);
333-
TermIteratorQueue<SegmentTerms> segmentTermsQueue = createSegmentTermsQueue(segmentTermsList);
336+
PriorityQueue<TermIterator<SegmentTerms>> segmentTermsQueue =
337+
createSegmentTermsQueue(segmentTermsList);
334338
List<TermIterator<SegmentTerms>> groupedSegmentTerms = new ArrayList<>(segmentTermsList.size());
335339
Map<String, List<SegmentPostings>> fieldPostingsMap =
336340
CollectionUtil.newHashMap(mergeState.fieldInfos.length);
337341
List<MergingFieldTerms> groupedFieldTerms = new ArrayList<>(mergeState.fieldInfos.length);
338342
List<FieldMetadataTermState> termStates = new ArrayList<>(mergeState.fieldInfos.length);
339343

340344
while (segmentTermsQueue.size() != 0) {
341-
TermIterator<SegmentTerms> topSegmentTerms = segmentTermsQueue.popTerms();
345+
TermIterator<SegmentTerms> topSegmentTerms = segmentTermsQueue.pop();
346+
assert topSegmentTerms != null && topSegmentTerms.term != null;
342347
BytesRef term = BytesRef.deepCopyOf(topSegmentTerms.term);
343348
groupByTerm(segmentTermsQueue, topSegmentTerms, groupedSegmentTerms);
344349
combineSegmentsFields(groupedSegmentTerms, fieldPostingsMap);
@@ -364,9 +369,10 @@ private Map<String, MergingFieldTerms> createMergingFieldTermsMap(
364369
return fieldTermsMap;
365370
}
366371

367-
private TermIteratorQueue<SegmentTerms> createSegmentTermsQueue(
372+
private PriorityQueue<TermIterator<SegmentTerms>> createSegmentTermsQueue(
368373
List<TermIterator<SegmentTerms>> segmentTermsList) throws IOException {
369-
TermIteratorQueue<SegmentTerms> segmentQueue = new TermIteratorQueue<>(segmentTermsList.size());
374+
PriorityQueue<TermIterator<SegmentTerms>> segmentQueue =
375+
PriorityQueue.usingComparator(segmentTermsList.size(), Comparator.naturalOrder());
370376
for (TermIterator<SegmentTerms> segmentTerms : segmentTermsList) {
371377
if (segmentTerms.nextTerm()) {
372378
// There is at least one term in the segment
@@ -447,26 +453,7 @@ PostingsEnum getPostings(String fieldName, PostingsEnum reuse, int flags) throws
447453
}
448454
}
449455

450-
private class TermIteratorQueue<T> extends PriorityQueue<TermIterator<T>> {
451-
452-
TermIteratorQueue(int numFields) {
453-
super(numFields);
454-
}
455-
456-
@Override
457-
protected boolean lessThan(TermIterator<T> a, TermIterator<T> b) {
458-
return a.compareTo(b) < 0;
459-
}
460-
461-
TermIterator<T> popTerms() {
462-
TermIterator<T> topTerms = pop();
463-
assert topTerms != null;
464-
assert topTerms.term != null;
465-
return topTerms;
466-
}
467-
}
468-
469-
private abstract class TermIterator<T> implements Comparable<TermIterator<T>> {
456+
private abstract static class TermIterator<T> implements Comparable<TermIterator<T>> {
470457

471458
BytesRef term;
472459

@@ -485,7 +472,7 @@ public int compareTo(TermIterator<T> other) {
485472
abstract int compareSecondary(TermIterator<T> other);
486473
}
487474

488-
private class FieldTerms extends TermIterator<FieldTerms> {
475+
private static class FieldTerms extends TermIterator<FieldTerms> {
489476

490477
final FieldMetadata fieldMetadata;
491478
final TermsEnum termsEnum;
@@ -520,7 +507,7 @@ void resetIterator(BytesRef term, List<SegmentPostings> segmentPostingsList) {
520507
}
521508
}
522509

523-
private class SegmentTerms extends TermIterator<SegmentTerms> {
510+
private static class SegmentTerms extends TermIterator<SegmentTerms> {
524511

525512
private final Integer segmentIndex;
526513
private final STMergingBlockReader mergingBlockReader;

lucene/core/src/java/org/apache/lucene/index/DocIDMerger.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -145,13 +145,12 @@ private SortedDocIDMerger(List<T> subs, int maxCount) throws IOException {
145145
}
146146
this.subs = subs;
147147
queue =
148-
new PriorityQueue<T>(maxCount - 1) {
149-
@Override
150-
protected boolean lessThan(Sub a, Sub b) {
151-
assert a.mappedDocID != b.mappedDocID;
152-
return a.mappedDocID < b.mappedDocID;
153-
}
154-
};
148+
PriorityQueue.usingComparator(
149+
maxCount - 1,
150+
(a, b) -> {
151+
assert a.mappedDocID != b.mappedDocID;
152+
return Integer.compare(a.mappedDocID, b.mappedDocID);
153+
});
155154
reset();
156155
}
157156

lucene/core/src/java/org/apache/lucene/index/DocValuesFieldUpdates.java

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
2020

21+
import java.util.Comparator;
2122
import org.apache.lucene.search.DocIdSetIterator;
2223
import org.apache.lucene.util.Accountable;
2324
import org.apache.lucene.util.BytesRef;
@@ -159,24 +160,12 @@ public static Iterator mergedIterator(Iterator[] subs) {
159160
return subs[0];
160161
}
161162

163+
// sort by smaller docID, then larger delGen
162164
PriorityQueue<Iterator> queue =
163-
new PriorityQueue<Iterator>(subs.length) {
164-
@Override
165-
protected boolean lessThan(Iterator a, Iterator b) {
166-
// sort by smaller docID
167-
int cmp = Integer.compare(a.docID(), b.docID());
168-
if (cmp == 0) {
169-
// then by larger delGen
170-
cmp = Long.compare(b.delGen(), a.delGen());
171-
172-
// delGens are unique across our subs:
173-
assert cmp != 0;
174-
}
175-
176-
return cmp < 0;
177-
}
178-
};
179-
165+
PriorityQueue.usingComparator(
166+
subs.length,
167+
Comparator.comparingInt(Iterator::docID)
168+
.thenComparing(Comparator.comparingLong(Iterator::delGen).reversed()));
180169
for (Iterator sub : subs) {
181170
if (sub.nextDoc() != NO_MORE_DOCS) {
182171
queue.add(sub);

lucene/core/src/java/org/apache/lucene/index/MultiSorter.java

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.lucene.index;
1919

2020
import java.io.IOException;
21+
import java.util.Comparator;
2122
import java.util.List;
2223
import org.apache.lucene.index.MergeState.DocMap;
2324
import org.apache.lucene.search.Sort;
@@ -82,24 +83,22 @@ static MergeState.DocMap[] sort(Sort sort, List<CodecReader> readers) throws IOE
8283
int leafCount = readers.size();
8384

8485
PriorityQueue<LeafAndDocID> queue =
85-
new PriorityQueue<LeafAndDocID>(leafCount) {
86-
@Override
87-
public boolean lessThan(LeafAndDocID a, LeafAndDocID b) {
88-
for (int i = 0; i < comparables.length; i++) {
89-
int cmp = Long.compare(a.valuesAsComparableLongs[i], b.valuesAsComparableLongs[i]);
90-
if (cmp != 0) {
91-
return reverseMuls[i] * cmp < 0;
92-
}
93-
}
94-
95-
// tie-break by docID natural order:
96-
if (a.readerIndex != b.readerIndex) {
97-
return a.readerIndex < b.readerIndex;
98-
} else {
99-
return a.docID < b.docID;
100-
}
101-
}
102-
};
86+
PriorityQueue.usingComparator(
87+
leafCount,
88+
((Comparator<LeafAndDocID>)
89+
(a, b) -> {
90+
for (int i = 0; i < comparables.length; i++) {
91+
int cmp =
92+
Long.compare(
93+
a.valuesAsComparableLongs[i], b.valuesAsComparableLongs[i]);
94+
if (cmp != 0) {
95+
return reverseMuls[i] * cmp;
96+
}
97+
}
98+
return 0;
99+
})
100+
.thenComparingInt(ld -> ld.readerIndex)
101+
.thenComparingInt(ld -> ld.docID));
103102

104103
PackedLongValues.Builder[] builders = new PackedLongValues.Builder[leafCount];
105104

lucene/core/src/java/org/apache/lucene/search/DisjunctionScorer.java

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.ArrayList;
2121
import java.util.Collection;
2222
import java.util.List;
23+
import org.apache.lucene.util.FloatComparator;
2324
import org.apache.lucene.util.PriorityQueue;
2425

2526
/** Base class for Scorers that score disjunctions. */
@@ -88,12 +89,7 @@ private TwoPhase(DocIdSetIterator approximation, float matchCost) {
8889
super(approximation);
8990
this.matchCost = matchCost;
9091
unverifiedMatches =
91-
new PriorityQueue<DisiWrapper>(numClauses) {
92-
@Override
93-
protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
94-
return a.matchCost < b.matchCost;
95-
}
96-
};
92+
PriorityQueue.usingComparator(numClauses, FloatComparator.comparing(d -> d.matchCost));
9793
}
9894

9995
DisiWrapper getSubMatches() throws IOException {

0 commit comments

Comments
 (0)