Skip to content

Commit 35f9ad5

Browse files
committed
FlatMap improvements:
* Proper stress tests * Liveness guarantee in corner cases * merge for Flow<Flow<*>>
1 parent 2b8218a commit 35f9ad5

File tree

3 files changed

+148
-19
lines changed

3 files changed

+148
-19
lines changed

binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,7 +821,9 @@ public final class kotlinx/coroutines/flow/FlowKt {
821821
public static final fun map (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function2;)Lkotlinx/coroutines/flow/Flow;
822822
public static final fun mapNotNull (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function2;)Lkotlinx/coroutines/flow/Flow;
823823
public static final fun merge (Ljava/lang/Iterable;II)Lkotlinx/coroutines/flow/Flow;
824+
public static final fun merge (Lkotlinx/coroutines/flow/Flow;II)Lkotlinx/coroutines/flow/Flow;
824825
public static synthetic fun merge$default (Ljava/lang/Iterable;IIILjava/lang/Object;)Lkotlinx/coroutines/flow/Flow;
826+
public static synthetic fun merge$default (Lkotlinx/coroutines/flow/Flow;IIILjava/lang/Object;)Lkotlinx/coroutines/flow/Flow;
825827
public static final fun onEach (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function2;)Lkotlinx/coroutines/flow/Flow;
826828
public static final fun onErrorCollect (Lkotlinx/coroutines/flow/Flow;Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function1;)Lkotlinx/coroutines/flow/Flow;
827829
public static synthetic fun onErrorCollect$default (Lkotlinx/coroutines/flow/Flow;Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lkotlinx/coroutines/flow/Flow;

kotlinx-coroutines-core/common/src/flow/operators/Merge.kt

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,25 +33,37 @@ public fun <T, R> Flow<T>.flatMap(concurrency: Int = 16, bufferSize: Int = 16, m
3333
semaphore.send(Unit) // Acquire concurrency permit
3434
val inner = mapper(outerValue)
3535
launch {
36-
inner.collect { value ->
37-
flatMap.push(value)
36+
try {
37+
inner.collect { value ->
38+
flatMap.emit(value)
39+
}
40+
} finally {
41+
semaphore.receive() // Release concurrency permit
3842
}
39-
semaphore.receive() // Release concurrency permit
4043
}
4144
}
4245
}
4346
}
4447
}
4548

4649
/**
47-
* Merges given sequence of flows into a single flow with no guarantees on the order.
50+
* Merges the given sequence of flows into a single flow with no guarantees on the order.
4851
*
4952
* [bufferSize] parameter controls the size of backpressure aka the amount of queued in-flight elements.
5053
* [concurrency] parameter controls the size of in-flight flows.
5154
*/
5255
@FlowPreview
5356
public fun <T> Iterable<Flow<T>>.merge(concurrency: Int = 16, bufferSize: Int = 16): Flow<T> = asFlow().flatMap(concurrency, bufferSize) { it }
5457

