Skip to content

Commit 374c0cc

Browse files
authored
fix codewhisperer auth (#3961)
1 parent cb08c9a commit 374c0cc

File tree

6 files changed

+175
-98
lines changed

6 files changed

+175
-98
lines changed

jetbrains-core/src/software/aws/toolkits/jetbrains/core/credentials/ToolkitConnectionImpls.kt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import com.intellij.openapi.Disposable
77
import com.intellij.openapi.util.Disposer
88
import software.aws.toolkits.core.TokenConnectionSettings
99
import software.aws.toolkits.core.credentials.ToolkitBearerTokenProvider
10+
import software.aws.toolkits.jetbrains.core.credentials.sso.DiskCache
1011
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
1112
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.InteractiveBearerTokenProvider
1213
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.ProfileSdkTokenProviderWrapper
@@ -16,6 +17,7 @@ class ManagedBearerSsoConnection(
1617
val startUrl: String,
1718
val region: String,
1819
override val scopes: List<String>,
20+
cache: DiskCache = diskCache
1921
) : BearerSsoConnection, Disposable {
2022
override val id: String = ToolkitBearerTokenProvider.ssoIdentifier(startUrl, region)
2123
override val label: String = ToolkitBearerTokenProvider.ssoDisplayName(startUrl)
@@ -25,7 +27,8 @@ class ManagedBearerSsoConnection(
2527
InteractiveBearerTokenProvider(
2628
startUrl,
2729
region,
28-
scopes
30+
scopes,
31+
cache
2932
),
3033
region
3134
)

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/explorer/CodeWhispererExplorerActionManager.kt

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,16 @@ import com.intellij.util.xmlb.annotations.Property
1313
import org.jetbrains.annotations.ApiStatus.ScheduledForRemoval
1414
import software.aws.toolkits.core.utils.getLogger
1515
import software.aws.toolkits.core.utils.warn
16+
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
1617
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
1718
import software.aws.toolkits.jetbrains.core.credentials.pinning.CodeWhispererConnection
1819
import software.aws.toolkits.jetbrains.core.credentials.sono.isSono
20+
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenAuthState
21+
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
1922
import software.aws.toolkits.jetbrains.core.explorer.refreshDevToolTree
2023
import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererLoginType
2124
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererConstants
2225
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererUtil.getConnectionStartUrl
23-
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererUtil.isRefreshTokenExpired
2426
import software.aws.toolkits.telemetry.AwsTelemetry
2527
import java.time.LocalDateTime
2628

@@ -119,21 +121,25 @@ class CodeWhispererExplorerActionManager : PersistentStateComponent<CodeWhispere
119121
return actionState.token
120122
}
121123

122-
fun checkActiveCodeWhispererConnectionType(project: Project) = when {
123-
actionState.token != null -> CodeWhispererLoginType.Accountless
124-
isRefreshTokenExpired(project) -> CodeWhispererLoginType.Expired
125-
else -> {
126-
val conn = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(CodeWhispererConnection.getInstance())
127-
if (conn != null) {
128-
if (conn.isSono()) {
129-
CodeWhispererLoginType.Sono
130-
} else {
131-
CodeWhispererLoginType.SSO
124+
fun checkActiveCodeWhispererConnectionType(project: Project): CodeWhispererLoginType {
125+
val conn = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(CodeWhispererConnection.getInstance()) as? AwsBearerTokenConnection
126+
return conn?.let {
127+
val provider = (it.getConnectionSettings().tokenProvider.delegate as? BearerTokenProvider) ?: return@let CodeWhispererLoginType.Logout
128+
129+
when (provider.state()) {
130+
BearerTokenAuthState.AUTHORIZED -> {
131+
if (it.isSono()) {
132+
CodeWhispererLoginType.Sono
133+
} else {
134+
CodeWhispererLoginType.SSO
135+
}
132136
}
133-
} else {
134-
CodeWhispererLoginType.Logout
137+
138+
BearerTokenAuthState.NEEDS_REFRESH -> CodeWhispererLoginType.Expired
139+
140+
BearerTokenAuthState.NOT_AUTHENTICATED -> CodeWhispererLoginType.Logout
135141
}
136-
}
142+
} ?: CodeWhispererLoginType.Logout
137143
}
138144

139145
fun nullifyAccountlessCredentialIfNeeded() {
@@ -202,6 +208,11 @@ fun isCodeWhispererEnabled(project: Project) = with(CodeWhispererExplorerActionM
202208
checkActiveCodeWhispererConnectionType(project) != CodeWhispererLoginType.Logout
203209
}
204210

211+
/**
212+
* Note: please use this util with extra caution, it will return "false" for a "logout" scenario,
213+
* the reasoning is we need handling specifically for a "Expired" condition thus excluding logout from here
214+
* If callers rather need a predicate "isInvalidConnection", please use the combination of the two (!isCodeWhispererEnabled() || isCodeWhispererExpired())
215+
*/
205216
fun isCodeWhispererExpired(project: Project) = with(CodeWhispererExplorerActionManager.getInstance()) {
206217
checkActiveCodeWhispererConnectionType(project) == CodeWhispererLoginType.Expired
207218
}

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererUtil.kt

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import software.aws.toolkits.jetbrains.core.credentials.loginSso
2121
import software.aws.toolkits.jetbrains.core.credentials.maybeReauthProviderIfNeeded
2222
import software.aws.toolkits.jetbrains.core.credentials.pinning.CodeWhispererConnection
2323
import software.aws.toolkits.jetbrains.core.credentials.sono.isSono
24-
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenAuthState
2524
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
2625
import software.aws.toolkits.jetbrains.core.explorer.refreshDevToolTree
2726
import software.aws.toolkits.jetbrains.services.codewhisperer.actions.CodeWhispererLoginLearnMoreAction
@@ -196,18 +195,6 @@ object CodeWhispererUtil {
196195
listOf(CodeWhispererSsoLearnMoreAction(), ConnectWithAwsToContinueActionError(), DoNotShowAgainActionError())
197196
)
198197

199-
fun isAccessTokenExpired(project: Project): Boolean {
200-
val tokenProvider = tokenProvider(project) ?: return false
201-
val state = tokenProvider.state()
202-
return state == BearerTokenAuthState.NEEDS_REFRESH
203-
}
204-
205-
fun isRefreshTokenExpired(project: Project): Boolean {
206-
val tokenProvider = tokenProvider(project) ?: return false
207-
val state = tokenProvider.state()
208-
return state == BearerTokenAuthState.NOT_AUTHENTICATED
209-
}
210-
211198
// This will be called only when there's a CW connection, but it has expired(either accessToken or refreshToken)
212199
// 1. If connection is expired, try to refresh
213200
// 2. If not able to refresh, requesting re-login by showing a notification

jetbrains-core/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererExplorerActionManagerTest.kt

Lines changed: 138 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,15 @@
33

44
package software.aws.toolkits.jetbrains.services.codewhisperer
55

6-
import com.intellij.openapi.application.ApplicationManager
76
import com.intellij.openapi.project.Project
8-
import com.intellij.testFramework.ApplicationRule
97
import com.intellij.testFramework.DisposableRule
108
import com.intellij.testFramework.ProjectRule
119
import com.intellij.testFramework.replaceService
1210
import org.assertj.core.api.Assertions.assertThat
1311
import org.junit.Before
1412
import org.junit.Rule
1513
import org.junit.Test
14+
import org.junit.rules.TemporaryFolder
1615
import org.mockito.kotlin.any
1716
import org.mockito.kotlin.mock
1817
import org.mockito.kotlin.spy
@@ -23,16 +22,31 @@ import software.aws.toolkits.jetbrains.core.MockClientManagerRule
2322
import software.aws.toolkits.jetbrains.core.credentials.ManagedBearerSsoConnection
2423
import software.aws.toolkits.jetbrains.core.credentials.MockToolkitAuthManagerRule
2524
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
25+
import software.aws.toolkits.jetbrains.core.credentials.pinning.CodeWhispererConnection
26+
import software.aws.toolkits.jetbrains.core.credentials.pinning.ConnectionPinningManager
27+
import software.aws.toolkits.jetbrains.core.credentials.sono.CODEWHISPERER_SCOPES
2628
import software.aws.toolkits.jetbrains.core.credentials.sono.SONO_URL
29+
import software.aws.toolkits.jetbrains.core.credentials.sso.AccessToken
30+
import software.aws.toolkits.jetbrains.core.credentials.sso.AccessTokenCacheKey
31+
import software.aws.toolkits.jetbrains.core.credentials.sso.DiskCache
32+
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenAuthState
33+
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.InteractiveBearerTokenProvider
2734
import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererLoginType
28-
import software.aws.toolkits.jetbrains.services.codewhisperer.explorer.CodeWhispererExploreActionState
2935
import software.aws.toolkits.jetbrains.services.codewhisperer.explorer.CodeWhispererExplorerActionManager
3036
import software.aws.toolkits.jetbrains.services.codewhisperer.explorer.isCodeWhispererEnabled
37+
import software.aws.toolkits.jetbrains.services.codewhisperer.explorer.isCodeWhispererExpired
38+
import java.nio.file.Files
39+
import java.nio.file.Path
40+
import java.nio.file.Paths
41+
import java.time.Clock
42+
import java.time.Instant
43+
import java.time.ZoneOffset
44+
import java.time.temporal.ChronoUnit
3145

3246
class CodeWhispererExplorerActionManagerTest {
3347
@JvmField
3448
@Rule
35-
val applicationRule = ApplicationRule()
49+
val tempFolder = TemporaryFolder()
3650

3751
@JvmField
3852
@Rule
@@ -50,17 +64,24 @@ class CodeWhispererExplorerActionManagerTest {
5064
@Rule
5165
val mockClientManager = MockClientManagerRule()
5266

67+
private val now = Instant.now()
68+
private val clock = Clock.fixed(now, ZoneOffset.UTC)
69+
5370
private lateinit var mockManager: CodeWhispererExplorerActionManager
5471
private lateinit var project: Project
55-
private lateinit var connectionManager: ToolkitConnectionManager
72+
private lateinit var cacheRoot: Path
73+
private lateinit var cacheLocation: Path
74+
private lateinit var testDiskCache: DiskCache
5675

5776
@Before
5877
fun setup() {
78+
cacheRoot = tempFolder.root.toPath().toAbsolutePath()
79+
cacheLocation = Paths.get(cacheRoot.toString(), "fakehome", ".aws", "sso", "cache")
80+
Files.createDirectories(cacheLocation)
81+
testDiskCache = DiskCache(cacheLocation, clock)
82+
5983
mockClientManager.create<SsoOidcClient>()
6084
project = projectRule.project
61-
connectionManager = mock()
62-
63-
project.replaceService(ToolkitConnectionManager::class.java, connectionManager, disposableRule.disposable)
6485
}
6586

6687
/**
@@ -69,49 +90,14 @@ class CodeWhispererExplorerActionManagerTest {
6990
@Test
7091
fun `when there is no connection, should return logout`() {
7192
mockManager = spy()
72-
whenever(connectionManager.activeConnectionForFeature(any())).thenReturn(null)
93+
val mockConnectionManager = mock<ToolkitConnectionManager>()
94+
whenever(mockConnectionManager.activeConnectionForFeature(any())).thenReturn(null)
95+
project.replaceService(ToolkitConnectionManager::class.java, mockConnectionManager, disposableRule.disposable)
7396

7497
val actual = mockManager.checkActiveCodeWhispererConnectionType(project)
7598
assertThat(actual).isEqualTo(CodeWhispererLoginType.Logout)
76-
}
77-
78-
@Test
79-
fun `when ToS accepted and there is an accountless token, should return accountless`() {
80-
mockManager = spy()
81-
mockManager.loadState(
82-
// set up accountless token
83-
CodeWhispererExploreActionState().apply {
84-
this.token = "foo"
85-
}
86-
)
87-
88-
val actual = mockManager.checkActiveCodeWhispererConnectionType(project)
89-
assertThat(actual).isEqualTo(CodeWhispererLoginType.Accountless)
90-
}
91-
92-
@Test
93-
fun `when ToS accepted, no accountless token and there is an AWS Builder ID connection, should return Sono`() {
94-
assertLoginType(SONO_URL, CodeWhispererLoginType.Sono)
95-
}
96-
97-
@Test
98-
fun `when ToS accepted, no accountless token and there is an SSO connection, should return SSO`() {
99-
assertLoginType(aString(), CodeWhispererLoginType.SSO)
100-
}
101-
102-
@Test
103-
fun `test nullifyAccountlessCredentialIfNeeded`() {
104-
mockManager = CodeWhispererExplorerActionManager()
105-
mockManager.loadState(CodeWhispererExploreActionState().apply { this.token = "foo" })
106-
107-
assertThat(mockManager.state.token)
108-
.isNotNull
109-
.isEqualTo("foo")
110-
111-
mockManager.nullifyAccountlessCredentialIfNeeded()
112-
113-
assertThat(mockManager.state.token)
114-
.isNull()
99+
assertThat(isCodeWhispererEnabled(project)).isFalse
100+
assertThat(isCodeWhispererExpired(project)).isFalse
115101
}
116102

117103
/**
@@ -120,31 +106,113 @@ class CodeWhispererExplorerActionManagerTest {
120106
* - should return true if loginType == Accountless || Sono || SSO
121107
*/
122108
@Test
123-
fun `test isCodeWhispererEnabled`() {
124-
mockManager = mock()
125-
ApplicationManager.getApplication().replaceService(CodeWhispererExplorerActionManager::class.java, mockManager, disposableRule.disposable)
126-
127-
whenever(mockManager.checkActiveCodeWhispererConnectionType(project)).thenReturn(CodeWhispererLoginType.Logout)
128-
assertThat(isCodeWhispererEnabled(project)).isFalse
129-
130-
whenever(mockManager.checkActiveCodeWhispererConnectionType(project)).thenReturn(CodeWhispererLoginType.Accountless)
131-
assertThat(isCodeWhispererEnabled(project)).isTrue
109+
fun `test connection state`() {
110+
assertConnectionState(
111+
startUrl = SONO_URL,
112+
refreshToken = aString(),
113+
expirationTime = now.plus(1, ChronoUnit.DAYS),
114+
expectedState = BearerTokenAuthState.AUTHORIZED,
115+
expectedLoginType = CodeWhispererLoginType.Sono,
116+
expectedIsCwEnabled = true,
117+
expectedIsCwExpired = false
118+
)
119+
assertThat(ConnectionPinningManager.getInstance().isFeaturePinned(CodeWhispererConnection.getInstance())).isFalse
120+
121+
assertConnectionState(
122+
startUrl = SONO_URL,
123+
refreshToken = aString(),
124+
expirationTime = now.minus(1, ChronoUnit.DAYS),
125+
expectedState = BearerTokenAuthState.NEEDS_REFRESH,
126+
expectedLoginType = CodeWhispererLoginType.Expired,
127+
expectedIsCwEnabled = true,
128+
expectedIsCwExpired = true
129+
)
130+
assertThat(ConnectionPinningManager.getInstance().isFeaturePinned(CodeWhispererConnection.getInstance())).isFalse
131+
132+
assertConnectionState(
133+
startUrl = SONO_URL,
134+
refreshToken = null,
135+
expirationTime = now.minus(1, ChronoUnit.DAYS),
136+
expectedState = BearerTokenAuthState.NOT_AUTHENTICATED,
137+
expectedLoginType = CodeWhispererLoginType.Logout,
138+
expectedIsCwEnabled = false,
139+
expectedIsCwExpired = false
140+
)
141+
assertThat(ConnectionPinningManager.getInstance().isFeaturePinned(CodeWhispererConnection.getInstance())).isFalse
142+
143+
assertConnectionState(
144+
startUrl = aString(),
145+
refreshToken = aString(),
146+
expirationTime = now.plus(1, ChronoUnit.DAYS),
147+
expectedState = BearerTokenAuthState.AUTHORIZED,
148+
expectedLoginType = CodeWhispererLoginType.SSO,
149+
expectedIsCwEnabled = true,
150+
expectedIsCwExpired = false
151+
)
152+
assertThat(ConnectionPinningManager.getInstance().isFeaturePinned(CodeWhispererConnection.getInstance())).isFalse
153+
154+
assertConnectionState(
155+
startUrl = aString(),
156+
refreshToken = aString(),
157+
expirationTime = now.minus(1, ChronoUnit.DAYS),
158+
expectedState = BearerTokenAuthState.NEEDS_REFRESH,
159+
expectedLoginType = CodeWhispererLoginType.Expired,
160+
expectedIsCwEnabled = true,
161+
expectedIsCwExpired = true
162+
)
163+
assertThat(ConnectionPinningManager.getInstance().isFeaturePinned(CodeWhispererConnection.getInstance())).isFalse
164+
165+
assertConnectionState(
166+
startUrl = aString(),
167+
refreshToken = null,
168+
expirationTime = now.minus(1, ChronoUnit.DAYS),
169+
expectedState = BearerTokenAuthState.NOT_AUTHENTICATED,
170+
expectedLoginType = CodeWhispererLoginType.Logout,
171+
expectedIsCwEnabled = false,
172+
expectedIsCwExpired = false
173+
)
174+
assertThat(ConnectionPinningManager.getInstance().isFeaturePinned(CodeWhispererConnection.getInstance())).isFalse
175+
}
132176

133-
whenever(mockManager.checkActiveCodeWhispererConnectionType(project)).thenReturn(CodeWhispererLoginType.Sono)
134-
assertThat(isCodeWhispererEnabled(project)).isTrue
177+
private fun assertConnectionState(
178+
startUrl: String,
179+
refreshToken: String?,
180+
expirationTime: Instant,
181+
expectedState: BearerTokenAuthState,
182+
expectedLoginType: CodeWhispererLoginType,
183+
expectedIsCwEnabled: Boolean,
184+
expectedIsCwExpired: Boolean
185+
) {
186+
testDiskCache.saveAccessToken(
187+
AccessTokenCacheKey(
188+
connectionId = "us-east-1",
189+
startUrl = startUrl,
190+
scopes = CODEWHISPERER_SCOPES
191+
),
192+
AccessToken(
193+
startUrl = startUrl,
194+
region = "us-east-1",
195+
accessToken = aString(),
196+
refreshToken = refreshToken,
197+
expiresAt = expirationTime
198+
)
199+
)
135200

136-
whenever(mockManager.checkActiveCodeWhispererConnectionType(project)).thenReturn(CodeWhispererLoginType.SSO)
137-
assertThat(isCodeWhispererEnabled(project)).isTrue
138-
}
201+
val myConnection = ManagedBearerSsoConnection(
202+
startUrl,
203+
"us-east-1",
204+
CODEWHISPERER_SCOPES,
205+
testDiskCache
206+
)
139207

140-
private fun assertLoginType(startUrl: String, expectedType: CodeWhispererLoginType) {
141-
mockManager = spy()
142-
val conn: ManagedBearerSsoConnection = mock()
143-
whenever(connectionManager.activeConnectionForFeature(any())).thenReturn(conn)
144-
whenever(conn.startUrl).thenReturn(startUrl)
145-
whenever(conn.getConnectionSettings()).thenReturn(null)
208+
ToolkitConnectionManager.getInstance(project).switchConnection(myConnection)
209+
val activeCwConn = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(CodeWhispererConnection.getInstance())
210+
val myTokenProvider = myConnection.getConnectionSettings().tokenProvider.delegate as InteractiveBearerTokenProvider
146211

147-
val actual = mockManager.checkActiveCodeWhispererConnectionType(project)
148-
assertThat(actual).isEqualTo(expectedType)
212+
assertThat(activeCwConn).isEqualTo(myConnection)
213+
assertThat(myTokenProvider.state()).isEqualTo(expectedState)
214+
assertThat(CodeWhispererExplorerActionManager.getInstance().checkActiveCodeWhispererConnectionType(project)).isEqualTo(expectedLoginType)
215+
assertThat(isCodeWhispererEnabled(project)).isEqualTo(expectedIsCwEnabled)
216+
assertThat(isCodeWhispererExpired(project)).isEqualTo(expectedIsCwExpired)
149217
}
150218
}

0 commit comments

Comments
 (0)