Skip to content

Commit dfae5e2

Browse files
author
Sergey Mashkov
committed
IO: add writeWhile function
1 parent ffc7134 commit dfae5e2

File tree

3 files changed

+165
-33
lines changed

3 files changed

+165
-33
lines changed

core/kotlinx-coroutines-io/src/main/kotlin/kotlinx/coroutines/experimental/io/ByteBufferChannel.kt

Lines changed: 115 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ package kotlinx.coroutines.experimental.io
55
import kotlinx.atomicfu.*
66
import kotlinx.coroutines.experimental.*
77
import kotlinx.coroutines.experimental.channels.*
8-
import kotlinx.coroutines.experimental.internal.*
98
import kotlinx.coroutines.experimental.io.internal.*
109
import kotlinx.coroutines.experimental.io.packet.*
1110
import kotlinx.io.core.*
@@ -368,30 +367,35 @@ internal class ByteBufferChannel(
368367
}
369368
}
370369

371-
private tailrec fun readAsMuchAsPossible(dst: ByteBuffer, consumed0: Int = 0): Int {
370+
private fun readAsMuchAsPossible(dst: ByteBuffer): Int {
372371
var consumed = 0
373372

374-
val rc = reading {
375-
val position = position()
376-
val remaining = limit() - position
373+
reading { state ->
374+
val buffer = this
375+
val bufferLimit = buffer.capacity() - reservedSize
377376

378-
val part = it.tryReadAtMost(minOf(remaining, dst.remaining()))
379-
if (part > 0) {
380-
consumed += part
377+
while (true) {
378+
val dstRemaining = dst.remaining()
379+
if (dstRemaining == 0) break
381380

382-
limit(position + part)
383-
dst.put(this)
381+
val position = readPosition
382+
val bufferRemaining = bufferLimit - position
384383

385-
bytesRead(it, part)
386-
true
387-
} else {
388-
false
384+
val part = state.tryReadAtMost(minOf(bufferRemaining, dstRemaining))
385+
if (part == 0) break
386+
387+
buffer.limit(position + part)
388+
buffer.position(position)
389+
dst.put(buffer)
390+
391+
bytesRead(state, part)
392+
consumed += part
389393
}
394+
395+
false
390396
}
391397

392-
return if (rc && dst.hasRemaining() && state.capacity.availableForRead > 0)
393-
readAsMuchAsPossible(dst, consumed0 + consumed)
394-
else consumed + consumed0
398+
return consumed
395399
}
396400

397401
private tailrec fun readAsMuchAsPossible(dst: BufferView, consumed0: Int = 0): Int {
@@ -421,25 +425,34 @@ internal class ByteBufferChannel(
421425
}
422426

423427

424-
private tailrec fun readAsMuchAsPossible(dst: ByteArray, offset: Int, length: Int, consumed0: Int = 0): Int {
428+
private fun readAsMuchAsPossible(dst: ByteArray, offset: Int, length: Int): Int {
425429
var consumed = 0
426430

427-
val rc = reading {
428-
val part = it.tryReadAtMost(minOf(remaining(), length))
429-
if (part > 0) {
430-
consumed += part
431-
get(dst, offset, part)
431+
reading { state ->
432+
val buffer = this
433+
val bufferLimit = buffer.capacity() - reservedSize
432434

433-
bytesRead(it, part)
434-
true
435-
} else {
436-
false
435+
while (true) {
436+
val lengthRemaining = length - consumed
437+
if (lengthRemaining == 0) break
438+
val position = readPosition
439+
val bufferRemaining = bufferLimit - position
440+
441+
val part = state.tryReadAtMost(minOf(bufferRemaining, lengthRemaining))
442+
if (part == 0) break
443+
444+
buffer.limit(position + part)
445+
buffer.position(position)
446+
buffer.get(dst, offset + consumed, part)
447+
448+
bytesRead(state, part)
449+
consumed += part
437450
}
451+
452+
false
438453
}
439454

440-
return if (rc && consumed < length && state.capacity.availableForRead > 0)
441-
readAsMuchAsPossible(dst, offset + consumed, length - consumed, consumed0 + consumed)
442-
else consumed + consumed0
455+
return consumed
443456
}
444457

445458
final suspend override fun readFully(dst: ByteArray, offset: Int, length: Int) {
@@ -1451,6 +1464,78 @@ internal class ByteBufferChannel(
14511464
return write(min, block)
14521465
}
14531466

1467+
override suspend fun writeWhile(block: (ByteBuffer) -> Boolean) {
1468+
if (!writeWhileNoSuspend(block)) return
1469+
closed?.let { throw it.sendException }
1470+
return writeWhileSuspend(block)
1471+
}
1472+
1473+
private fun writeWhileNoSuspend(block: (ByteBuffer) -> Boolean): Boolean {
1474+
var continueWriting = true
1475+
1476+
writing { dst, capacity ->
1477+
continueWriting = writeWhileLoop(dst, capacity, block)
1478+
}
1479+
1480+
return continueWriting
1481+
}
1482+
1483+
private suspend fun writeWhileSuspend(block: (ByteBuffer) -> Boolean) {
1484+
var continueWriting = true
1485+
1486+
writing { dst, capacity ->
1487+
while (true) {
1488+
writeSuspend(1)
1489+
if (joining != null) break
1490+
if (!writeWhileLoop(dst, capacity, block)) {
1491+
continueWriting = false
1492+
break
1493+
}
1494+
if (closed != null) break
1495+
}
1496+
}
1497+
1498+
if (!continueWriting) return
1499+
closed?.let { throw it.sendException }
1500+
joining?.let { return writeWhile(block) }
1501+
}
1502+
1503+
// it should be writing state set to use this function
1504+
private fun writeWhileLoop(dst: ByteBuffer, capacity: RingBufferCapacity, block: (ByteBuffer) -> Boolean): Boolean {
1505+
var continueWriting = true
1506+
val bufferLimit = dst.capacity() - reservedSize
1507+
1508+
while (continueWriting) {
1509+
val locked = capacity.tryWriteAtLeast(1) // see comments in [write]
1510+
if (locked == 0) break
1511+
1512+
val position = writePosition
1513+
val l = (position + locked).coerceAtMost(bufferLimit)
1514+
dst.limit(l)
1515+
dst.position(position)
1516+
1517+
continueWriting = try {
1518+
block(dst)
1519+
} catch (t: Throwable) {
1520+
capacity.completeRead(locked)
1521+
throw t
1522+
}
1523+
1524+
if (dst.limit() != l) throw IllegalStateException("buffer limit modified")
1525+
val actuallyWritten = dst.position() - position
1526+
if (actuallyWritten < 0) throw IllegalStateException("position has been moved backward: pushback is not supported")
1527+
1528+
dst.bytesWritten(capacity, actuallyWritten)
1529+
if (actuallyWritten < locked) {
1530+
capacity.completeRead(locked - actuallyWritten) // return back extra bytes
1531+
// it is important to use completeRead in spite of that we are writing here
1532+
// no need to resume read here
1533+
}
1534+
}
1535+
1536+
return continueWriting
1537+
}
1538+
14541539
override suspend fun read(min: Int, block: (ByteBuffer) -> Unit) {
14551540
require(min >= 0) { "min should be positive or zero" }
14561541

core/kotlinx-coroutines-io/src/main/kotlin/kotlinx/coroutines/experimental/io/ByteWriteChannel.kt

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package kotlinx.coroutines.experimental.io
33
import kotlinx.coroutines.experimental.io.packet.*
44
import kotlinx.coroutines.experimental.io.packet.ByteReadPacket
55
import kotlinx.io.core.*
6-
import java.nio.ByteBuffer
76
import java.nio.CharBuffer
87
import java.util.concurrent.CancellationException
98

@@ -76,6 +75,14 @@ public interface ByteWriteChannel {
7675
*/
7776
suspend fun write(min: Int = 1, block: (ByteBuffer) -> Unit)
7877

78+
/**
79+
* Invokes [block] for every free buffer until it return `false`. It will also suspend every time when no free
80+
* space available for write.
81+
*
82+
* @param block to be invoked when there is free space available for write
83+
*/
84+
suspend fun writeWhile(block: (ByteBuffer) -> Boolean)
85+
7986
/**
8087
* Writes a [packet] fully or fails if channel get closed before the whole packet has been written
8188
*/

core/kotlinx-coroutines-io/src/test/kotlin/kotlinx/coroutines/experimental/io/ByteBufferChannelScenarioTest.kt

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ import kotlinx.coroutines.experimental.launch
77
import kotlinx.coroutines.experimental.runBlocking
88
import kotlinx.coroutines.experimental.yield
99
import org.junit.*
10+
import org.junit.Test
1011
import java.io.IOException
11-
import kotlin.test.assertEquals
12-
import kotlin.test.fail
12+
import kotlin.test.*
1313

1414
class ByteBufferChannelScenarioTest : TestBase() {
1515
private val ch = ByteBufferChannel(true)
@@ -553,4 +553,44 @@ class ByteBufferChannelScenarioTest : TestBase() {
553553

554554
finish(5)
555555
}
556+
557+
@Test
558+
fun testWriteWhile() = runBlocking {
559+
val size = 16384
560+
561+
launch(coroutineContext) {
562+
expect(1)
563+
var b: Byte = 0
564+
var count = 0
565+
566+
ch.writeWhile { buffer ->
567+
while (buffer.hasRemaining() && count < size) {
568+
buffer.put(b++)
569+
count++
570+
}
571+
count < size
572+
}
573+
expect(3)
574+
ch.close()
575+
}
576+
577+
yield()
578+
579+
expect(2)
580+
581+
val buffer = ByteArray(size)
582+
ch.readFully(buffer)
583+
584+
var expectedB: Byte = 0
585+
for (i in buffer.indices) {
586+
assertEquals(expectedB, buffer[i])
587+
expectedB++
588+
}
589+
590+
yield()
591+
yield()
592+
593+
finish(4)
594+
assertTrue(ch.isClosedForRead)
595+
}
556596
}

0 commit comments

Comments
 (0)