58+
/**
59+
* Merges the given flow of flows into a single flow with no guarantees on the order.
60+
*
61+
* [bufferSize] parameter controls the size of backpressure aka the amount of queued in-flight elements.
62+
* [concurrency] parameter controls the size of in-flight flows.
63+
*/
64+
@FlowPreview
65+
public fun <T> Flow<Flow<T>>.merge(concurrency: Int = 16, bufferSize: Int = 16): Flow<T> = flatMap(concurrency, bufferSize) { it }
66+
5567
/**
5668
* Concatenates values of each flow sequentially, without interleaving them.
5769
*/
@@ -85,31 +97,45 @@ private class SerializingFlatMapCollector<T>(
8597
) {
8698

8799
// Let's try to leverage the fact that flatMap is never contended
88-
private val channel: Channel<Any?> by lazy { Channel<Any?>(bufferSize) }
89-
private val inProgress = atomic(false)
100+
private val channel: Channel<Any?> by lazy { Channel<Any?>(bufferSize) } // Should be any, but KT-30796
101+
private val inProgressLock = atomic(false)
102+
private val sentValues = atomic(0)
90103

91-
public suspend fun push(value: T) {
92-
if (!inProgress.compareAndSet(false, true)) {
104+
public suspend fun emit(value: T) {
105+
if (!inProgressLock.tryAcquire()) {
106+
sentValues.incrementAndGet()
93107
channel.send(value ?: NullSurrogate)
94-
if (inProgress.compareAndSet(false, true)) {
95-
helpPush()
108+
if (inProgressLock.tryAcquire()) {
109+
helpEmit()
96110
}
97111
return
98112
}
99113

100114
downstream.emit(value)
101-
helpPush()
115+
helpEmit()
102116
}
103117

104118
@Suppress("UNCHECKED_CAST")
105-
private suspend fun helpPush() {
106-
var element = channel.poll()
107-
while (element != null) { // TODO receive or closed
108-
if (element === NullSurrogate) downstream.emit(null as T)
109-
else downstream.emit(element as T)
110-
element = channel.poll()
111-
}
119+
private suspend fun helpEmit() {
120+
while (true) {
121+
var element = channel.poll()
122+
while (element != null) { // TODO receive or closed
123+
if (element === NullSurrogate) downstream.emit(null as T)
124+
else downstream.emit(element as T)
125+
sentValues.decrementAndGet()
126+
element = channel.poll()
127+
}
112128

113-
inProgress.value = false
129+
inProgressLock.release()
130+
// Enforce liveness of the algorithm
131+
// TODO looks like isEmpty use-case
132+
if (sentValues.value == 0 || !inProgressLock.tryAcquire()) break
133+
}
114134
}
115135
}
136+
137+
private fun AtomicBoolean.tryAcquire(): Boolean = compareAndSet(false, true)
138+
139+
private fun AtomicBoolean.release() {
140+
value = false
141+
}
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.coroutines.flow
6+
7+
import kotlinx.coroutines.*
8+
import kotlinx.coroutines.scheduling.*
9+
import org.junit.Assume.*
10+
import org.junit.Test
11+
import java.util.concurrent.atomic.*
12+
import kotlin.test.*
13+
14+
class FlatMapStressTest : TestBase() {
15+
16+
private val iterations = 2000 * stressTestMultiplier
17+
private val expectedSum = iterations * (iterations + 1) / 2
18+
19+
@Test
20+
fun testConcurrencyLevel() = runTest {
21+
withContext(Dispatchers.Default) {
22+
testConcurrencyLevel(2)
23+
}
24+
}
25+
26+
@Test
27+
fun testConcurrencyLevel2() = runTest {
28+
withContext(Dispatchers.Default) {
29+
testConcurrencyLevel(3)
30+
}
31+
}
32+
33+
@Test
34+
fun testBufferSize() = runTest {
35+
val bufferSize = 5
36+
withContext(Dispatchers.Default) {
37+
val inFlightElements = AtomicLong(0L)
38+
var result = 0
39+
(1..iterations step 4).asFlow().flatMap(bufferSize = bufferSize) { value ->
40+
unsafeFlow {
41+
repeat(4) {
42+
emit(value + it)
43+
inFlightElements.incrementAndGet()
44+
}
45+
}
46+
}.collect { value ->
47+
val inFlight = inFlightElements.get()
48+
assertTrue(inFlight <= bufferSize + 1,
49+
"Expected less in flight elements than ${bufferSize + 1}, but had $inFlight")
50+
inFlightElements.decrementAndGet()
51+
result += value
52+
}
53+
54+
assertEquals(0, inFlightElements.get())
55+
assertEquals(expectedSum, result)
56+
}
57+
}
58+
59+
@Test
60+
fun testDelivery() = runTest {
61+
withContext(Dispatchers.Default) {
62+
val result = (1..iterations step 4).asFlow().flatMap { value ->
63+
unsafeFlow {
64+
repeat(4) { emit(value + it) }
65+
}
66+
}.sum()
67+
assertEquals(expectedSum, result)
68+
}
69+
}
70+
71+
@Test
72+
fun testIndependentShortBursts() = runTest {
73+
withContext(Dispatchers.Default) {
74+
repeat(iterations) {
75+
val result = (1..4).asFlow().flatMap { value ->
76+
unsafeFlow {
77+
emit(value)
78+
emit(value)
79+
}
80+
}.sum()
81+
assertEquals(20, result)
82+
}
83+
}
84+
}
85+
86+
private suspend fun testConcurrencyLevel(maxConcurrency: Int) {
87+
assumeTrue(maxConcurrency <= CORE_POOL_SIZE)
88+
val concurrency = AtomicLong()
89+
val result = (1..iterations).asFlow().flatMap(concurrency = maxConcurrency) { value ->
90+
unsafeFlow {
91+
val current = concurrency.incrementAndGet()
92+
assertTrue(current in 1..maxConcurrency)
93+
emit(value)
94+
concurrency.decrementAndGet()
95+
}
96+
}.sum()
97+
98+
assertEquals(0, concurrency.get())
99+
assertEquals(expectedSum, result)
100+
}
101+
}

0 commit comments

Comments
 (0)