Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions packages/amazonq/src/extensionNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,11 @@ async function activateAmazonQNode(context: vscode.ExtensionContext) {
async function getAuthState(): Promise<Omit<AuthUserState, 'source'>> {
const state = AuthUtil.instance.getAuthState()

if (AuthUtil.instance.isConnected() && !(AuthUtil.instance.isSsoSession() || isSageMaker())) {
getLogger().error('Current Amazon Q connection is not SSO')
if (
AuthUtil.instance.isConnected() &&
!(AuthUtil.instance.isSsoSession() || AuthUtil.instance.isIamSession() || isSageMaker())
) {
getLogger().error('Current Amazon Q connection is not SSO nor IAM')
}

return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ export class InlineChatProvider {
private async generateResponse(
triggerPayload: TriggerPayload & { projectContextQueryLatencyMs?: number },
triggerID: string
) {
): Promise<GenerateAssistantResponseCommandOutput | undefined> {
const triggerEvent = this.triggerEventsStorage.getTriggerEvent(triggerID)
if (triggerEvent === undefined) {
return
Expand Down Expand Up @@ -182,7 +182,12 @@ export class InlineChatProvider {
let response: GenerateAssistantResponseCommandOutput | undefined = undefined
session.createNewTokenSource()
try {
response = await session.chatSso(request)
if (AuthUtil.instance.isSsoSession()) {
response = await session.chatSso(request)
} else {
// Call sendMessage because Q Developer Streaming Client does not have generateAssistantResponse
throw new ToolkitError('Inline chat is only available with SSO authentication')
}
getLogger().info(
`response to tab: ${tabID} conversationID: ${session.sessionIdentifier} requestID: ${response.$metadata.requestId} metadata: %O`,
response.$metadata
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ describe('AuthUtil', async function () {
await auth.getIamCredential()
assert.fail('Should have thrown an error')
} catch (err) {
assert.strictEqual((err as Error).message, 'Cannot get token with SSO session')
assert.strictEqual((err as Error).message, 'Cannot get credential without logging in with IAM.')
}
})

Expand All @@ -494,7 +494,7 @@ describe('AuthUtil', async function () {
await auth.getIamCredential()
assert.fail('Should have thrown an error')
} catch (err) {
assert.strictEqual((err as Error).message, 'Cannot get credential without logging in.')
assert.strictEqual((err as Error).message, 'Cannot get credential without logging in with IAM.')
}
})
})
Expand Down
9 changes: 7 additions & 2 deletions packages/core/src/auth/auth2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,13 @@ export abstract class BaseLogin {
* Decrypts an encrypted string, removes its quotes, and returns the resulting string
*/
protected async decrypt(encrypted: string): Promise<string> {
const decrypted = await jose.compactDecrypt(encrypted, this.lspAuth.encryptionKey)
return decrypted.plaintext.toString().replaceAll('"', '')
try {
const decrypted = await jose.compactDecrypt(encrypted, this.lspAuth.encryptionKey)
return decrypted.plaintext.toString().replaceAll('"', '')
} catch (e) {
getLogger().error(`Failed to decrypt: ${encrypted}`)
return encrypted
}
}
}

Expand Down
6 changes: 5 additions & 1 deletion packages/core/src/codewhisperer/ui/codeWhispererNodes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,11 @@ export function createManageSubscription(): DataQuickPickItem<'manageSubscriptio
export function createSignout(): DataQuickPickItem<'signout'> {
const label = localize('AWS.codewhisperer.signoutNode.label', 'Sign Out')
const icon = getIcon('vscode-export')
const connection = AuthUtil.instance.isBuilderIdConnection() ? 'AWS Builder ID' : 'IAM Identity Center'
const connection = AuthUtil.instance.isIamSession()
? 'IAM Credentials'
: AuthUtil.instance.isBuilderIdConnection()
? 'AWS Builder ID'
: 'IAM Identity Center'

return {
data: 'signout',
Expand Down
58 changes: 32 additions & 26 deletions packages/core/src/codewhisperer/util/authUtil.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,16 @@ import { showAmazonQWalkthroughOnce } from '../../amazonq/onboardingPage/walkthr
import { setContext } from '../../shared/vscode/setContext'
import { openUrl } from '../../shared/utilities/vsCodeUtils'
import { telemetry } from '../../shared/telemetry/telemetry'
import { AuthStateEvent, cacheChangedEvent, LanguageClientAuth, Login, SsoLogin, IamLogin } from '../../auth/auth2'
import {
AuthStateEvent,
cacheChangedEvent,
LanguageClientAuth,
Login,
SsoLogin,
IamLogin,
AuthState,
LoginTypes,
} from '../../auth/auth2'
import { builderIdStartUrl, internalStartUrl } from '../../auth/sso/constants'
import { VSCODE_EXTENSION_ID } from '../../shared/extensions'
import { RegionProfileManager } from '../region/regionProfileManager'
Expand Down Expand Up @@ -108,11 +117,11 @@ export class AuthUtil implements IAuthProvider {
}

isSsoSession(): boolean {
return this.session instanceof SsoLogin
return this.session?.loginType === LoginTypes.SSO
}

isIamSession(): boolean {
return this.session instanceof IamLogin
return this.session?.loginType === LoginTypes.IAM
}

/**
Expand Down Expand Up @@ -204,36 +213,31 @@ export class AuthUtil implements IAuthProvider {
}

async getToken() {
if (this.session) {
const token = (await this.session.getCredential()).credential
if (typeof token !== 'string') {
throw new ToolkitError('Cannot get token with IAM session')
}
return token
if (this.isSsoSession()) {
const response = await this.session!.getCredential()
return response.credential as string
} else {
throw new ToolkitError('Cannot get credential without logging in.')
throw new ToolkitError('Cannot get credential without logging in with SSO.')
}
}

async getIamCredential() {
if (this.session) {
const credential = (await this.session.getCredential()).credential
if (typeof credential !== 'object') {
throw new ToolkitError('Cannot get token with SSO session')
}
return credential
if (this.isIamSession()) {
const response = await this.session!.getCredential()
return response.credential as IamCredentials
} else {
throw new ToolkitError('Cannot get credential without logging in.')
throw new ToolkitError('Cannot get credential without logging in with IAM.')
}
}

get connection() {
return this.session?.data
}

getAuthState() {
if (this.session) {
return this.session.getConnectionState()
getAuthState(): AuthState {
// Check if getConnectionState exists in case of type casts
if (typeof this.session?.getConnectionState === 'function') {
return this.session!.getConnectionState()
} else {
return 'notConnected'
}
Expand Down Expand Up @@ -356,11 +360,12 @@ export class AuthUtil implements IAuthProvider {

private async stateChangeHandler(e: AuthStateEvent) {
if (e.state === 'refreshed') {
const params = this.session ? (await this.session.getCredential()).updateCredentialsParams : undefined
if (this.isSsoSession()) {
await this.lspAuth.updateBearerToken(params)
const params = await this.session!.getCredential()
await this.lspAuth.updateBearerToken(params.updateCredentialsParams)
} else if (this.isIamSession()) {
await this.lspAuth.updateIamCredential(params)
const params = await this.session!.getCredential()
await this.lspAuth.updateIamCredential(params.updateCredentialsParams)
}
} else {
this.logger.info(`codewhisperer: connection changed to ${e.state}`)
Expand All @@ -383,11 +388,12 @@ export class AuthUtil implements IAuthProvider {
this.session = undefined
}
if (state === 'connected') {
const params = this.session ? (await this.session.getCredential()).updateCredentialsParams : undefined
if (this.isSsoSession()) {
await this.lspAuth.updateBearerToken(params)
const params = await this.session!.getCredential()
await this.lspAuth.updateBearerToken(params.updateCredentialsParams)
} else if (this.isIamSession()) {
await this.lspAuth.updateIamCredential(params)
const params = await this.session!.getCredential()
await this.lspAuth.updateIamCredential(params.updateCredentialsParams)
}

if (this.isIdcConnection()) {
Expand Down
12 changes: 11 additions & 1 deletion packages/core/src/login/webview/vue/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import { getLogger } from '../../../shared/logger/logger'
import { isValidUrl } from '../../../shared/utilities/uriUtils'
import { RegionProfile } from '../../../codewhisperer/models/model'
import { ProfileSwitchIntent } from '../../../codewhisperer/region/regionProfileManager'
import { showMessage } from '../../../shared/utilities/messages'

export abstract class CommonAuthWebview extends VueWebview {
private readonly className = 'CommonAuthWebview'
Expand Down Expand Up @@ -183,7 +184,7 @@ export abstract class CommonAuthWebview extends VueWebview {
abstract fetchConnections(): Promise<AwsConnection[] | undefined>

async errorNotification(e: AuthError) {
void vscode.window.showInformationMessage(`${e.text}`)
await showMessage('error', e.text)
}

abstract quitLoginScreen(): Promise<void>
Expand Down Expand Up @@ -296,6 +297,15 @@ export abstract class CommonAuthWebview extends VueWebview {
return globals.globalState.tryGet('recentSso', Object, { startUrl: '', region: 'us-east-1' })
}

getDefaultIamKeys(): { accessKey: string } {
const devSettings = DevSettings.instance.get('autofillAccessKey', '')
if (devSettings) {
return { accessKey: devSettings }
}

return globals.globalState.tryGet('recentIamKeys', Object, { accessKey: '' })
}

cancelAuthFlow() {
AuthSSOServer.lastInstance?.cancelCurrentFlow()
}
Expand Down
55 changes: 42 additions & 13 deletions packages/core/src/login/webview/vue/login.vue
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@

<template v-if="stage === 'AUTHENTICATING'">
<div class="auth-container-section">
<div v-if="app === 'TOOLKIT' && profileName.length > 0" class="header bottomMargin">
<div v-if="selectedLoginOption === LoginOption.IAM_CREDENTIAL" class="header bottomMargin">
Connecting to IAM...
</div>
<div v-else class="header bottomMargin">Authenticating in browser...</div>
Expand Down Expand Up @@ -273,7 +273,7 @@
<div class="title">Secret Access Key</div>
<input
class="iamInput bottomMargin"
type="text"
type="password"
id="secretKey"
name="secretKey"
v-model="secretKey"
Expand Down Expand Up @@ -330,6 +330,10 @@ interface ImportedLogin {
type: number
startUrl: string
region: string
// Add IAM credential fields
profileName?: string
accessKey?: string
secretKey?: string // Note: storing secrets has security implications
}

export default defineComponent({
Expand All @@ -349,6 +353,7 @@ export default defineComponent({
data() {
return {
existingStartUrls: [] as string[],
existingIamAccessKeys: [] as string[],
importedLogins: [] as ImportedLogin[],
selectedLoginOption: LoginOption.NONE,
stage: 'START' as Stage,
Expand All @@ -368,6 +373,8 @@ export default defineComponent({
const defaultSso = await this.getDefaultSso()
this.startUrl = defaultSso.startUrl
this.selectedRegion = defaultSso.region
const defaultIamAccessKey = await this.getDefaultIamAccessKey()
this.accessKey = defaultIamAccessKey.accessKey
await this.emitUpdate('created')
},

Expand Down Expand Up @@ -397,6 +404,10 @@ export default defineComponent({
}
},
handleDocumentClick(event: any) {
// Only reset selection when in START stage to avoid clearing during authentication
if (this.stage !== 'START') {
return
}
const isClickInsideSelectableItems = event.target.closest('.selectable-item')
if (!isClickInsideSelectableItems) {
this.selectedLoginOption = 0
Expand Down Expand Up @@ -437,17 +448,32 @@ export default defineComponent({
const selectedConnection =
this.importedLogins[this.selectedLoginOption - LoginOption.IMPORTED_LOGINS]

// Imported connections cannot be Builder IDs, they are filtered out in the client.
const error = await client.startEnterpriseSetup(
selectedConnection.startUrl,
selectedConnection.region,
this.app
)
if (error) {
this.stage = 'START'
void client.errorNotification(error)
} else {
this.stage = 'CONNECTED'
// Handle both SSO and IAM imported connections
if (selectedConnection.type === LoginOption.ENTERPRISE_SSO) {
const error = await client.startEnterpriseSetup(
selectedConnection.startUrl,
selectedConnection.region,
this.app
)
if (error) {
this.stage = 'START'
void client.errorNotification(error)
} else {
this.stage = 'CONNECTED'
}
} else if (selectedConnection.type === LoginOption.IAM_CREDENTIAL) {
// Use stored IAM credentials
const error = await client.startIamCredentialSetup(
selectedConnection.profileName || '',
selectedConnection.accessKey || '',
selectedConnection.secretKey || ''
)
if (error) {
this.stage = 'START'
void client.errorNotification(error)
} else {
this.stage = 'CONNECTED'
}
}
} else if (this.selectedLoginOption === LoginOption.IAM_CREDENTIAL) {
this.stage = 'AWS_PROFILE'
Expand Down Expand Up @@ -581,6 +607,9 @@ export default defineComponent({
async getDefaultSso() {
return await client.getDefaultSsoProfile()
},
async getDefaultIamAccessKey() {
return await client.getDefaultIamKeys()
},
handleHelpLinkClick() {
void client.emitUiClick('auth_helpLink')
},
Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/shared/featureConfig.ts
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ export class FeatureConfigProvider {
}

async fetchFeatureConfigs(): Promise<void> {
if (AuthUtil.instance.isConnectionExpired()) {
if (AuthUtil.instance.isConnectionExpired() || AuthUtil.instance.isIamSession()) {
return
}

Expand Down
1 change: 1 addition & 0 deletions packages/core/src/shared/globalState.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ export type globalKey =
| 'lastOsStartTime'
| 'recentCredentials'
| 'recentSso'
| 'recentIamKeys'
// List of regions enabled in AWS Explorer.
| 'region'
// TODO: implement this via `PromptSettings` instead of globalState.
Expand Down
1 change: 1 addition & 0 deletions packages/core/src/shared/settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,7 @@ const devSettings = {
amazonqWorkspaceLsp: Record(String, String),
ssoCacheDirectory: String,
autofillStartUrl: String,
autofillAccessKey: String,
webAuth: Boolean,
notificationsPollInterval: Number,
}
Expand Down
14 changes: 7 additions & 7 deletions packages/core/src/test/amazonq/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,22 +73,22 @@ export async function createSession({

const client = sinon.createStubInstance(FeatureDevClient)
client.createConversation.resolves(conversationID)
const session = new Session(sessionConfig, messenger, tabID, sessionState, client)
const sessionNew = new Session(sessionConfig, messenger, tabID, sessionState, client)

sinon.stub(session, 'conversationId').get(() => conversationID)
sinon.stub(session, 'uploadId').get(() => uploadID)
sinon.stub(sessionNew, 'conversationId').get(() => conversationID)
sinon.stub(sessionNew, 'uploadId').get(() => uploadID)

return session
return sessionNew
}

export async function sessionRegisterProvider(session: Session, uri: vscode.Uri, fileContents: Uint8Array) {
session.config.fs.registerProvider(uri, new VirtualMemoryFile(fileContents))
}

export function generateVirtualMemoryUri(uploadID: string, filePath: string, scheme: string) {
const generationFilePath = path.join(uploadID, filePath)
const uri = vscode.Uri.from({ scheme, path: generationFilePath })
return uri
const generationFilePathNew = path.join(uploadID, filePath)
const uriNew = vscode.Uri.from({ scheme, path: generationFilePathNew })
return uriNew
}

export async function sessionWriteFile(session: Session, uri: vscode.Uri, encodedContent: Uint8Array) {
Expand Down
Loading
Loading