Skip to content

Commit 623db41

Browse files
qwwdfsadlowasser
andauthored
Add update, updateAndGet, and getAndUpdate extension functions to MutableStateFlow (#2729)
* Add update, updateAndGet, and getAndUpdate extension functions to MutableStateFlow (#2720). Fixes #2720 Co-authored-by: Louis Wasserman <[email protected]>
1 parent d8eb80e commit 623db41

File tree

4 files changed

+106
-24
lines changed

4 files changed

+106
-24
lines changed

kotlinx-coroutines-core/api/kotlinx-coroutines-core.api

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,6 +1100,9 @@ public abstract interface class kotlinx/coroutines/flow/StateFlow : kotlinx/coro
11001100

11011101
public final class kotlinx/coroutines/flow/StateFlowKt {
11021102
public static final fun MutableStateFlow (Ljava/lang/Object;)Lkotlinx/coroutines/flow/MutableStateFlow;
1103+
public static final fun getAndUpdate (Lkotlinx/coroutines/flow/MutableStateFlow;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
1104+
public static final fun update (Lkotlinx/coroutines/flow/MutableStateFlow;Lkotlin/jvm/functions/Function1;)V
1105+
public static final fun updateAndGet (Lkotlinx/coroutines/flow/MutableStateFlow;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
11031106
}
11041107

11051108
public abstract class kotlinx/coroutines/flow/internal/ChannelFlow : kotlinx/coroutines/flow/internal/FusibleFlow {

kotlinx-coroutines-core/common/src/flow/StateFlow.kt

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ import kotlin.native.concurrent.*
3737
* val counter = _counter.asStateFlow() // publicly exposed as read-only state flow
3838
*
3939
* fun inc() {
40-
* _counter.value++
40+
* _counter.update { count -> count + 1 } // atomic, safe for concurrent use
4141
* }
4242
* }
4343
* ```
@@ -186,6 +186,56 @@ public interface MutableStateFlow<T> : StateFlow<T>, MutableSharedFlow<T> {
186186
@Suppress("FunctionName")
187187
public fun <T> MutableStateFlow(value: T): MutableStateFlow<T> = StateFlowImpl(value ?: NULL)
188188

189+
// ------------------------------------ Update methods ------------------------------------
190+
191+
/**
192+
* Updates the [MutableStateFlow.value] atomically using the specified [function] of its value, and returns the new
193+
* value.
194+
*
195+
* [function] may be evaluated multiple times, if [value] is being concurrently updated.
196+
*/
197+
public inline fun <T> MutableStateFlow<T>.updateAndGet(function: (T) -> T): T {
198+
while (true) {
199+
val prevValue = value
200+
val nextValue = function(prevValue)
201+
if (compareAndSet(prevValue, nextValue)) {
202+
return nextValue
203+
}
204+
}
205+
}
206+
207+
/**
208+
* Updates the [MutableStateFlow.value] atomically using the specified [function] of its value, and returns its
209+
* prior value.
210+
*
211+
* [function] may be evaluated multiple times, if [value] is being concurrently updated.
212+
*/
213+
public inline fun <T> MutableStateFlow<T>.getAndUpdate(function: (T) -> T): T {
214+
while (true) {
215+
val prevValue = value
216+
val nextValue = function(prevValue)
217+
if (compareAndSet(prevValue, nextValue)) {
218+
return prevValue
219+
}
220+
}
221+
}
222+
223+
224+
/**
225+
* Updates the [MutableStateFlow.value] atomically using the specified [function] of its value.
226+
*
227+
* [function] may be evaluated multiple times, if [value] is being concurrently updated.
228+
*/
229+
public inline fun <T> MutableStateFlow<T>.update(function: (T) -> T) {
230+
while (true) {
231+
val prevValue = value
232+
val nextValue = function(prevValue)
233+
if (compareAndSet(prevValue, nextValue)) {
234+
return
235+
}
236+
}
237+
}
238+
189239
// ------------------------------------ Implementation ------------------------------------
190240

191241
@SharedImmutable
@@ -366,10 +416,7 @@ private class StateFlowImpl<T>(
366416
}
367417

368418
internal fun MutableStateFlow<Int>.increment(delta: Int) {
369-
while (true) { // CAS loop
370-
val current = value
371-
if (compareAndSet(current, current + delta)) return
372-
}
419+
update { it + delta }
373420
}
374421

375422
internal fun <T> StateFlow<T>.fuseStateFlow(

kotlinx-coroutines-core/common/test/flow/sharing/StateFlowTest.kt

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -174,23 +174,11 @@ class StateFlowTest : TestBase() {
174174
}
175175

176176
@Test
177-
fun testReferenceUpdatesAndCAS() {
178-
val d0 = Data(0)
179-
val d0_1 = Data(0)
180-
val d1 = Data(1)
181-
val d1_1 = Data(1)
182-
val d1_2 = Data(1)
183-
val state = MutableStateFlow(d0)
184-
assertSame(d0, state.value)
185-
state.value = d0_1 // equal, nothing changes
186-
assertSame(d0, state.value)
187-
state.value = d1 // updates
188-
assertSame(d1, state.value)
189-
assertFalse(state.compareAndSet(d0, d0)) // wrong value
190-
assertSame(d1, state.value)
191-
assertTrue(state.compareAndSet(d1_1, d1_2)) // "updates", but ref stays
192-
assertSame(d1, state.value)
193-
assertTrue(state.compareAndSet(d1_1, d0)) // updates, reference changes
194-
assertSame(d0, state.value)
177+
fun testUpdate() = runTest {
178+
val state = MutableStateFlow(0)
179+
state.update { it + 2 }
180+
assertEquals(2, state.value)
181+
state.update { it + 3 }
182+
assertEquals(5, state.value)
195183
}
196-
}
184+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.coroutines.flow
6+
7+
import kotlinx.coroutines.*
8+
import org.junit.*
9+
import kotlin.test.*
10+
import kotlin.test.Test
11+
12+
class StateFlowUpdateStressTest : TestBase() {
13+
private val iterations = 1_000_000 * stressTestMultiplier
14+
15+
@get:Rule
16+
public val executor = ExecutorRule(2)
17+
18+
@Test
19+
fun testUpdate() = doTest { update { it + 1 } }
20+
21+
@Test
22+
fun testUpdateAndGet() = doTest { updateAndGet { it + 1 } }
23+
24+
@Test
25+
fun testGetAndUpdate() = doTest { getAndUpdate { it + 1 } }
26+
27+
private fun doTest(increment: MutableStateFlow<Int>.() -> Unit) = runTest {
28+
val flow = MutableStateFlow(0)
29+
val j1 = launch(Dispatchers.Default) {
30+
repeat(iterations / 2) {
31+
flow.increment()
32+
}
33+
}
34+
35+
val j2 = launch(Dispatchers.Default) {
36+
repeat(iterations / 2) {
37+
flow.increment()
38+
}
39+
}
40+
41+
joinAll(j1, j2)
42+
assertEquals(iterations, flow.value)
43+
}
44+
}

0 commit comments

Comments
 (0)