Skip to content

Commit 67c8da0

Browse files
authored
Q Code Transform: Refresh token if grant expires during transformation (#4638)
1 parent c9dd554 commit 67c8da0

File tree

2 files changed

+98
-1
lines changed

2 files changed

+98
-1
lines changed

plugins/amazonq/codetransform/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codemodernizer/utils/CodeTransformApiUtils.kt

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import software.amazon.awssdk.services.codewhispererruntime.model.Transformation
1616
import software.amazon.awssdk.services.codewhispererruntime.model.TransformationProgressUpdate
1717
import software.amazon.awssdk.services.codewhispererruntime.model.TransformationStatus
1818
import software.amazon.awssdk.services.codewhispererruntime.model.ValidationException
19+
import software.amazon.awssdk.services.ssooidc.model.InvalidGrantException
1920
import software.aws.toolkits.core.utils.WaiterUnrecoverableException
2021
import software.aws.toolkits.core.utils.Waiters.waitUntil
2122
import software.aws.toolkits.jetbrains.services.codemodernizer.CodeTransformTelemetryManager
@@ -53,6 +54,9 @@ suspend fun JobId.pollTransformationStatusAndPlan(
5354
var didSleepOnce = false
5455
val maxRefreshes = 10
5556
var numRefreshes = 0
57+
58+
// We refresh token at the start of polling, but for some long jobs that runs for 30 minutes plus, tokens may need to be
59+
// refreshed again when AccessDeniedException or InvalidGrantException are caught.
5660
refreshToken(project)
5761

5862
try {
@@ -99,6 +103,10 @@ suspend fun JobId.pollTransformationStatusAndPlan(
99103
if (numRefreshes++ > maxRefreshes) throw e
100104
refreshToken(project)
101105
return@waitUntil state
106+
} catch (e: InvalidGrantException) {
107+
if (numRefreshes++ > maxRefreshes) throw e
108+
refreshToken(project)
109+
return@waitUntil state
102110
} finally {
103111
sleep(sleepDurationMillis)
104112
}
@@ -107,7 +115,7 @@ suspend fun JobId.pollTransformationStatusAndPlan(
107115
// Still call onStateChange to update the UI
108116
onStateChange(state, TransformationStatus.FAILED, transformationPlan)
109117
when (e) {
110-
is WaiterUnrecoverableException, is AccessDeniedException -> {
118+
is WaiterUnrecoverableException, is AccessDeniedException, is InvalidGrantException -> {
111119
return PollingResult(false, transformationResponse?.transformationJob(), state, transformationPlan)
112120
}
113121
else -> throw e

plugins/amazonq/codetransform/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codemodernizer/CodeWhispererCodeModernizerUtilsTest.kt

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
// SPDX-License-Identifier: Apache-2.0
33

44
package software.aws.toolkits.jetbrains.services.codemodernizer
5+
import io.mockk.every
6+
import io.mockk.just
7+
import io.mockk.mockkStatic
8+
import io.mockk.runs
59
import kotlinx.coroutines.runBlocking
610
import org.assertj.core.api.Assertions.assertThat
711
import org.junit.Before
@@ -11,10 +15,13 @@ import org.mockito.kotlin.any
1115
import org.mockito.kotlin.times
1216
import org.mockito.kotlin.verify
1317
import org.mockito.kotlin.whenever
18+
import software.amazon.awssdk.services.codewhispererruntime.model.AccessDeniedException
1419
import software.amazon.awssdk.services.codewhispererruntime.model.TransformationProgressUpdate
1520
import software.amazon.awssdk.services.codewhispererruntime.model.TransformationStatus
21+
import software.amazon.awssdk.services.ssooidc.model.InvalidGrantException
1622
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.getTableMapping
1723
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.pollTransformationStatusAndPlan
24+
import software.aws.toolkits.jetbrains.services.codemodernizer.utils.refreshToken
1825
import java.util.concurrent.atomic.AtomicBoolean
1926

2027
class CodeWhispererCodeModernizerUtilsTest : CodeWhispererCodeModernizerTestBase() {
@@ -57,6 +64,88 @@ class CodeWhispererCodeModernizerUtilsTest : CodeWhispererCodeModernizerTestBase
5764
assertThat(expected).isEqualTo(mutableList)
5865
}
5966

67+
@Test
68+
fun `refresh on access denied`() {
69+
val mockAccessDeniedException = Mockito.mock(AccessDeniedException::class.java)
70+
71+
mockkStatic(::refreshToken)
72+
every { refreshToken(any()) } just runs
73+
74+
Mockito.doThrow(
75+
mockAccessDeniedException
76+
).doReturn(
77+
exampleGetCodeMigrationResponse,
78+
exampleGetCodeMigrationResponse.replace(TransformationStatus.STARTED),
79+
exampleGetCodeMigrationResponse.replace(TransformationStatus.COMPLETED), // Should stop before this point
80+
).whenever(clientAdaptorSpy).getCodeModernizationJob(any())
81+
82+
Mockito.doReturn(exampleGetCodeMigrationPlanResponse)
83+
.whenever(clientAdaptorSpy).getCodeModernizationPlan(any())
84+
85+
val mutableList = mutableListOf<TransformationStatus>()
86+
runBlocking {
87+
jobId.pollTransformationStatusAndPlan(
88+
setOf(TransformationStatus.STARTED),
89+
setOf(TransformationStatus.FAILED),
90+
clientAdaptorSpy,
91+
0,
92+
0,
93+
AtomicBoolean(false),
94+
project
95+
) { _, status, _ ->
96+
mutableList.add(status)
97+
}
98+
}
99+
val expected =
100+
listOf<TransformationStatus>(
101+
exampleGetCodeMigrationResponse.transformationJob().status(),
102+
TransformationStatus.STARTED,
103+
)
104+
assertThat(expected).isEqualTo(mutableList)
105+
io.mockk.verify { refreshToken(any()) }
106+
}
107+
108+
@Test
109+
fun `refresh on invalid grant`() {
110+
val mockInvalidGrantException = Mockito.mock(InvalidGrantException::class.java)
111+
112+
mockkStatic(::refreshToken)
113+
every { refreshToken(any()) } just runs
114+
115+
Mockito.doThrow(
116+
mockInvalidGrantException
117+
).doReturn(
118+
exampleGetCodeMigrationResponse,
119+
exampleGetCodeMigrationResponse.replace(TransformationStatus.STARTED),
120+
exampleGetCodeMigrationResponse.replace(TransformationStatus.COMPLETED), // Should stop before this point
121+
).whenever(clientAdaptorSpy).getCodeModernizationJob(any())
122+
123+
Mockito.doReturn(exampleGetCodeMigrationPlanResponse)
124+
.whenever(clientAdaptorSpy).getCodeModernizationPlan(any())
125+
126+
val mutableList = mutableListOf<TransformationStatus>()
127+
runBlocking {
128+
jobId.pollTransformationStatusAndPlan(
129+
setOf(TransformationStatus.STARTED),
130+
setOf(TransformationStatus.FAILED),
131+
clientAdaptorSpy,
132+
0,
133+
0,
134+
AtomicBoolean(false),
135+
project
136+
) { _, status, _ ->
137+
mutableList.add(status)
138+
}
139+
}
140+
val expected =
141+
listOf<TransformationStatus>(
142+
exampleGetCodeMigrationResponse.transformationJob().status(),
143+
TransformationStatus.STARTED,
144+
)
145+
assertThat(expected).isEqualTo(mutableList)
146+
io.mockk.verify { refreshToken(any()) }
147+
}
148+
60149
@Test
61150
fun `stops polling when status transitions to failOn`() {
62151
Mockito.doReturn(

0 commit comments

Comments
 (0)