diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index ba3218b0b8d5..26fdeffa75cb 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -122,7 +122,7 @@ Improvements Optimizations --------------------- -(No changes) +* GITHUB#15140: Optimize TopScoreDocCollector with TernaryLongHeap for improved performance over Binary-LongHeap. (Ramakrishna Chilaka) Bug Fixes --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/search/TopScoreDocCollector.java b/lucene/core/src/java/org/apache/lucene/search/TopScoreDocCollector.java index 2ab46cb38362..e878f6f880b8 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TopScoreDocCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/TopScoreDocCollector.java @@ -18,7 +18,7 @@ import java.io.IOException; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.util.LongHeap; +import org.apache.lucene.util.TernaryLongHeap; /** * A {@link Collector} implementation that collects the top-scoring hits, returning them as a {@link @@ -33,7 +33,7 @@ public class TopScoreDocCollector extends TopDocsCollector { private final ScoreDoc after; - private final LongHeap heap; + private final TernaryLongHeap heap; final int totalHitsThreshold; final MaxScoreAccumulator minScoreAcc; @@ -41,7 +41,7 @@ public class TopScoreDocCollector extends TopDocsCollector { TopScoreDocCollector( int numHits, ScoreDoc after, int totalHitsThreshold, MaxScoreAccumulator minScoreAcc) { super(null); - this.heap = new LongHeap(numHits, DocScoreEncoder.LEAST_COMPETITIVE_CODE); + this.heap = new TernaryLongHeap(numHits, DocScoreEncoder.LEAST_COMPETITIVE_CODE); this.after = after; this.totalHitsThreshold = totalHitsThreshold; this.minScoreAcc = minScoreAcc; diff --git a/lucene/core/src/java/org/apache/lucene/util/LongHeap.java b/lucene/core/src/java/org/apache/lucene/util/LongHeap.java index 2cca6a3c524b..3f1115e09076 100644 --- a/lucene/core/src/java/org/apache/lucene/util/LongHeap.java +++ b/lucene/core/src/java/org/apache/lucene/util/LongHeap.java @@ -22,15 +22,15 @@ * A min heap that stores longs; a primitive priority queue that like all priority queues maintains * a partial ordering of its elements such that the least element can always be found in constant * time. Put()'s and pop()'s require log(size). This heap provides unbounded growth via {@link - * #push(long)}, and bounded-size insertion based on its nominal maxSize via {@link + * #push(long)}, and bounded-size insertion based on its initial capacity via {@link * #insertWithOverflow(long)}. The heap is a min heap, meaning that the top element is the lowest - * value of the heap. + * value of the heap. LongHeap implements 2-ary heap. * * @lucene.internal */ public final class LongHeap { - private final int maxSize; + private final int initialCapacity; private long[] heap; private int size = 0; @@ -50,19 +50,21 @@ public LongHeap(int size, long initialValue) { /** * Create an empty priority queue of the configured initial size. * - * @param maxSize the maximum size of the heap, or if negative, the initial size of an unbounded - * heap + * @param initialCapacity the initial capacity of the heap */ - public LongHeap(int maxSize) { + public LongHeap(int initialCapacity) { final int heapSize; - if (maxSize < 1 || maxSize >= ArrayUtil.MAX_ARRAY_LENGTH) { + if (initialCapacity < 1 || initialCapacity >= ArrayUtil.MAX_ARRAY_LENGTH) { // Throw exception to prevent confusing OOME: throw new IllegalArgumentException( - "maxSize must be > 0 and < " + (ArrayUtil.MAX_ARRAY_LENGTH - 1) + "; got: " + maxSize); + "initialCapacity must be > 0 and < " + + (ArrayUtil.MAX_ARRAY_LENGTH - 1) + + "; got: " + + initialCapacity); } // NOTE: we add +1 because all access to heap is 1-based not 0-based. heap[0] is unused. - heapSize = maxSize + 1; - this.maxSize = maxSize; + heapSize = initialCapacity + 1; + this.initialCapacity = initialCapacity; this.heap = new long[heapSize]; } @@ -83,13 +85,13 @@ public long push(long element) { /** * Adds a value to an LongHeap in log(size) time. If the number of values would exceed the heap's - * maxSize, the least value is discarded. + * initialCapacity, the least value is discarded. * * @return whether the value was added (unless the heap is full, or the new value is less than the * top value) */ public boolean insertWithOverflow(long value) { - if (size >= maxSize) { + if (size >= initialCapacity) { if (value < heap[1]) { return false; } diff --git a/lucene/core/src/java/org/apache/lucene/util/TernaryLongHeap.java b/lucene/core/src/java/org/apache/lucene/util/TernaryLongHeap.java new file mode 100644 index 000000000000..e10bfaecb8ae --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/TernaryLongHeap.java @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.util; + +import java.util.Arrays; + +/** + * A ternary min heap that stores longs; a primitive priority queue that like all priority queues + * maintains a partial ordering of its elements such that the least element can always be found in + * constant time. Put()'s and pop()'s require log_3(size). This heap provides unbounded growth via + * {@link #push(long)}, and bounded-size insertion based on its nominal initial capacity via {@link + * #insertWithOverflow(long)}. The heap is a min heap, meaning that the top element is the lowest + * value of the heap. TernaryLongHeap implements 3-ary heap. + * + * @lucene.internal + */ +public final class TernaryLongHeap { + + private final int initialCapacity; + + private long[] heap; + private int size = 0; + private static final int ARITY = 3; + + /** + * Constructs a heap with specified size and initializes all elements with the given value. + * + * @param size the number of elements to initialize in the heap. + * @param initialValue the value to fill the heap with. + */ + public TernaryLongHeap(int size, long initialValue) { + this(size <= 0 ? 1 : size); + Arrays.fill(heap, 1, size + 1, initialValue); + this.size = size; + } + + /** + * Create an empty priority queue of the configured initial size. + * + * @param initialCapacity the initial capacity of the heap + */ + public TernaryLongHeap(int initialCapacity) { + if (initialCapacity < 1 || initialCapacity >= ArrayUtil.MAX_ARRAY_LENGTH) { + // Throw exception to prevent confusing OOME: + throw new IllegalArgumentException( + "initialCapacity must be > 0 and < " + + (ArrayUtil.MAX_ARRAY_LENGTH - 1) + + "; got: " + + initialCapacity); + } + // NOTE: we add +1 because all access to heap is 1-based not 0-based. heap[0] is unused. + final int heapSize = initialCapacity + 1; + this.initialCapacity = initialCapacity; + this.heap = new long[heapSize]; + } + + /** + * Adds a value in log(size) time. Grows unbounded as needed to accommodate new values. + * + * @return the new 'top' element in the queue. + */ + public long push(long element) { + size++; + if (size == heap.length) { + heap = ArrayUtil.grow(heap, (size * 3 + 1) / 2); + } + heap[size] = element; + TernaryLongHeap.upHeap(heap, size, ARITY); + return heap[1]; + } + + /** + * Adds a value to an TernaryLongHeap in log(size) time. If the number of values would exceed the + * heap's initialCapacity, the least value is discarded. + * + * @return whether the value was added (unless the heap is full, or the new value is less than the + * top value) + */ + public boolean insertWithOverflow(long value) { + if (size >= initialCapacity) { + if (value < heap[1]) { + return false; + } + updateTop(value); + return true; + } + push(value); + return true; + } + + /** + * Returns the least element of the TernaryLongHeap in constant time. It is up to the caller to + * verify that the heap is not empty; no checking is done, and if no elements have been added, 0 + * is returned. + */ + public long top() { + return heap[1]; + } + + /** + * Removes and returns the least element of the PriorityQueue in log(size) time. + * + * @throws IllegalStateException if the TernaryLongHeap is empty. + */ + public long pop() { + if (size > 0) { + long result = heap[1]; // save first value + heap[1] = heap[size]; // move last to first + size--; + TernaryLongHeap.downHeap(heap, 1, size, ARITY); // adjust heap + return result; + } else { + throw new IllegalStateException("The heap is empty"); + } + } + + /** + * Replace the top of the pq with {@code newTop}. Should be called when the top value changes. + * Still log(n) worst case, but it's at least twice as fast to + * + *
+   * pq.updateTop(value);
+   * 
+ * + *

