@@ -13,6 +13,8 @@ import kotlinx.io.pool.*
13
13
import java.io.EOFException
14
14
import java.nio.*
15
15
import java.util.concurrent.atomic.*
16
+ import kotlin.coroutines.experimental.*
17
+ import kotlin.coroutines.experimental.intrinsics.*
16
18
17
19
internal const val DEFAULT_CLOSE_MESSAGE = " Byte channel was closed"
18
20
@@ -42,10 +44,10 @@ internal class ByteBufferChannel(
42
44
private var joining: JoiningState ? = null
43
45
44
46
@Volatile
45
- private var readOp: CancellableContinuation <Boolean >? = null
47
+ private var readOp: Continuation <Boolean >? = null
46
48
47
49
@Volatile
48
- private var writeOp: CancellableContinuation <Unit >? = null
50
+ private var writeOp: Continuation <Unit >? = null
49
51
50
52
private var readPosition = 0
51
53
private var writePosition = 0
@@ -56,8 +58,9 @@ internal class ByteBufferChannel(
56
58
internal fun attachJob (job : Job ) {
57
59
attachedJob?.cancel()
58
60
attachedJob = job
59
- job.invokeOnCompletion {
61
+ job.invokeOnCompletion(onCancelling = true ) { cause ->
60
62
attachedJob = null
63
+ if (cause != null ) cancel(cause)
61
64
}
62
65
}
63
66
@@ -104,6 +107,8 @@ internal class ByteBufferChannel(
104
107
}
105
108
106
109
if (cause != null ) attachedJob?.cancel(cause)
110
+ readSuspendContinuationCache.close()
111
+ writeSuspendContinuationCache.close()
107
112
108
113
return true
109
114
}
@@ -150,7 +155,9 @@ internal class ByteBufferChannel(
150
155
}
151
156
152
157
private fun setupStateForWrite (): ByteBuffer ? {
153
- if (writeOp != null ) throw IllegalStateException (" Write operation is already in progress" )
158
+ if (writeOp != null ) {
159
+ throw IllegalStateException (" Write operation is already in progress" )
160
+ }
154
161
155
162
var _allocated : ReadWriteBufferState .Initial ? = null
156
163
val (old, newState) = updateState { state ->
@@ -1116,7 +1123,7 @@ internal class ByteBufferChannel(
1116
1123
1117
1124
private suspend fun writeFullySuspend (src : ByteBuffer ) {
1118
1125
while (src.hasRemaining()) {
1119
- writeSuspend (1 )
1126
+ tryWriteSuspend (1 )
1120
1127
1121
1128
joining?.let { resolveDelegation(this , it)?.let { return it.writeFully(src) } }
1122
1129
@@ -1126,7 +1133,7 @@ internal class ByteBufferChannel(
1126
1133
1127
1134
private suspend fun writeFullySuspend (src : BufferView ) {
1128
1135
while (src.canRead()) {
1129
- writeSuspend (1 )
1136
+ tryWriteSuspend (1 )
1130
1137
1131
1138
joining?.let { resolveDelegation(this , it)?.let { return it.writeFully(src) } }
1132
1139
@@ -1197,7 +1204,7 @@ internal class ByteBufferChannel(
1197
1204
while (copied < limit) {
1198
1205
var avWBefore = state.availableForWrite
1199
1206
if (avWBefore == 0 ) {
1200
- writeSuspend (1 )
1207
+ tryWriteSuspend (1 )
1201
1208
if (joining != null ) break
1202
1209
avWBefore = state.availableForWrite
1203
1210
}
@@ -1256,7 +1263,7 @@ internal class ByteBufferChannel(
1256
1263
// println("readSuspend?")
1257
1264
flush()
1258
1265
1259
- if (src.availableForRead == 0 && ! src.readSuspend (1 )) {
1266
+ if (src.availableForRead == 0 && ! src.readSuspendImpl (1 )) {
1260
1267
// println("readSuspend failed")
1261
1268
if (joined == null || src.tryCompleteJoining(joined)) break
1262
1269
}
@@ -1406,7 +1413,7 @@ internal class ByteBufferChannel(
1406
1413
1407
1414
private suspend fun writeSuspend (src : ByteArray , offset : Int , length : Int ): Int {
1408
1415
while (true ) {
1409
- writeSuspend (1 )
1416
+ tryWriteSuspend (1 )
1410
1417
1411
1418
joining?.let { resolveDelegation(this , it)?.let { return it.writeSuspend(src, offset, length) } }
1412
1419
@@ -1704,6 +1711,49 @@ internal class ByteBufferChannel(
1704
1711
return result!!
1705
1712
}
1706
1713
1714
+ override suspend fun writeSuspendSession (visitor : suspend WriterSuspendSession .() -> Unit ) {
1715
+ writing { byteBuffer, ringBufferCapacity ->
1716
+ var locked = 0
1717
+
1718
+ val session = object : WriterSuspendSession {
1719
+ override fun request (min : Int ): ByteBuffer ? {
1720
+ locked + = ringBufferCapacity.tryWriteAtLeast(0 )
1721
+ if (locked < min) return null
1722
+ byteBuffer.prepareBuffer(writeByteOrder, writePosition, locked)
1723
+ if (byteBuffer.remaining() < min) return null
1724
+ if (joining != null ) return null
1725
+
1726
+ return byteBuffer
1727
+ }
1728
+
1729
+ override fun written (n : Int ) {
1730
+ require(n >= 0 )
1731
+ if (n > locked) throw IllegalStateException ()
1732
+ locked - = n
1733
+ byteBuffer.bytesWritten(ringBufferCapacity, n)
1734
+ }
1735
+
1736
+ override suspend fun tryAwait (n : Int ) {
1737
+ if (locked >= n) return
1738
+ if (locked > 0 ) {
1739
+ ringBufferCapacity.completeRead(locked)
1740
+ locked = 0
1741
+ }
1742
+
1743
+ return tryWriteSuspend(n)
1744
+ }
1745
+ }
1746
+
1747
+ try {
1748
+ visitor(session)
1749
+ } finally {
1750
+ if (locked > 0 ) {
1751
+ ringBufferCapacity.completeRead(locked)
1752
+ }
1753
+ }
1754
+ }
1755
+ }
1756
+
1707
1757
override fun consumed (n : Int ) {
1708
1758
require(n >= 0 )
1709
1759
@@ -1713,19 +1763,23 @@ internal class ByteBufferChannel(
1713
1763
}
1714
1764
}
1715
1765
1716
- final suspend override fun awaitAtLeast (n : Int ) {
1766
+ final override suspend fun awaitAtLeast (n : Int ): Boolean {
1717
1767
if (state.capacity.availableForRead >= n) {
1718
- if (state.idle) setupStateForRead()
1719
- return
1768
+ if (state.idle || state is ReadWriteBufferState . Writing ) setupStateForRead()
1769
+ return true
1720
1770
}
1721
1771
1722
- return awaitAtLeastSuspend(n)
1772
+ if (state.idle || state is ReadWriteBufferState .Writing ) return awaitAtLeastSuspend(n)
1773
+ else if (n == 1 ) return readSuspendImpl(1 )
1774
+ else return readSuspend(n)
1723
1775
}
1724
1776
1725
- private suspend fun awaitAtLeastSuspend (n : Int ) {
1726
- if (readSuspend(n) && state.idle) {
1777
+ private suspend fun awaitAtLeastSuspend (n : Int ): Boolean {
1778
+ val rc = readSuspend(n)
1779
+ if (rc && state.idle) {
1727
1780
setupStateForRead()
1728
1781
}
1782
+ return rc
1729
1783
}
1730
1784
1731
1785
override fun request (skip : Int , atLeast : Int ): ByteBuffer ? {
@@ -1746,33 +1800,38 @@ internal class ByteBufferChannel(
1746
1800
}
1747
1801
1748
1802
private inline fun consumeEachBufferRangeFast (last : Boolean , visitor : (buffer: ByteBuffer , last: Boolean ) -> Boolean ): Boolean {
1749
- if (state == = ReadWriteBufferState .Terminated && ! last) return false
1750
-
1751
1803
val rc = reading {
1752
1804
do {
1753
- val available = state.capacity.availableForRead
1754
-
1755
- val rem = if (available > 0 || last) {
1756
- if (! visitor(this , last)) {
1757
- afterBufferVisited(this , it)
1758
- return true
1759
- }
1760
-
1761
- val consumed = afterBufferVisited(this , it)
1762
- available - consumed
1763
- } else 0
1764
- } while (rem > 0 && ! last)
1805
+ if (hasRemaining() || last) {
1806
+ val rc = visitor(this , last)
1807
+ afterBufferVisited(this , it)
1808
+ if (! rc || (last && ! hasRemaining())) return true
1809
+ } else break
1810
+ } while (true )
1765
1811
1766
1812
last
1767
1813
}
1768
1814
1769
1815
if (! rc && closed != null ) {
1770
1816
visitor(EmptyByteBuffer , true )
1817
+ return true
1771
1818
}
1772
1819
1773
1820
return rc
1774
1821
}
1775
1822
1823
+ // private suspend fun consumeEachBufferRangeSuspendLoop(visitor: RendezvousChannel<ConsumeEachBufferVisitor>) {
1824
+ // var last = false
1825
+ //
1826
+ // do {
1827
+ // if (consumeEachBufferRangeFast(last, visitor)) return
1828
+ // if (last) return
1829
+ // if (!readSuspend(1)) {
1830
+ // last = true
1831
+ // }
1832
+ // } while (true)
1833
+ // }
1834
+
1776
1835
private suspend fun consumeEachBufferRangeSuspend (visitor : (buffer: ByteBuffer , last: Boolean ) -> Boolean ) {
1777
1836
var last = false
1778
1837
@@ -1998,7 +2057,23 @@ internal class ByteBufferChannel(
1998
2057
ClosedWriteChannelException (DEFAULT_CLOSE_MESSAGE ))
1999
2058
}
2000
2059
2001
- private tailrec suspend fun readSuspend (size : Int ): Boolean {
2060
+ private suspend fun readSuspend (size : Int ): Boolean {
2061
+ val capacity = state.capacity
2062
+ if (capacity.availableForRead >= size) return true
2063
+
2064
+ closed?.let { c ->
2065
+ if (c.cause != null ) throw c.cause
2066
+ val afterCapacity = state.capacity
2067
+ val result = afterCapacity.flush() && afterCapacity.availableForRead >= size
2068
+ if (readOp != null ) throw IllegalStateException (" Read operation is already in progress" )
2069
+ return result
2070
+ }
2071
+
2072
+ if (size == 1 ) return readSuspendImpl(1 )
2073
+ return readSuspendLoop(size)
2074
+ }
2075
+
2076
+ private tailrec suspend fun readSuspendLoop (size : Int ): Boolean {
2002
2077
val capacity = state.capacity
2003
2078
if (capacity.availableForRead >= size) return true
2004
2079
@@ -2012,28 +2087,38 @@ internal class ByteBufferChannel(
2012
2087
2013
2088
if (! readSuspendImpl(size)) return false
2014
2089
2015
- return readSuspend (size)
2090
+ return readSuspendLoop (size)
2016
2091
}
2017
2092
2018
- private suspend fun readSuspendImpl (size : Int ): Boolean {
2019
- if (state.capacity.availableForRead >= size) return true
2093
+ private val readSuspendContinuationCache = MutableDelegateContinuation <Boolean >()
2020
2094
2021
- return suspendCancellableCoroutine(holdCancellability = true ) { c ->
2022
- do {
2023
- if (state.capacity.availableForRead >= size) {
2024
- c.resume(true )
2025
- break
2026
- }
2095
+ private fun suspensionForSize ( size : Int , c : Continuation < Boolean >): Any {
2096
+ do {
2097
+ if (this . state.capacity.availableForRead >= size) {
2098
+ c.resume(true )
2099
+ break
2100
+ }
2027
2101
2028
- closed?.let {
2029
- if (it.cause != null ) {
2030
- c.resumeWithException(it.cause)
2031
- } else {
2032
- c.resume(state.capacity.flush() && state.capacity.availableForRead >= size)
2033
- }
2034
- return @suspendCancellableCoroutine
2102
+ closed?.let {
2103
+ if (it.cause != null ) {
2104
+ c.resumeWithException(it.cause)
2105
+ } else {
2106
+ c.resume(state.capacity.flush() && state.capacity.availableForRead >= size)
2035
2107
}
2036
- } while (! setContinuation({ readOp }, ReadOp , c, { closed == null && state.capacity.availableForRead < size }))
2108
+ return COROUTINE_SUSPENDED
2109
+ }
2110
+ } while (! setContinuation({ readOp }, ReadOp , c, { closed == null && state.capacity.availableForRead < size }))
2111
+
2112
+ return COROUTINE_SUSPENDED
2113
+ }
2114
+
2115
+ private suspend fun readSuspendImpl (size : Int ): Boolean {
2116
+ if (state.capacity.availableForRead >= size) return true
2117
+
2118
+ return suspendCoroutineOrReturn { raw ->
2119
+ val c = readSuspendContinuationCache
2120
+ suspensionForSize(size, c)
2121
+ c.swap(raw)
2037
2122
}
2038
2123
}
2039
2124
@@ -2049,8 +2134,44 @@ internal class ByteBufferChannel(
2049
2134
}
2050
2135
}
2051
2136
2137
+ private val writeSuspendContinuationCache = MutableDelegateContinuation <Unit >()
2138
+ @Volatile
2139
+ private var writeSuspensionSize: Int = 0
2140
+ private val writeSuspension = { c: Continuation <Unit > ->
2141
+ val size = writeSuspensionSize
2142
+
2143
+ do {
2144
+ closed?.sendException?.let { throw it }
2145
+ if (! writeSuspendPredicate(size)) {
2146
+ c.resume(Unit )
2147
+ break
2148
+ }
2149
+ } while (! setContinuation({ writeOp }, WriteOp , c, { writeSuspendPredicate(size) }))
2150
+
2151
+ flushImpl(1 , minWriteSize = size)
2152
+
2153
+ COROUTINE_SUSPENDED
2154
+ }
2155
+
2156
+ private suspend fun tryWriteSuspend (size : Int ) {
2157
+ if (! writeSuspendPredicate(size)) {
2158
+ closed?.sendException?.let { throw it }
2159
+ return
2160
+ }
2161
+
2162
+ writeSuspensionSize = size
2163
+ if (attachedJob != null ) {
2164
+ return suspendCoroutineOrReturn(writeSuspension)
2165
+ }
2166
+
2167
+ return suspendCoroutineOrReturn { raw ->
2168
+ val c = writeSuspendContinuationCache
2169
+ writeSuspension(c)
2170
+ c.swap(raw)
2171
+ }
2172
+ }
2173
+
2052
2174
private suspend fun writeSuspend (size : Int ) {
2053
- // println("Write suspend (enter)")
2054
2175
while (writeSuspendPredicate(size)) {
2055
2176
suspendCancellableCoroutine<Unit >(holdCancellability = true ) { c ->
2056
2177
do {
@@ -2062,15 +2183,13 @@ internal class ByteBufferChannel(
2062
2183
} while (! setContinuation({ writeOp }, WriteOp , c, { writeSuspendPredicate(size) }))
2063
2184
2064
2185
flushImpl(1 , minWriteSize = size)
2065
- // println("Write suspend (loop), op = ${writeOp}, state = $state, joined = $joining")
2066
2186
}
2067
2187
}
2068
2188
2069
2189
closed?.sendException?.let { throw it }
2070
- // println("Write suspend (leave)")
2071
2190
}
2072
2191
2073
- private inline fun <T , C : CancellableContinuation <T >> setContinuation (getter : () -> C ? , updater : AtomicReferenceFieldUpdater <ByteBufferChannel , C ?>, continuation : C , predicate : () -> Boolean ): Boolean {
2192
+ private inline fun <T , C : Continuation <T >> setContinuation (getter : () -> C ? , updater : AtomicReferenceFieldUpdater <ByteBufferChannel , C ?>, continuation : C , predicate : () -> Boolean ): Boolean {
2074
2193
while (true ) {
2075
2194
val current = getter()
2076
2195
if (current != null ) throw IllegalStateException (" Operation is already in progress" )
@@ -2081,7 +2200,9 @@ internal class ByteBufferChannel(
2081
2200
2082
2201
if (updater.compareAndSet(this , null , continuation)) {
2083
2202
if (predicate() || ! updater.compareAndSet(this , continuation, null )) {
2084
- continuation.initCancellability()
2203
+ if (attachedJob == null && continuation is CancellableContinuation <* >) {
2204
+ continuation.initCancellability()
2205
+ }
2085
2206
return true
2086
2207
}
2087
2208
@@ -2144,7 +2265,8 @@ internal class ByteBufferChannel(
2144
2265
2145
2266
override fun request (skip : Int , atLeast : Int ) = null
2146
2267
2147
- suspend override fun awaitAtLeast (n : Int ) {
2268
+ suspend override fun awaitAtLeast (n : Int ): Boolean {
2269
+ return false
2148
2270
}
2149
2271
}
2150
2272
0 commit comments