Skip to content

Commit 218787c

Browse files
authored
Backport SSO scope changes to public (#3321)
1 parent 162acab commit 218787c

File tree

10 files changed

+163
-48
lines changed

10 files changed

+163
-48
lines changed

gradle/libs.versions.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[versions]
22
apacheCommons = "2.8.0"
33
assertJ = "3.20.2" # Upgrading leads to SAM errors: https://youtrack.jetbrains.com/issue/KT-17765
4-
awsSdk = "2.17.219"
4+
awsSdk = "2.17.289"
55
commonmark = "0.17.1"
66
detekt = "1.21.0"
77
intellijGradle = "1.9.0"

jetbrains-core/src/software/aws/toolkits/jetbrains/core/credentials/profiles/ProfileCredentialProviderFactory.kt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ import software.aws.toolkits.core.credentials.CredentialType
2525
import software.aws.toolkits.core.credentials.CredentialsChangeEvent
2626
import software.aws.toolkits.core.credentials.CredentialsChangeListener
2727
import software.aws.toolkits.core.region.AwsRegion
28+
import software.aws.toolkits.core.utils.getLogger
29+
import software.aws.toolkits.core.utils.warn
2830
import software.aws.toolkits.jetbrains.core.credentials.MfaRequiredInteractiveCredentials
2931
import software.aws.toolkits.jetbrains.core.credentials.SsoRequiredInteractiveCredentials
3032
import software.aws.toolkits.jetbrains.core.credentials.ToolkitCredentialProcessProvider
@@ -143,6 +145,8 @@ class ProfileCredentialProviderFactory(private val ssoCache: SsoCache = diskCach
143145
": $it"
144146
} ?: ""
145147

148+
LOG.warn(e) { loadingFailureMessage }
149+
146150
if (AwsSettings.getInstance().profilesNotification != ProfilesNotification.Never) {
147151
notifyError(
148152
title = message("credentials.profile.refresh_ok_title"),
@@ -298,6 +302,10 @@ class ProfileCredentialProviderFactory(private val ssoCache: SsoCache = diskCach
298302

299303
private fun Profile.requiresSso(profiles: Map<String, Profile>) = this.traverseCredentialChain(profiles)
300304
.any { it.propertyExists(ProfileProperty.SSO_START_URL) }
305+
306+
companion object {
307+
private val LOG = getLogger<ProfileCredentialProviderFactory>()
308+
}
301309
}
302310

303311
private fun Profile.toCredentialType(): CredentialType? = when {

jetbrains-core/src/software/aws/toolkits/jetbrains/core/credentials/profiles/ProfileReader.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ data class Profiles(val validProfiles: Map<String, Profile>, val invalidProfiles
1414
* Reads the AWS shared credentials files and produces what profiles are valid and if not why it is not
1515
*/
1616
fun validateAndGetProfiles(): Profiles {
17-
val allProfiles: Map<String, Profile> = ProfileFile.defaultProfileFile().profiles()
17+
val allProfiles: Map<String, Profile> = ProfileFile.defaultProfileFile().profiles().orEmpty()
1818

1919
val validProfiles = mutableMapOf<String, Profile>()
2020
val invalidProfiles = mutableMapOf<String, Exception>()

jetbrains-core/src/software/aws/toolkits/jetbrains/core/credentials/sso/AccessToken.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
package software.aws.toolkits.jetbrains.core.credentials.sso
55

6+
import com.fasterxml.jackson.annotation.JsonInclude
67
import software.amazon.awssdk.services.sso.SsoClient
78
import software.amazon.awssdk.services.ssooidc.SsoOidcClient
89
import java.time.Instant
@@ -14,5 +15,7 @@ data class AccessToken(
1415
val startUrl: String,
1516
val region: String,
1617
val accessToken: String,
18+
@JsonInclude(JsonInclude.Include.NON_NULL)
19+
val refreshToken: String? = null,
1720
val expiresAt: Instant
1821
)

jetbrains-core/src/software/aws/toolkits/jetbrains/core/credentials/sso/DiskCache.kt

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ import java.nio.file.Paths
2424
import java.nio.file.attribute.PosixFilePermission
2525
import java.security.MessageDigest
2626
import java.time.Clock
27+
import java.time.Duration
2728
import java.time.Instant
2829
import java.time.ZoneOffset
2930
import java.time.format.DateTimeFormatter.ISO_INSTANT
30-
import java.time.temporal.ChronoUnit
3131
import java.util.TimeZone
3232

3333
/**
@@ -78,11 +78,11 @@ class DiskCache(
7878
val inputStream = cacheFile.inputStreamIfExists() ?: return null
7979

8080
return tryOrNull {
81-
val clientRegistration = objectMapper.readValue<AccessToken>(inputStream)
81+
val accessToken = objectMapper.readValue<AccessToken>(inputStream)
8282
// Use same expiration logic as client registration even though RFC/SEP does not specify it.
8383
// This prevents a cache entry being returned as valid and then expired when we go to use it.
84-
if (clientRegistration.expiresAt.isNotExpired()) {
85-
clientRegistration
84+
if (!accessToken.isDefinitelyExpired()) {
85+
accessToken
8686
} else {
8787
null
8888
}
@@ -113,7 +113,9 @@ class DiskCache(
113113
}
114114

115115
// If the item is going to expire in the next 15 mins, we must treat it as already expired
116-
private fun Instant.isNotExpired(): Boolean = this.isAfter(Instant.now(clock).plus(15, ChronoUnit.MINUTES))
116+
private fun Instant.isNotExpired(): Boolean = this.isAfter(Instant.now(clock).plus(EXPIRATION_THRESHOLD))
117+
118+
private fun AccessToken.isDefinitelyExpired(): Boolean = refreshToken == null && !expiresAt.isNotExpired()
117119

118120
private class CliCompatibleInstantDeserializer : StdDeserializer<Instant>(Instant::class.java) {
119121
override fun deserialize(parser: JsonParser, context: DeserializationContext): Instant {
@@ -129,4 +131,8 @@ class DiskCache(
129131
return ISO_INSTANT.parse(sanitized) { Instant.from(it) }
130132
}
131133
}
134+
135+
companion object {
136+
val EXPIRATION_THRESHOLD = Duration.ofMinutes(15)
137+
}
132138
}

jetbrains-core/src/software/aws/toolkits/jetbrains/core/credentials/sso/SsoAccessTokenProvider.kt

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ package software.aws.toolkits.jetbrains.core.credentials.sso
66
import com.intellij.openapi.progress.ProgressManager
77
import software.amazon.awssdk.services.ssooidc.SsoOidcClient
88
import software.amazon.awssdk.services.ssooidc.model.AuthorizationPendingException
9+
import software.amazon.awssdk.services.ssooidc.model.CreateTokenResponse
910
import software.amazon.awssdk.services.ssooidc.model.InvalidClientException
11+
import software.amazon.awssdk.services.ssooidc.model.InvalidRequestException
1012
import software.amazon.awssdk.services.ssooidc.model.SlowDownException
1113
import software.aws.toolkits.jetbrains.utils.assertIsNonDispatchThread
1214
import software.aws.toolkits.jetbrains.utils.sleepWithCancellation
@@ -23,6 +25,7 @@ class SsoAccessTokenProvider(
2325
private val onPendingToken: SsoLoginCallback,
2426
private val cache: SsoCache,
2527
private val client: SsoOidcClient,
28+
private val scopes: List<String> = emptyList(),
2629
private val clock: Clock = Clock.systemUTC()
2730
) {
2831
fun accessToken(): AccessToken {
@@ -47,6 +50,7 @@ class SsoAccessTokenProvider(
4750
// Based on botocore: https://github.com/boto/botocore/blob/5dc8ee27415dc97cfff75b5bcfa66d410424e665/botocore/utils.py#L1753
4851
val registerResponse = client.registerClient {
4952
it.clientType(CLIENT_REGISTRATION_TYPE)
53+
it.scopes(scopes)
5054
it.clientName("aws-toolkit-jetbrains-${Instant.now(clock)}")
5155
}
5256

@@ -99,20 +103,13 @@ class SsoAccessTokenProvider(
99103
val tokenResponse = client.createToken {
100104
it.clientId(registration.clientId)
101105
it.clientSecret(registration.clientSecret)
102-
it.grantType(GRANT_TYPE)
106+
it.grantType(DEVICE_GRANT_TYPE)
103107
it.deviceCode(authorization.deviceCode)
104108
}
105109

106-
val expirationTime = Instant.now(clock).plusSeconds(tokenResponse.expiresIn().toLong())
107-
108110
onPendingToken.tokenRetrieved()
109111

110-
return AccessToken(
111-
ssoUrl,
112-
ssoRegion,
113-
tokenResponse.accessToken(),
114-
expirationTime
115-
)
112+
return tokenResponse.toAccessToken()
116113
} catch (e: SlowDownException) {
117114
backOffTime = backOffTime.plusSeconds(SLOW_DOWN_DELAY_SECS)
118115
} catch (e: AuthorizationPendingException) {
@@ -126,13 +123,46 @@ class SsoAccessTokenProvider(
126123
}
127124
}
128125

126+
fun refreshToken(currentToken: AccessToken): AccessToken {
127+
if (currentToken.refreshToken == null) {
128+
throw InvalidRequestException.builder().build()
129+
}
130+
131+
val registration = cache.loadClientRegistration(ssoRegion) ?: throw InvalidClientException.builder().build()
132+
133+
val newToken = client.createToken {
134+
it.clientId(registration.clientId)
135+
it.clientSecret(registration.clientSecret)
136+
it.grantType(REFRESH_GRANT_TYPE)
137+
it.refreshToken(currentToken.refreshToken)
138+
}
139+
140+
val token = newToken.toAccessToken()
141+
cache.saveAccessToken(ssoUrl, token)
142+
143+
return token
144+
}
145+
129146
fun invalidate() {
130147
cache.invalidateAccessToken(ssoUrl)
131148
}
132149

150+
private fun CreateTokenResponse.toAccessToken(): AccessToken {
151+
val expirationTime = Instant.now(clock).plusSeconds(expiresIn().toLong())
152+
153+
return AccessToken(
154+
startUrl = ssoUrl,
155+
region = ssoRegion,
156+
accessToken = accessToken(),
157+
refreshToken = refreshToken(),
158+
expiresAt = expirationTime
159+
)
160+
}
161+
133162
private companion object {
134163
const val CLIENT_REGISTRATION_TYPE = "public"
135-
const val GRANT_TYPE = "urn:ietf:params:oauth:grant-type:device_code"
164+
const val DEVICE_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:device_code"
165+
const val REFRESH_GRANT_TYPE = "refresh_token"
136166

137167
// Default number of seconds to poll for token, https://tools.ietf.org/html/draft-ietf-oauth-device-flow-15#section-3.5
138168
const val DEFAULT_INTERVAL_SECS = 5L

jetbrains-core/tst/software/aws/toolkits/jetbrains/core/credentials/ToolkitCredentialProcessProviderTest.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,15 @@ class ToolkitCredentialProcessProviderTest {
116116
}
117117

118118
@Test
119-
fun `expiry in the past means command is re-run`() {
119+
fun `expiry in the past means command is not re-run`() {
120+
// Java SDK prefers threads to block when this happens https://github.com/aws/aws-sdk-java-v2/commit/5151e4049382bdb5ea6b487e6f150314b579181d
120121
val sut = createSut("echo")
121122
stubParser(CredentialProcessOutput("foo", "bar", null, Instant.now().minus(Duration.ofHours(1))))
122123

123124
sut.resolveCredentials()
124125
sut.resolveCredentials()
125126

126-
verify(parser, times(2)).parse(any())
127+
verify(parser).parse(any())
127128
}
128129

129130
@Test

0 commit comments

Comments
 (0)