Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,15 @@ class DiskCache(
clientRegistrationCache(ssoRegion).tryDeleteIfExists()
}

override fun loadClientRegistration(cacheKey: ClientRegistrationCacheKey): ClientRegistration? {
LOG.info { "loadClientRegistration for $cacheKey" }
override fun loadClientRegistration(cacheKey: ClientRegistrationCacheKey, source: String): ClientRegistration? {
LOG.info { "loadClientRegistration:$source for $cacheKey" }
val inputStream = clientRegistrationCache(cacheKey).tryInputStreamIfExists()
if (inputStream == null) {
val stage = LoadCredentialStage.ACCESS_FILE
LOG.info { "Failed to load Client Registration: cache file does not exist" }
AuthTelemetry.modifyConnection(
action = "Load cache file",
source = "loadClientRegistration",
source = "loadClientRegistration:$source",
result = Result.Failed,
reason = "Failed to load Client Registration",
reasonDesc = "Load Step:$stage failed. Cache file does not exist"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ class SsoAccessTokenProvider(

@Deprecated("Device authorization grant flow is deprecated")
private fun registerDAGClient(): ClientRegistration {
loadDagClientRegistration()?.let {
loadDagClientRegistration(SourceOfLoadRegistration.REGISTER_CLIENT.toString())?.let {
return it
}

Expand Down Expand Up @@ -235,7 +235,7 @@ class SsoAccessTokenProvider(
}

private fun registerPkceClient(): PKCEClientRegistration {
loadPkceClientRegistration()?.let {
loadPkceClientRegistration(SourceOfLoadRegistration.REGISTER_CLIENT.toString())?.let {
return it
}

Expand Down Expand Up @@ -431,8 +431,8 @@ class SsoAccessTokenProvider(
stageName = RefreshCredentialStage.LOAD_REGISTRATION
val registration = try {
when (currentToken) {
is DeviceAuthorizationGrantToken -> loadDagClientRegistration()
is PKCEAuthorizationGrantToken -> loadPkceClientRegistration()
is DeviceAuthorizationGrantToken -> loadDagClientRegistration(SourceOfLoadRegistration.REFRESH_TOKEN.toString())
is PKCEAuthorizationGrantToken -> loadPkceClientRegistration(SourceOfLoadRegistration.REFRESH_TOKEN.toString())
}
} catch (e: Exception) {
val message = e.message ?: "$stageName: ${e::class.java.name}"
Expand Down Expand Up @@ -505,6 +505,11 @@ class SsoAccessTokenProvider(
}
}

enum class SourceOfLoadRegistration {
REGISTER_CLIENT,
REFRESH_TOKEN,
}

private enum class RefreshCredentialStage {
VALIDATE_REFRESH_TOKEN,
LOAD_REGISTRATION,
Expand All @@ -514,13 +519,13 @@ class SsoAccessTokenProvider(
SAVE_TOKEN,
}

private fun loadDagClientRegistration(): ClientRegistration? =
cache.loadClientRegistration(dagClientRegistrationCacheKey)?.let {
private fun loadDagClientRegistration(source: String): ClientRegistration? =
cache.loadClientRegistration(dagClientRegistrationCacheKey, source)?.let {
return it
}

private fun loadPkceClientRegistration(): PKCEClientRegistration? =
cache.loadClientRegistration(pkceClientRegistrationCacheKey)?.let {
private fun loadPkceClientRegistration(source: String): PKCEClientRegistration? =
cache.loadClientRegistration(pkceClientRegistrationCacheKey, source)?.let {
return it as PKCEClientRegistration
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ interface SsoCache {
fun invalidateClientRegistration(ssoRegion: String)
fun invalidateAccessToken(ssoUrl: String)

fun loadClientRegistration(cacheKey: ClientRegistrationCacheKey): ClientRegistration?
fun loadClientRegistration(cacheKey: ClientRegistrationCacheKey, source: String): ClientRegistration?
fun saveClientRegistration(cacheKey: ClientRegistrationCacheKey, registration: ClientRegistration)
fun invalidateClientRegistration(cacheKey: ClientRegistrationCacheKey)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ class DiskCacheTest {
startUrl = ssoUrl,
scopes = scopes,
region = ssoRegion
)
),
"testSource"
)
).isNull()
}
Expand All @@ -71,7 +72,7 @@ class DiskCacheTest {
)
cacheLocation.resolve("223224b6f0b4702c1a984be8284fe2c9d9718759.json").writeText("badData")

assertThat(sut.loadClientRegistration(key)).isNull()
assertThat(sut.loadClientRegistration(key, "testSource")).isNull()
}

@Test
Expand All @@ -91,7 +92,7 @@ class DiskCacheTest {
""".trimIndent()
)

assertThat(sut.loadClientRegistration(key)).isNull()
assertThat(sut.loadClientRegistration(key, "testSource")).isNull()
}

@Test
Expand All @@ -112,7 +113,7 @@ class DiskCacheTest {
""".trimIndent()
)

assertThat(sut.loadClientRegistration(key)).isNull()
assertThat(sut.loadClientRegistration(key, "testSource")).isNull()
}

@Test
Expand All @@ -134,7 +135,7 @@ class DiskCacheTest {
""".trimIndent()
)

assertThat(sut.loadClientRegistration(key))
assertThat(sut.loadClientRegistration(key, "testSource"))
.usingRecursiveComparison()
.isEqualTo(
DeviceAuthorizationClientRegistration(
Expand Down Expand Up @@ -217,7 +218,7 @@ class DiskCacheTest {
""".trimIndent()
)

assertThat(sut.loadClientRegistration(key))
assertThat(sut.loadClientRegistration(key, "testSource"))
.usingRecursiveComparison()
.isEqualTo(
PKCEClientRegistration(
Expand Down Expand Up @@ -323,10 +324,10 @@ class DiskCacheTest {
)
)

assertThat(sut.loadClientRegistration(key1))
assertThat(sut.loadClientRegistration(key1, "testSource"))
.usingRecursiveComparison()
.isEqualTo(
sut.loadClientRegistration(key2)
sut.loadClientRegistration(key2, "testSource")
)
}

Expand All @@ -350,11 +351,11 @@ class DiskCacheTest {
region = ssoRegion
)

assertThat(sut.loadClientRegistration(key)).isNotNull()
assertThat(sut.loadClientRegistration(key, "testSource")).isNotNull()

sut.invalidateClientRegistration(key)

assertThat(sut.loadClientRegistration(key)).isNull()
assertThat(sut.loadClientRegistration(key, "testSource")).isNull()
assertThat(cacheFile).doesNotExist()
}

Expand Down Expand Up @@ -619,7 +620,7 @@ class DiskCacheTest {
registration.setPosixFilePermissions(emptySet())
assertPosixPermissions(registration, "---------")

assertThat(sut.loadClientRegistration(key)).isNotNull()
assertThat(sut.loadClientRegistration(key, "testSource")).isNotNull()

assertPosixPermissions(registration, "rw-------")
}
Expand Down
Loading