Skip to content

Commit 6e61c23

Browse files
author
Sergey Mashkov
committed
IO: add Input/OutputStream adapters for regular Java blocking API
1 parent 074f0a2 commit 6e61c23

File tree

2 files changed

+325
-1
lines changed
  • core/kotlinx-coroutines-io/src
    • main/kotlin/kotlinx/coroutines/experimental/io/jvm/javaio
    • test/kotlin/kotlinx/coroutines/experimental/io

2 files changed

+325
-1
lines changed
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
package kotlinx.coroutines.experimental.io.jvm.javaio
2+
3+
import kotlinx.atomicfu.*
4+
import kotlinx.coroutines.experimental.*
5+
import kotlinx.coroutines.experimental.io.*
6+
import java.io.*
7+
import java.util.concurrent.locks.*
8+
import kotlin.coroutines.experimental.*
9+
import kotlin.coroutines.experimental.intrinsics.*
10+
11+
/**
12+
* Create blocking [java.io.InputStream] for this channel that does block every time the channel suspends at read
13+
* Similar to do reading in [runBlocking] however you can pass it to regular blocking API
14+
*/
15+
fun ByteReadChannel.toInputStream(parent: Job? = null): InputStream = InputAdapter(parent, this)
16+
17+
/**
18+
* Create blocking [java.io.OutputStream] for this channel that does block every time the channel suspends at write
19+
* Similar to do reading in [runBlocking] however you can pass it to regular blocking API
20+
*/
21+
fun ByteWriteChannel.toOutputStream(parent: Job? = null): OutputStream = OutputAdapter(parent, this)
22+
23+
private class InputAdapter(parent: Job?, private val channel: ByteReadChannel) : InputStream() {
24+
private val loop = object : BlockingAdapter(parent) {
25+
override suspend fun loop() {
26+
var rc = 0
27+
while (true) {
28+
val buffer = rendezvous(rc) as ByteArray
29+
rc = channel.readAvailable(buffer, offset, length)
30+
if (rc == -1) break
31+
}
32+
}
33+
}
34+
35+
private var single: ByteArray? = null
36+
37+
override fun available(): Int {
38+
return channel.availableForRead
39+
}
40+
41+
@Synchronized
42+
override fun read(): Int {
43+
val buffer = single ?: ByteArray(1).also { single = it }
44+
loop.submitAndAwait(buffer, 0, 1)
45+
return buffer[0].toInt() and 0xff
46+
}
47+
48+
@Synchronized
49+
override fun read(b: ByteArray?, off: Int, len: Int): Int {
50+
return loop.submitAndAwait(b!!, off, len)
51+
}
52+
53+
@Synchronized
54+
override fun close() {
55+
super.close()
56+
channel.cancel()
57+
loop.shutdown()
58+
}
59+
}
60+
61+
private val CloseToken = Any()
62+
private val FlushToken = Any()
63+
64+
private class OutputAdapter(parent: Job?, private val channel: ByteWriteChannel) : OutputStream() {
65+
private val loop = object : BlockingAdapter(parent) {
66+
override suspend fun loop() {
67+
try {
68+
while (true) {
69+
val task = rendezvous(0)
70+
if (task === CloseToken) break
71+
if (task === FlushToken) channel.flush()
72+
else if (task is ByteArray) channel.writeFully(task, offset, length)
73+
}
74+
} catch (t: Throwable) {
75+
if (t !is CancellationException) {
76+
channel.close(t)
77+
}
78+
} finally {
79+
channel.close()
80+
}
81+
}
82+
}
83+
84+
private var single: ByteArray? = null
85+
86+
@Synchronized
87+
override fun write(b: Int) {
88+
val buffer = single ?: ByteArray(1).also { single = it }
89+
buffer[0] = b.toByte()
90+
loop.submitAndAwait(buffer, 0, 1)
91+
}
92+
93+
@Synchronized
94+
override fun write(b: ByteArray?, off: Int, len: Int) {
95+
loop.submitAndAwait(b!!, off, len)
96+
}
97+
98+
@Synchronized
99+
override fun flush() {
100+
loop.submitAndAwait(FlushToken)
101+
}
102+
103+
@Synchronized
104+
override fun close() {
105+
loop.submitAndAwait(CloseToken)
106+
loop.shutdown()
107+
}
108+
}
109+
110+
private abstract class BlockingAdapter(val parent: Job? = null) {
111+
private val end: Continuation<Unit> = object : Continuation<Unit> {
112+
override val context: CoroutineContext
113+
get() = if (parent != null) Unconfined + parent else EmptyCoroutineContext
114+
115+
override fun resume(value: Unit) {
116+
var thread: Thread? = null
117+
result.value = -1
118+
state.update { current ->
119+
when (current) {
120+
is Thread -> {
121+
thread = current
122+
Unit
123+
}
124+
this -> Unit
125+
else -> return
126+
}
127+
}
128+
129+
thread?.let { LockSupport.unpark(it) }
130+
disposable?.dispose()
131+
}
132+
133+
override fun resumeWithException(exception: Throwable) {
134+
var thread: Thread? = null
135+
var continuation: Continuation<*>? = null
136+
137+
result.value = -1
138+
state.update { current ->
139+
when (current) {
140+
is Thread -> {
141+
thread = current
142+
exception
143+
}
144+
is Continuation<*> -> {
145+
continuation = current
146+
exception
147+
}
148+
this -> exception
149+
else -> return
150+
}
151+
}
152+
153+
thread?.let { LockSupport.unpark(it) }
154+
continuation?.resumeWithException(exception)
155+
156+
if (exception !is CancellationException) {
157+
parent?.cancel(exception)
158+
}
159+
160+
disposable?.dispose()
161+
}
162+
}
163+
164+
@Suppress("LeakingThis")
165+
private val state: AtomicRef<Any> = atomic(this) // could be a thread, a continuation, Unit, an exception or this if not yet started
166+
private val result = atomic(0)
167+
private val disposable: DisposableHandle? = parent?.invokeOnCompletion { cause ->
168+
if (cause != null) {
169+
end.resumeWithException(cause)
170+
}
171+
}
172+
173+
protected var offset: Int = 0
174+
private set
175+
protected var length: Int = 0
176+
private set
177+
178+
init {
179+
val block: suspend () -> Unit = { loop() }
180+
block.startCoroutineUninterceptedOrReturn(end)
181+
require(state.value !== this)
182+
}
183+
184+
protected abstract suspend fun loop()
185+
186+
fun shutdown() {
187+
disposable?.dispose()
188+
end.resumeWithException(CancellationException("Stream closed"))
189+
}
190+
191+
fun submitAndAwait(buffer: ByteArray, offset: Int, length: Int): Int {
192+
this.offset = offset
193+
this.length = length
194+
return submitAndAwait(buffer)
195+
}
196+
197+
fun submitAndAwait(jobToken: Any): Int {
198+
val thread = Thread.currentThread()!!
199+
200+
var cont: Continuation<Any>? = null
201+
202+
state.update { value ->
203+
when (value) {
204+
is Continuation<*> -> {
205+
@Suppress("UNCHECKED_CAST")
206+
cont = value as Continuation<Any>
207+
thread
208+
}
209+
is Unit -> {
210+
return result.value
211+
}
212+
is Throwable -> {
213+
throw value
214+
}
215+
is Thread -> throw IllegalStateException("There is already thread owning adapter")
216+
this -> throw IllegalStateException("Not yet started")
217+
else -> NoWhenBranchMatchedException()
218+
}
219+
}
220+
221+
cont!!.resume(jobToken)
222+
223+
while (state.value === thread) {
224+
LockSupport.park()
225+
}
226+
227+
return result.value
228+
}
229+
230+
@Suppress("NOTHING_TO_INLINE")
231+
protected suspend inline fun rendezvous(rc: Int): Any {
232+
result.value = rc
233+
234+
return suspendCoroutineOrReturn { c ->
235+
var thread: Thread? = null
236+
237+
state.update { value ->
238+
when (value) {
239+
is Thread -> {
240+
thread = value
241+
c
242+
}
243+
this -> c
244+
else -> throw IllegalStateException("Already suspended or in finished state")
245+
}
246+
}
247+
248+
if (thread != null) {
249+
LockSupport.unpark(thread)
250+
}
251+
252+
COROUTINE_SUSPENDED
253+
}
254+
}
255+
}