instead of + * + *

+   * pq.pop();
+   * pq.push(value);
+   * 
+ * + *

Calling this method on an empty TernaryLongHeap has no visible effect. + * + * @param value the new element that is less than the current top. + * @return the new 'top' element after shuffling the heap. + */ + public long updateTop(long value) { + heap[1] = value; + TernaryLongHeap.downHeap(heap, 1, size, ARITY); + return heap[1]; + } + + /** Returns the number of elements currently stored in the PriorityQueue. */ + public int size() { + return size; + } + + /** Removes all entries from the PriorityQueue. */ + public void clear() { + size = 0; + } + + public void pushAll(TernaryLongHeap other) { + for (int i = 1; i <= other.size; i++) { + push(other.heap[i]); + } + } + + /** + * Return the element at the ith location in the heap array. Use for iterating over elements when + * the order doesn't matter. Note that the valid arguments range from [1, size]. + */ + public long get(int i) { + return heap[i]; + } + + /** + * This method returns the internal heap array. + * + * @lucene.internal + */ + // pkg-private for testing + long[] getHeapArray() { + return heap; + } + + /** + * Restores heap order by moving an element up the heap until it finds its proper position. Works + * with heaps of any arity (number of children per node). + * + * @param heap the heap array (1-based indexing) + * @param i the index of the element to move up + * @param arity the number of children each node can have + */ + static void upHeap(long[] heap, int i, int arity) { + final long value = heap[i]; // save bottom value + while (i > 1) { + // parent formula for 1-based indexing + final int parent = ((i - 2) / arity) + 1; + final long parentVal = heap[parent]; + if (value >= parentVal) break; + heap[i] = parentVal; // shift parent down + i = parent; + } + heap[i] = value; // install saved value + } + + /** + * Restores heap order by moving an element down the heap until it finds its proper position. + * Works with heaps of any arity (number of children per node). + * + * @param heap the heap array (1-based indexing) + * @param i the index of the element to move down + * @param size the current size of the heap + * @param arity the number of children each node can have + */ + static void downHeap(long[] heap, int i, int size, int arity) { + long value = heap[i]; // save top value + for (; ; ) { + // first child formula for 1-based indexing + int firstChild = arity * (i - 1) + 2; + if (firstChild > size) break; // i is a leaf + + int lastChild = Math.min(firstChild + arity - 1, size); + + // find the smallest child in [firstChild, lastChild] + int best = firstChild; + long bestVal = heap[firstChild]; + + for (int c = firstChild + 1; c <= lastChild; c++) { + final long v = heap[c]; + if (v < bestVal) { + bestVal = v; + best = c; + } + } + + if (bestVal >= value) break; + + heap[i] = bestVal; + i = best; + } + heap[i] = value; // install saved value + } +} diff --git a/lucene/core/src/test/org/apache/lucene/util/TestTernaryLongHeap.java b/lucene/core/src/test/org/apache/lucene/util/TestTernaryLongHeap.java new file mode 100644 index 000000000000..335116e7324b --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/util/TestTernaryLongHeap.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.util; + +import java.util.ArrayList; +import java.util.Random; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.tests.util.TestUtil; + +public class TestTernaryLongHeap extends LuceneTestCase { + + private static void checkValidity(TernaryLongHeap heap) { + long[] heapArray = heap.getHeapArray(); + int d = 3; + int size = heap.size(); + for (int parent = 1; parent <= size; parent++) { + int firstChild = d * (parent - 1) + 2; + int lastChild = Math.min(firstChild + d - 1, size); + for (int c = firstChild; c <= lastChild; c++) { + assert heapArray[parent] <= heapArray[c]; + } + } + } + + public void testPQ() { + testPQ(atLeast(10000), random()); + } + + public static void testPQ(int count, Random gen) { + TernaryLongHeap pq = new TernaryLongHeap(count); + long sum = 0, sum2 = 0; + + for (int i = 0; i < count; i++) { + long next = gen.nextLong(); + sum += next; + pq.push(next); + } + + long last = Long.MIN_VALUE; + for (long i = 0; i < count; i++) { + long next = pq.pop(); + assertTrue(next >= last); + last = next; + sum2 += last; + } + + assertEquals(sum, sum2); + } + + public void testClear() { + TernaryLongHeap pq = new TernaryLongHeap(3); + pq.push(2); + pq.push(3); + pq.push(1); + assertEquals(3, pq.size()); + pq.clear(); + assertEquals(0, pq.size()); + } + + public void testExceedBounds() { + TernaryLongHeap pq = new TernaryLongHeap(1); + pq.push(2); + pq.push(0); + assertEquals(2, pq.size()); // the heap has been extended to a new max size + assertEquals(0, pq.top()); + } + + public void testFixedSize() { + TernaryLongHeap pq = new TernaryLongHeap(3); + pq.insertWithOverflow(2); + pq.insertWithOverflow(3); + pq.insertWithOverflow(1); + pq.insertWithOverflow(5); + pq.insertWithOverflow(7); + pq.insertWithOverflow(1); + assertEquals(3, pq.size()); + assertEquals(3, pq.top()); + } + + public void testDuplicateValues() { + TernaryLongHeap pq = new TernaryLongHeap(3); + pq.push(2); + pq.push(3); + pq.push(1); + assertEquals(1, pq.top()); + pq.updateTop(3); + assertEquals(3, pq.size()); + assertArrayEquals(new long[] {0, 2, 3, 3}, pq.getHeapArray()); + } + + public void testInsertions() { + Random random = random(); + int numDocsInPQ = TestUtil.nextInt(random, 1, 100); + TernaryLongHeap pq = new TernaryLongHeap(numDocsInPQ); + Long lastLeast = null; + + // Basic insertion of new content + ArrayList sds = new ArrayList<>(numDocsInPQ); + for (int i = 0; i < numDocsInPQ * 10; i++) { + long newEntry = Math.abs(random.nextLong()); + sds.add(newEntry); + pq.insertWithOverflow(newEntry); + checkValidity(pq); + long newLeast = pq.top(); + if ((lastLeast != null) && (newLeast != newEntry) && (newLeast != lastLeast)) { + // If there has been a change of least entry and it wasn't our new + // addition we expect the scores to increase + assertTrue(newLeast <= newEntry); + assertTrue(newLeast >= lastLeast); + } + lastLeast = newLeast; + } + } + + public void testInvalid() { + expectThrows(IllegalArgumentException.class, () -> new TernaryLongHeap(-1)); + expectThrows(IllegalArgumentException.class, () -> new TernaryLongHeap(0)); + expectThrows( + IllegalArgumentException.class, () -> new TernaryLongHeap(ArrayUtil.MAX_ARRAY_LENGTH)); + } + + public void testUnbounded() { + int initialSize = random().nextInt(10) + 1; + TernaryLongHeap pq = new TernaryLongHeap(initialSize); + int num = random().nextInt(100) + 1; + long maxValue = Long.MIN_VALUE; + int count = 0; + for (int i = 0; i < num; i++) { + long value = random().nextLong(); + if (random().nextBoolean()) { + pq.push(value); + count++; + } else { + boolean full = pq.size() >= initialSize; + if (pq.insertWithOverflow(value)) { + if (full == false) { + count++; + } + } + } + maxValue = Math.max(maxValue, value); + } + assertEquals(count, pq.size()); + long last = Long.MIN_VALUE; + while (pq.size() > 0) { + long top = pq.top(); + long next = pq.pop(); + assertEquals(top, next); + --count; + assertTrue(next >= last); + last = next; + } + assertEquals(0, count); + assertEquals(maxValue, last); + } +}