Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,15 @@ class DiskCache(
clientRegistrationCache(ssoRegion).tryDeleteIfExists()
}

override fun loadClientRegistration(cacheKey: ClientRegistrationCacheKey): ClientRegistration? {
LOG.info { "loadClientRegistration for $cacheKey" }
override fun loadClientRegistration(cacheKey: ClientRegistrationCacheKey, source: String): ClientRegistration? {
LOG.info { "loadClientRegistration:$source for $cacheKey" }
val inputStream = clientRegistrationCache(cacheKey).tryInputStreamIfExists()
if (inputStream == null) {
val stage = LoadCredentialStage.ACCESS_FILE
LOG.info { "Failed to load Client Registration: cache file does not exist" }
AuthTelemetry.modifyConnection(
action = "Load cache file",
source = "loadClientRegistration",
source = "loadClientRegistration:$source",
result = Result.Failed,
reason = "Failed to load Client Registration",
reasonDesc = "Load Step:$stage failed. Cache file does not exist"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@

@Deprecated("Device authorization grant flow is deprecated")
private fun registerDAGClient(): ClientRegistration {
loadDagClientRegistration()?.let {
loadDagClientRegistration(SourceOfLoadRegistration.REGISTER_CLIENT.toString())?.let {
return it
}

Expand Down Expand Up @@ -235,7 +235,7 @@
}

private fun registerPkceClient(): PKCEClientRegistration {
loadPkceClientRegistration()?.let {
loadPkceClientRegistration(SourceOfLoadRegistration.REGISTER_CLIENT.toString())?.let {
return it
}

Expand Down Expand Up @@ -431,8 +431,8 @@
stageName = RefreshCredentialStage.LOAD_REGISTRATION
val registration = try {
when (currentToken) {
is DeviceAuthorizationGrantToken -> loadDagClientRegistration()
is PKCEAuthorizationGrantToken -> loadPkceClientRegistration()
is DeviceAuthorizationGrantToken -> loadDagClientRegistration(SourceOfLoadRegistration.REFRESH_TOKEN.toString())
is PKCEAuthorizationGrantToken -> loadPkceClientRegistration(SourceOfLoadRegistration.REFRESH_TOKEN.toString())
}
} catch (e: Exception) {
val message = e.message ?: "$stageName: ${e::class.java.name}"
Expand Down Expand Up @@ -505,6 +505,11 @@
}
}

enum class SourceOfLoadRegistration {
REGISTER_CLIENT,
REFRESH_TOKEN,
}

Check warning on line 511 in plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/SsoAccessTokenProvider.kt

View check run for this annotation

Codecov / codecov/patch

plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/SsoAccessTokenProvider.kt#L511

Added line #L511 was not covered by tests

private enum class RefreshCredentialStage {
VALIDATE_REFRESH_TOKEN,
LOAD_REGISTRATION,
Expand All @@ -514,13 +519,13 @@
SAVE_TOKEN,
}

private fun loadDagClientRegistration(): ClientRegistration? =
cache.loadClientRegistration(dagClientRegistrationCacheKey)?.let {
private fun loadDagClientRegistration(source: String): ClientRegistration? =
cache.loadClientRegistration(dagClientRegistrationCacheKey, source)?.let {
return it
}

private fun loadPkceClientRegistration(): PKCEClientRegistration? =
cache.loadClientRegistration(pkceClientRegistrationCacheKey)?.let {
private fun loadPkceClientRegistration(source: String): PKCEClientRegistration? =
cache.loadClientRegistration(pkceClientRegistrationCacheKey, source)?.let {
return it as PKCEClientRegistration
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ interface SsoCache {
fun invalidateClientRegistration(ssoRegion: String)
fun invalidateAccessToken(ssoUrl: String)

fun loadClientRegistration(cacheKey: ClientRegistrationCacheKey): ClientRegistration?
fun loadClientRegistration(cacheKey: ClientRegistrationCacheKey, source: String): ClientRegistration?
fun saveClientRegistration(cacheKey: ClientRegistrationCacheKey, registration: ClientRegistration)
fun invalidateClientRegistration(cacheKey: ClientRegistrationCacheKey)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ class DiskCacheTest {
startUrl = ssoUrl,
scopes = scopes,
region = ssoRegion
)
),
"testSource"
)
).isNull()
}
Expand All @@ -71,7 +72,7 @@ class DiskCacheTest {
)
cacheLocation.resolve("223224b6f0b4702c1a984be8284fe2c9d9718759.json").writeText("badData")

assertThat(sut.loadClientRegistration(key)).isNull()
assertThat(sut.loadClientRegistration(key, "testSource")).isNull()
}

@Test
Expand All @@ -91,7 +92,7 @@ class DiskCacheTest {
""".trimIndent()
)

assertThat(sut.loadClientRegistration(key)).isNull()
assertThat(sut.loadClientRegistration(key, "testSource")).isNull()
}

@Test
Expand All @@ -112,7 +113,7 @@ class DiskCacheTest {
""".trimIndent()
)

assertThat(sut.loadClientRegistration(key)).isNull()
assertThat(sut.loadClientRegistration(key, "testSource")).isNull()
}

@Test
Expand All @@ -134,7 +135,7 @@ class DiskCacheTest {
""".trimIndent()
)

assertThat(sut.loadClientRegistration(key))
assertThat(sut.loadClientRegistration(key, "testSource"))
.usingRecursiveComparison()
.isEqualTo(
DeviceAuthorizationClientRegistration(
Expand Down Expand Up @@ -217,7 +218,7 @@ class DiskCacheTest {
""".trimIndent()
)

assertThat(sut.loadClientRegistration(key))
assertThat(sut.loadClientRegistration(key, "testSource"))
.usingRecursiveComparison()
.isEqualTo(
PKCEClientRegistration(
Expand Down Expand Up @@ -323,10 +324,10 @@ class DiskCacheTest {
)
)

assertThat(sut.loadClientRegistration(key1))
assertThat(sut.loadClientRegistration(key1, "testSource"))
.usingRecursiveComparison()
.isEqualTo(
sut.loadClientRegistration(key2)
sut.loadClientRegistration(key2, "testSource")
)
}

Expand All @@ -350,11 +351,11 @@ class DiskCacheTest {
region = ssoRegion
)

assertThat(sut.loadClientRegistration(key)).isNotNull()
assertThat(sut.loadClientRegistration(key, "testSource")).isNotNull()

sut.invalidateClientRegistration(key)

assertThat(sut.loadClientRegistration(key)).isNull()
assertThat(sut.loadClientRegistration(key, "testSource")).isNull()
assertThat(cacheFile).doesNotExist()
}

Expand Down Expand Up @@ -619,7 +620,7 @@ class DiskCacheTest {
registration.setPosixFilePermissions(emptySet())
assertPosixPermissions(registration, "---------")

assertThat(sut.loadClientRegistration(key)).isNotNull()
assertThat(sut.loadClientRegistration(key, "testSource")).isNotNull()

assertPosixPermissions(registration, "rw-------")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class SsoAccessTokenProviderTest {
verify(ssoOidcClient).startDeviceAuthorization(any<StartDeviceAuthorizationRequest>())
verify(ssoOidcClient).createToken(any<CreateTokenRequest>())
verify(ssoCache).loadAccessToken(argThat<DeviceGrantAccessTokenCacheKey> { startUrl == ssoUrl })
verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion })
verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion }, any<String>())
verify(ssoCache).saveAccessToken(argThat<DeviceGrantAccessTokenCacheKey> { startUrl == ssoUrl }, eq(accessToken))
}