core/kotlinx-coroutines-io/src/test/kotlin/kotlinx/coroutines/experimental/io/JavaIOTest.kt

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import kotlinx.coroutines.experimental.io.jvm.nio.*
66
import org.junit.Test
77
import java.io.*
88
import java.nio.channels.*
9+
import java.util.*
910
import kotlin.test.*
1011

1112
class JavaIOTest : TestBase() {
@@ -207,4 +208,72 @@ class JavaIOTest : TestBase() {
207208
pipe.sink().close()
208209
exec.close()
209210
}
210-
}
211+
212+
@Test
213+
fun testInputAdapter() {
214+
newFixedThreadPoolContext(2, "blocking-io").use { exec ->
215+
val input = channel.toInputStream()
216+
val data = ByteArray(100)
217+
Random().nextBytes(data)
218+
launch(exec) {
219+
channel.writeFully(data)
220+
channel.close()
221+
}
222+
223+
val result = ByteArray(100)
224+
assertEquals(100, input.read(result))
225+
assertEquals(-1, input.read(result))
226+
227+
assertTrue(result.contentEquals(data))
228+
}
229+
}
230+
231+
@Test
232+
fun testInputAdapter2() {
233+
newFixedThreadPoolContext(2, "blocking-io").use { exec ->
234+
val count = 100
235+
val data = ByteArray(4096)
236+
Random().nextBytes(data)
237+
238+
repeat(10000) {
239+
val channel = ByteChannel(false)
240+
launch(exec) {
241+
for (i in 1..count) {
242+
channel.writeFully(data)
243+
}
244+
channel.close()
245+
}
246+
247+
val result = channel.toInputStream().readBytes()
248+
assertEquals(4096 * count, result.size)
249+
}
250+
}
251+
}
252+
253+
@Test
254+
fun testOutputAdapter() {
255+
newFixedThreadPoolContext(2, "blocking-io").use { exec ->
256+
val output = channel.toOutputStream()
257+
val data = ByteArray(100)
258+
Random().nextBytes(data)
259+
260+
val j = launch(exec) {
261+
val result = ByteArray(100)
262+
assertEquals(100, channel.readAvailable(result))
263+
assertEquals(-1, channel.readAvailable(result))
264+
assertTrue(result.contentEquals(data))
265+
}
266+
267+
output.write(data)
268+
output.flush()
269+
output.close()
270+
271+
runBlocking {
272+
j.join()
273+
}
274+
j.invokeOnCompletion { cause ->
275+
if (cause != null) throw cause
276+
}
277+
}
278+
}
279+
}

0 commit comments

Comments
 (0)