Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
package com.google.firebase.dataconnect

import com.google.firebase.auth.FirebaseAuth
import com.google.firebase.dataconnect.core.FirebaseDataConnectInternal
import com.google.firebase.dataconnect.testutil.DataConnectBackend
import com.google.firebase.dataconnect.testutil.DataConnectIntegrationTestBase
import com.google.firebase.dataconnect.testutil.InProcessDataConnectGrpcServer
import com.google.firebase.dataconnect.testutil.awaitAuthReady
import com.google.firebase.dataconnect.testutil.newInstance
import com.google.firebase.dataconnect.testutil.property.arbitrary.dataConnect
import com.google.firebase.dataconnect.testutil.schemas.PersonSchema
Expand Down Expand Up @@ -202,7 +202,7 @@ class AuthIntegrationTest : DataConnectIntegrationTestBase() {
}

private suspend fun signIn() {
(personSchema.dataConnect as FirebaseDataConnectInternal).awaitAuthReady()
personSchema.dataConnect.awaitAuthReady()
val authResult = auth.run { signInAnonymously().await() }
withClue("authResult.user returned from signInAnonymously()") {
authResult.user.shouldNotBeNull()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import com.google.android.gms.tasks.Tasks
import com.google.firebase.appcheck.AppCheckProvider
import com.google.firebase.appcheck.AppCheckProviderFactory
import com.google.firebase.appcheck.FirebaseAppCheck
import com.google.firebase.dataconnect.core.FirebaseDataConnectInternal
import com.google.firebase.dataconnect.generated.GeneratedConnector
import com.google.firebase.dataconnect.generated.GeneratedMutation
import com.google.firebase.dataconnect.generated.GeneratedQuery
Expand All @@ -32,6 +31,8 @@ import com.google.firebase.dataconnect.testutil.DataConnectIntegrationTestBase
import com.google.firebase.dataconnect.testutil.DataConnectTestAppCheckToken
import com.google.firebase.dataconnect.testutil.FirebaseAuthBackend
import com.google.firebase.dataconnect.testutil.InProcessDataConnectGrpcServer
import com.google.firebase.dataconnect.testutil.awaitAppCheckReady
import com.google.firebase.dataconnect.testutil.awaitAuthReady
import com.google.firebase.dataconnect.testutil.getFirebaseAppIdFromStrings
import com.google.firebase.dataconnect.testutil.newInstance
import com.google.firebase.dataconnect.util.SuspendingLazy
Expand Down Expand Up @@ -138,7 +139,7 @@ class GrpcMetadataIntegrationTest : DataConnectIntegrationTestBase() {
fun executeQueryShouldNotSendAuthMetadataWhenNotLoggedIn() = runTest {
val grpcServer = inProcessDataConnectGrpcServer.newInstance()
val dataConnect = dataConnectFactory.newInstance(grpcServer)
(dataConnect as FirebaseDataConnectInternal).awaitAuthReady()
dataConnect.awaitAuthReady()
val queryRef = dataConnect.query("qryfyk7yfppfe", Unit, serializer<Unit>(), serializer<Unit>())
val metadatasJob = async { grpcServer.metadatas.first() }

Expand All @@ -151,7 +152,7 @@ class GrpcMetadataIntegrationTest : DataConnectIntegrationTestBase() {
fun executeMutationShouldNotSendAuthMetadataWhenNotLoggedIn() = runTest {
val grpcServer = inProcessDataConnectGrpcServer.newInstance()
val dataConnect = dataConnectFactory.newInstance(grpcServer)
(dataConnect as FirebaseDataConnectInternal).awaitAuthReady()
dataConnect.awaitAuthReady()
val mutationRef =
dataConnect.mutation("mutckjpte9v9j", Unit, serializer<Unit>(), serializer<Unit>())
val metadatasJob = async { grpcServer.metadatas.first() }
Expand All @@ -165,7 +166,7 @@ class GrpcMetadataIntegrationTest : DataConnectIntegrationTestBase() {
fun executeQueryShouldSendAuthMetadataWhenLoggedIn() = runTest {
val grpcServer = inProcessDataConnectGrpcServer.newInstance()
val dataConnect = dataConnectFactory.newInstance(grpcServer)
(dataConnect as FirebaseDataConnectInternal).awaitAuthReady()
dataConnect.awaitAuthReady()
val queryRef = dataConnect.query("qryyarwrxe2fv", Unit, serializer<Unit>(), serializer<Unit>())
val metadatasJob = async { grpcServer.metadatas.first() }
firebaseAuthSignIn(dataConnect)
Expand All @@ -179,7 +180,7 @@ class GrpcMetadataIntegrationTest : DataConnectIntegrationTestBase() {
fun executeMutationShouldSendAuthMetadataWhenLoggedIn() = runTest {
val grpcServer = inProcessDataConnectGrpcServer.newInstance()
val dataConnect = dataConnectFactory.newInstance(grpcServer)
(dataConnect as FirebaseDataConnectInternal).awaitAuthReady()
dataConnect.awaitAuthReady()
val mutationRef =
dataConnect.mutation("mutayn7as5k7d", Unit, serializer<Unit>(), serializer<Unit>())
val metadatasJob = async { grpcServer.metadatas.first() }
Expand All @@ -194,7 +195,7 @@ class GrpcMetadataIntegrationTest : DataConnectIntegrationTestBase() {
fun executeQueryShouldNotSendAuthMetadataAfterLogout() = runTest {
val grpcServer = inProcessDataConnectGrpcServer.newInstance()
val dataConnect = dataConnectFactory.newInstance(grpcServer)
(dataConnect as FirebaseDataConnectInternal).awaitAuthReady()
dataConnect.awaitAuthReady()
val queryRef = dataConnect.query("qryyarwrxe2fv", Unit, serializer<Unit>(), serializer<Unit>())
val metadatasJob1 = async { grpcServer.metadatas.first() }
val metadatasJob2 = async { grpcServer.metadatas.take(2).last() }
Expand All @@ -212,7 +213,7 @@ class GrpcMetadataIntegrationTest : DataConnectIntegrationTestBase() {
fun executeMutationShouldNotSendAuthMetadataAfterLogout() = runTest {
val grpcServer = inProcessDataConnectGrpcServer.newInstance()
val dataConnect = dataConnectFactory.newInstance(grpcServer)
(dataConnect as FirebaseDataConnectInternal).awaitAuthReady()
dataConnect.awaitAuthReady()
val mutationRef =
dataConnect.mutation("mutvw945ag3vv", Unit, serializer<Unit>(), serializer<Unit>())
val metadatasJob1 = async { grpcServer.metadatas.first() }
Expand All @@ -233,7 +234,7 @@ class GrpcMetadataIntegrationTest : DataConnectIntegrationTestBase() {
// appcheck token is sent at all.
val grpcServer = inProcessDataConnectGrpcServer.newInstance()
val dataConnect = dataConnectFactory.newInstance(grpcServer)
(dataConnect as FirebaseDataConnectInternal).awaitAppCheckReady()
dataConnect.awaitAppCheckReady()
val queryRef = dataConnect.query("qrybbeekpkkck", Unit, serializer<Unit>(), serializer<Unit>())
val metadatasJob = async { grpcServer.metadatas.first() }

Expand All @@ -248,7 +249,7 @@ class GrpcMetadataIntegrationTest : DataConnectIntegrationTestBase() {
// appcheck token is sent at all.
val grpcServer = inProcessDataConnectGrpcServer.newInstance()
val dataConnect = dataConnectFactory.newInstance(grpcServer)
(dataConnect as FirebaseDataConnectInternal).awaitAppCheckReady()
dataConnect.awaitAppCheckReady()
val mutationRef =
dataConnect.mutation("mutbs7hhxk39c", Unit, serializer<Unit>(), serializer<Unit>())
val metadatasJob = async { grpcServer.metadatas.first() }
Expand All @@ -262,7 +263,7 @@ class GrpcMetadataIntegrationTest : DataConnectIntegrationTestBase() {
fun executeQueryShouldSendAppCheckMetadataWhenAppCheckIsEnabled() = runTest {
val grpcServer = inProcessDataConnectGrpcServer.newInstance()
val dataConnect = dataConnectFactory.newInstance(grpcServer)
(dataConnect as FirebaseDataConnectInternal).awaitAppCheckReady()
dataConnect.awaitAppCheckReady()
val queryRef = dataConnect.query("qryyarwrxe2fv", Unit, serializer<Unit>(), serializer<Unit>())
val metadatasJob = async { grpcServer.metadatas.first() }
val appCheck = FirebaseAppCheck.getInstance(dataConnect.app)
Expand All @@ -277,7 +278,7 @@ class GrpcMetadataIntegrationTest : DataConnectIntegrationTestBase() {
fun executeMutationShouldSendAppCheckMetadataWhenAppCheckIsEnabled() = runTest {
val grpcServer = inProcessDataConnectGrpcServer.newInstance()
val dataConnect = dataConnectFactory.newInstance(grpcServer)
(dataConnect as FirebaseDataConnectInternal).awaitAppCheckReady()
dataConnect.awaitAppCheckReady()
val mutationRef =
dataConnect.mutation("mutz4hzqzpgb4", Unit, serializer<Unit>(), serializer<Unit>())
val metadatasJob = async { grpcServer.metadatas.first() }
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.firebase.dataconnect.testutil

import com.google.firebase.dataconnect.FirebaseDataConnect
import com.google.firebase.dataconnect.core.FirebaseDataConnectInternal

suspend fun FirebaseDataConnect.awaitAuthReady() =
(this as FirebaseDataConnectInternal).awaitAuthReady()

suspend fun FirebaseDataConnect.awaitAppCheckReady() =
(this as FirebaseDataConnectInternal).awaitAppCheckReady()
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import com.google.firebase.inject.Provider
import com.google.firebase.internal.api.FirebaseNoSignedInUserException
import com.google.firebase.util.nextAlphanumericString
import java.lang.ref.WeakReference
import java.util.concurrent.atomic.AtomicReference
import kotlin.coroutines.coroutineContext
import kotlin.random.Random
import kotlinx.coroutines.CancellationException
Expand All @@ -46,10 +45,9 @@ import kotlinx.coroutines.async
import kotlinx.coroutines.cancel
import kotlinx.coroutines.ensureActive
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.launch
import kotlinx.coroutines.yield

/** Base class that shares logic for managing the Auth token and AppCheck token. */
internal sealed class DataConnectCredentialsTokenManager<T : Any>(
Expand All @@ -61,9 +59,6 @@ internal sealed class DataConnectCredentialsTokenManager<T : Any>(
val instanceId: String
get() = logger.nameWithId

private val _providerAvailable = MutableStateFlow(false)
val providerAvailable: StateFlow<Boolean> = _providerAvailable.asStateFlow()

@Suppress("LeakingThis") private val weakThis = WeakReference(this)

private val coroutineScope =
Expand All @@ -87,49 +82,39 @@ internal sealed class DataConnectCredentialsTokenManager<T : Any>(
}
}

private interface ProviderProvider<T> {
val provider: T?
}

private sealed interface State<out T> {

/** State indicating that [close] has been invoked. */
object Closed : State<Nothing>

/** State indicating that there is no outstanding "get token" request. */
class Idle<T>(

/**
* The [InternalAuthProvider] or [InteropAppCheckTokenProvider]; may be null if the deferred
* has not yet given us a provider.
*/
override val provider: T?,

sealed interface StateWithForceTokenRefresh<out T> : State<T> {
/** The value to specify for `forceRefresh` on the next invocation of [getToken]. */
val forceTokenRefresh: Boolean
) : State<T>, ProviderProvider<T>
}

/** State indicating that there _is_ an outstanding "get token" request. */
class Active<T>(
/** State indicating that the token provider is not (yet?) available. */
data class New(override val forceTokenRefresh: Boolean) : StateWithForceTokenRefresh<Nothing>

/**
* The [InternalAuthProvider] or [InteropAppCheckTokenProvider] that is performing the "get
* token" request.
*/
sealed interface StateWithProvider<out T> : State<T> {
/** The token provider, [InternalAuthProvider] or [InteropAppCheckTokenProvider] */
val provider: T
}

/** State indicating that there is no outstanding "get token" request. */
data class Idle<T>(override val provider: T, override val forceTokenRefresh: Boolean) :
StateWithProvider<T>, StateWithForceTokenRefresh<T>

/** State indicating that there _is_ an outstanding "get token" request. */
data class Active<out T>(
override val provider: T,

/** The job that is performing the "get token" request. */
val job: Deferred<SequencedReference<Result<GetTokenResult>>>
) : State<T>, ProviderProvider<T>
) : StateWithProvider<T>
}

/**
* The current state of this object. The value should only be changed in a compare-and-swap loop
* in order to be thread-safe. Such a loop should call `yield()` on each iteration to allow other
* coroutines to run on the thread.
*/
private val state =
AtomicReference<State<T>>(State.Idle(provider = null, forceTokenRefresh = false))
/** The current state of this object. */
private val state = MutableStateFlow<State<T>>(State.New(forceTokenRefresh = false))

/**
* Adds the token listener to the given provider.
Expand Down Expand Up @@ -168,19 +153,42 @@ internal sealed class DataConnectCredentialsTokenManager<T : Any>(
setClosedState()
}

/**
* Suspends until the token provider becomes available to this object.
*
* If [close] has been invoked, or is invoked _before_ a token provider becomes available, then
* this method returns normally, as if a token provider _had_ become available.
*/
suspend fun awaitTokenProvider() {
logger.debug { "awaitTokenProvider() start" }
val currentState =
state
.filter {
when (it) {
State.Closed -> true
is State.New -> false
is State.Idle -> true
is State.Active -> true
}
}
.first()
logger.debug { "awaitTokenProvider() done: currentState=$currentState" }
}

// This function must ONLY be called from close().
private fun setClosedState() {
while (true) {
val oldState = state.get()
val providerProvider: ProviderProvider<T> =
val oldState = state.value
val provider: T? =
when (oldState) {
is State.Closed -> return
is State.Idle -> oldState
is State.Active -> oldState
is State.New -> null
is State.Idle -> oldState.provider
is State.Active -> oldState.provider
}

if (state.compareAndSet(oldState, State.Closed)) {
providerProvider.provider?.let { removeTokenListener(it) }
provider?.let { removeTokenListener(it) }
break
}
}
Expand All @@ -191,27 +199,28 @@ internal sealed class DataConnectCredentialsTokenManager<T : Any>(
*
* If [close] has been called, this method does nothing.
*/
suspend fun forceRefresh() {
fun forceRefresh() {
logger.debug { "forceRefresh()" }
while (true) {
val oldState = state.get()
val oldStateProviderProvider =
val oldState = state.value
val newState: State.StateWithForceTokenRefresh<T> =
when (oldState) {
is State.Closed -> return
is State.Idle -> oldState
is State.New -> oldState.copy(forceTokenRefresh = true)
is State.Idle -> oldState.copy(forceTokenRefresh = true)
is State.Active -> {
val message = "needs token refresh (wgrwbrvjxt)"
oldState.job.cancel(message, ForceRefresh(message))
oldState
State.Idle(oldState.provider, forceTokenRefresh = true)
}
}

val newState = State.Idle(oldStateProviderProvider.provider, forceTokenRefresh = true)
check(newState.forceTokenRefresh) {
"newState.forceTokenRefresh should be true (error code gnvr2wx7nz)"
}
if (state.compareAndSet(oldState, newState)) {
break
}

yield()
}
}

Expand Down Expand Up @@ -246,7 +255,7 @@ internal sealed class DataConnectCredentialsTokenManager<T : Any>(
logger.debug { "$invocationId getToken(requestId=$requestId)" }
while (true) {
val attemptSequenceNumber = nextSequenceNumber()
val oldState = state.get()
val oldState = state.value

val newState: State.Active<T> =
when (oldState) {
Expand All @@ -257,13 +266,13 @@ internal sealed class DataConnectCredentialsTokenManager<T : Any>(
}
throw CredentialsTokenManagerClosedException(this)
}
is State.Idle -> {
if (oldState.provider === null) {
logger.debug {
"$invocationId getToken() returns null (token provider is not (yet?) available)"
}
return null
is State.New -> {
logger.debug {
"$invocationId getToken() returns null (token provider is not (yet?) available)"
}
return null
}
is State.Idle -> {
newActiveState(invocationId, oldState.provider, oldState.forceTokenRefresh)
}
is State.Active -> {
Expand Down Expand Up @@ -342,7 +351,7 @@ internal sealed class DataConnectCredentialsTokenManager<T : Any>(
addTokenListener(newProvider)

while (true) {
val oldState = state.get()
val oldState = state.value
val newState =
when (oldState) {
is State.Closed -> {
Expand All @@ -353,6 +362,7 @@ internal sealed class DataConnectCredentialsTokenManager<T : Any>(
removeTokenListener(newProvider)
break
}
is State.New -> State.Idle(newProvider, oldState.forceTokenRefresh)
is State.Idle -> State.Idle(newProvider, oldState.forceTokenRefresh)
is State.Active -> {
val newProviderClassName = newProvider::class.qualifiedName
Expand All @@ -366,8 +376,6 @@ internal sealed class DataConnectCredentialsTokenManager<T : Any>(
break
}
}

_providerAvailable.value = true
}

/**
Expand Down
Loading
Loading