diff --git a/firebase-dataconnect/CHANGELOG.md b/firebase-dataconnect/CHANGELOG.md index 7f64bb4d9e8..e6a6f7d636b 100644 --- a/firebase-dataconnect/CHANGELOG.md +++ b/firebase-dataconnect/CHANGELOG.md @@ -1,5 +1,8 @@ # Unreleased +- [changed] Internal refactor for managing Auth and App Check tokens + ([#7184](https://github.com/firebase/firebase-android-sdk/pull/7184)) + # 17.1.0 - [fixed] Addressed minor reference documentation issues (#7399) diff --git a/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/DataConnectAppCheck.kt b/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/DataConnectAppCheck.kt index abe60656f73..ce069f54a61 100644 --- a/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/DataConnectAppCheck.kt +++ b/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/DataConnectAppCheck.kt @@ -20,6 +20,7 @@ import com.google.firebase.annotations.DeferredApi import com.google.firebase.appcheck.AppCheckTokenResult import com.google.firebase.appcheck.interop.AppCheckTokenListener import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider +import com.google.firebase.dataconnect.core.DataConnectAppCheck.GetAppCheckTokenResult import com.google.firebase.dataconnect.core.Globals.toScrubbedAccessToken import com.google.firebase.dataconnect.core.LoggerGlobals.debug import kotlinx.coroutines.CoroutineDispatcher @@ -32,7 +33,7 @@ internal class DataConnectAppCheck( blockingDispatcher: CoroutineDispatcher, logger: Logger, ) : - DataConnectCredentialsTokenManager( + DataConnectCredentialsTokenManager( deferredProvider = deferredAppCheckTokenProvider, parentCoroutineScope = parentCoroutineScope, blockingDispatcher = blockingDispatcher, @@ -48,7 +49,9 @@ internal class DataConnectAppCheck( provider.removeAppCheckTokenListener(appCheckTokenListener) override suspend fun getToken(provider: InteropAppCheckTokenProvider, forceRefresh: Boolean) = - provider.getToken(forceRefresh).await().let { GetTokenResult(it.token) } + provider.getToken(forceRefresh).await().let { GetAppCheckTokenResult(it.token) } + + data class GetAppCheckTokenResult(override val token: String?) : GetTokenResult private class AppCheckTokenListenerImpl(private val logger: Logger) : AppCheckTokenListener { override fun onAppCheckTokenChanged(tokenResult: AppCheckTokenResult) { diff --git a/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/DataConnectAuth.kt b/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/DataConnectAuth.kt index a3a27ccc31c..59131a1ddad 100644 --- a/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/DataConnectAuth.kt +++ b/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/DataConnectAuth.kt @@ -19,6 +19,7 @@ package com.google.firebase.dataconnect.core import com.google.firebase.annotations.DeferredApi import com.google.firebase.auth.internal.IdTokenListener import com.google.firebase.auth.internal.InternalAuthProvider +import com.google.firebase.dataconnect.core.DataConnectAuth.GetAuthTokenResult import com.google.firebase.dataconnect.core.Globals.toScrubbedAccessToken import com.google.firebase.dataconnect.core.LoggerGlobals.debug import com.google.firebase.internal.InternalTokenResult @@ -32,7 +33,7 @@ internal class DataConnectAuth( blockingDispatcher: CoroutineDispatcher, logger: Logger, ) : - DataConnectCredentialsTokenManager( + DataConnectCredentialsTokenManager( deferredProvider = deferredAuthProvider, parentCoroutineScope = parentCoroutineScope, blockingDispatcher = blockingDispatcher, @@ -48,7 +49,9 @@ internal class DataConnectAuth( provider.removeIdTokenListener(idTokenListener) override suspend fun getToken(provider: InternalAuthProvider, forceRefresh: Boolean) = - provider.getAccessToken(forceRefresh).await().let { GetTokenResult(it.token) } + provider.getAccessToken(forceRefresh).await().let { GetAuthTokenResult(it.token) } + + data class GetAuthTokenResult(override val token: String?) : GetTokenResult private class IdTokenListenerImpl(private val logger: Logger) : IdTokenListener { override fun onIdTokenChanged(tokenResult: InternalTokenResult) { diff --git a/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/DataConnectCredentialsTokenManager.kt b/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/DataConnectCredentialsTokenManager.kt index b67db33b36c..90330820bfc 100644 --- a/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/DataConnectCredentialsTokenManager.kt +++ b/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/DataConnectCredentialsTokenManager.kt @@ -20,6 +20,7 @@ import com.google.firebase.annotations.DeferredApi import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider import com.google.firebase.auth.internal.InternalAuthProvider import com.google.firebase.dataconnect.DataConnectException +import com.google.firebase.dataconnect.core.DataConnectCredentialsTokenManager.GetTokenResult import com.google.firebase.dataconnect.core.Globals.toScrubbedAccessToken import com.google.firebase.dataconnect.core.LoggerGlobals.debug import com.google.firebase.dataconnect.core.LoggerGlobals.warn @@ -52,7 +53,7 @@ import kotlinx.coroutines.flow.update import kotlinx.coroutines.launch /** Base class that shares logic for managing the Auth token and AppCheck token. */ -internal sealed class DataConnectCredentialsTokenManager( +internal sealed class DataConnectCredentialsTokenManager( private val deferredProvider: com.google.firebase.inject.Deferred, parentCoroutineScope: CoroutineScope, private val blockingDispatcher: CoroutineDispatcher, @@ -75,13 +76,13 @@ internal sealed class DataConnectCredentialsTokenManager( } ) - private sealed interface State { + private sealed interface State { /** * State indicating that the object has just been created and [initialize] has not yet been * called. */ - object New : State + object New : State /** * State indicating that [initialize] has been invoked but the token provider is not (yet?) @@ -93,33 +94,33 @@ internal sealed class DataConnectCredentialsTokenManager( } /** State indicating that [close] has been invoked. */ - object Closed : State + object Closed : State - sealed interface StateWithForceTokenRefresh : State { + sealed interface StateWithForceTokenRefresh : State { /** The value to specify for `forceRefresh` on the next invocation of [getToken]. */ val forceTokenRefresh: Boolean } - sealed interface StateWithProvider : State { + sealed interface StateWithProvider : State { /** The token provider, [InternalAuthProvider] or [InteropAppCheckTokenProvider] */ val provider: T } /** State indicating that there is no outstanding "get token" request. */ data class Idle(override val provider: T, override val forceTokenRefresh: Boolean) : - StateWithProvider, StateWithForceTokenRefresh + StateWithProvider, StateWithForceTokenRefresh /** State indicating that there _is_ an outstanding "get token" request. */ - data class Active( + data class Active( override val provider: T, /** The job that is performing the "get token" request. */ - val job: Deferred>> - ) : StateWithProvider + val job: Deferred>> + ) : StateWithProvider } /** The current state of this object. */ - private val state = MutableStateFlow>(State.New) + private val state = MutableStateFlow>(State.New) /** * Adds the token listener to the given provider. @@ -139,7 +140,7 @@ internal sealed class DataConnectCredentialsTokenManager( * Starts an asynchronous task to get a new access token from the given provider, forcing a token * refresh if and only if `forceRefresh` is true. */ - protected abstract suspend fun getToken(provider: T, forceRefresh: Boolean): GetTokenResult + protected abstract suspend fun getToken(provider: T, forceRefresh: Boolean): R /** * Initializes this object. @@ -274,7 +275,7 @@ internal sealed class DataConnectCredentialsTokenManager( invocationId: String, provider: T, forceRefresh: Boolean - ): State.Active { + ): State.Active { val coroutineName = CoroutineName( "$instanceId 535gmcvv5a $invocationId getToken(" + @@ -296,14 +297,14 @@ internal sealed class DataConnectCredentialsTokenManager( * @throws DataConnectException if [close] has been called or is called while the operation is in * progress. */ - suspend fun getToken(requestId: String): String? { + suspend fun getToken(requestId: String): R? { val invocationId = "gat" + Random.nextAlphanumericString(length = 8) logger.debug { "$invocationId getToken(requestId=$requestId)" } while (true) { val attemptSequenceNumber = nextSequenceNumber() val oldState = state.value - val newState: State.Active = + val newState: State.Active = when (oldState) { is State.New -> throw IllegalStateException("initialize() must be called before getToken()") @@ -381,11 +382,12 @@ internal sealed class DataConnectCredentialsTokenManager( } } - val accessToken = sequencedResult!!.ref.getOrThrow().token + val tokenResult: R = sequencedResult!!.ref.getOrThrow() logger.debug { - "$invocationId getToken() returns retrieved token: ${accessToken?.toScrubbedAccessToken()}" + "$invocationId getToken() returns retrieved token: " + + tokenResult.token?.toScrubbedAccessToken() } - return accessToken + return tokenResult } } @@ -440,8 +442,9 @@ internal sealed class DataConnectCredentialsTokenManager( * strong reference to the [DataConnectCredentialsTokenManager] instance indefinitely, in the case * that the callback never occurs. */ - private class DeferredProviderHandlerImpl( - private val weakCredentialsTokenManagerRef: WeakReference> + private class DeferredProviderHandlerImpl( + private val weakCredentialsTokenManagerRef: + WeakReference> ) : DeferredHandler { override fun handle(provider: Provider) { weakCredentialsTokenManagerRef.get()?.onProviderAvailable(provider.get()) @@ -449,7 +452,7 @@ internal sealed class DataConnectCredentialsTokenManager( } private class CredentialsTokenManagerClosedException( - tokenProvider: DataConnectCredentialsTokenManager<*> + tokenProvider: DataConnectCredentialsTokenManager<*, *> ) : DataConnectException( "DataConnectCredentialsTokenManager ${tokenProvider.instanceId} was closed (code cqrbq4zfvy)" @@ -458,7 +461,9 @@ internal sealed class DataConnectCredentialsTokenManager( private class GetTokenCancelledException(cause: Throwable) : DataConnectException("getToken() was cancelled, likely by close() (code rqdd4jam9d)", cause) - protected data class GetTokenResult(val token: String?) + interface GetTokenResult { + val token: String? + } private companion object { diff --git a/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/DataConnectGrpcMetadata.kt b/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/DataConnectGrpcMetadata.kt index dd5ae849264..4e988aee89f 100644 --- a/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/DataConnectGrpcMetadata.kt +++ b/firebase-dataconnect/src/main/kotlin/com/google/firebase/dataconnect/core/DataConnectGrpcMetadata.kt @@ -85,12 +85,8 @@ internal class DataConnectGrpcMetadata( if (appId.isNotBlank()) { it.put(gmpAppIdHeader, appId) } - if (authToken !== null) { - it.put(firebaseAuthTokenHeader, authToken) - } - if (appCheckToken !== null) { - it.put(firebaseAppCheckTokenHeader, appCheckToken) - } + authToken?.token?.let { token -> it.put(firebaseAuthTokenHeader, token) } + appCheckToken?.token?.let { token -> it.put(firebaseAppCheckTokenHeader, token) } } } diff --git a/firebase-dataconnect/src/test/kotlin/com/google/firebase/dataconnect/core/DataConnectAuthUnitTest.kt b/firebase-dataconnect/src/test/kotlin/com/google/firebase/dataconnect/core/DataConnectAuthUnitTest.kt index b35b3c0a402..e3c9576f813 100644 --- a/firebase-dataconnect/src/test/kotlin/com/google/firebase/dataconnect/core/DataConnectAuthUnitTest.kt +++ b/firebase-dataconnect/src/test/kotlin/com/google/firebase/dataconnect/core/DataConnectAuthUnitTest.kt @@ -49,6 +49,7 @@ import io.kotest.assertions.withClue import io.kotest.matchers.collections.shouldContain import io.kotest.matchers.collections.shouldContainExactly import io.kotest.matchers.nulls.shouldBeNull +import io.kotest.matchers.nulls.shouldNotBeNull import io.kotest.matchers.shouldBe import io.kotest.matchers.types.shouldBeSameInstanceAs import io.kotest.property.Arb @@ -302,7 +303,7 @@ class DataConnectAuthUnitTest { val result = dataConnectAuth.getToken(requestId) - withClue("result=$result") { result shouldBe accessToken } + withClue("result=$result") { result.shouldNotBeNull().token shouldBe accessToken } mockLogger.shouldHaveLoggedExactlyOneMessageContaining(requestId) mockLogger.shouldHaveLoggedExactlyOneMessageContaining( "returns retrieved token: ${accessToken.toScrubbedAccessToken()}" @@ -363,7 +364,7 @@ class DataConnectAuthUnitTest { dataConnectAuth.forceRefresh() val result = dataConnectAuth.getToken(requestId) - withClue("result=$result") { result shouldBe accessToken } + withClue("result=$result") { result.shouldNotBeNull().token shouldBe accessToken } verify(exactly = 1) { mockInternalAuthProvider.getAccessToken(true) } verify(exactly = 0) { mockInternalAuthProvider.getAccessToken(false) } mockLogger.shouldHaveLoggedExactlyOneMessageContaining(requestId) @@ -419,7 +420,7 @@ class DataConnectAuthUnitTest { taskForToken(accessTokenGenerator.next().also { tokens.add(it) }) } - val results = List(5) { dataConnectAuth.getToken(requestId) } + val results = List(5) { dataConnectAuth.getToken(requestId)?.token } results shouldContainExactly tokens } @@ -447,7 +448,7 @@ class DataConnectAuthUnitTest { } } - val actualTokens = jobs.map { it.await() } + val actualTokens = jobs.map { it.await()?.token } actualTokens.forEachIndexed { index, token -> withClue("actualTokens[$index]") { tokens shouldContain token } } @@ -481,7 +482,7 @@ class DataConnectAuthUnitTest { val result = dataConnectAuth.getToken(requestId) - withClue("result=$result") { result shouldBe tokens.last() } + withClue("result=$result") { result.shouldNotBeNull().token shouldBe tokens.last() } verify(exactly = 2) { mockInternalAuthProvider.getAccessToken(true) } verify(exactly = 1) { mockInternalAuthProvider.getAccessToken(false) } mockLogger.shouldHaveLoggedAtLeastOneMessageContaining("retrying due to needs token refresh") @@ -496,11 +497,7 @@ class DataConnectAuthUnitTest { advanceUntilIdle() val invocationCount = AtomicInteger(0) val tokens = CopyOnWriteArrayList() - val getTokenJob2 = - async(start = CoroutineStart.LAZY) { - val accessToken = dataConnectAuth.getToken(requestId) - accessToken - } + val getTokenJob2 = async(start = CoroutineStart.LAZY) { dataConnectAuth.getToken(requestId) } coEvery { mockInternalAuthProvider.getAccessToken(any()) } coAnswers { if (invocationCount.getAndIncrement() == 0) { @@ -509,16 +506,15 @@ class DataConnectAuthUnitTest { getTokenJob2.start() advanceUntilIdle() } - val rv = taskForToken(accessTokenGenerator.next().also { tokens.add(it) }) - rv + taskForToken(accessTokenGenerator.next().also { tokens.add(it) }) } val result1 = dataConnectAuth.getToken(requestId) withClue("getTokenJob2.isActive") { getTokenJob2.isActive shouldBe true } val result2 = getTokenJob2.await() - withClue("result1=$result1") { result1 shouldBe tokens[0] } - withClue("result2=$result2") { result2 shouldBe tokens[1] } + withClue("result1=$result1") { result1.shouldNotBeNull().token shouldBe tokens[0] } + withClue("result2=$result2") { result2.shouldNotBeNull().token shouldBe tokens[1] } verify(exactly = 2) { mockInternalAuthProvider.getAccessToken(false) } verify(exactly = 0) { mockInternalAuthProvider.getAccessToken(true) } mockLogger.shouldHaveLoggedExactlyOneMessageContaining("got an old result; retrying") diff --git a/firebase-dataconnect/src/test/kotlin/com/google/firebase/dataconnect/core/DataConnectGrpcMetadataUnitTest.kt b/firebase-dataconnect/src/test/kotlin/com/google/firebase/dataconnect/core/DataConnectGrpcMetadataUnitTest.kt index e6f45753132..772caf9af07 100644 --- a/firebase-dataconnect/src/test/kotlin/com/google/firebase/dataconnect/core/DataConnectGrpcMetadataUnitTest.kt +++ b/firebase-dataconnect/src/test/kotlin/com/google/firebase/dataconnect/core/DataConnectGrpcMetadataUnitTest.kt @@ -20,6 +20,8 @@ import androidx.test.ext.junit.runners.AndroidJUnit4 import com.google.firebase.dataconnect.BuildConfig import com.google.firebase.dataconnect.FirebaseDataConnect.CallerSdkType import com.google.firebase.dataconnect.testutil.FirebaseAppUnitTestingRule +import com.google.firebase.dataconnect.testutil.property.arbitrary.appCheckTokenResult +import com.google.firebase.dataconnect.testutil.property.arbitrary.authTokenResult import com.google.firebase.dataconnect.testutil.property.arbitrary.dataConnect import com.google.firebase.dataconnect.testutil.property.arbitrary.dataConnectGrpcMetadata import io.grpc.Metadata @@ -175,8 +177,8 @@ class DataConnectGrpcMetadataUnitTest { @Test fun `should include x-firebase-auth-token when the auth token is not null`() = runTest { val dataConnectAuth: DataConnectAuth = mockk() - val accessToken = Arb.dataConnect.accessToken().next() - coEvery { dataConnectAuth.getToken(any()) } returns accessToken + val authTokenResult = Arb.dataConnect.authTokenResult().next() + coEvery { dataConnectAuth.getToken(any()) } returns authTokenResult val dataConnectGrpcMetadata = Arb.dataConnect .dataConnectGrpcMetadata(dataConnectAuth = Arb.constant(dataConnectAuth)) @@ -189,7 +191,7 @@ class DataConnectGrpcMetadataUnitTest { metadata.asClue { it.keys() shouldContain "x-firebase-auth-token" val metadataKey = Metadata.Key.of("x-firebase-auth-token", Metadata.ASCII_STRING_MARSHALLER) - it.get(metadataKey) shouldBe accessToken + it.get(metadataKey) shouldBe authTokenResult.token } } @@ -212,9 +214,9 @@ class DataConnectGrpcMetadataUnitTest { @Test fun `should include x-firebase-appcheck when the AppCheck token is not null`() = runTest { - val accessToken = Arb.dataConnect.accessToken().next() + val appCheckTokenResult = Arb.dataConnect.appCheckTokenResult().next() val dataConnectAppCheck: DataConnectAppCheck = mockk { - coEvery { getToken(any()) } returns accessToken + coEvery { getToken(any()) } returns appCheckTokenResult } val dataConnectGrpcMetadata = Arb.dataConnect @@ -228,7 +230,7 @@ class DataConnectGrpcMetadataUnitTest { metadata.asClue { it.keys() shouldContain "x-firebase-appcheck" val metadataKey = Metadata.Key.of("x-firebase-appcheck", Metadata.ASCII_STRING_MARSHALLER) - it.get(metadataKey) shouldBe accessToken + it.get(metadataKey) shouldBe appCheckTokenResult.token } } diff --git a/firebase-dataconnect/src/test/kotlin/com/google/firebase/dataconnect/testutil/property/arbitrary/arbs.kt b/firebase-dataconnect/src/test/kotlin/com/google/firebase/dataconnect/testutil/property/arbitrary/arbs.kt index 89b2c89bb0f..6d66f01fa88 100644 --- a/firebase-dataconnect/src/test/kotlin/com/google/firebase/dataconnect/testutil/property/arbitrary/arbs.kt +++ b/firebase-dataconnect/src/test/kotlin/com/google/firebase/dataconnect/testutil/property/arbitrary/arbs.kt @@ -22,7 +22,9 @@ import com.google.firebase.dataconnect.DataConnectPathSegment import com.google.firebase.dataconnect.FirebaseDataConnect.CallerSdkType import com.google.firebase.dataconnect.OperationRef import com.google.firebase.dataconnect.core.DataConnectAppCheck +import com.google.firebase.dataconnect.core.DataConnectAppCheck.GetAppCheckTokenResult import com.google.firebase.dataconnect.core.DataConnectAuth +import com.google.firebase.dataconnect.core.DataConnectAuth.GetAuthTokenResult import com.google.firebase.dataconnect.core.DataConnectGrpcClient import com.google.firebase.dataconnect.core.DataConnectGrpcMetadata import com.google.firebase.dataconnect.core.DataConnectOperationFailureResponseImpl @@ -50,14 +52,17 @@ import io.kotest.property.arbitrary.list import io.kotest.property.arbitrary.map import io.kotest.property.arbitrary.orNull import io.kotest.property.arbitrary.string +import io.mockk.coEvery import io.mockk.mockk import kotlinx.serialization.DeserializationStrategy import kotlinx.serialization.SerializationStrategy import kotlinx.serialization.modules.SerializersModule internal fun DataConnectArb.dataConnectGrpcMetadata( - dataConnectAuth: Arb = Arb.constant(mockk(relaxed = true)), - dataConnectAppCheck: Arb = Arb.constant(mockk(relaxed = true)), + dataConnectAuth: Arb = + Arb.constant(mockk(relaxed = true) { coEvery { getToken(any()) } returns null }), + dataConnectAppCheck: Arb = + Arb.constant(mockk(relaxed = true) { coEvery { getToken(any()) } returns null }), connectorLocation: Arb = connectorLocation(), kotlinVersion: Arb = Arb.string(size = 8, Codepoint.alphanumeric()), androidVersion: Arb = Arb.int(0..100), @@ -326,3 +331,11 @@ internal inline fun DataConnectArb.operationRefConstru variablesSerializersModule = variablesSerializersModule.bind(), ) } + +internal fun DataConnectArb.authTokenResult( + accessToken: Arb = accessToken() +): Arb = accessToken.map { GetAuthTokenResult(it) } + +internal fun DataConnectArb.appCheckTokenResult( + accessToken: Arb = accessToken() +): Arb = accessToken.map { GetAppCheckTokenResult(it) }