Skip to content

Commit d8dd30f

Browse files
committed
Convert more complex implementations to Comparators
1 parent a31367e commit d8dd30f

File tree

14 files changed

+149
-171
lines changed

14 files changed

+149
-171
lines changed

lucene/classification/src/java/org/apache/lucene/classification/utils/NearestFuzzyQuery.java

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.io.IOException;
2020
import java.util.ArrayList;
21+
import java.util.Comparator;
2122
import java.util.HashMap;
2223
import java.util.HashSet;
2324
import java.util.Objects;
@@ -40,6 +41,7 @@
4041
import org.apache.lucene.search.QueryVisitor;
4142
import org.apache.lucene.search.TermQuery;
4243
import org.apache.lucene.util.BytesRef;
44+
import org.apache.lucene.util.FloatComparator;
4345
import org.apache.lucene.util.PriorityQueue;
4446
import org.apache.lucene.util.automaton.LevenshteinAutomata;
4547

@@ -125,7 +127,8 @@ public void addTerms(String queryString, String fieldName) {
125127
fieldVals.add(new FieldVals(fieldName, maxEdits, queryString));
126128
}
127129

128-
private void addTerms(IndexReader reader, FieldVals f, ScoreTermQueue q) throws IOException {
130+
private void addTerms(IndexReader reader, FieldVals f, PriorityQueue<ScoreTerm> q)
131+
throws IOException {
129132
if (f.queryString == null) return;
130133
final Terms terms = MultiTerms.getTerms(reader, f.fieldName);
131134
if (terms == null) {
@@ -141,8 +144,8 @@ private void addTerms(IndexReader reader, FieldVals f, ScoreTermQueue q) throws
141144
String term = termAtt.toString();
142145
if (!processedTerms.contains(term)) {
143146
processedTerms.add(term);
144-
ScoreTermQueue variantsQ =
145-
new ScoreTermQueue(
147+
PriorityQueue<ScoreTerm> variantsQ =
148+
createScoreTermQueue(
146149
MAX_VARIANTS_PER_TERM); // maxNum variants considered for any one term
147150
float minScore = 0;
148151
Term startTerm = new Term(f.fieldName, term);
@@ -214,7 +217,7 @@ private Query newTermQuery(IndexReader reader, Term term) throws IOException {
214217
@Override
215218
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
216219
IndexReader reader = indexSearcher.getIndexReader();
217-
ScoreTermQueue q = new ScoreTermQueue(MAX_NUM_TERMS);
220+
PriorityQueue<ScoreTerm> q = createScoreTermQueue(MAX_NUM_TERMS);
218221
// load up the list of possible terms
219222
for (FieldVals f : fieldVals) {
220223
addTerms(reader, f, q);
@@ -275,19 +278,11 @@ private static class ScoreTerm {
275278
}
276279
}
277280

278-
private static class ScoreTermQueue extends PriorityQueue<ScoreTerm> {
279-
ScoreTermQueue(int size) {
280-
super(size);
281-
}
282-
283-
/* (non-Javadoc)
284-
* @see org.apache.lucene.util.PriorityQueue#lessThan(java.lang.Object, java.lang.Object)
285-
*/
286-
@Override
287-
protected boolean lessThan(ScoreTerm termA, ScoreTerm termB) {
288-
if (termA.score == termB.score) return termA.term.compareTo(termB.term) > 0;
289-
else return termA.score < termB.score;
290-
}
281+
private static PriorityQueue<ScoreTerm> createScoreTermQueue(int size) {
282+
return PriorityQueue.usingComparator(
283+
size,
284+
FloatComparator.<ScoreTerm>comparing(st -> st.score)
285+
.thenComparing(st -> st.term, Comparator.reverseOrder()));
291286
}
292287

293288
@Override

lucene/codecs/src/java/org/apache/lucene/codecs/uniformsplit/sharedterms/STUniformSplitTermsWriter.java

Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -164,13 +164,14 @@ private Collection<FieldMetadata> writeSingleSegment(
164164
throws IOException {
165165
List<FieldMetadata> fieldMetadataList =
166166
createFieldMetadataList(new FieldsIterator(fields, fieldInfos), maxDoc);
167-
TermIteratorQueue<FieldTerms> fieldTermsQueue =
167+
PriorityQueue<TermIterator<FieldTerms>> fieldTermsQueue =
168168
createFieldTermsQueue(fields, fieldMetadataList);
169169
List<TermIterator<FieldTerms>> groupedFieldTerms = new ArrayList<>(fieldTermsQueue.size());
170170
List<FieldMetadataTermState> termStates = new ArrayList<>(fieldTermsQueue.size());
171171

172172
while (fieldTermsQueue.size() != 0) {
173-
TermIterator<FieldTerms> topFieldTerms = fieldTermsQueue.popTerms();
173+
TermIterator<FieldTerms> topFieldTerms = fieldTermsQueue.pop();
174+
assert topFieldTerms != null && topFieldTerms.term != null;
174175
BytesRef term = BytesRef.deepCopyOf(topFieldTerms.term);
175176
groupByTerm(fieldTermsQueue, topFieldTerms, groupedFieldTerms);
176177
writePostingLines(term, groupedFieldTerms, normsProducer, termStates);
@@ -190,9 +191,10 @@ private List<FieldMetadata> createFieldMetadataList(Iterator<FieldInfo> fieldInf
190191
return fieldMetadataList;
191192
}
192193

193-
private TermIteratorQueue<FieldTerms> createFieldTermsQueue(
194+
private PriorityQueue<TermIterator<FieldTerms>> createFieldTermsQueue(
194195
Fields fields, List<FieldMetadata> fieldMetadataList) throws IOException {
195-
TermIteratorQueue<FieldTerms> fieldQueue = new TermIteratorQueue<>(fieldMetadataList.size());
196+
PriorityQueue<TermIterator<FieldTerms>> fieldQueue =
197+
PriorityQueue.usingComparator(fieldMetadataList.size(), Comparator.naturalOrder());
196198
for (FieldMetadata fieldMetadata : fieldMetadataList) {
197199
Terms terms = fields.terms(fieldMetadata.getFieldInfo().name);
198200
if (terms != null) {
@@ -207,7 +209,7 @@ private TermIteratorQueue<FieldTerms> createFieldTermsQueue(
207209
}
208210

209211
private <T> void groupByTerm(
210-
TermIteratorQueue<T> termIteratorQueue,
212+
PriorityQueue<TermIterator<T>> termIteratorQueue,
211213
TermIterator<T> topTermIterator,
212214
List<TermIterator<T>> groupedTermIterators) {
213215
groupedTermIterators.clear();
@@ -243,7 +245,8 @@ private void writePostingLines(
243245
}
244246

245247
private <T> void nextTermForIterators(
246-
List<? extends TermIterator<T>> termIterators, TermIteratorQueue<T> termIteratorQueue)
248+
List<? extends TermIterator<T>> termIterators,
249+
PriorityQueue<TermIterator<T>> termIteratorQueue)
247250
throws IOException {
248251
for (TermIterator<T> termIterator : termIterators) {
249252
if (termIterator.nextTerm()) {
@@ -330,15 +333,17 @@ private Collection<FieldMetadata> mergeSegments(
330333
mergeState.mergeFieldInfos.iterator(), mergeState.segmentInfo.maxDoc());
331334
Map<String, MergingFieldTerms> fieldTermsMap =
332335
createMergingFieldTermsMap(fieldMetadataList, mergeState.fieldsProducers.length);
333-
TermIteratorQueue<SegmentTerms> segmentTermsQueue = createSegmentTermsQueue(segmentTermsList);
336+
PriorityQueue<TermIterator<SegmentTerms>> segmentTermsQueue =
337+
createSegmentTermsQueue(segmentTermsList);
334338
List<TermIterator<SegmentTerms>> groupedSegmentTerms = new ArrayList<>(segmentTermsList.size());
335339
Map<String, List<SegmentPostings>> fieldPostingsMap =
336340
CollectionUtil.newHashMap(mergeState.fieldInfos.length);
337341
List<MergingFieldTerms> groupedFieldTerms = new ArrayList<>(mergeState.fieldInfos.length);
338342
List<FieldMetadataTermState> termStates = new ArrayList<>(mergeState.fieldInfos.length);
339343

340344
while (segmentTermsQueue.size() != 0) {
341-
TermIterator<SegmentTerms> topSegmentTerms = segmentTermsQueue.popTerms();
345+
TermIterator<SegmentTerms> topSegmentTerms = segmentTermsQueue.pop();
346+
assert topSegmentTerms != null && topSegmentTerms.term != null;
342347
BytesRef term = BytesRef.deepCopyOf(topSegmentTerms.term);
343348
groupByTerm(segmentTermsQueue, topSegmentTerms, groupedSegmentTerms);
344349
combineSegmentsFields(groupedSegmentTerms, fieldPostingsMap);
@@ -364,9 +369,10 @@ private Map<String, MergingFieldTerms> createMergingFieldTermsMap(
364369
return fieldTermsMap;
365370
}
366371

367-
private TermIteratorQueue<SegmentTerms> createSegmentTermsQueue(
372+
private PriorityQueue<TermIterator<SegmentTerms>> createSegmentTermsQueue(
368373
List<TermIterator<SegmentTerms>> segmentTermsList) throws IOException {
369-
TermIteratorQueue<SegmentTerms> segmentQueue = new TermIteratorQueue<>(segmentTermsList.size());
374+
PriorityQueue<TermIterator<SegmentTerms>> segmentQueue =
375+
PriorityQueue.usingComparator(segmentTermsList.size(), Comparator.naturalOrder());
370376
for (TermIterator<SegmentTerms> segmentTerms : segmentTermsList) {
371377
if (segmentTerms.nextTerm()) {
372378
// There is at least one term in the segment
@@ -447,26 +453,7 @@ PostingsEnum getPostings(String fieldName, PostingsEnum reuse, int flags) throws
447453
}
448454
}
449455

450-
private class TermIteratorQueue<T> extends PriorityQueue<TermIterator<T>> {
451-
452-
TermIteratorQueue(int numFields) {
453-
super(numFields);
454-
}
455-
456-
@Override
457-
protected boolean lessThan(TermIterator<T> a, TermIterator<T> b) {
458-
return a.compareTo(b) < 0;
459-
}
460-
461-
TermIterator<T> popTerms() {
462-
TermIterator<T> topTerms = pop();
463-
assert topTerms != null;
464-
assert topTerms.term != null;
465-
return topTerms;
466-
}
467-
}
468-
469-
private abstract class TermIterator<T> implements Comparable<TermIterator<T>> {
456+
private abstract static class TermIterator<T> implements Comparable<TermIterator<T>> {
470457

471458
BytesRef term;
472459

@@ -485,7 +472,7 @@ public int compareTo(TermIterator<T> other) {
485472
abstract int compareSecondary(TermIterator<T> other);
486473
}
487474

488-
private class FieldTerms extends TermIterator<FieldTerms> {
475+
private static class FieldTerms extends TermIterator<FieldTerms> {
489476

490477
final FieldMetadata fieldMetadata;
491478
final TermsEnum termsEnum;
@@ -520,7 +507,7 @@ void resetIterator(BytesRef term, List<SegmentPostings> segmentPostingsList) {
520507
}
521508
}
522509

523-
private class SegmentTerms extends TermIterator<SegmentTerms> {
510+
private static class SegmentTerms extends TermIterator<SegmentTerms> {
524511

525512
private final Integer segmentIndex;
526513
private final STMergingBlockReader mergingBlockReader;

lucene/core/src/java/org/apache/lucene/index/DocIDMerger.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -145,13 +145,12 @@ private SortedDocIDMerger(List<T> subs, int maxCount) throws IOException {
145145
}
146146
this.subs = subs;
147147
queue =
148-
new PriorityQueue<T>(maxCount - 1) {
149-
@Override
150-
protected boolean lessThan(Sub a, Sub b) {
151-
assert a.mappedDocID != b.mappedDocID;
152-
return a.mappedDocID < b.mappedDocID;
153-
}
154-
};
148+
PriorityQueue.usingComparator(
149+
maxCount - 1,
150+
(a, b) -> {
151+
assert a.mappedDocID != b.mappedDocID;
152+
return Integer.compare(a.mappedDocID, b.mappedDocID);
153+
});
155154
reset();
156155
}
157156

lucene/core/src/java/org/apache/lucene/index/MultiSorter.java

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.lucene.index;
1919

2020
import java.io.IOException;
21+
import java.util.Comparator;
2122
import java.util.List;
2223
import org.apache.lucene.index.MergeState.DocMap;
2324
import org.apache.lucene.search.Sort;
@@ -82,24 +83,22 @@ static MergeState.DocMap[] sort(Sort sort, List<CodecReader> readers) throws IOE
8283
int leafCount = readers.size();
8384

8485
PriorityQueue<LeafAndDocID> queue =
85-
new PriorityQueue<LeafAndDocID>(leafCount) {
86-
@Override
87-
public boolean lessThan(LeafAndDocID a, LeafAndDocID b) {
88-
for (int i = 0; i < comparables.length; i++) {
89-
int cmp = Long.compare(a.valuesAsComparableLongs[i], b.valuesAsComparableLongs[i]);
90-
if (cmp != 0) {
91-
return reverseMuls[i] * cmp < 0;
92-
}
93-
}
94-
95-
// tie-break by docID natural order:
96-
if (a.readerIndex != b.readerIndex) {
97-
return a.readerIndex < b.readerIndex;
98-
} else {
99-
return a.docID < b.docID;
100-
}
101-
}
102-
};
86+
PriorityQueue.usingComparator(
87+
leafCount,
88+
((Comparator<LeafAndDocID>)
89+
(a, b) -> {
90+
for (int i = 0; i < comparables.length; i++) {
91+
int cmp =
92+
Long.compare(
93+
a.valuesAsComparableLongs[i], b.valuesAsComparableLongs[i]);
94+
if (cmp != 0) {
95+
return reverseMuls[i] * cmp;
96+
}
97+
}
98+
return 0;
99+
})
100+
.thenComparingInt(ld -> ld.readerIndex)
101+
.thenComparingInt(ld -> ld.docID));
103102

104103
PackedLongValues.Builder[] builders = new PackedLongValues.Builder[leafCount];
105104

lucene/core/src/java/org/apache/lucene/search/DisjunctionScorer.java

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.ArrayList;
2121
import java.util.Collection;
2222
import java.util.List;
23+
import org.apache.lucene.util.FloatComparator;
2324
import org.apache.lucene.util.PriorityQueue;
2425

2526
/** Base class for Scorers that score disjunctions. */
@@ -88,12 +89,7 @@ private TwoPhase(DocIdSetIterator approximation, float matchCost) {
8889
super(approximation);
8990
this.matchCost = matchCost;
9091
unverifiedMatches =
91-
new PriorityQueue<DisiWrapper>(numClauses) {
92-
@Override
93-
protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
94-
return a.matchCost < b.matchCost;
95-
}
96-
};
92+
PriorityQueue.usingComparator(numClauses, FloatComparator.comparing(d -> d.matchCost));
9793
}
9894

9995
DisiWrapper getSubMatches() throws IOException {

lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -511,18 +511,11 @@ public List<Impact> getImpacts(int level) {
511511
}
512512

513513
PriorityQueue<SubIterator> pq =
514-
new PriorityQueue<>(impacts.length) {
515-
@Override
516-
protected boolean lessThan(SubIterator a, SubIterator b) {
517-
if (a.current == null) { // means iteration is finished
518-
return false;
519-
}
520-
if (b.current == null) {
521-
return true;
522-
}
523-
return Long.compareUnsigned(a.current.norm, b.current.norm) < 0;
524-
}
525-
};
514+
PriorityQueue.usingComparator(
515+
impacts.length,
516+
Comparator.comparing(
517+
it -> it.current,
518+
Comparator.nullsLast((a, b) -> Long.compareUnsigned(a.norm, b.norm))));
526519
for (List<Impact> impacts : toMerge) {
527520
pq.add(new SubIterator(impacts.iterator()));
528521
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.lucene.util;
18+
19+
import java.util.Comparator;
20+
21+
public interface FloatComparator {
22+
interface ToFloatFunction<T> {
23+
float applyAsFloat(T obj);
24+
}
25+
26+
static <T> Comparator<T> comparing(ToFloatFunction<T> function) {
27+
return (a, b) -> Float.compare(function.applyAsFloat(a), function.applyAsFloat(b));
28+
}
29+
30+
int compare(float f1, float f2);
31+
}

lucene/core/src/java/org/apache/lucene/util/PriorityQueue.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
*/
3838
public abstract class PriorityQueue<T> implements Iterable<T> {
3939

40+
/** Create a {@code PriorityQueue} that orders elements using the specified {@code comparator} */
4041
public static <T> PriorityQueue<T> usingComparator(
4142
int maxSize, Comparator<? super T> comparator) {
4243
return new PriorityQueue<>(maxSize) {
@@ -47,6 +48,17 @@ protected boolean lessThan(T a, T b) {
4748
};
4849
}
4950

51+
/** Create a {@code PriorityQueue} that orders elements using the specified {@code comparator} */
52+
public static <T> PriorityQueue<T> usingComparator(
53+
int maxSize, Supplier<T> sentinelObjectSupplier, Comparator<? super T> comparator) {
54+
return new PriorityQueue<>(maxSize, sentinelObjectSupplier) {
55+
@Override
56+
protected boolean lessThan(T a, T b) {
57+
return comparator.compare(a, b) < 0;
58+
}
59+
};
60+
}
61+
5062
private int size = 0;
5163
private final int maxSize;
5264
private final T[] heap;

lucene/facet/src/java/org/apache/lucene/facet/facetset/MatchingFacetSetsCounts.java

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -161,12 +161,8 @@ public FacetResult getTopChildren(int topN, String dim, String... path) throws I
161161
topN = Math.min(topN, counts.length);
162162

163163
PriorityQueue<Entry> pq =
164-
new PriorityQueue<>(topN, () -> new Entry("", 0)) {
165-
@Override
166-
protected boolean lessThan(Entry a, Entry b) {
167-
return compare(a.count, b.count, a.label, b.label) < 0;
168-
}
169-
};
164+
PriorityQueue.usingComparator(
165+
topN, () -> new Entry("", 0), (a, b) -> compare(a.count, b.count, a.label, b.label));
170166

171167
int childCount = 0;
172168
Entry reuse = pq.top();

0 commit comments

Comments
 (0)