Skip to content

Commit 0342a0a

Browse files
authored
Restore context preservation invariant in flatMapMerge (#1452)
* Introduce (again) flowProduce in order to properly propagate cancellation to the upstream in flatMapMerge. Previously this issue was masked by SerializingCollector fast-path * Re-implement flatMapMerge via the channel to have context preservation property Fixes #1440
1 parent bcf4a8c commit 0342a0a

File tree

14 files changed

+216
-151
lines changed

14 files changed

+216
-151
lines changed

benchmarks/src/jmh/kotlin/benchmarks/YieldRelativeCostBenchmark.kt

Lines changed: 0 additions & 35 deletions
This file was deleted.
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package benchmarks.flow
6+
7+
import kotlinx.coroutines.*
8+
import kotlinx.coroutines.flow.*
9+
import org.openjdk.jmh.annotations.*
10+
import java.util.concurrent.*
11+
12+
@Warmup(iterations = 7, time = 1, timeUnit = TimeUnit.SECONDS)
13+
@Measurement(iterations = 7, time = 1, timeUnit = TimeUnit.SECONDS)
14+
@Fork(value = 1)
15+
@BenchmarkMode(Mode.AverageTime)
16+
@OutputTimeUnit(TimeUnit.MICROSECONDS)
17+
@State(Scope.Benchmark)
18+
open class FlatMapMergeBenchmark {
19+
20+
// Note: tests only absence of contention on downstream
21+
22+
@Param("10", "100", "1000")
23+
private var iterations = 100
24+
25+
@Benchmark
26+
fun flatMapUnsafe() = runBlocking {
27+
benchmarks.flow.scrabble.flow {
28+
repeat(iterations) { emit(it) }
29+
}.flatMapMerge { value ->
30+
flowOf(value)
31+
}.collect {
32+
if (it == -1) error("")
33+
}
34+
}
35+
36+
@Benchmark
37+
fun flatMapSafe() = runBlocking {
38+
kotlinx.coroutines.flow.flow {
39+
repeat(iterations) { emit(it) }
40+
}.flatMapMerge { value ->
41+
flowOf(value)
42+
}.collect {
43+
if (it == -1) error("")
44+
}
45+
}
46+
47+
}

benchmarks/src/jmh/kotlin/benchmarks/flow/misc/Numbers.kt renamed to benchmarks/src/jmh/kotlin/benchmarks/flow/NumbersBenchmark.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
*/
44

55

6-
package benchmarks.flow.misc
6+
package benchmarks.flow
77

88
import benchmarks.flow.scrabble.flow
99
import io.reactivex.*
@@ -35,7 +35,7 @@ import java.util.concurrent.*
3535
@BenchmarkMode(Mode.AverageTime)
3636
@OutputTimeUnit(TimeUnit.MICROSECONDS)
3737
@State(Scope.Benchmark)
38-
open class Numbers {
38+
open class NumbersBenchmark {
3939

4040
companion object {
4141
private const val primes = 100

benchmarks/src/jmh/kotlin/benchmarks/flow/misc/SafeFlowBenchmark.kt renamed to benchmarks/src/jmh/kotlin/benchmarks/flow/SafeFlowBenchmark.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
33
*/
44

5-
package benchmarks.flow.misc
5+
package benchmarks.flow
66

77
import kotlinx.coroutines.*
88
import kotlinx.coroutines.flow.*

binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -993,7 +993,7 @@ public final class kotlinx/coroutines/flow/internal/SafeCollectorKt {
993993
public static final fun unsafeFlow (Lkotlin/jvm/functions/Function2;)Lkotlinx/coroutines/flow/Flow;
994994
}
995995

996-
public final class kotlinx/coroutines/flow/internal/SendingCollector : kotlinx/coroutines/flow/internal/ConcurrentFlowCollector {
996+
public final class kotlinx/coroutines/flow/internal/SendingCollector : kotlinx/coroutines/flow/FlowCollector {
997997
public fun <init> (Lkotlinx/coroutines/channels/SendChannel;)V
998998
public fun emit (Ljava/lang/Object;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
999999
}

kotlinx-coroutines-core/common/src/channels/Produce.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ public fun <E> CoroutineScope.produce(
126126
return coroutine
127127
}
128128

129-
private class ProducerCoroutine<E>(
129+
internal open class ProducerCoroutine<E>(
130130
parentContext: CoroutineContext, channel: Channel<E>
131131
) : ChannelCoroutine<E>(parentContext, channel, active = true), ProducerScope<E> {
132132
override val isActive: Boolean

kotlinx-coroutines-core/common/src/flow/internal/ChannelFlow.kt

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ public abstract class ChannelFlow<T>(
5858
protected abstract suspend fun collectTo(scope: ProducerScope<T>)
5959

6060
// shared code to create a suspend lambda from collectTo function in one place
61-
private val collectToFun: suspend (ProducerScope<T>) -> Unit
61+
internal val collectToFun: suspend (ProducerScope<T>) -> Unit
6262
get() = { collectTo(it) }
6363

6464
private val produceCapacity: Int
@@ -140,13 +140,11 @@ internal class ChannelFlowOperatorImpl<T>(
140140
private fun <T> FlowCollector<T>.withUndispatchedContextCollector(emitContext: CoroutineContext): FlowCollector<T> = when (this) {
141141
// SendingCollector & NopCollector do not care about the context at all and can be used as is
142142
is SendingCollector, is NopCollector -> this
143-
// Original collector is concurrent, so wrap into ConcurrentUndispatchedContextCollector (also concurrent)
144-
is ConcurrentFlowCollector -> ConcurrentUndispatchedContextCollector(this, emitContext)
145143
// Otherwise just wrap into UndispatchedContextCollector interface implementation
146144
else -> UndispatchedContextCollector(this, emitContext)
147145
}
148146

149-
private open class UndispatchedContextCollector<T>(
147+
private class UndispatchedContextCollector<T>(
150148
downstream: FlowCollector<T>,
151149
private val emitContext: CoroutineContext
152150
) : FlowCollector<T> {
@@ -157,12 +155,6 @@ private open class UndispatchedContextCollector<T>(
157155
withContextUndispatched(emitContext, countOrElement, emitRef, value)
158156
}
159157

160-
// named class for a combination of UndispatchedContextCollector & ConcurrentFlowCollector interface
161-
private class ConcurrentUndispatchedContextCollector<T>(
162-
downstream: ConcurrentFlowCollector<T>,
163-
emitContext: CoroutineContext
164-
) : UndispatchedContextCollector<T>(downstream, emitContext), ConcurrentFlowCollector<T>
165-
166158
// Efficiently computes block(value) in the newContext
167159
private suspend fun <T, V> withContextUndispatched(
168160
newContext: CoroutineContext,

kotlinx-coroutines-core/common/src/flow/internal/Concurrent.kt

Lines changed: 0 additions & 81 deletions
This file was deleted.

kotlinx-coroutines-core/common/src/flow/internal/FlowCoroutine.kt

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,18 @@ internal fun <R> scopedFlow(@BuilderInference block: suspend CoroutineScope.(Flo
5252
flowScope { block(collector) }
5353
}
5454

55+
internal fun <T> CoroutineScope.flowProduce(
56+
context: CoroutineContext,
57+
capacity: Int = 0,
58+
@BuilderInference block: suspend ProducerScope<T>.() -> Unit
59+
): ReceiveChannel<T> {
60+
val channel = Channel<T>(capacity)
61+
val newContext = newCoroutineContext(context)
62+
val coroutine = FlowProduceCoroutine(newContext, channel)
63+
coroutine.start(CoroutineStart.DEFAULT, coroutine, block)
64+
return coroutine
65+
}
66+
5567
private class FlowCoroutine<T>(
5668
context: CoroutineContext,
5769
uCont: Continuation<T>
@@ -61,3 +73,13 @@ private class FlowCoroutine<T>(
6173
return cancelImpl(cause)
6274
}
6375
}
76+
77+
private class FlowProduceCoroutine<T>(
78+
parentContext: CoroutineContext,
79+
channel: Channel<T>
80+
) : ProducerCoroutine<T>(parentContext, channel) {
81+
public override fun childCancelled(cause: Throwable): Boolean {
82+
if (cause is ChildCancelledException) return true
83+
return cancelImpl(cause)
84+
}
85+
}

kotlinx-coroutines-core/common/src/flow/internal/Merge.kt

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,21 @@ internal class ChannelFlowTransformLatest<T, R>(
3838
}
3939

4040
internal class ChannelFlowMerge<T>(
41-
flow: Flow<Flow<T>>,
41+
private val flow: Flow<Flow<T>>,
4242
private val concurrency: Int,
4343
context: CoroutineContext = EmptyCoroutineContext,
44-
capacity: Int = Channel.OPTIONAL_CHANNEL
45-
) : ChannelFlowOperator<Flow<T>, T>(flow, context, capacity) {
44+
capacity: Int = Channel.BUFFERED
45+
) : ChannelFlow<T>(context, capacity) {
4646
override fun create(context: CoroutineContext, capacity: Int): ChannelFlow<T> =
4747
ChannelFlowMerge(flow, concurrency, context, capacity)
4848

49-
// The actual merge implementation with concurrency limit
50-
private suspend fun mergeImpl(scope: CoroutineScope, collector: ConcurrentFlowCollector<T>) {
49+
override fun produceImpl(scope: CoroutineScope): ReceiveChannel<T> {
50+
return scope.flowProduce(context, capacity, block = collectToFun)
51+
}
52+
53+
override suspend fun collectTo(scope: ProducerScope<T>) {
5154
val semaphore = Semaphore(concurrency)
55+
val collector = SendingCollector(scope)
5256
val job: Job? = coroutineContext[Job]
5357
flow.collect { inner ->
5458
/*
@@ -68,19 +72,6 @@ internal class ChannelFlowMerge<T>(
6872
}
6973
}
7074

71-
// Fast path in ChannelFlowOperator calls this function (channel was not created yet)
72-
override suspend fun flowCollect(collector: FlowCollector<T>) {
73-
// this function should not have been invoked when channel was explicitly requested
74-
assert { capacity == Channel.OPTIONAL_CHANNEL }
75-
flowScope {
76-
mergeImpl(this, collector.asConcurrentFlowCollector())
77-
}
78-
}
79-
80-
// Slow path when output channel is required (and was created)
81-
override suspend fun collectTo(scope: ProducerScope<T>) =
82-
mergeImpl(scope, SendingCollector(scope))
83-
8475
override fun additionalToStringProps(): String =
8576
"concurrency=$concurrency, "
8677
}

0 commit comments

Comments
 (0)