44package software.aws.toolkits.jetbrains.services.codewhisperer
55
66import com.intellij.openapi.application.ApplicationManager
7- import com.intellij.openapi.util.Disposer
87import com.intellij.openapi.util.SystemInfo
98import com.intellij.testFramework.DisposableRule
109import com.intellij.testFramework.RuleChain
@@ -15,6 +14,7 @@ import org.junit.After
1514import org.junit.Before
1615import org.junit.Rule
1716import org.junit.Test
17+ import org.junit.jupiter.api.assertThrows
1818import org.mockito.kotlin.any
1919import org.mockito.kotlin.argThat
2020import org.mockito.kotlin.argumentCaptor
@@ -24,7 +24,6 @@ import org.mockito.kotlin.mock
2424import org.mockito.kotlin.stub
2525import org.mockito.kotlin.times
2626import org.mockito.kotlin.verify
27- import org.mockito.kotlin.whenever
2827import software.amazon.awssdk.services.codewhispererruntime.CodeWhispererRuntimeClient
2928import software.amazon.awssdk.services.codewhispererruntime.model.ArtifactType
3029import software.amazon.awssdk.services.codewhispererruntime.model.CodeAnalysisFindingsSchema
@@ -54,21 +53,22 @@ import software.amazon.awssdk.services.codewhispererruntime.model.SuggestionStat
5453import software.amazon.awssdk.services.codewhispererruntime.paginators.GenerateCompletionsIterable
5554import software.amazon.awssdk.services.codewhispererruntime.paginators.ListAvailableCustomizationsIterable
5655import software.amazon.awssdk.services.ssooidc.SsoOidcClient
57- import software.aws.toolkits.core.TokenConnectionSettings
5856import software.aws.toolkits.core.utils.test.aString
5957import software.aws.toolkits.jetbrains.core.MockClientManagerRule
6058import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
6159import software.aws.toolkits.jetbrains.core.credentials.ManagedSsoProfile
6260import software.aws.toolkits.jetbrains.core.credentials.MockCredentialManagerRule
6361import software.aws.toolkits.jetbrains.core.credentials.MockToolkitAuthManagerRule
6462import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
63+ import software.aws.toolkits.jetbrains.core.credentials.logoutFromSsoConnection
64+ import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
65+ import software.aws.toolkits.jetbrains.core.credentials.sono.Q_SCOPES
6566import software.aws.toolkits.jetbrains.core.credentials.sono.SONO_REGION
6667import software.aws.toolkits.jetbrains.services.amazonq.FEATURE_EVALUATION_PRODUCT_NAME
6768import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.metadata
6869import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonRequest
6970import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonResponseWithToken
7071import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.sdkHttpResponse
71- import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptor
7272import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptorImpl
7373import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization
7474import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererModelConfigurator
@@ -93,13 +93,12 @@ class CodeWhispererClientAdaptorTest {
9393
9494 @Rule
9595 @JvmField
96- val ruleChain = RuleChain (projectRule, mockCredentialRule, mockClientManagerRule, disposableRule)
96+ val ruleChain = RuleChain (projectRule, mockCredentialRule, mockClientManagerRule, authManagerRule, disposableRule)
9797
9898 private lateinit var bearerClient: CodeWhispererRuntimeClient
9999 private lateinit var ssoClient: SsoOidcClient
100100
101- private lateinit var sut: CodeWhispererClientAdaptor
102- private lateinit var connectionManager: ToolkitConnectionManager
101+ private lateinit var sut: CodeWhispererClientAdaptorImpl
103102 private var isTelemetryEnabledDefault: Boolean = false
104103
105104 @Before
@@ -117,15 +116,8 @@ class CodeWhispererClientAdaptorTest {
117116 on { listFeatureEvaluations(any<ListFeatureEvaluationsRequest >()) } doReturn listFeatureEvaluationsResponse
118117 }
119118
120- val mockConnection = mock<AwsBearerTokenConnection >()
121- whenever(mockConnection.getConnectionSettings()) doReturn mock<TokenConnectionSettings >()
122-
123- connectionManager = mock {
124- on {
125- activeConnectionForFeature(any())
126- } doReturn authManagerRule.createConnection(ManagedSsoProfile (" us-east-1" , aString(), listOf (" scopes" ))) as AwsBearerTokenConnection
127- }
128- projectRule.project.replaceService(ToolkitConnectionManager ::class .java, connectionManager, disposableRule.disposable)
119+ val conn = authManagerRule.createConnection(ManagedSsoProfile (" us-east-1" , " url" , Q_SCOPES ))
120+ ToolkitConnectionManager .getInstance(projectRule.project).switchConnection(conn)
129121
130122 isTelemetryEnabledDefault = AwsSettings .getInstance().isTelemetryEnabled
131123 }
@@ -135,16 +127,37 @@ class CodeWhispererClientAdaptorTest {
135127 AwsSettings .getInstance().isTelemetryEnabled = isTelemetryEnabledDefault
136128 }
137129
138- @After
139- fun cleanup () {
140- Disposer .dispose(sut)
141- }
142-
143130 @Test
144131 fun `Sono region is us-east-1` () {
145132 assertThat(" us-east-1" ).isEqualTo(SONO_REGION )
146133 }
147134
135+ @Test
136+ fun `should throw if there is no valid credential, otherwise return codewhispererRuntimeClient` () {
137+ val connectionManager = ToolkitConnectionManager .getInstance(projectRule.project)
138+
139+ assertThat(connectionManager.activeConnectionForFeature(QConnection .getInstance()))
140+ .isNotNull
141+ assertThat(sut.bearerClient())
142+ .isNotNull
143+ .isInstanceOf(CodeWhispererRuntimeClient ::class .java)
144+
145+ logoutFromSsoConnection(projectRule.project, connectionManager.activeConnectionForFeature(QConnection .getInstance()) as AwsBearerTokenConnection )
146+ assertThat(connectionManager.activeConnectionForFeature(QConnection .getInstance())).isNull()
147+ assertThrows<Exception >(" attempt to get bearer client while there is no valid credential" ) {
148+ sut.bearerClient()
149+ }
150+
151+ val anotherQConnection = authManagerRule.createConnection(ManagedSsoProfile (" us-east-1" , aString(), Q_SCOPES ))
152+ connectionManager.switchConnection(anotherQConnection)
153+ assertThat(connectionManager.activeConnectionForFeature(QConnection .getInstance()))
154+ .isNotNull
155+ .isEqualTo(anotherQConnection)
156+ assertThat(sut.bearerClient())
157+ .isNotNull
158+ .isInstanceOf(CodeWhispererRuntimeClient ::class .java)
159+ }
160+
148161 @Test
149162 fun `listCustomizations` () {
150163 val sdkIterable = ListAvailableCustomizationsIterable (bearerClient, ListAvailableCustomizationsRequest .builder().build())
0 commit comments