Skip to content

Commit e15eae8

Browse files
author
Sergey Mashkov
committed
IO: fix continuous writing session interfered with joining
1 parent e06b6ca commit e15eae8

File tree

2 files changed

+127
-28
lines changed

2 files changed

+127
-28
lines changed

core/kotlinx-coroutines-io/src/main/kotlin/kotlinx/coroutines/experimental/io/ByteBufferChannel.kt

Lines changed: 55 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1757,47 +1757,74 @@ internal class ByteBufferChannel(
17571757
}
17581758

17591759
override suspend fun writeSuspendSession(visitor: suspend WriterSuspendSession.() -> Unit) {
1760-
writing { byteBuffer, ringBufferCapacity ->
1761-
var locked = 0
1762-
1763-
val session = object : WriterSuspendSession {
1764-
override fun request(min: Int): ByteBuffer? {
1765-
locked += ringBufferCapacity.tryWriteAtLeast(0)
1766-
if (locked < min) return null
1767-
byteBuffer.prepareBuffer(writeByteOrder, writePosition, locked)
1768-
if (byteBuffer.remaining() < min) return null
1769-
if (joining != null) return null
1760+
var locked = 0
17701761

1771-
return byteBuffer
1772-
}
1762+
var current = joining?.let { resolveDelegation(this, it) } ?: this
1763+
var byteBuffer = current.setupStateForWrite() ?: return writeSuspendSession(visitor)
1764+
var ringBufferCapacity = current.state.capacity
17731765

1774-
override fun written(n: Int) {
1775-
require(n >= 0)
1776-
if (n > locked) throw IllegalStateException()
1777-
locked -= n
1778-
byteBuffer.bytesWritten(ringBufferCapacity, n)
1779-
}
1766+
val session = object : WriterSuspendSession {
1767+
override fun request(min: Int): ByteBuffer? {
1768+
locked += ringBufferCapacity.tryWriteAtLeast(0)
1769+
if (locked < min) return null
1770+
byteBuffer.prepareBuffer(writeByteOrder, writePosition, locked)
1771+
if (byteBuffer.remaining() < min) return null
1772+
if (current.joining != null) return null
17801773

1781-
override suspend fun tryAwait(n: Int) {
1782-
if (locked >= n) return
1783-
if (locked > 0) {
1784-
ringBufferCapacity.completeRead(locked)
1785-
locked = 0
1786-
}
1774+
return byteBuffer
1775+
}
1776+
1777+
override fun written(n: Int) {
1778+
require(n >= 0)
1779+
if (n > locked) throw IllegalStateException()
1780+
locked -= n
1781+
byteBuffer.bytesWritten(ringBufferCapacity, n)
1782+
}
17871783

1788-
return tryWriteSuspend(n)
1784+
override suspend fun tryAwait(n: Int) {
1785+
val joining = current.joining
1786+
if (joining != null) {
1787+
return tryAwaitJoinSwitch(n, joining)
17891788
}
1789+
1790+
if (locked >= n) return
1791+
if (locked > 0) {
1792+
ringBufferCapacity.completeRead(locked)
1793+
locked = 0
1794+
}
1795+
1796+
return tryWriteSuspend(n)
17901797
}
17911798

1792-
try {
1793-
visitor(session)
1794-
} finally {
1799+
private suspend fun tryAwaitJoinSwitch(n: Int, joining: JoiningState) {
17951800
if (locked > 0) {
17961801
ringBufferCapacity.completeRead(locked)
17971802
locked = 0
17981803
}
1804+
flush()
1805+
restoreStateAfterWrite()
1806+
tryTerminate()
1807+
1808+
do {
1809+
current.tryWriteSuspend(n)
1810+
current = resolveDelegation(current, joining) ?: continue
1811+
byteBuffer = current.setupStateForWrite() ?: continue
1812+
ringBufferCapacity = current.state.capacity
1813+
} while (false)
17991814
}
18001815
}
1816+
1817+
try {
1818+
visitor(session)
1819+
} finally {
1820+
if (locked > 0) {
1821+
ringBufferCapacity.completeRead(locked)
1822+
locked = 0
1823+
}
1824+
1825+
current.restoreStateAfterWrite()
1826+
current.tryTerminate()
1827+
}
18011828
}
18021829

18031830
override fun consumed(n: Int) {

core/kotlinx-coroutines-io/src/test/kotlin/kotlinx/coroutines/experimental/io/ByteBufferChannelTest.kt

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1546,6 +1546,78 @@ class ByteBufferChannelTest : TestBase() {
15461546
sub.join()
15471547
}
15481548

1549+
@Test
1550+
fun testWriteSuspendSessionSmokeTest() = runTest {
1551+
ch.writeSuspendSession {
1552+
val buffer = request(1)
1553+
assertNotNull(buffer)
1554+
}
1555+
1556+
ch.writeSuspendSession {
1557+
val buffer = request(1)!!
1558+
buffer.put(0x11)
1559+
written(1)
1560+
}
1561+
1562+
assertEquals(0, ch.availableForRead)
1563+
ch.flush()
1564+
assertEquals(1, ch.availableForRead)
1565+
assertEquals(0x11, ch.readByte())
1566+
}
1567+
1568+
@Test
1569+
fun testWriteSuspendSessionJoined() = runTest {
1570+
val next = ByteChannel()
1571+
launch(Unconfined) {
1572+
ch.joinTo(next, true)
1573+
}
1574+
1575+
yield()
1576+
1577+
ch.writeSuspendSession {
1578+
val buffer = request(1)
1579+
assertNotNull(buffer)
1580+
buffer!!.put(0x11)
1581+
written(1)
1582+
}
1583+
1584+
assertEquals(0, next.availableForRead)
1585+
ch.flush()
1586+
assertEquals(1, next.availableForRead)
1587+
assertEquals(0x11, next.readByte())
1588+
}
1589+
1590+
@Test
1591+
fun testWriteSuspendSessionJoinDuringWrite() = runTest {
1592+
val next = ByteChannel()
1593+
1594+
ch.writeSuspendSession {
1595+
var buffer = request(1)
1596+
assertNotNull(buffer)
1597+
buffer!!.put(0x11)
1598+
written(1)
1599+
1600+
launch(Unconfined) {
1601+
ch.joinTo(next, true)
1602+
}
1603+
1604+
yield()
1605+
1606+
assertNull(request(1))
1607+
tryAwait(1)
1608+
buffer = request(1)
1609+
assertNotNull(buffer)
1610+
buffer!!.put(0x22)
1611+
written(1)
1612+
}
1613+
1614+
ch.flush()
1615+
1616+
assertEquals(2, next.availableForRead)
1617+
assertEquals(0x11, next.readByte())
1618+
assertEquals(0x22, next.readByte())
1619+
}
1620+
15491621
private inline fun buildPacket(block: ByteWritePacket.() -> Unit): ByteReadPacket {
15501622
val builder = BytePacketBuilder(0, pktPool)
15511623
try {

0 commit comments

Comments
 (0)