Skip to content

Commit 80af499

Browse files
authored
Do not propagate cancellation to the upstream in Flow flat* operators (#2964)
* Do not propagate cancellation to the upstream in Flow flat* operators Fixes #2964
1 parent 85b17ce commit 80af499

File tree

8 files changed

+84
-46
lines changed

8 files changed

+84
-46
lines changed

kotlinx-coroutines-core/common/src/channels/Produce.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ internal fun <E> CoroutineScope.produce(
133133
return coroutine
134134
}
135135

136-
internal open class ProducerCoroutine<E>(
136+
private class ProducerCoroutine<E>(
137137
parentContext: CoroutineContext, channel: Channel<E>
138138
) : ChannelCoroutine<E>(parentContext, channel, true, active = true), ProducerScope<E> {
139139
override val isActive: Boolean

kotlinx-coroutines-core/common/src/flow/internal/FlowCoroutine.kt

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -51,33 +51,11 @@ internal fun <R> scopedFlow(@BuilderInference block: suspend CoroutineScope.(Flo
5151
flowScope { block(this@flow) }
5252
}
5353

54-
internal fun <T> CoroutineScope.flowProduce(
55-
context: CoroutineContext,
56-
capacity: Int = 0,
57-
@BuilderInference block: suspend ProducerScope<T>.() -> Unit
58-
): ReceiveChannel<T> {
59-
val channel = Channel<T>(capacity)
60-
val newContext = newCoroutineContext(context)
61-
val coroutine = FlowProduceCoroutine(newContext, channel)
62-
coroutine.start(CoroutineStart.ATOMIC, coroutine, block)
63-
return coroutine
64-
}
65-
6654
private class FlowCoroutine<T>(
6755
context: CoroutineContext,
6856
uCont: Continuation<T>
6957
) : ScopeCoroutine<T>(context, uCont) {
70-
public override fun childCancelled(cause: Throwable): Boolean {
71-
if (cause is ChildCancelledException) return true
72-
return cancelImpl(cause)
73-
}
74-
}
75-
76-
private class FlowProduceCoroutine<T>(
77-
parentContext: CoroutineContext,
78-
channel: Channel<T>
79-
) : ProducerCoroutine<T>(parentContext, channel) {
80-
public override fun childCancelled(cause: Throwable): Boolean {
58+
override fun childCancelled(cause: Throwable): Boolean {
8159
if (cause is ChildCancelledException) return true
8260
return cancelImpl(cause)
8361
}

kotlinx-coroutines-core/common/src/flow/internal/Merge.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ internal class ChannelFlowTransformLatest<T, R>(
2222

2323
override suspend fun flowCollect(collector: FlowCollector<R>) {
2424
assert { collector is SendingCollector } // So cancellation behaviour is not leaking into the downstream
25-
flowScope {
25+
coroutineScope {
2626
var previousFlow: Job? = null
2727
flow.collect { value ->
2828
previousFlow?.apply {
@@ -49,7 +49,7 @@ internal class ChannelFlowMerge<T>(
4949
ChannelFlowMerge(flow, concurrency, context, capacity, onBufferOverflow)
5050

5151
override fun produceImpl(scope: CoroutineScope): ReceiveChannel<T> {
52-
return scope.flowProduce(context, capacity, block = collectToFun)
52+
return scope.produce(context, capacity, block = collectToFun)
5353
}
5454

5555
override suspend fun collectTo(scope: ProducerScope<T>) {
@@ -87,7 +87,7 @@ internal class ChannelLimitedFlowMerge<T>(
8787
ChannelLimitedFlowMerge(flows, context, capacity, onBufferOverflow)
8888

8989
override fun produceImpl(scope: CoroutineScope): ReceiveChannel<T> {
90-
return scope.flowProduce(context, capacity, block = collectToFun)
90+
return scope.produce(context, capacity, block = collectToFun)
9191
}
9292

9393
override suspend fun collectTo(scope: ProducerScope<T>) {

kotlinx-coroutines-core/common/src/flow/operators/Merge.kt

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public fun <T, R> Flow<T>.flatMapConcat(transform: suspend (value: T) -> Flow<R>
6161
* its concurrent merging so that only one properly configured channel is used for execution of merging logic.
6262
*
6363
* @param concurrency controls the number of in-flight flows, at most [concurrency] flows are collected
64-
* at the same time. By default it is equal to [DEFAULT_CONCURRENCY].
64+
* at the same time. By default, it is equal to [DEFAULT_CONCURRENCY].
6565
*/
6666
@FlowPreview
6767
public fun <T, R> Flow<T>.flatMapMerge(
@@ -71,8 +71,7 @@ public fun <T, R> Flow<T>.flatMapMerge(
7171
map(transform).flattenMerge(concurrency)
7272

7373
/**
74-
* Flattens the given flow of flows into a single flow in a sequentially manner, without interleaving nested flows.
75-
* This method is conceptually identical to `flattenMerge(concurrency = 1)` but has faster implementation.
74+
* Flattens the given flow of flows into a single flow in a sequential manner, without interleaving nested flows.
7675
*
7776
* Inner flows are collected by this operator *sequentially*.
7877
*/
@@ -119,7 +118,7 @@ public fun <T> merge(vararg flows: Flow<T>): Flow<T> = flows.asIterable().merge(
119118
* Flattens the given flow of flows into a single flow with a [concurrency] limit on the number of
120119
* concurrently collected flows.
121120
*
122-
* If [concurrency] is more than 1, then inner flows are be collected by this operator *concurrently*.
121+
* If [concurrency] is more than 1, then inner flows are collected by this operator *concurrently*.
123122
* With `concurrency == 1` this operator is identical to [flattenConcat].
124123
*
125124
* ### Operator fusion
@@ -131,7 +130,7 @@ public fun <T> merge(vararg flows: Flow<T>): Flow<T> = flows.asIterable().merge(
131130
* and size of its output buffer can be changed by applying subsequent [buffer] operator.
132131
*
133132
* @param concurrency controls the number of in-flight flows, at most [concurrency] flows are collected
134-
* at the same time. By default it is equal to [DEFAULT_CONCURRENCY].
133+
* at the same time. By default, it is equal to [DEFAULT_CONCURRENCY].
135134
*/
136135
@FlowPreview
137136
public fun <T> Flow<Flow<T>>.flattenMerge(concurrency: Int = DEFAULT_CONCURRENCY): Flow<T> {

kotlinx-coroutines-core/common/test/flow/operators/FlatMapMergeFastPathTest.kt

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,14 @@ class FlatMapMergeFastPathTest : FlatMapMergeBaseTest() {
3939

4040
@Test
4141
fun testCancellationExceptionDownstream() = runTest {
42-
val flow = flow {
43-
emit(1)
44-
hang { expect(2) }
45-
}.flatMapMerge {
42+
val flow = flowOf(1, 2, 3).flatMapMerge {
4643
flow {
4744
emit(it)
48-
expect(1)
4945
throw CancellationException("")
5046
}
5147
}.buffer(64)
5248

53-
assertFailsWith<CancellationException>(flow)
54-
finish(3)
49+
assertEquals(listOf(1, 2, 3), flow.toList())
5550
}
5651

5752
@Test

kotlinx-coroutines-core/common/test/flow/operators/FlatMapMergeTest.kt

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,19 +69,14 @@ class FlatMapMergeTest : FlatMapMergeBaseTest() {
6969

7070
@Test
7171
fun testCancellationExceptionDownstream() = runTest {
72-
val flow = flow {
73-
emit(1)
74-
hang { expect(2) }
75-
}.flatMapMerge {
72+
val flow = flowOf(1, 2, 3).flatMapMerge {
7673
flow {
7774
emit(it)
78-
expect(1)
7975
throw CancellationException("")
8076
}
8177
}
8278

83-
assertFailsWith<CancellationException>(flow)
84-
finish(3)
79+
assertEquals(listOf(1, 2, 3), flow.toList())
8580
}
8681

8782
@Test

kotlinx-coroutines-core/common/test/flow/operators/FlattenConcatTest.kt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,17 @@ class FlattenConcatTest : FlatMapBaseTest() {
3636
consumer.cancelAndJoin()
3737
finish(2)
3838
}
39+
40+
@Test
41+
fun testCancellation() = runTest {
42+
val flow = flow {
43+
repeat(5) {
44+
emit(flow {
45+
if (it == 2) throw CancellationException("")
46+
emit(1)
47+
})
48+
}
49+
}
50+
assertFailsWith<CancellationException>(flow.flattenConcat())
51+
}
3952
}

kotlinx-coroutines-core/common/test/flow/operators/MergeTest.kt

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,64 @@ abstract class MergeTest : TestBase() {
4545
assertEquals(listOf("source"), result)
4646
}
4747

48+
@Test
49+
fun testOneSourceCancelled() = runTest {
50+
val flow = flow {
51+
expect(1)
52+
emit(1)
53+
expect(2)
54+
yield()
55+
throw CancellationException("")
56+
}
57+
58+
val otherFlow = flow {
59+
repeat(5) {
60+
emit(1)
61+
yield()
62+
}
63+
64+
expect(3)
65+
}
66+
67+
val result = listOf(flow, otherFlow).merge().toList()
68+
assertEquals(MutableList(6) { 1 }, result)
69+
finish(4)
70+
}
71+
72+
@Test
73+
fun testOneSourceCancelledNonFused() = runTest {
74+
val flow = flow {
75+
expect(1)
76+
emit(1)
77+
expect(2)
78+
yield()
79+
throw CancellationException("")
80+
}
81+
82+
val otherFlow = flow {
83+
repeat(5) {
84+
emit(1)
85+
yield()
86+
}
87+
88+
expect(3)
89+
}
90+
91+
val result = listOf(flow, otherFlow).nonFuseableMerge().toList()
92+
assertEquals(MutableList(6) { 1 }, result)
93+
finish(4)
94+
}
95+
96+
private fun <T> Iterable<Flow<T>>.nonFuseableMerge(): Flow<T> {
97+
return channelFlow {
98+
forEach { flow ->
99+
launch {
100+
flow.collect { send(it) }
101+
}
102+
}
103+
}
104+
}
105+
48106
@Test
49107
fun testIsolatedContext() = runTest {
50108
val flow = flow {

0 commit comments

Comments
 (0)