Skip to content

Commit e6e8ce8

Browse files
committed
Fixed lock-freedom of send/offer on closed channels and related bugs.
Conflated[Broadcast]Channel hang of concurrent offer/close is fixed.
1 parent 584ae3d commit e6e8ce8

File tree

6 files changed

+324
-27
lines changed

6 files changed

+324
-27
lines changed

kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/channels/AbstractChannel.kt

Lines changed: 72 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -98,22 +98,40 @@ public abstract class AbstractSendChannel<E> : SendChannel<E> {
9898
queue.removeFirstIfIsInstanceOfOrPeekIf<Send> { it is Closed<*> }
9999

100100
/**
101+
* Queues buffered element, returns null on success or
102+
* returns node reference if it was already closed or is waiting for receive.
101103
* @suppress **This is unstable API and it is subject to change.**
102104
*/
103-
protected fun sendBuffered(element: E): Boolean =
104-
queue.addLastIfPrev(SendBuffered(element), { it !is ReceiveOrClosed<*> })
105+
protected fun sendBuffered(element: E): ReceiveOrClosed<*>? {
106+
queue.addLastIfPrev(SendBuffered(element), { prev ->
107+
if (prev is ReceiveOrClosed<*>) return@sendBuffered prev
108+
true
109+
})
110+
return null
111+
}
105112

106113
/**
114+
* Queues conflated element, returns null on success or
115+
* returns node reference if it was already closed or is waiting for receive.
107116
* @suppress **This is unstable API and it is subject to change.**
108117
*/
109-
protected fun sendConflated(element: E): Boolean {
118+
protected fun sendConflated(element: E): ReceiveOrClosed<*>? {
110119
val node = SendBuffered(element)
111-
if (!queue.addLastIfPrev(node, { it !is ReceiveOrClosed<*> })) return false
112-
// remove previous SendBuffered
120+
queue.addLastIfPrev(node, { prev ->
121+
if (prev is ReceiveOrClosed<*>) return@sendConflated prev
122+
true
123+
})
124+
conflatePreviousSendBuffered(node)
125+
return null
126+
}
127+
128+
/**
129+
* @suppress **This is unstable API and it is subject to change.**
130+
*/
131+
protected fun conflatePreviousSendBuffered(node: LockFreeLinkedListNode) {
113132
val prev = node.prev
114133
if (prev is SendBuffered<*>)
115134
prev.remove()
116-
return true
117135
}
118136

119137
/**
@@ -173,32 +191,56 @@ public abstract class AbstractSendChannel<E> : SendChannel<E> {
173191
private suspend fun sendSuspend(element: E): Unit = suspendAtomicCancellableCoroutine(holdCancellability = true) sc@ { cont ->
174192
val send = SendElement(element, cont)
175193
loop@ while (true) {
176-
if (enqueueSend(send)) {
177-
cont.initCancellability() // make it properly cancellable
178-
cont.removeOnCancel(send)
179-
return@sc
194+
val enqueueResult = enqueueSend(send)
195+
when (enqueueResult) {
196+
null -> { // enqueued successfully
197+
cont.initCancellability() // make it properly cancellable
198+
cont.removeOnCancel(send)
199+
return@sc
200+
}
201+
is Closed<*> -> {
202+
cont.resumeWithException(enqueueResult.sendException)
203+
return@sc
204+
}
180205
}
181-
// hm... something is not right. try to offer
182-
val result = offerInternal(element)
206+
// hm... receiver is waiting or buffer is not full. try to offer
207+
val offerResult = offerInternal(element)
183208
when {
184-
result === OFFER_SUCCESS -> {
209+
offerResult === OFFER_SUCCESS -> {
185210
cont.resume(Unit)
186211
return@sc
187212
}
188-
result === OFFER_FAILED -> continue@loop
189-
result is Closed<*> -> {
190-
cont.resumeWithException(result.sendException)
213+
offerResult === OFFER_FAILED -> continue@loop
214+
offerResult is Closed<*> -> {
215+
cont.resumeWithException(offerResult.sendException)
191216
return@sc
192217
}
193-
else -> error("offerInternal returned $result")
218+
else -> error("offerInternal returned $offerResult")
194219
}
195220
}
196221
}
197222

198-
private fun enqueueSend(send: SendElement) =
199-
if (isBufferAlwaysFull)
200-
queue.addLastIfPrev(send, { it !is ReceiveOrClosed<*> }) else
201-
queue.addLastIfPrevAndIf(send, { it !is ReceiveOrClosed<*> }, { isBufferFull })
223+
/**
224+
* Result is:
225+
* * null -- successfully enqueued
226+
* * ENQUEUE_FAILED -- buffer is not full (should not enqueue)
227+
* * ReceiveOrClosed<*> -- receiver is waiting or it is closed (should not enqueue)
228+
*/
229+
private fun enqueueSend(send: SendElement): Any? {
230+
if (isBufferAlwaysFull) {
231+
queue.addLastIfPrev(send, { prev ->
232+
if (prev is ReceiveOrClosed<*>) return@enqueueSend prev
233+
true
234+
})
235+
} else {
236+
if (!queue.addLastIfPrevAndIf(send, { prev ->
237+
if (prev is ReceiveOrClosed<*>) return@enqueueSend prev
238+
true
239+
}, { isBufferFull }))
240+
return ENQUEUE_FAILED
241+
}
242+
return null
243+
}
202244

203245
public override fun close(cause: Throwable?): Boolean {
204246
val closed = Closed<E>(cause)
@@ -207,6 +249,7 @@ public abstract class AbstractSendChannel<E> : SendChannel<E> {
207249
if (receive == null) {
208250
// queue empty or has only senders -- try add last "Closed" item to the queue
209251
if (queue.addLastIfPrev(closed, { it !is ReceiveOrClosed<*> })) {
252+
onClosed(closed)
210253
afterClose(cause)
211254
return true
212255
}
@@ -218,6 +261,12 @@ public abstract class AbstractSendChannel<E> : SendChannel<E> {
218261
}
219262
}
220263

264+
/**
265+
* Invoked when [Closed] element was just added.
266+
* @suppress **This is unstable API and it is subject to change.**
267+
*/
268+
protected open fun onClosed(closed: Closed<E>) {}
269+
221270
/**
222271
* Invoked after successful [close].
223272
*/
@@ -870,8 +919,8 @@ public class Closed<in E>(
870919
override val pollResult get() = this
871920
override fun tryResumeSend(idempotent: Any?): Any? = CLOSE_RESUMED
872921
override fun completeResumeSend(token: Any) { check(token === CLOSE_RESUMED) }
873-
override fun tryResumeReceive(value: E, idempotent: Any?): Any? = throw sendException
874-
override fun completeResumeReceive(token: Any) = throw sendException
922+
override fun tryResumeReceive(value: E, idempotent: Any?): Any? = CLOSE_RESUMED
923+
override fun completeResumeReceive(token: Any) { check(token === CLOSE_RESUMED) }
875924
override fun toString(): String = "Closed[$closeCause]"
876925
}
877926

kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/channels/ConflatedChannel.kt

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,27 @@ public open class ConflatedChannel<E> : AbstractChannel<E>() {
3434
protected final override val isBufferAlwaysFull: Boolean get() = false
3535
protected final override val isBufferFull: Boolean get() = false
3636

37+
/**
38+
* This implementation conflates last sent item when channel is closed.
39+
* @suppress **This is unstable API and it is subject to change.**
40+
*/
41+
override fun onClosed(closed: Closed<E>) {
42+
conflatePreviousSendBuffered(closed)
43+
}
44+
3745
// result is always `OFFER_SUCCESS | Closed`
3846
protected override fun offerInternal(element: E): Any {
3947
while (true) {
4048
val result = super.offerInternal(element)
4149
when {
4250
result === OFFER_SUCCESS -> return OFFER_SUCCESS
4351
result === OFFER_FAILED -> { // try to buffer
44-
if (sendConflated(element))
45-
return OFFER_SUCCESS
52+
val sendResult = sendConflated(element)
53+
when (sendResult) {
54+
null -> return OFFER_SUCCESS
55+
is Closed<*> -> return sendResult
56+
}
57+
// otherwise there was receiver in queue, retry super.offerInternal
4658
}
4759
result is Closed<*> -> return result
4860
else -> error("Invalid offerInternal result $result")

kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/channels/LinkedListChannel.kt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,12 @@ public open class LinkedListChannel<E> : AbstractChannel<E>() {
3838
when {
3939
result === OFFER_SUCCESS -> return OFFER_SUCCESS
4040
result === OFFER_FAILED -> { // try to buffer
41-
if (sendBuffered(element))
42-
return OFFER_SUCCESS
41+
val sendResult = sendBuffered(element)
42+
when (sendResult) {
43+
null -> return OFFER_SUCCESS
44+
is Closed<*> -> return sendResult
45+
}
46+
// otherwise there was receiver in queue, retry super.offerInternal
4347
}
4448
result is Closed<*> -> return result
4549
else -> error("Invalid offerInternal result $result")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*
2+
* Copyright 2016-2017 JetBrains s.r.o.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package kotlinx.coroutines.experimental.channels
18+
19+
import kotlinx.coroutines.experimental.*
20+
import org.hamcrest.MatcherAssert.assertThat
21+
import org.hamcrest.core.IsEqual
22+
import org.junit.Test
23+
import java.util.concurrent.atomic.AtomicInteger
24+
25+
class ConflatedBroadcastChannelNotifyStressTest : TestBase() {
26+
val nSenders = 2
27+
val nReceivers = 3
28+
val nEvents = 1_000_000 * stressTestMultiplier
29+
val timeLimit = 30_000L * stressTestMultiplier // 30 sec
30+
31+
val broadcast = ConflatedBroadcastChannel<Int>()
32+
33+
val sendersCompleted = AtomicInteger()
34+
val receiversCompleted = AtomicInteger()
35+
val sentTotal = AtomicInteger()
36+
val receivedTotal = AtomicInteger()
37+
38+
@Test
39+
fun testStressNotify()= runBlocking<Unit> {
40+
val senders = List(nSenders) { senderId ->
41+
launch(CommonPool + CoroutineName("Sender$senderId")) {
42+
repeat(nEvents) { i ->
43+
if (i % nSenders == senderId) {
44+
broadcast.offer(i)
45+
sentTotal.incrementAndGet()
46+
yield()
47+
}
48+
}
49+
sendersCompleted.incrementAndGet()
50+
}
51+
}
52+
val receivers = List(nReceivers) { receiverId ->
53+
launch(CommonPool + CoroutineName("Receiver$receiverId")) {
54+
var last = -1
55+
while (isActive) {
56+
val i = waitForEvent()
57+
if (i > last) {
58+
receivedTotal.incrementAndGet()
59+
last = i
60+
}
61+
if (i >= nEvents) break
62+
yield()
63+
}
64+
receiversCompleted.incrementAndGet()
65+
}
66+
}
67+
// print progress
68+
val progressJob = launch(context) {
69+
var seconds = 0
70+
while (true) {
71+
delay(1000)
72+
println("${++seconds}: Sent ${sentTotal.get()}, received ${receivedTotal.get()}")
73+
}
74+
}
75+
try {
76+
withTimeout(timeLimit) {
77+
senders.forEach { it.join() }
78+
broadcast.offer(nEvents) // last event to signal receivers termination
79+
receivers.forEach { it.join() }
80+
}
81+
} catch (e: CancellationException) {
82+
println("!!! Test timed out $e")
83+
}
84+
progressJob.cancel()
85+
println("Tested with nSenders=$nSenders, nReceivers=$nReceivers")
86+
println("Completed successfully ${sendersCompleted.get()} sender coroutines")
87+
println("Completed successfully ${receiversCompleted.get()} receiver coroutines")
88+
println(" Sent ${sentTotal.get()} events")
89+
println(" Received ${receivedTotal.get()} events")
90+
assertThat(sendersCompleted.get(), IsEqual(nSenders))
91+
assertThat(receiversCompleted.get(), IsEqual(nReceivers))
92+
assertThat(sentTotal.get(), IsEqual(nEvents))
93+
}
94+
95+
suspend fun waitForEvent(): Int =
96+
broadcast.open().use {
97+
it.receive()
98+
}
99+
}

0 commit comments

Comments
 (0)