Skip to content

Commit cafccfa

Browse files
committed
Fixed removal of arbitrary nodes from ThreadSafeHeap,
previously heap invariant could have been violated because of non-first removal.
1 parent e873c0a commit cafccfa

File tree

2 files changed

+60
-16
lines changed

2 files changed

+60
-16
lines changed

core/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/internal/ThreadSafeHeap.kt

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,12 @@ public class ThreadSafeHeap<T> where T: ThreadSafeHeapNode, T: Comparable<T> {
8484
size--
8585
if (index < size) {
8686
swap(index, size)
87-
var i = index
88-
while (true) {
89-
var j = 2 * i + 1
90-
if (j >= size) break
91-
if (j + 1 < size && a[j + 1]!! < a[j]!!) j++
92-
if (a[i]!! <= a[j]!!) break
93-
swap(i, j)
94-
i = j
87+
val j = (index - 1) / 2
88+
if (index > 0 && a[index]!! < a[j]!!) {
89+
swap(index, j)
90+
siftUpFrom(j)
91+
} else {
92+
siftDownFrom(index)
9593
}
9694
}
9795
val result = a[size]!!
@@ -106,14 +104,27 @@ public class ThreadSafeHeap<T> where T: ThreadSafeHeapNode, T: Comparable<T> {
106104
var i = size++
107105
a[i] = node
108106
node.index = i
109-
while (i > 0) {
110-
val j = (i - 1) / 2
111-
if (a[j]!! <= a[i]!!) break
112-
swap(i, j)
113-
i = j
114-
}
107+
siftUpFrom(i)
108+
}
109+
110+
private tailrec fun siftUpFrom(i: Int) {
111+
if (i <= 0) return
112+
val a = a!!
113+
val j = (i - 1) / 2
114+
if (a[j]!! <= a[i]!!) return
115+
swap(i, j)
116+
siftUpFrom(j)
115117
}
116118

119+
private tailrec fun siftDownFrom(i: Int) {
120+
var j = 2 * i + 1
121+
if (j >= size) return
122+
val a = a!!
123+
if (j + 1 < size && a[j + 1]!! < a[j]!!) j++
124+
if (a[i]!! <= a[j]!!) return
125+
swap(i, j)
126+
siftDownFrom(j)
127+
}
117128

118129
@Suppress("UNCHECKED_CAST")
119130
private fun realloc(): Array<T?> {

core/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/internal/ThreadSafeHeapTest.kt

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616

1717
package kotlinx.coroutines.experimental.internal
1818

19+
import kotlinx.coroutines.experimental.*
1920
import kotlin.test.*
2021
import java.util.*
2122

22-
class ThreadSafeHeapTest {
23+
class ThreadSafeHeapTest : TestBase() {
2324
class Node(val value: Int) : ThreadSafeHeapNode, Comparable<Node> {
2425
override var index = -1
2526
override fun compareTo(other: Node): Int = value.compareTo(other.value)
@@ -62,7 +63,7 @@ class ThreadSafeHeapTest {
6263

6364
@Test
6465
fun testRandomSort() {
65-
val n = 1000
66+
val n = 1000 * stressTestMultiplier
6667
val r = Random(1)
6768
val h = ThreadSafeHeap<Node>()
6869
val a = IntArray(n) { r.nextInt() }
@@ -71,4 +72,36 @@ class ThreadSafeHeapTest {
7172
repeat(n) { assertEquals(Node(a[it]), h.removeFirstOrNull()) }
7273
assertEquals(null, h.peek())
7374
}
75+
76+
@Test
77+
fun testRandomRemove() {
78+
val n = 1000 * stressTestMultiplier
79+
check(n % 2 == 0) { "Must be even" }
80+
val r = Random(1)
81+
val h = ThreadSafeHeap<Node>()
82+
val set = TreeSet<Node>()
83+
repeat(n) {
84+
val node = Node(r.nextInt())
85+
h.addLast(node)
86+
assertTrue(set.add(node))
87+
}
88+
while (!h.isEmpty) {
89+
// pick random node to remove
90+
val rndNode: Node
91+
while (true) {
92+
val tail = set.tailSet(Node(r.nextInt()))
93+
if (!tail.isEmpty()) {
94+
rndNode = tail.first()
95+
break
96+
}
97+
}
98+
assertTrue(set.remove(rndNode))
99+
assertTrue(h.remove(rndNode))
100+
// remove head and validate
101+
val headNode = h.removeFirstOrNull()!! // must not be null!!!
102+
assertTrue(headNode === set.first(), "Expected ${set.first()}, but found $headNode, remaining size ${h.size}")
103+
assertTrue(set.remove(headNode))
104+
assertEquals(set.size, h.size)
105+
}
106+
}
74107
}

0 commit comments

Comments
 (0)