Skip to content

Commit 0a8ab64

Browse files
authored
Make CW auto-reconnect on startup (#3619)
* Make CW auto-reconnect on startup This fixes the issue when CW does not auto-reconnect on IDE startup. CW will try to auto-connect when the connection(accessToken is expired) It will first try to use the refresh token to refresh, if failed, it will prompt a notification for people to re-auth(go through the browser login flow again). Also did some refactors to ensure: 1. CW actions are not performed when CW is logged out, or CW needs to re-auth, for the case when CW needs to refresh, refresh the token and continue with the CW action. 2. Reuse helpers for serveral parts for reconnect, refresh, showing expiry notification and expiry check. * add back codescan reset logic
1 parent d70e0cb commit 0a8ab64

File tree

8 files changed

+74
-78
lines changed

8 files changed

+74
-78
lines changed

jetbrains-core/src/software/aws/toolkits/jetbrains/core/credentials/ToolkitAuthManager.kt

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,26 +224,30 @@ fun reauthProviderIfNeeded(project: Project?, tokenProvider: BearerTokenProvider
224224
return tokenProvider
225225
}
226226

227-
fun maybeReauthProviderIfNeeded(project: Project?, tokenProvider: BearerTokenProvider, onReauthRequired: (SsoOidcException?) -> Any) {
227+
// Return true if need to re-auth, false otherwise
228+
fun maybeReauthProviderIfNeeded(project: Project?, tokenProvider: BearerTokenProvider, onReauthRequired: (SsoOidcException?) -> Any): Boolean {
228229
val state = tokenProvider.state()
229230
when (state) {
230231
BearerTokenAuthState.NOT_AUTHENTICATED -> {
231232
getLogger<ToolkitAuthManager>().info { "Token provider NOT_AUTHENTICATED, requesting login" }
232233
onReauthRequired(null)
234+
return true
233235
}
234236

235237
BearerTokenAuthState.NEEDS_REFRESH -> {
236238
try {
237-
runUnderProgressIfNeeded(project, message("credentials.sono.login.refreshing"), true) {
239+
return runUnderProgressIfNeeded(project, message("credentials.sono.login.refreshing"), true) {
238240
tokenProvider.resolveToken()
239241
BearerTokenProviderListener.notifyCredUpdate(tokenProvider.id)
242+
return@runUnderProgressIfNeeded false
240243
}
241244
} catch (e: SsoOidcException) {
242245
getLogger<ToolkitAuthManager>().warn(e) { "Redriving AWS Builder ID login flow since token could not be refreshed" }
243246
onReauthRequired(e)
247+
return true
244248
}
245249
}
246250

247-
BearerTokenAuthState.AUTHORIZED -> {}
251+
BearerTokenAuthState.AUTHORIZED -> { return false }
248252
}
249253
}

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/codescan/CodeWhispererCodeScanManager.kt

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ import software.aws.toolkits.jetbrains.services.codewhisperer.codescan.listeners
5858
import software.aws.toolkits.jetbrains.services.codewhisperer.codescan.sessionconfig.CodeScanSessionConfig
5959
import software.aws.toolkits.jetbrains.services.codewhisperer.editor.CodeWhispererEditorUtil.overlaps
6060
import software.aws.toolkits.jetbrains.services.codewhisperer.explorer.CodeWhispererExplorerActionManager
61+
import software.aws.toolkits.jetbrains.services.codewhisperer.explorer.isCodeWhispererEnabled
6162
import software.aws.toolkits.jetbrains.services.codewhisperer.model.CodeScanTelemetryEvent
6263
import software.aws.toolkits.jetbrains.services.codewhisperer.telemetry.CodeWhispererTelemetryService
6364
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererColorUtil.INACTIVE_TEXT_COLOR
@@ -126,14 +127,15 @@ class CodeWhispererCodeScanManager(val project: Project) {
126127
* Triggers a code scan and displays results in the new tab in problems view panel.
127128
*/
128129
fun runCodeScan() {
130+
if (!isCodeWhispererEnabled(project)) return
131+
129132
// Return if a scan is already in progress.
130133
if (isCodeScanInProgress.getAndSet(true)) return
131-
if (CodeWhispererUtil.isConnectionExpired(project)) {
132-
promptReAuth(project) {
133-
isCodeScanInProgress.set(false)
134-
}
134+
if (promptReAuth(project)) {
135+
isCodeScanInProgress.set(false)
135136
return
136137
}
138+
137139
// Prepare for a code scan
138140
beforeCodeScan()
139141
// launch code scan coroutine

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/explorer/CodeWhispererExplorerActionManager.kt

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ import software.aws.toolkits.jetbrains.core.explorer.refreshDevToolTree
2020
import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererLoginType
2121
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererConstants
2222
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererUtil.getConnectionStartUrl
23-
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererUtil.isConnectionExpired
23+
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererUtil.isAccessTokenExpired
24+
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererUtil.isRefreshTokenExpired
2425
import software.aws.toolkits.telemetry.AwsTelemetry
2526
import java.time.LocalDateTime
2627

@@ -119,7 +120,7 @@ class CodeWhispererExplorerActionManager : PersistentStateComponent<CodeWhispere
119120

120121
fun checkActiveCodeWhispererConnectionType(project: Project) = when {
121122
actionState.token != null -> CodeWhispererLoginType.Accountless
122-
isConnectionExpired(project) -> CodeWhispererLoginType.Expired
123+
isAccessTokenExpired(project) || isRefreshTokenExpired(project) -> CodeWhispererLoginType.Expired
123124
else -> {
124125
val conn = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(CodeWhispererConnection.getInstance())
125126
if (conn != null) {
@@ -198,3 +199,7 @@ interface CodeWhispererActivationChangedListener {
198199
fun isCodeWhispererEnabled(project: Project) = with(CodeWhispererExplorerActionManager.getInstance()) {
199200
checkActiveCodeWhispererConnectionType(project) != CodeWhispererLoginType.Logout
200201
}
202+
203+
fun isCodeWhispererExpired(project: Project) = with(CodeWhispererExplorerActionManager.getInstance()) {
204+
checkActiveCodeWhispererConnectionType(project) == CodeWhispererLoginType.Expired
205+
}

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/explorer/nodes/CodeWhispererReconnectNode.kt

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,8 @@
44
package software.aws.toolkits.jetbrains.services.codewhisperer.explorer.nodes
55

66
import com.intellij.icons.AllIcons
7-
import com.intellij.openapi.application.ApplicationManager
87
import com.intellij.openapi.project.Project
9-
import software.aws.toolkits.jetbrains.core.credentials.BearerSsoConnection
10-
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
11-
import software.aws.toolkits.jetbrains.core.credentials.loginSso
12-
import software.aws.toolkits.jetbrains.core.credentials.pinning.CodeWhispererConnection
13-
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererUtil.getConnectionStartUrl
8+
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererUtil.reconnectCodeWhisperer
149
import software.aws.toolkits.resources.message
1510
import java.awt.event.MouseEvent
1611

@@ -21,12 +16,6 @@ class CodeWhispererReconnectNode(nodeProject: Project) : CodeWhispererActionNode
2116
AllIcons.Actions.Execute
2217
) {
2318
override fun onDoubleClick(event: MouseEvent) {
24-
val connection = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(CodeWhispererConnection.getInstance())
25-
if (connection !is BearerSsoConnection) return
26-
ApplicationManager.getApplication().executeOnPooledThread {
27-
getConnectionStartUrl(connection)?.let { startUrl ->
28-
loginSso(project, startUrl, scopes = connection.scopes)
29-
}
30-
}
19+
reconnectCodeWhisperer(project)
3120
}
3221
}

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/service/CodeWhispererService.kt

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ import software.aws.toolkits.jetbrains.services.codewhisperer.util.CaretMovement
6262
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererConstants
6363
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererUtil.checkCompletionType
6464
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererUtil.checkEmptyRecommendations
65-
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererUtil.isConnectionExpired
6665
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererUtil.notifyErrorCodeWhispererUsageLimit
6766
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererUtil.promptReAuth
6867
import software.aws.toolkits.resources.message
@@ -81,15 +80,9 @@ class CodeWhispererService {
8180
if (!isCodeWhispererEnabled(project)) return
8281

8382
latencyContext.credentialFetchingStart = System.nanoTime()
84-
if (isConnectionExpired(project)) {
85-
if (triggerTypeInfo.triggerType == CodewhispererTriggerType.AutoTrigger) {
86-
if (reAuthPromptShown) return
87-
promptReAuth(project, ::markReAuthPromptShown)
88-
} else {
89-
promptReAuth(project)
90-
}
91-
return
92-
}
83+
84+
if (promptReAuth(project)) return
85+
9386
latencyContext.credentialFetchingEnd = System.nanoTime()
9487
val psiFile = runReadAction { PsiDocumentManager.getInstance(project).getPsiFile(editor.document) }
9588

@@ -621,10 +614,12 @@ class CodeWhispererService {
621614
const val KET_SESSION_ID = "x-amzn-SessionId"
622615
private var reAuthPromptShown = false
623616

624-
private fun markReAuthPromptShown() {
617+
fun markReAuthPromptShown() {
625618
reAuthPromptShown = true
626619
}
627620

621+
fun hasReAuthPromptBeenShown() = reAuthPromptShown
622+
628623
fun buildCodeWhispererRequest(
629624
fileContextInfo: FileContextInfo
630625
): GenerateCompletionsRequest {

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/startup/CodeWhispererProjectStartupActivity.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import software.aws.toolkits.jetbrains.services.codewhisperer.status.CodeWhisper
1919
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererConstants
2020
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererUtil.notifyErrorAccountless
2121
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererUtil.notifyWarnAccountless
22+
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererUtil.promptReAuth
2223
import java.time.LocalDateTime
2324
import java.util.Date
2425
import java.util.Timer
@@ -40,6 +41,9 @@ class CodeWhispererProjectStartupActivity : StartupActivity.DumbAware {
4041
if (!isCodeWhispererEnabled(project)) return
4142
if (runOnce) return
4243

44+
// Reconnect CodeWhisperer on startup
45+
promptReAuth(project)
46+
4347
CodeWhispererAutoTriggerService.getInstance().determineUserGroupIfNeeded()
4448

4549
// install intellsense autotrigger listener, this only need to be executed 1 time

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/status/CodeWhispererStatusBarWidget.kt

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,11 @@ import com.intellij.openapi.wm.StatusBarWidget
1313
import com.intellij.openapi.wm.impl.status.EditorBasedWidget
1414
import com.intellij.ui.AnimatedIcon
1515
import com.intellij.util.Consumer
16-
import software.aws.toolkits.jetbrains.core.credentials.BearerSsoConnection
17-
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
18-
import software.aws.toolkits.jetbrains.core.credentials.loginSso
19-
import software.aws.toolkits.jetbrains.core.credentials.pinning.CodeWhispererConnection
2016
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProviderListener
21-
import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererLoginType
22-
import software.aws.toolkits.jetbrains.services.codewhisperer.explorer.CodeWhispererExplorerActionManager
17+
import software.aws.toolkits.jetbrains.services.codewhisperer.explorer.isCodeWhispererExpired
2318
import software.aws.toolkits.jetbrains.services.codewhisperer.service.CodeWhispererInvocationStateChangeListener
2419
import software.aws.toolkits.jetbrains.services.codewhisperer.service.CodeWhispererInvocationStatus
25-
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererUtil
20+
import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhispererUtil.reconnectCodeWhisperer
2621
import software.aws.toolkits.resources.message
2722
import java.awt.event.MouseEvent
2823
import javax.swing.Icon
@@ -59,39 +54,23 @@ class CodeWhispererStatusBarWidget(project: Project) :
5954

6055
override fun getClickConsumer(): Consumer<MouseEvent>? = null
6156

62-
override fun getPopupStep(): ListPopup? {
63-
val connectionType = CodeWhispererExplorerActionManager.getInstance().checkActiveCodeWhispererConnectionType(project)
64-
return if (connectionType == CodeWhispererLoginType.Expired) {
65-
JBPopupFactory.getInstance().createConfirmation(message("codewhisperer.statusbar.popup.title"), ::reconnect, 0)
57+
override fun getPopupStep(): ListPopup? =
58+
if (isCodeWhispererExpired(project)) {
59+
JBPopupFactory.getInstance().createConfirmation(message("codewhisperer.statusbar.popup.title"), { reconnectCodeWhisperer(project) }, 0)
6660
} else {
6761
null
6862
}
69-
}
70-
71-
private fun reconnect() {
72-
val connection = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(CodeWhispererConnection.getInstance())
73-
if (connection !is BearerSsoConnection) {
74-
return
75-
}
76-
ApplicationManager.getApplication().executeOnPooledThread {
77-
CodeWhispererUtil.getConnectionStartUrl(connection)?.let { startUrl ->
78-
loginSso(project, startUrl, scopes = connection.scopes)
79-
}
80-
}
81-
}
8263

8364
override fun getSelectedValue(): String = message("codewhisperer.statusbar.display_name")
8465

85-
override fun getIcon(): Icon {
86-
val connectionType = CodeWhispererExplorerActionManager.getInstance().checkActiveCodeWhispererConnectionType(project)
87-
return if (connectionType == CodeWhispererLoginType.Expired) {
66+
override fun getIcon(): Icon =
67+
if (isCodeWhispererExpired(project)) {
8868
AllIcons.General.BalloonWarning
8969
} else if (CodeWhispererInvocationStatus.getInstance().hasExistingInvocation()) {
9070
AnimatedIcon.Default()
9171
} else {
9272
AllIcons.Actions.Commit
9373
}
94-
}
9574

9675
companion object {
9776
const val ID = "aws.codewhisperer.statusWidget"

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererUtil.kt

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ import software.aws.toolkits.jetbrains.services.codewhisperer.actions.ConnectWit
2525
import software.aws.toolkits.jetbrains.services.codewhisperer.actions.DoNotShowAgainActionError
2626
import software.aws.toolkits.jetbrains.services.codewhisperer.actions.DoNotShowAgainActionWarn
2727
import software.aws.toolkits.jetbrains.services.codewhisperer.explorer.CodeWhispererExplorerActionManager
28+
import software.aws.toolkits.jetbrains.services.codewhisperer.explorer.isCodeWhispererExpired
29+
import software.aws.toolkits.jetbrains.services.codewhisperer.service.CodeWhispererService
2830
import software.aws.toolkits.jetbrains.utils.notifyError
2931
import software.aws.toolkits.jetbrains.utils.notifyInfo
3032
import software.aws.toolkits.jetbrains.utils.notifyWarn
@@ -92,40 +94,46 @@ object CodeWhispererUtil {
9294
listOf(CodeWhispererSsoLearnMoreAction(), ConnectWithAwsToContinueActionError(), DoNotShowAgainActionError())
9395
)
9496

95-
fun isConnectionExpired(project: Project): Boolean {
97+
fun isAccessTokenExpired(project: Project): Boolean {
9698
val tokenProvider = tokenProvider(project) ?: return false
9799
val state = tokenProvider.state()
98-
return state == BearerTokenAuthState.NEEDS_REFRESH || state == BearerTokenAuthState.NOT_AUTHENTICATED
100+
return state == BearerTokenAuthState.NEEDS_REFRESH
99101
}
100102

101-
fun promptReAuth(project: Project, callback: () -> Unit = {}) {
102-
val connection = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(CodeWhispererConnection.getInstance())
103-
if (connection !is BearerSsoConnection) return
104-
val tokenProvider = tokenProvider(project) ?: return
105-
maybeReauthProviderIfNeeded(project, tokenProvider) {
103+
fun isRefreshTokenExpired(project: Project): Boolean {
104+
val tokenProvider = tokenProvider(project) ?: return false
105+
val state = tokenProvider.state()
106+
return state == BearerTokenAuthState.NOT_AUTHENTICATED
107+
}
108+
109+
// This will be called only when there's a CW connection, but it has expired(either accessToken or refreshToken)
110+
// 1. If connection is expired, try to refresh
111+
// 2. If not able to refresh, requesting re-login by showing a notification
112+
// 3. The notification will be shown at most once per IDE session
113+
// Return true if need to re-auth, false otherwise
114+
fun promptReAuth(project: Project): Boolean {
115+
if (CodeWhispererService.hasReAuthPromptBeenShown()) return false
116+
if (!isCodeWhispererExpired(project)) return false
117+
val tokenProvider = tokenProvider(project) ?: return false
118+
return maybeReauthProviderIfNeeded(project, tokenProvider) {
106119
runInEdt {
107-
notifyConnectionExpiredRequestReauth(project, connection)
108-
callback()
120+
notifyConnectionExpiredRequestReauth(project)
121+
CodeWhispererService.markReAuthPromptShown()
109122
}
110123
}
111124
}
112125

113-
private fun notifyConnectionExpiredRequestReauth(project: Project, connection: BearerSsoConnection?) {
114-
connection ?: return
126+
private fun notifyConnectionExpiredRequestReauth(project: Project) {
115127
if (CodeWhispererExplorerActionManager.getInstance().getConnectionExpiredDoNotShowAgain()) {
116128
return
117129
}
118130
notifyError(
119-
message("toolkit.sso_expire.dialog.title", connection.label),
131+
message("toolkit.sso_expire.dialog.title"),
120132
message("toolkit.sso_expire.dialog_message"),
121133
project,
122134
listOf(
123135
NotificationAction.create(message("toolkit.sso_expire.dialog.yes_button")) { _, notification ->
124-
ApplicationManager.getApplication().executeOnPooledThread {
125-
getConnectionStartUrl(connection)?.let { startUrl ->
126-
loginSso(project, startUrl, scopes = connection.scopes)
127-
}
128-
}
136+
reconnectCodeWhisperer(project)
129137
notification.expire()
130138
},
131139
NotificationAction.create(message("toolkit.sso_expire.dialog.no_button")) { _, notification ->
@@ -150,6 +158,16 @@ object CodeWhispererUtil {
150158
?.getConnectionSettings()
151159
?.tokenProvider
152160
?.delegate as? BearerTokenProvider
161+
162+
fun reconnectCodeWhisperer(project: Project) {
163+
val connection = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(CodeWhispererConnection.getInstance())
164+
if (connection !is BearerSsoConnection) return
165+
ApplicationManager.getApplication().executeOnPooledThread {
166+
getConnectionStartUrl(connection)?.let { startUrl ->
167+
loginSso(project, startUrl, scopes = connection.scopes)
168+
}
169+
}
170+
}
153171
}
154172

155173
enum class CaretMovement {

0 commit comments

Comments
 (0)