diff --git a/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/DataConnectCredentialsTokenManager.kt b/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/DataConnectCredentialsTokenManager.kt index 9458a678bff..d96e544d6af 100644 --- a/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/DataConnectCredentialsTokenManager.kt +++ b/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/DataConnectCredentialsTokenManager.kt @@ -47,6 +47,7 @@ import kotlinx.coroutines.ensureActive import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.filter import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.getAndUpdate import kotlinx.coroutines.launch /** Base class that shares logic for managing the Auth token and AppCheck token. */ @@ -148,9 +149,18 @@ internal sealed class DataConnectCredentialsTokenManager( */ fun close() { logger.debug { "close()" } + weakThis.clear() coroutineScope.cancel() - setClosedState() + + val oldState = state.getAndUpdate { State.Closed } + when (oldState) { + is State.Closed -> {} + is State.New -> {} + is State.StateWithProvider -> { + removeTokenListener(oldState.provider) + } + } } /** @@ -175,25 +185,6 @@ internal sealed class DataConnectCredentialsTokenManager( logger.debug { "awaitTokenProvider() done: currentState=$currentState" } } - // This function must ONLY be called from close(). - private fun setClosedState() { - while (true) { - val oldState = state.value - val provider: T? = - when (oldState) { - is State.Closed -> return - is State.New -> null - is State.Idle -> oldState.provider - is State.Active -> oldState.provider - } - - if (state.compareAndSet(oldState, State.Closed)) { - provider?.let { removeTokenListener(it) } - break - } - } - } - /** * Sets a flag to force-refresh the token upon the next call to [getToken]. * @@ -201,25 +192,34 @@ internal sealed class DataConnectCredentialsTokenManager( */ fun forceRefresh() { logger.debug { "forceRefresh()" } - while (true) { - val oldState = state.value - val newState: State.StateWithForceTokenRefresh = - when (oldState) { - is State.Closed -> return - is State.New -> oldState.copy(forceTokenRefresh = true) - is State.Idle -> oldState.copy(forceTokenRefresh = true) - is State.Active -> { - val message = "needs token refresh (wgrwbrvjxt)" - oldState.job.cancel(message, ForceRefresh(message)) - State.Idle(oldState.provider, forceTokenRefresh = true) + val oldState = + state.getAndUpdate { currentState -> + val newState = + when (currentState) { + is State.Closed -> State.Closed + is State.New -> currentState.copy(forceTokenRefresh = true) + is State.Idle -> currentState.copy(forceTokenRefresh = true) + is State.Active -> State.Idle(currentState.provider, forceTokenRefresh = true) } + + check(newState is State.Closed || newState is State.StateWithForceTokenRefresh) { + "internal error gbazc7qr66: newState should have been Closed or " + + "StateWithForceTokenRefresh, but got: $newState" + } + check((newState as? State.StateWithForceTokenRefresh)?.forceTokenRefresh !== false) { + "internal error fnzwyrsez2: newState.forceTokenRefresh should have been true" } - check(newState.forceTokenRefresh) { - "newState.forceTokenRefresh should be true (error code gnvr2wx7nz)" + newState } - if (state.compareAndSet(oldState, newState)) { - break + + when (oldState) { + is State.Closed -> {} + is State.New -> {} + is State.Idle -> {} + is State.Active -> { + val message = "needs token refresh (wgrwbrvjxt)" + oldState.job.cancel(message, ForceRefresh(message)) } } } @@ -350,30 +350,30 @@ internal sealed class DataConnectCredentialsTokenManager( logger.debug { "onProviderAvailable(newProvider=$newProvider)" } addTokenListener(newProvider) - while (true) { - val oldState = state.value - val newState = - when (oldState) { - is State.Closed -> { - logger.debug { - "onProviderAvailable(newProvider=$newProvider)" + - " unregistering token listener that was just added" - } - removeTokenListener(newProvider) - break - } - is State.New -> State.Idle(newProvider, oldState.forceTokenRefresh) - is State.Idle -> State.Idle(newProvider, oldState.forceTokenRefresh) - is State.Active -> { - val newProviderClassName = newProvider::class.qualifiedName - val message = "a new provider $newProviderClassName is available (symhxtmazy)" - oldState.job.cancel(message, NewProvider(message)) - State.Idle(newProvider, forceTokenRefresh = false) - } + val oldState = + state.getAndUpdate { currentState -> + when (currentState) { + is State.Closed -> State.Closed + is State.New -> State.Idle(newProvider, currentState.forceTokenRefresh) + is State.Idle -> State.Idle(newProvider, currentState.forceTokenRefresh) + is State.Active -> State.Idle(newProvider, forceTokenRefresh = false) } + } - if (state.compareAndSet(oldState, newState)) { - break + when (oldState) { + is State.Closed -> { + logger.debug { + "onProviderAvailable(newProvider=$newProvider)" + + " unregistering token listener that was just added" + } + removeTokenListener(newProvider) + } + is State.New -> {} + is State.Idle -> {} + is State.Active -> { + val newProviderClassName = newProvider::class.qualifiedName + val message = "a new provider $newProviderClassName is available (symhxtmazy)" + oldState.job.cancel(message, NewProvider(message)) } } } diff --git a/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/FirebaseDataConnectImpl.kt b/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/FirebaseDataConnectImpl.kt index f49afda964f..9ad64cc6054 100644 --- a/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/FirebaseDataConnectImpl.kt +++ b/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/FirebaseDataConnectImpl.kt @@ -54,6 +54,7 @@ import kotlinx.coroutines.async import kotlinx.coroutines.cancel import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.collect +import kotlinx.coroutines.flow.updateAndGet import kotlinx.coroutines.runBlocking import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock @@ -406,35 +407,42 @@ internal class FirebaseDataConnectImpl( dataConnectAuth.close() dataConnectAppCheck.close() - // Start the job to asynchronously close the gRPC client. - while (true) { - val oldCloseJob = closeJob.value - - oldCloseJob.ref?.let { - if (!it.isCancelled) { - return it - } + // Create the "close job" to asynchronously close the gRPC client. + @OptIn(DelicateCoroutinesApi::class) + val newCloseJob = + GlobalScope.async(start = CoroutineStart.LAZY) { + lazyGrpcRPCs.initializedValueOrNull?.close() } + newCloseJob.invokeOnCompletion { exception -> + if (exception === null) { + logger.debug { "close() completed successfully" } + } else { + logger.warn(exception) { "close() failed" } + } + } - @OptIn(DelicateCoroutinesApi::class) - val newCloseJob = - GlobalScope.async(start = CoroutineStart.LAZY) { - lazyGrpcRPCs.initializedValueOrNull?.close() - } - - newCloseJob.invokeOnCompletion { exception -> - if (exception === null) { - logger.debug { "close() completed successfully" } + // Register the new "close job". Do not overwrite a close job that is already in progress (to + // avoid having more than one close job in progress at a time) or a close job that completed + // successfully (since there is nothing to do if a previous close job was successful). + val updatedCloseJobRef = + closeJob.updateAndGet { currentCloseJobRef: NullableReference> -> + if (currentCloseJobRef.ref !== null && !currentCloseJobRef.ref.isCancelled) { + currentCloseJobRef } else { - logger.warn(exception) { "close() failed" } + NullableReference(newCloseJob) } } - if (closeJob.compareAndSet(oldCloseJob, NullableReference(newCloseJob))) { - newCloseJob.start() - return newCloseJob + // Start the updated "close job" (if it was already started then start() is a no-op). + val updatedCloseJob = + checkNotNull(updatedCloseJobRef.ref) { + "internal error: closeJob.updateAndGet() returned a NullableReference whose 'ref' " + + "property was null; however it should NOT have been null (error code y5fk4ntdnd)" } - } + updatedCloseJob.start() + + // Return the "close job", which _may_ already be completed, so the caller can await it. + return updatedCloseJob } // The generated SDK relies on equals() and hashCode() using object identity. diff --git a/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/QuerySubscriptionImpl.kt b/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/QuerySubscriptionImpl.kt index ceeb861cab8..2ca9aea6771 100644 --- a/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/QuerySubscriptionImpl.kt +++ b/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/QuerySubscriptionImpl.kt @@ -27,6 +27,7 @@ import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.channelFlow +import kotlinx.coroutines.flow.update import kotlinx.coroutines.launch internal class QuerySubscriptionImpl(query: QueryRefImpl) : @@ -80,22 +81,17 @@ internal class QuerySubscriptionImpl(query: QueryRefImpl= prospectiveSequenceNumber) { - return - } - } - - if (_lastResult.compareAndSet(currentLastResult, NullableReference(prospectiveLastResult))) { - return + _lastResult.update { currentLastResult -> + if ( + currentLastResult.ref != null && + currentLastResult.ref.sequencedResult.sequenceNumber >= + prospectiveLastResult.sequencedResult.sequenceNumber + ) { + currentLastResult + } else { + NullableReference(prospectiveLastResult) } } } diff --git a/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/querymgr/RegisteredDataDeserialzer.kt b/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/querymgr/RegisteredDataDeserialzer.kt index 3f94a7f95a0..1fa6d94eae4 100644 --- a/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/querymgr/RegisteredDataDeserialzer.kt +++ b/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/querymgr/RegisteredDataDeserialzer.kt @@ -31,6 +31,7 @@ import kotlinx.coroutines.channels.BufferOverflow import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.onSubscription +import kotlinx.coroutines.flow.update import kotlinx.coroutines.withContext import kotlinx.serialization.DeserializationStrategy import kotlinx.serialization.modules.SerializersModule @@ -84,17 +85,14 @@ internal class RegisteredDataDeserializer( lazyDeserialize(requestId, sequencedResult) ) - // Use a compare-and-swap ("CAS") loop to ensure that an old update never clobbers a newer one. - while (true) { - val currentUpdate = latestUpdate.value + latestUpdate.update { currentUpdate -> if ( currentUpdate.ref !== null && currentUpdate.ref.sequenceNumber > sequencedResult.sequenceNumber ) { - break // don't clobber a newer update with an older one - } - if (latestUpdate.compareAndSet(currentUpdate, NullableReference(newUpdate))) { - break + currentUpdate // don't clobber a newer update with an older one + } else { + NullableReference(newUpdate) } } diff --git a/firebase-dataconnect/testutil/src/main/kotlin/com/google/firebase/dataconnect/testutil/SuspendingCountDownLatch.kt b/firebase-dataconnect/testutil/src/main/kotlin/com/google/firebase/dataconnect/testutil/SuspendingCountDownLatch.kt index 7098a390886..b7828e0c36a 100644 --- a/firebase-dataconnect/testutil/src/main/kotlin/com/google/firebase/dataconnect/testutil/SuspendingCountDownLatch.kt +++ b/firebase-dataconnect/testutil/src/main/kotlin/com/google/firebase/dataconnect/testutil/SuspendingCountDownLatch.kt @@ -19,6 +19,7 @@ package com.google.firebase.dataconnect.testutil import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.filter import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.update /** * An implementation of [java.util.concurrent.CountDownLatch] that suspends instead of blocking. @@ -60,14 +61,10 @@ class SuspendingCountDownLatch(count: Int) { * @throws IllegalStateException if called when the count has already reached zero. */ fun countDown(): SuspendingCountDownLatch { - while (true) { - val oldValue = _count.value - check(oldValue > 0) { "countDown() called too many times (oldValue=$oldValue)" } - - val newValue = oldValue - 1 - if (_count.compareAndSet(oldValue, newValue)) { - return this - } + _count.update { currentValue -> + check(currentValue > 0) { "countDown() called too many times (currentValue=$currentValue)" } + currentValue - 1 } + return this } }