diff --git a/lucene/core/src/java/org/apache/lucene/util/PriorityQueue.java b/lucene/core/src/java/org/apache/lucene/util/PriorityQueue.java index de4f8e81265a..b472573260f6 100644 --- a/lucene/core/src/java/org/apache/lucene/util/PriorityQueue.java +++ b/lucene/core/src/java/org/apache/lucene/util/PriorityQueue.java @@ -19,8 +19,12 @@ import java.util.Arrays; import java.util.Collection; import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; +import java.util.Map; import java.util.NoSuchElementException; +import java.util.Set; import java.util.function.IntFunction; import java.util.function.Supplier; @@ -72,6 +76,7 @@ public static PriorityQueue usingComparator( private final int maxSize; private final T[] heap; private final LessThan lessThan; + private final Map> indexMap = new HashMap<>(); /** Create an empty priority queue of the configured size using the specified {@link LessThan}. */ public PriorityQueue(int maxSize, LessThan lessThan) { @@ -182,9 +187,9 @@ public void addAll(Collection elements) { * @return the new 'top' element in the queue. */ public final T add(T element) { - // don't modify size until we know heap access didn't throw AIOOB. int index = size + 1; heap[index] = element; + addIndex(element, index); size = index; upHeap(index); return heap[1]; @@ -272,6 +277,7 @@ public final int size() { public final void clear() { Arrays.fill(heap, 0, size + 1, null); size = 0; + indexMap.clear(); } /** @@ -280,20 +286,25 @@ public final void clear() { * constant remove time but the trade-off would be extra cost to all additions/insertions) */ public final boolean remove(T element) { - for (int i = 1; i <= size; i++) { - if (heap[i] == element) { - heap[i] = heap[size]; - heap[size] = null; // permit GC of objects - size--; - if (i <= size) { - if (!upHeap(i)) { - downHeap(i); - } - } - return true; - } + Set indices = indexMap.get(element); + if (indices == null || indices.isEmpty()) return false; + Integer idx = indices.iterator().next(); + removeIndex(element, idx); + T last = heap[size]; + if (idx == size) { + heap[size] = null; + size--; + return true; } - return false; + removeIndex(last, size); + heap[idx] = last; + addIndex(last, idx); + heap[size] = null; + size--; + if (!upHeap(idx)) { + downHeap(idx); + } + return true; } /** @@ -320,36 +331,54 @@ public T[] drainToArrayHighestFirst(IntFunction newArray) { return array; } - private boolean upHeap(int origPos) { - int i = origPos; - T node = heap[i]; // save bottom node - int j = i >>> 1; - while (j > 0 && lessThan.lessThan(node, heap[j])) { - heap[i] = heap[j]; // shift parents down - i = j; - j = j >>> 1; + private void addIndex(T element, int idx) { + indexMap.computeIfAbsent(element, k -> new HashSet<>()).add(idx); + } + + private void removeIndex(T element, int idx) { + Set indices = indexMap.get(element); + if (indices != null) { + indices.remove(idx); + if (indices.isEmpty()) indexMap.remove(element); } - heap[i] = node; // install saved node - return i != origPos; } - private void downHeap(int i) { - T node = heap[i]; // save top node - int j = i << 1; // find smaller child - int k = j + 1; - if (k <= size && lessThan.lessThan(heap[k], heap[j])) { - j = k; + protected boolean upHeap(int i) { + T node = heap[i]; + int j = i; + while (j > 1 && lessThan.lessThan(node, heap[j >> 1])) { + heap[j] = heap[j >> 1]; + removeIndex(heap[j], j >> 1); + addIndex(heap[j], j); + j >>= 1; } - while (j <= size && lessThan.lessThan(heap[j], node)) { - heap[i] = heap[j]; // shift up child - i = j; - j = i << 1; - k = j + 1; - if (k <= size && lessThan.lessThan(heap[k], heap[j])) { + heap[j] = node; + removeIndex(node, i); + addIndex(node, j); + return j < i; + } + + protected boolean downHeap(int i) { + T node = heap[i]; + int j = i; + int k; + while ((k = j << 1) <= size) { + if (k < size && lessThan.lessThan(heap[k + 1], heap[k])) { + k++; + } + if (lessThan.lessThan(heap[k], node)) { + heap[j] = heap[k]; + removeIndex(heap[j], k); + addIndex(heap[j], j); j = k; + } else { + break; } } - heap[i] = node; // install saved node + heap[j] = node; + removeIndex(node, i); + addIndex(node, j); + return j > i; } /** diff --git a/lucene/core/src/test/org/apache/lucene/util/TestPriorityQueue.java b/lucene/core/src/test/org/apache/lucene/util/TestPriorityQueue.java index 091834b0e380..9aa996b26294 100644 --- a/lucene/core/src/test/org/apache/lucene/util/TestPriorityQueue.java +++ b/lucene/core/src/test/org/apache/lucene/util/TestPriorityQueue.java @@ -208,47 +208,80 @@ public void testAddAllDoesNotFitIntoQueue() { () -> pq.addAll(list)); } - /** Randomly add and remove elements, comparing against the reference java.util.PriorityQueue. */ - public void testRemovalsAndInsertions() { + /** Randomly remove elements, comparing against the reference java.util.PriorityQueue by value. */ + public void testRemovals() { int maxElement = RandomNumbers.randomIntBetween(random(), 1, 10_000); int size = maxElement / 2 + 1; - var reference = new java.util.PriorityQueue(); var pq = new IntegerQueue(size); - Random localRandom = nonAssertingRandom(random()); - - // Lucene's PriorityQueue.remove uses reference equality, not .equals to determine which - // elements - // to remove (!). HashMap ints = new HashMap<>(); + // Fill both queues with up to maxSize elements + for (int i = 0; i < size; i++) { + Integer element = ints.computeIfAbsent(localRandom.nextInt(maxElement), k -> k); + pq.add(element); + reference.add(element); + } + // Perform random removals and compare by value + for (int i = 0; i < size; i++) { + Integer element = ints.computeIfAbsent(localRandom.nextInt(maxElement), k -> k); + int pqCount = 0, refCount = 0; + for (Integer val : pq) if (val.equals(element)) pqCount++; + for (Integer val : reference) if (val.equals(element)) refCount++; + boolean pqRemoved = pq.remove(element); + boolean refRemoved = reference.remove(element); + assertEquals("remove() should return true if value was present", refCount > 0, pqRemoved); + assertEquals("remove() should return true if value was present", refCount > 0, refRemoved); + int pqCountAfter = 0, refCountAfter = 0; + for (Integer val : pq) if (val.equals(element)) pqCountAfter++; + for (Integer val : reference) if (val.equals(element)) refCountAfter++; + assertEquals("Should remove only one instance (value)", Math.max(0, refCount - 1), refCountAfter); + assertEquals("Should remove only one instance (value)", Math.max(0, pqCount - 1), pqCountAfter); + assertEquals("pq and reference should match counts after removal", refCountAfter, pqCountAfter); + assertEquals("size after removal should match", reference.size(), pq.size()); + Integer pqTop = pq.top(); + Integer refTop = reference.peek(); + if (pqTop != null && refTop != null) { + assertEquals("top() value difference after removal?", refTop.intValue(), pqTop.intValue()); + } else { + assertEquals("top() value difference after removal?", refTop, pqTop); + } + } + pq.checkValidity(); + } + /** Randomly add elements, comparing against the reference java.util.PriorityQueue by value. */ + public void testInsertions() { + int maxElement = RandomNumbers.randomIntBetween(random(), 1, 10_000); + int size = maxElement / 2 + 1; + var reference = new java.util.PriorityQueue(); + var pq = new IntegerQueue(size); + Random localRandom = nonAssertingRandom(random()); + HashMap ints = new HashMap<>(); for (int i = 0, iters = size * 2; i < iters; i++) { Integer element = ints.computeIfAbsent(localRandom.nextInt(maxElement), k -> k); - - var action = localRandom.nextInt(100); - if (action < 25) { - // removals, possibly misses. - assertEquals("remove() difference: " + i, reference.remove(element), pq.remove(element)); + var dropped = pq.insertWithOverflow(element); + reference.add(element); + Integer droppedReference; + if (reference.size() > size) { + droppedReference = reference.remove(); } else { - // additions. - var dropped = pq.insertWithOverflow(element); - - reference.add(element); - Integer droppedReference; - if (reference.size() > size) { - droppedReference = reference.remove(); - } else { - droppedReference = null; - } - - assertEquals("insertWithOverflow() difference.", dropped, droppedReference); + droppedReference = null; + } + if (dropped != null && droppedReference != null) { + assertEquals("insertWithOverflow() dropped value difference.", dropped.intValue(), droppedReference.intValue()); + } else { + assertEquals("insertWithOverflow() dropped value difference.", droppedReference, dropped); } - assertEquals("insertWithOverflow() size difference?", reference.size(), pq.size()); - assertEquals("top() difference?", reference.peek(), pq.top()); + Integer pqTop = pq.top(); + Integer refTop = reference.peek(); + if (pqTop != null && refTop != null) { + assertEquals("top() value difference?", refTop.intValue(), pqTop.intValue()); + } else { + assertEquals("top() value difference?", refTop, pqTop); + } } - pq.checkValidity(); }