Expand Down Expand Up @@ -170,7 +170,7 @@ class SsoAccessTokenProviderTest {
verify(ssoOidcClient).startDeviceAuthorization(any<StartDeviceAuthorizationRequest>())
verify(ssoOidcClient).createToken(any<CreateTokenRequest>())
verify(ssoCache).loadAccessToken(argThat<DeviceGrantAccessTokenCacheKey> { startUrl == ssoUrl })
verify(ssoCache).loadClientRegistration(argThat<DeviceAuthorizationClientRegistrationCacheKey> { region == ssoRegion })
verify(ssoCache).loadClientRegistration(argThat<DeviceAuthorizationClientRegistrationCacheKey> { region == ssoRegion }, any<String>())
verify(ssoCache).saveClientRegistration(argThat<DeviceAuthorizationClientRegistrationCacheKey> { region == ssoRegion }, any())
verify(ssoCache).saveAccessToken(argThat<DeviceGrantAccessTokenCacheKey> { startUrl == ssoUrl }, eq(accessToken))
}
Expand Down Expand Up @@ -267,7 +267,7 @@ class SsoAccessTokenProviderTest {
verify(ssoOidcClient).startDeviceAuthorization(any<StartDeviceAuthorizationRequest>())
verify(ssoOidcClient, times(2)).createToken(any<CreateTokenRequest>())
verify(ssoCache).loadAccessToken(argThat<DeviceGrantAccessTokenCacheKey> { startUrl == ssoUrl })
verify(ssoCache).loadClientRegistration(argThat<DeviceAuthorizationClientRegistrationCacheKey> { region == ssoRegion })
verify(ssoCache).loadClientRegistration(argThat<DeviceAuthorizationClientRegistrationCacheKey> { region == ssoRegion }, any<String>())
verify(ssoCache).saveAccessToken(argThat<DeviceGrantAccessTokenCacheKey> { startUrl == ssoUrl }, eq(accessToken))
}

