@@ -6,6 +6,7 @@ package software.aws.toolkits.jetbrains.core.credentials.sso.pkce
6
6
import com.intellij.collaboration.auth.OAuthCallbackHandlerBase
7
7
import com.intellij.collaboration.auth.services.OAuthCredentialsAcquirer
8
8
import com.intellij.collaboration.auth.services.OAuthRequest
9
+ import com.intellij.collaboration.auth.services.OAuthService
9
10
import com.intellij.collaboration.auth.services.OAuthServiceBase
10
11
import com.intellij.collaboration.auth.services.PkceUtils
11
12
import com.intellij.openapi.application.ApplicationNamesInfo
@@ -61,17 +62,23 @@ class ToolkitOAuthService : OAuthServiceBase<AccessToken>() {
61
62
return authorize(ToolkitOAuthRequest (registration))
62
63
}
63
64
64
- override fun handleServerCallback (path : String , parameters : Map <String , List <String >>): Boolean {
65
- val request = currentRequest.get() ? : return false
66
- val toolkitRequest = request.request as ? ToolkitOAuthRequest ? : return false
65
+ override fun handleOAuthServerCallback (path : String , parameters : Map <String , List <String >>): OAuthService . OAuthResult < AccessToken > ? {
66
+ val request = currentRequest.get() ? : return OAuthService . OAuthResult ( null , false )
67
+ val toolkitRequest = request.request as ? ToolkitOAuthRequest ? : return OAuthService . OAuthResult (request.request, false )
67
68
68
69
val callbackState = parameters[" state" ]?.firstOrNull()
69
70
if (toolkitRequest.csrfToken != callbackState) {
70
71
request.result.completeExceptionally(RuntimeException (" Invalid CSRF token" ))
71
- return false
72
+ return OAuthService .OAuthResult (toolkitRequest, false )
73
+ }
74
+
75
+ if (parameters[" code" ] == null ) {
76
+ val error = parameters[" error" ]?.firstOrNull()
77
+ val errorDescription = parameters[" error_description" ]?.firstOrNull()
78
+ toolkitRequest.error = OAuthError (error = error, errorDescription = errorDescription)
72
79
}
73
80
74
- return super .handleServerCallback (path, parameters)
81
+ return super .handleOAuthServerCallback (path, parameters)
75
82
}
76
83
77
84
override fun revokeToken (token : String ) {
@@ -83,6 +90,11 @@ class ToolkitOAuthService : OAuthServiceBase<AccessToken>() {
83
90
}
84
91
}
85
92
93
+ private data class OAuthError (
94
+ val error : String? ,
95
+ val errorDescription : String?
96
+ )
97
+
86
98
private class ToolkitOAuthRequest (internal val registration : PKCEClientRegistration ) : OAuthRequest<AccessToken> {
87
99
private val port: Int get() = BuiltInServerManager .getInstance().port
88
100
private val base64Encoder = Base64 .getUrlEncoder().withoutPadding()
@@ -120,6 +132,8 @@ private class ToolkitOAuthRequest(internal val registration: PKCEClientRegistrat
120
132
)
121
133
122
134
private fun randB64url (bits : Int ): String = base64Encoder.encodeToString(BigInteger (bits, DigestUtil .random).toByteArray())
135
+
136
+ internal var error: OAuthError ? = null
123
137
}
124
138
125
139
// exchange for real token
@@ -157,7 +171,7 @@ internal class ToolkitOAuthCallbackHandler : OAuthCallbackHandlerBase() {
157
171
override fun oauthService () = ToolkitOAuthService .getInstance()
158
172
159
173
// on success / fail
160
- override fun handleAcceptCode ( isAccepted : Boolean ): AcceptCodeHandleResult {
174
+ override fun handleOAuthResult ( oAuthResult : OAuthService . OAuthResult < * > ): AcceptCodeHandleResult {
161
175
// focus should be on requesting component?
162
176
runInEdt {
163
177
IdeFocusManager .getGlobalInstance().getLastFocusedIdeWindow()?.toFront()
@@ -166,16 +180,22 @@ internal class ToolkitOAuthCallbackHandler : OAuthCallbackHandlerBase() {
166
180
val urlBase = newFromEncoded(
167
181
" http://127.0.0.1:${BuiltInServerManager .getInstance().port} /api/${ToolkitOAuthCallbackResultService .SERVICE_NAME } /index.html"
168
182
)
169
- val params = if (isAccepted) {
183
+ val params = if (oAuthResult. isAccepted) {
170
184
mapOf (
171
185
" productName" to PKCE_CLIENT_NAME ,
172
186
// we don't have the request context to get the requested scopes in this callback until 233
173
187
" scopes" to ApplicationNamesInfo .getInstance().fullProductName
174
188
)
175
189
} else {
190
+ val (error, errorDescription) = (oAuthResult.request as ? ToolkitOAuthRequest )?.error ? : OAuthError (null , null )
191
+ val errorString = if (error != null && errorDescription != null ) {
192
+ " $error : $errorDescription "
193
+ } else {
194
+ errorDescription ? : error ? : message(" general.unknown_error" )
195
+ }
196
+
176
197
mapOf (
177
- // when 233, check if we can retrieve the underlying error
178
- " error" to message(" general.unknown_error" )
198
+ " error" to errorString
179
199
)
180
200
}
181
201
0 commit comments