@@ -5,6 +5,7 @@ package software.aws.toolkits.jetbrains.core.credentials
55
66import com.intellij.openapi.project.Project
77import com.intellij.testFramework.ApplicationExtension
8+ import io.mockk.clearAllMocks
89import org.junit.jupiter.api.BeforeEach
910import org.junit.jupiter.api.Test
1011import org.junit.jupiter.api.TestInstance
@@ -13,15 +14,20 @@ import org.mockito.kotlin.never
1314import io.mockk.mockkStatic
1415import io.mockk.verify
1516import io.mockk.every
17+ import io.mockk.just
18+ import io.mockk.mockkObject
19+ import io.mockk.runs
1620import org.mockito.kotlin.verify
1721import org.mockito.kotlin.whenever
1822import org.junit.jupiter.api.Assertions.*
1923import org.junit.jupiter.api.extension.ExtendWith
2024import org.mockito.kotlin.doThrow
25+ import org.mockito.kotlin.reset
2126import software.amazon.awssdk.auth.token.credentials.SdkToken
2227import software.aws.toolkits.jetbrains.core.credentials.sso.DeviceAuthorizationGrantToken
2328import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenAuthState
2429import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
30+ import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProviderListener
2531import software.aws.toolkits.jetbrains.utils.notifyInfo
2632import software.aws.toolkits.resources.AwsCoreBundle.message
2733import java.net.UnknownHostException
@@ -34,22 +40,23 @@ class ToolkitAuthManagerTest {
3440 private lateinit var project: Project
3541 private lateinit var tokenProvider: BearerTokenProvider
3642 private var reauthCallCount = 0
37- private var notificationShown = false
3843
3944 @BeforeEach
4045 fun setUp () {
4146 project = mock()
4247 tokenProvider = mock()
4348 reauthCallCount = 0
44- notificationShown = false
49+ val field = Class .forName(" software.aws.toolkits.jetbrains.core.credentials.ToolkitAuthManagerKt" )
50+ .getDeclaredField(" hasNotifiedNetworkErrorOnce" )
51+ field.isAccessible = true
52+ field.set(null , false )
4553
46- // Mock the notifyInfo function
54+ mockkObject( BearerTokenProviderListener )
4755 mockkStatic(" software.aws.toolkits.jetbrains.utils.NotificationUtilsKt" )
4856 every {
4957 notifyInfo(any(), any(), any())
50- } answers {
51- notificationShown = true
52- }
58+ } just runs
59+ every { BearerTokenProviderListener .notifyCredUpdate(any<String >()) } just runs
5360 }
5461
5562
@@ -68,8 +75,7 @@ class ToolkitAuthManagerTest {
6875
6976 assertFalse(result)
7077 assertEquals(0 , reauthCallCount)
71- assertTrue(notificationShown)
72- verify {
78+ verify(exactly = 1 ){
7379 notifyInfo(
7480 message(" general.auth.network.error" ),
7581 message(" general.auth.network.error.message" ),
@@ -92,9 +98,6 @@ class ToolkitAuthManagerTest {
9298 tokenProvider
9399 ) { _ -> reauthCallCount++ }
94100
95- // Reset our tracking variable
96- notificationShown = false
97-
98101 // Second call - should not show notification
99102 val result = maybeReauthProviderIfNeeded(
100103 project,
@@ -104,7 +107,6 @@ class ToolkitAuthManagerTest {
104107
105108 assertFalse(result)
106109 assertEquals(0 , reauthCallCount)
107- assertFalse(notificationShown)
108110 verify(exactly = 1 ) {
109111 notifyInfo(
110112 message(" general.auth.network.error" ),
@@ -129,7 +131,10 @@ class ToolkitAuthManagerTest {
129131 tokenProvider
130132 ) { _ -> reauthCallCount++ }
131133
134+ reset(tokenProvider)
135+
132136 // Now simulate successful refresh
137+ whenever(tokenProvider.state()).thenReturn(BearerTokenAuthState .NEEDS_REFRESH )
133138 whenever(tokenProvider.resolveToken()).thenReturn(
134139 DeviceAuthorizationGrantToken (
135140 startUrl = " https://example.com" ,
@@ -145,10 +150,9 @@ class ToolkitAuthManagerTest {
145150 tokenProvider
146151 ) { _ -> reauthCallCount++ }
147152
148- // Reset tracking
149- notificationShown = false
150-
153+ reset(tokenProvider)
151154 // Now trigger another network error - should show notification again
155+ whenever(tokenProvider.state()).thenReturn(BearerTokenAuthState .NEEDS_REFRESH )
152156 doThrow(RuntimeException (" Unable to execute HTTP request" ))
153157 .`when `(tokenProvider)
154158 .resolveToken()
@@ -158,7 +162,6 @@ class ToolkitAuthManagerTest {
158162 tokenProvider
159163 ) { _ -> reauthCallCount++ }
160164
161- assertTrue(notificationShown)
162165 verify(exactly = 2 ) {
163166 notifyInfo(
164167 message(" general.auth.network.error" ),
0 commit comments