Expand Down Expand Up @@ -296,7 +296,7 @@ class SsoAccessTokenProviderTest {
val refreshedToken = runBlocking { sut.refreshToken(sut.accessToken()) }

verify(ssoCache).loadAccessToken(argThat<DeviceGrantAccessTokenCacheKey> { startUrl == ssoUrl })
verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion })
verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion }, any<String>())
verify(ssoOidcClient).createToken(any<CreateTokenRequest>())
verify(ssoCache).saveAccessToken(argThat<DeviceGrantAccessTokenCacheKey> { startUrl == ssoUrl }, eq(refreshedToken))
}
Expand Down Expand Up @@ -342,7 +342,7 @@ class SsoAccessTokenProviderTest {
)

on(
ssoCache.loadClientRegistration(any<PKCEClientRegistrationCacheKey>())
ssoCache.loadClientRegistration(any<PKCEClientRegistrationCacheKey>(), any<String>())
).thenReturn(
PKCEClientRegistration(
clientType = "public",
Expand All @@ -369,7 +369,7 @@ class SsoAccessTokenProviderTest {
val refreshedToken = runBlocking { sut.refreshToken(sut.accessToken()) }

verify(ssoCache).loadAccessToken(any<PKCEAccessTokenCacheKey>())
verify(ssoCache).loadClientRegistration(any<PKCEClientRegistrationCacheKey>())
verify(ssoCache).loadClientRegistration(any<PKCEClientRegistrationCacheKey>(), any<String>())
verify(ssoOidcClient).createToken(any<CreateTokenRequest>())
verify(ssoCache).saveAccessToken(any<PKCEAccessTokenCacheKey>(), eq(refreshedToken))
}
Expand All @@ -390,7 +390,7 @@ class SsoAccessTokenProviderTest {
verify(ssoOidcClient).startDeviceAuthorization(any<StartDeviceAuthorizationRequest>())
verify(ssoOidcClient).createToken(any<CreateTokenRequest>())
verify(ssoCache).loadAccessToken(argThat<DeviceGrantAccessTokenCacheKey> { startUrl == ssoUrl })
verify(ssoCache).loadClientRegistration(argThat<DeviceAuthorizationClientRegistrationCacheKey> { region == ssoRegion })
verify(ssoCache).loadClientRegistration(argThat<DeviceAuthorizationClientRegistrationCacheKey> { region == ssoRegion }, any<String>())
}

@Test
Expand Down Expand Up @@ -432,7 +432,7 @@ class SsoAccessTokenProviderTest {
verify(ssoOidcClient).startDeviceAuthorization(any<StartDeviceAuthorizationRequest>())
verify(ssoOidcClient, times(2)).createToken(any<CreateTokenRequest>())
verify(ssoCache).loadAccessToken(argThat<DeviceGrantAccessTokenCacheKey> { startUrl == ssoUrl })
verify(ssoCache).loadClientRegistration(argThat<DeviceAuthorizationClientRegistrationCacheKey> { region == ssoRegion })
verify(ssoCache).loadClientRegistration(argThat<DeviceAuthorizationClientRegistrationCacheKey> { region == ssoRegion }, any<String>())
verify(ssoCache).saveAccessToken(argThat<DeviceGrantAccessTokenCacheKey> { startUrl == ssoUrl }, eq(accessToken))
}

Expand All @@ -452,7 +452,7 @@ class SsoAccessTokenProviderTest {

verify(ssoOidcClient).registerClient(any<RegisterClientRequest>())
verify(ssoCache).loadAccessToken(any())
verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion })
verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion }, any<String>())
}

@Test
Expand Down Expand Up @@ -492,7 +492,7 @@ class SsoAccessTokenProviderTest {
)

on(
ssoCache.loadClientRegistration(argThat { region == ssoRegion })
ssoCache.loadClientRegistration(argThat { region == ssoRegion }, any<String>())
).thenReturn(
returnValue
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import org.junit.jupiter.api.assertThrows
import org.mockito.Mockito
import org.mockito.kotlin.any
import org.mockito.kotlin.argThat
import org.mockito.kotlin.eq
import org.mockito.kotlin.mock
import org.mockito.kotlin.spy
import org.mockito.kotlin.times
Expand Down Expand Up @@ -273,7 +274,7 @@ class InteractiveBearerTokenProviderTest {
)

private fun stubClientRegistration() {
whenever(diskCache.loadClientRegistration(any<DeviceAuthorizationClientRegistrationCacheKey>())).thenReturn(
whenever(diskCache.loadClientRegistration(any<DeviceAuthorizationClientRegistrationCacheKey>(), eq("testSource"))).thenReturn(
DeviceAuthorizationClientRegistration(
"",
"",
Expand Down
Loading