Skip to content

Commit bd7c030

Browse files
authored
Optimize Flow.take (#1538)
* Allocate SM instance only once for the last flow value
1 parent 3826ae5 commit bd7c030

File tree

3 files changed

+168
-3
lines changed

3 files changed

+168
-3
lines changed
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
/*
2+
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package benchmarks.flow
6+
7+
import kotlinx.coroutines.*
8+
import kotlinx.coroutines.flow.*
9+
import org.openjdk.jmh.annotations.*
10+
import java.util.concurrent.*
11+
import java.util.concurrent.CancellationException
12+
import kotlin.coroutines.*
13+
import kotlin.coroutines.intrinsics.*
14+
import benchmarks.flow.scrabble.flow as unsafeFlow
15+
16+
@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
17+
@Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
18+
@Fork(value = 1)
19+
@BenchmarkMode(Mode.AverageTime)
20+
@OutputTimeUnit(TimeUnit.MICROSECONDS)
21+
@State(Scope.Benchmark)
22+
open class TakeBenchmark {
23+
@Param("1", "10", "100", "1000")
24+
private var size: Int = 0
25+
26+
private suspend inline fun Flow<Long>.consume() =
27+
filter { it % 2L != 0L }
28+
.map { it * it }.count()
29+
30+
@Benchmark
31+
fun baseline() = runBlocking<Int> {
32+
(0L until size).asFlow().consume()
33+
}
34+
35+
@Benchmark
36+
fun originalTake() = runBlocking<Int> {
37+
(0L..Long.MAX_VALUE).asFlow().originalTake(size).consume()
38+
}
39+
40+
@Benchmark
41+
fun fastPathTake() = runBlocking<Int> {
42+
(0L..Long.MAX_VALUE).asFlow().fastPathTake(size).consume()
43+
}
44+
45+
@Benchmark
46+
fun mergedStateMachine() = runBlocking<Int> {
47+
(0L..Long.MAX_VALUE).asFlow().mergedStateMachineTake(size).consume()
48+
}
49+
50+
internal class StacklessCancellationException() : CancellationException() {
51+
override fun fillInStackTrace(): Throwable = this
52+
}
53+
54+
public fun <T> Flow<T>.originalTake(count: Int): Flow<T> {
55+
return unsafeFlow {
56+
var consumed = 0
57+
try {
58+
collect { value ->
59+
emit(value)
60+
if (++consumed == count) {
61+
throw StacklessCancellationException()
62+
}
63+
}
64+
} catch (e: StacklessCancellationException) {
65+
// Nothing, bail out
66+
}
67+
}
68+
}
69+
70+
private suspend fun <T> FlowCollector<T>.emitAbort(value: T) {
71+
emit(value)
72+
throw StacklessCancellationException()
73+
}
74+
75+
public fun <T> Flow<T>.fastPathTake(count: Int): Flow<T> {
76+
return unsafeFlow {
77+
var consumed = 0
78+
try {
79+
collect { value ->
80+
if (++consumed < count) {
81+
return@collect emit(value)
82+
} else {
83+
return@collect emitAbort(value)
84+
}
85+
}
86+
} catch (e: StacklessCancellationException) {
87+
// Nothing, bail out
88+
}
89+
}
90+
}
91+
92+
93+
public fun <T> Flow<T>.mergedStateMachineTake(count: Int): Flow<T> {
94+
return unsafeFlow() {
95+
try {
96+
val takeCollector = FlowTakeCollector(count, this)
97+
collect(takeCollector)
98+
} catch (e: StacklessCancellationException) {
99+
// Nothing, bail out
100+
}
101+
}
102+
}
103+
104+
105+
private class FlowTakeCollector<T>(
106+
private val count: Int,
107+
downstream: FlowCollector<T>
108+
) : FlowCollector<T>, Continuation<Unit> {
109+
private var consumed = 0
110+
// Workaround for KT-30991
111+
private val emitFun = run {
112+
val suspendFun: suspend (T) -> Unit = { downstream.emit(it) }
113+
suspendFun as Function2<T, Continuation<Unit>, Any?>
114+
}
115+
116+
private var caller: Continuation<Unit>? = null // lateinit
117+
118+
override val context: CoroutineContext
119+
get() = caller?.context ?: EmptyCoroutineContext
120+
121+
override fun resumeWith(result: Result<Unit>) {
122+
val completion = caller!!
123+
if (++consumed == count) completion.resumeWith(Result.failure(StacklessCancellationException()))
124+
else completion.resumeWith(Result.success(Unit))
125+
}
126+
127+
override suspend fun emit(value: T) = suspendCoroutineUninterceptedOrReturn<Unit> sc@{
128+
// Invoke it in non-suspending way
129+
caller = it
130+
val result = emitFun.invoke(value, this)
131+
if (result !== COROUTINE_SUSPENDED) {
132+
if (++consumed == count) throw StacklessCancellationException()
133+
else return@sc Unit
134+
}
135+
COROUTINE_SUSPENDED
136+
}
137+
}
138+
}

kotlinx-coroutines-core/common/src/flow/operators/Limit.kt

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,10 @@ public fun <T> Flow<T>.take(count: Int): Flow<T> {
5555
var consumed = 0
5656
try {
5757
collect { value ->
58-
emit(value)
59-
if (++consumed == count) {
60-
throw AbortFlowException()
58+
if (++consumed < count) {
59+
return@collect emit(value)
60+
} else {
61+
return@collect emitAbort(value)
6162
}
6263
}
6364
} catch (e: AbortFlowException) {
@@ -66,6 +67,11 @@ public fun <T> Flow<T>.take(count: Int): Flow<T> {
6667
}
6768
}
6869

70+
private suspend fun <T> FlowCollector<T>.emitAbort(value: T) {
71+
emit(value)
72+
throw AbortFlowException()
73+
}
74+
6975
/**
7076
* Returns a flow that contains first elements satisfying the given [predicate].
7177
*/

kotlinx-coroutines-core/common/test/flow/operators/TakeTest.kt

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,27 @@ class TakeTest : TestBase() {
2121
assertEquals(2, flow.drop(1).take(1).single())
2222
}
2323

24+
@Test
25+
fun testIllegalArgument() {
26+
assertFailsWith<IllegalArgumentException> { flowOf(1).take(0) }
27+
assertFailsWith<IllegalArgumentException> { flowOf(1).take(-1) }
28+
}
29+
30+
@Test
31+
fun testTakeSuspending() = runTest {
32+
val flow = flow {
33+
emit(1)
34+
yield()
35+
emit(2)
36+
yield()
37+
}
38+
39+
assertEquals(3, flow.take(2).sum())
40+
assertEquals(3, flow.take(Int.MAX_VALUE).sum())
41+
assertEquals(1, flow.take(1).single())
42+
assertEquals(2, flow.drop(1).take(1).single())
43+
}
44+
2445
@Test
2546
fun testEmptyFlow() = runTest {
2647
val sum = emptyFlow<Int>().take(10).sum()

0 commit comments

Comments
 (0)