Skip to content

Commit 7154d59

Browse files
author
Sergey Mashkov
committed
IO: add convinient packet functions on Java IO Streams and NIO channels
1 parent 5428565 commit 7154d59

File tree

4 files changed

+317
-1
lines changed

4 files changed

+317
-1
lines changed

core/kotlinx-coroutines-io/src/main/kotlin/kotlinx/coroutines/experimental/io/buffers/BufferView.kt

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,19 @@ internal class BufferView private constructor(private var content: ByteBuffer,
274274
}
275275
}
276276

277-
internal fun writeDirect(size: Int, block: (ByteBuffer) -> Unit) {
277+
internal inline fun readDirect(block: (ByteBuffer) -> Unit) {
278+
val bb = readDuplicated(readRemaining)
279+
val positionBefore = bb.position()
280+
val limit = bb.limit()
281+
block(bb)
282+
val delta = bb.position() - positionBefore
283+
if (delta < 0) throw IllegalStateException("Wrong buffer position change: negative shift $delta")
284+
if (bb.limit() != limit) throw IllegalStateException("Limit change is now allowed")
285+
286+
readPosition += delta
287+
}
288+
289+
internal inline fun writeDirect(size: Int, block: (ByteBuffer) -> Unit) {
278290
val rem = writeRemaining
279291
require (size <= rem) { "size $size is greater than buffer's remaining capacity $rem" }
280292
val buffer = writeDuplicated(rem)
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
package kotlinx.coroutines.experimental.io.packet
2+
3+
import kotlinx.coroutines.experimental.io.buffers.*
4+
import kotlinx.coroutines.experimental.io.internal.*
5+
import java.io.*
6+
import java.nio.channels.*
7+
8+
fun WritableByteChannel.writePacket(builder: ByteWritePacket.() -> Unit) {
9+
writePacket(buildPacket(block = builder))
10+
}
11+
12+
fun WritableByteChannel.writePacket(p: ByteReadPacket) {
13+
if (p is ByteReadPacketViewBased) {
14+
var b: BufferView? = null
15+
try {
16+
while (true) {
17+
b = p.steal() ?: break
18+
19+
b.readDirect { bb ->
20+
while (bb.hasRemaining()) {
21+
write(bb)
22+
}
23+
}
24+
}
25+
} finally {
26+
b?.release()
27+
p.release()
28+
}
29+
} else {
30+
writePacketSlow(p)
31+
}
32+
}
33+
34+
private fun WritableByteChannel.writePacketSlow(p: ByteReadPacket) {
35+
val buffer = BufferPool.borrow()
36+
try {
37+
while (!p.isEmpty) {
38+
buffer.clear()
39+
p.readAvailable(buffer)
40+
buffer.flip()
41+
42+
while (buffer.hasRemaining()) {
43+
write(buffer)
44+
}
45+
}
46+
} finally {
47+
BufferPool.recycle(buffer)
48+
p.release()
49+
}
50+
}
51+
52+
fun ReadableByteChannel.readPacketExact(n: Long): ByteReadPacket = readPacketImpl(n, n)
53+
fun ReadableByteChannel.readPacketAtLeast(n: Long): ByteReadPacket = readPacketImpl(n, Long.MAX_VALUE)
54+
fun ReadableByteChannel.readPacketAtMost(n: Long): ByteReadPacket = readPacketImpl(1L, n)
55+
56+
private fun ReadableByteChannel.readPacketImpl(min: Long, max: Long): ByteReadPacket {
57+
require(min >= 0L)
58+
require(min <= max)
59+
60+
if (max == 0L) return ByteReadPacketEmpty
61+
62+
val empty = BufferView.Empty
63+
var head: BufferView = empty
64+
var tail: BufferView = empty
65+
66+
var read = 0L
67+
68+
try {
69+
while (read < min || (read == min && min == 0L)) {
70+
val remInt = (max - read).coerceAtMost(Int.MAX_VALUE.toLong()).toInt()
71+
72+
val part = tail.takeIf { it.writeRemaining.let { it > 200 || it >= remInt } } ?: BufferView.Pool.borrow().also {
73+
if (head === empty) {
74+
head = it; tail = it
75+
}
76+
}
77+
if (tail !== part) {
78+
tail.next = part
79+
tail = part
80+
}
81+
82+
part.writeDirect(1) { bb ->
83+
val l = bb.limit()
84+
if (bb.remaining() > remInt) {
85+
bb.limit(bb.position() + remInt)
86+
}
87+
88+
val rc = read(bb)
89+
if (rc == -1) throw EOFException("Premature end of stream: was read $read bytes of $min")
90+
91+
bb.limit(l)
92+
read += rc
93+
}
94+
}
95+
} catch (t: Throwable) {
96+
head.releaseAll()
97+
throw t
98+
}
99+
100+
return ByteReadPacketViewBased(head)
101+
}
102+
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package kotlinx.coroutines.experimental.io.packet
2+
3+
import java.io.*
4+
5+
fun OutputStream.writePacket(builder: ByteWritePacket.() -> Unit) {
6+
writePacket(buildPacket(block = builder))
7+
}
8+
9+
fun OutputStream.writePacket(p: ByteReadPacket) {
10+
val s = p.remaining
11+
if (s == 0) return
12+
val buffer = ByteArray(s.coerceAtMost(4096))
13+
14+
try {
15+
while (!p.isEmpty) {
16+
val size = p.readAvailable(buffer)
17+
write(buffer, 0, size)
18+
}
19+
} finally {
20+
p.release()
21+
}
22+
}
23+
24+
fun InputStream.readPacketExact(n: Long): ByteReadPacket = readPacketImpl(n, n)
25+
fun InputStream.readPacketAtLeast(n: Long): ByteReadPacket = readPacketImpl(n, Long.MAX_VALUE)
26+
fun InputStream.readPacketAtMost(n: Long): ByteReadPacket = readPacketImpl(1L, n)
27+
28+
private fun InputStream.readPacketImpl(min: Long, max: Long): ByteReadPacket {
29+
require(min >= 0L)
30+
require(min <= max)
31+
32+
val buffer = ByteArray(max.coerceAtMost(4096).toInt())
33+
val builder = WritePacket()
34+
35+
var read = 0L
36+
37+
try {
38+
while (read < min || (read == min && min == 0L)) {
39+
val remInt = minOf(max - read, Int.MAX_VALUE.toLong()).toInt()
40+
val rc = read(buffer, 0, minOf(remInt, buffer.size))
41+
if (rc == -1) throw EOFException("Premature end of stream: was read $read bytes of $min")
42+
read += rc
43+
builder.writeFully(buffer, 0, rc)
44+
}
45+
} catch (t: Throwable) {
46+
builder.release()
47+
throw t
48+
}
49+
50+
return builder.build()
51+
}
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
package kotlinx.coroutines.experimental.io
2+
3+
import kotlinx.coroutines.experimental.io.packet.*
4+
import org.junit.*
5+
import java.io.*
6+
import java.nio.ByteOrder
7+
import java.nio.channels.*
8+
import kotlin.test.*
9+
10+
class PacketInteropTest {
11+
private val baos = ByteArrayOutputStream()
12+
private val out = Channels.newChannel(baos)!!
13+
14+
private val bais by lazy { ByteArrayInputStream(baos.toByteArray()) }
15+
private val input by lazy { Channels.newChannel(bais)!! }
16+
17+
@Test
18+
fun testStream() {
19+
baos.writePacket {
20+
writeInt(777)
21+
writeLong(0x1234567812345678L)
22+
writeStringUtf8("OK")
23+
}
24+
25+
val result = ByteBuffer.wrap(baos.toByteArray())!!
26+
result.order(ByteOrder.BIG_ENDIAN)
27+
assertEquals(777, result.getInt())
28+
assertEquals(0x1234567812345678L, result.getLong())
29+
assertEquals(0x4f4b, result.getShort())
30+
31+
val p = bais.readPacketExact(14)
32+
assertEquals(777, p.readInt())
33+
assertEquals(0x1234567812345678L, p.readLong())
34+
assertEquals("OK", p.readUTF8Line())
35+
}
36+
37+
@Test
38+
fun testStreamLong() {
39+
baos.writePacket {
40+
writeInt(777)
41+
writeLong(0x1234567812345678L)
42+
43+
repeat(4000) {
44+
append("OK")
45+
}
46+
}
47+
48+
val result = ByteBuffer.wrap(baos.toByteArray())!!
49+
result.order(ByteOrder.BIG_ENDIAN)
50+
assertEquals(777, result.getInt())
51+
assertEquals(0x1234567812345678L, result.getLong())
52+
repeat(4000) {
53+
assertEquals(0x4f4b, result.getShort())
54+
}
55+
56+
val p = bais.readPacketExact(baos.size().toLong())
57+
assertEquals(777, p.readInt())
58+
assertEquals(0x1234567812345678L, p.readLong())
59+
repeat(4000) {
60+
assertEquals(0x4f4b, p.readShort())
61+
}
62+
}
63+
64+
@Test
65+
fun testChannel() {
66+
out.writePacket {
67+
writeInt(777)
68+
writeLong(0x1234567812345678L)
69+
writeStringUtf8("OK")
70+
}
71+
out.close()
72+
73+
val result = ByteBuffer.wrap(baos.toByteArray())!!
74+
result.order(ByteOrder.BIG_ENDIAN)
75+
assertEquals(777, result.getInt())
76+
assertEquals(0x1234567812345678L, result.getLong())
77+
assertEquals(0x4f4b, result.getShort())
78+
79+
val p = input.readPacketExact(14)
80+
assertEquals(777, p.readInt())
81+
assertEquals(0x1234567812345678L, p.readLong())
82+
assertEquals("OK", p.readUTF8Line())
83+
}
84+
85+
@Test
86+
fun testChannelLong() {
87+
out.writePacket {
88+
writeInt(777)
89+
writeLong(0x1234567812345678L)
90+
repeat(4000) {
91+
append("OK")
92+
}
93+
}
94+
out.close()
95+
96+
val result = ByteBuffer.wrap(baos.toByteArray())!!
97+
result.order(ByteOrder.BIG_ENDIAN)
98+
assertEquals(777, result.getInt())
99+
assertEquals(0x1234567812345678L, result.getLong())
100+
repeat(4000) {
101+
assertEquals(0x4f4b, result.getShort())
102+
}
103+
104+
val p = input.readPacketExact(baos.size().toLong())
105+
assertEquals(777, p.readInt())
106+
assertEquals(0x1234567812345678L, p.readLong())
107+
repeat(4000) {
108+
assertEquals(0x4f4b, p.readShort())
109+
}
110+
}
111+
112+
@Test
113+
fun testStreamReadPacketAtLeast() {
114+
baos.writePacket {
115+
writeInt(0x12345678)
116+
}
117+
118+
assertEquals(4, bais.readPacketAtLeast(1).remaining)
119+
assertEquals(0, bais.available())
120+
}
121+
122+
@Test
123+
fun testStreamReadPacketAtMost() {
124+
baos.writePacket {
125+
writeInt(0x12345678)
126+
}
127+
128+
assertEquals(1, bais.readPacketAtMost(1).remaining)
129+
assertEquals(3, bais.available())
130+
}
131+
132+
@Test
133+
fun testChannelReadPacketAtLeast() {
134+
baos.writePacket {
135+
writeInt(0x12345678)
136+
}
137+
138+
assertEquals(4, input.readPacketAtLeast(1).remaining)
139+
assertEquals(0, bais.available())
140+
}
141+
142+
@Test
143+
fun testChannelReadPacketAtMost() {
144+
baos.writePacket {
145+
writeInt(0x12345678)
146+
}
147+
148+
assertEquals(1, input.readPacketAtMost(1).remaining)
149+
assertEquals(3, bais.available())
150+
}
151+
}

0 commit comments

Comments
 (0)