diff --git a/packages/amazonq/.changes/next-release/Bug Fix-da0e805d-faab-4274-b37a-943c7263e42b.json b/packages/amazonq/.changes/next-release/Bug Fix-da0e805d-faab-4274-b37a-943c7263e42b.json new file mode 100644 index 00000000000..3bb0428ebac --- /dev/null +++ b/packages/amazonq/.changes/next-release/Bug Fix-da0e805d-faab-4274-b37a-943c7263e42b.json @@ -0,0 +1,4 @@ +{ + "type": "Bug Fix", + "description": "Amazon Q automatically refreshes expired IAM Credentials in Sagemaker instances" +} diff --git a/packages/amazonq/src/lsp/auth.ts b/packages/amazonq/src/lsp/auth.ts index 161ba4d9762..5dfc269f5ef 100644 --- a/packages/amazonq/src/lsp/auth.ts +++ b/packages/amazonq/src/lsp/auth.ts @@ -17,9 +17,9 @@ import * as crypto from 'crypto' import { LanguageClient } from 'vscode-languageclient' import { AuthUtil } from 'aws-core-vscode/codewhisperer' import { Writable } from 'stream' -import { onceChanged } from 'aws-core-vscode/utils' +import { onceChanged, onceChangedWithComparator } from 'aws-core-vscode/utils' import { getLogger, oneMinute, isSageMaker } from 'aws-core-vscode/shared' -import { isSsoConnection, isIamConnection } from 'aws-core-vscode/auth' +import { isSsoConnection, isIamConnection, areCredentialsEqual } from 'aws-core-vscode/auth' export const encryptionKey = crypto.randomBytes(32) @@ -108,7 +108,10 @@ export class AmazonQLspAuth { this.client.info(`UpdateBearerToken: ${JSON.stringify(request)}`) } - public updateIamCredentials = onceChanged(this._updateIamCredentials.bind(this)) + public updateIamCredentials = onceChangedWithComparator( + this._updateIamCredentials.bind(this), + ([prevCreds], [currentCreds]) => areCredentialsEqual(prevCreds, currentCreds) + ) private async _updateIamCredentials(credentials: any) { getLogger().info( `[SageMaker Debug] Updating IAM credentials - credentials received: ${credentials ? 'YES' : 'NO'}` diff --git a/packages/core/src/auth/auth.ts b/packages/core/src/auth/auth.ts index 9a050c9f185..053df50321e 100644 --- a/packages/core/src/auth/auth.ts +++ b/packages/core/src/auth/auth.ts @@ -862,6 +862,7 @@ export class Auth implements AuthService, ConnectionManager { private async createCachedCredentials(provider: CredentialsProvider) { const providerId = provider.getCredentialsId() + getLogger().debug(`credentials: create cache credentials for ${provider.getProviderType()}`) globals.loginManager.store.invalidateCredentials(providerId) const { credentials, endpointUrl } = await globals.loginManager.store.upsertCredentials(providerId, provider) await globals.loginManager.validateCredentials(credentials, endpointUrl, provider.getDefaultRegion()) diff --git a/packages/core/src/auth/connection.ts b/packages/core/src/auth/connection.ts index fca46da53e4..fea929fc8af 100644 --- a/packages/core/src/auth/connection.ts +++ b/packages/core/src/auth/connection.ts @@ -71,6 +71,18 @@ export const isBuilderIdConnection = (conn?: Connection): conn is SsoConnection export const isValidCodeCatalystConnection = (conn?: Connection): conn is SsoConnection => isSsoConnection(conn) && hasScopes(conn, scopesCodeCatalyst) +export const areCredentialsEqual = (creds1: any, creds2: any): boolean => { + if (!creds1 || !creds2) { + return creds1 === creds2 + } + + return ( + creds1.accessKeyId === creds2.accessKeyId && + creds1.secretAccessKey === creds2.secretAccessKey && + creds1.sessionToken === creds2.sessionToken + ) +} + export function hasScopes(target: SsoConnection | SsoProfile | string[], scopes: string[]): boolean { return scopes?.every((s) => (Array.isArray(target) ? target : target.scopes)?.includes(s)) } diff --git a/packages/core/src/auth/credentials/store.ts b/packages/core/src/auth/credentials/store.ts index d99595a3877..9fd73c9130c 100644 --- a/packages/core/src/auth/credentials/store.ts +++ b/packages/core/src/auth/credentials/store.ts @@ -31,11 +31,16 @@ export class CredentialsStore { * If the expiration property does not exist, it is assumed to never expire. */ public isValid(key: string): boolean { + // Apply 60-second buffer similar to SSO token expiry logic + const expirationBufferMs = 60000 + if (this.credentialsCache[key]) { const expiration = this.credentialsCache[key].credentials.expiration - return expiration !== undefined ? expiration >= new globals.clock.Date() : true + const now = new globals.clock.Date() + const bufferedNow = new globals.clock.Date(now.getTime() + expirationBufferMs) + return expiration !== undefined ? expiration >= bufferedNow : true } - + getLogger().debug(`credentials: no credentials found for ${key}`) return false } diff --git a/packages/core/src/auth/index.ts b/packages/core/src/auth/index.ts index c180d603c67..a5a3ca0edd9 100644 --- a/packages/core/src/auth/index.ts +++ b/packages/core/src/auth/index.ts @@ -19,6 +19,7 @@ export { getTelemetryMetadataForConn, isIamConnection, isSsoConnection, + areCredentialsEqual, } from './connection' export { Auth } from './auth' export { CredentialsStore } from './credentials/store' diff --git a/packages/core/src/shared/utilities/functionUtils.ts b/packages/core/src/shared/utilities/functionUtils.ts index 214721b1cdb..fa0e61847bb 100644 --- a/packages/core/src/shared/utilities/functionUtils.ts +++ b/packages/core/src/shared/utilities/functionUtils.ts @@ -63,6 +63,32 @@ export function onceChanged(fn: (...args: U) => T): (...args : ((val = fn(...args)), (ran = true), (prevArgs = args.map(String).join(':')), val) } +/** + * Creates a function that runs only if the args changed versus the previous invocation, + * using a custom comparator function for argument comparison. + * + * @param fn The function to wrap + * @param comparator Function that returns true if arguments are equal + */ +export function onceChangedWithComparator( + fn: (...args: U) => T, + comparator: (prev: U, current: U) => boolean +): (...args: U) => T { + let val: T + let ran = false + let prevArgs: U + + return (...args) => { + if (ran && comparator(prevArgs, args)) { + return val + } + val = fn(...args) + ran = true + prevArgs = args + return val + } +} + /** * Creates a new function that stores the result of a call. * diff --git a/packages/core/src/test/shared/utilities/functionUtils.test.ts b/packages/core/src/test/shared/utilities/functionUtils.test.ts index b675fe74feb..3ba11518414 100644 --- a/packages/core/src/test/shared/utilities/functionUtils.test.ts +++ b/packages/core/src/test/shared/utilities/functionUtils.test.ts @@ -4,7 +4,13 @@ */ import assert from 'assert' -import { once, onceChanged, debounce, oncePerUniqueArg } from '../../../shared/utilities/functionUtils' +import { + once, + onceChanged, + debounce, + oncePerUniqueArg, + onceChangedWithComparator, +} from '../../../shared/utilities/functionUtils' import { installFakeClock } from '../../testUtil' describe('functionUtils', function () { @@ -49,6 +55,36 @@ describe('functionUtils', function () { assert.strictEqual(counter, 3) }) + it('onceChangedWithComparator()', function () { + let counter = 0 + const credentialsEqual = ([prev]: [any], [current]: [any]) => { + if (!prev && !current) { + return true + } + if (!prev || !current) { + return false + } + return prev.accessKeyId === current.accessKeyId && prev.secretAccessKey === current.secretAccessKey + } + const fn = onceChangedWithComparator((creds: any) => void counter++, credentialsEqual) + + const creds1 = { accessKeyId: 'key1', secretAccessKey: 'secret1' } + const creds2 = { accessKeyId: 'key1', secretAccessKey: 'secret1' } + const creds3 = { accessKeyId: 'key2', secretAccessKey: 'secret2' } + + fn(creds1) + assert.strictEqual(counter, 1) + + fn(creds2) // Same values, should not execute + assert.strictEqual(counter, 1) + + fn(creds3) // Different values, should execute + assert.strictEqual(counter, 2) + + fn(creds3) // Same as previous, should not execute + assert.strictEqual(counter, 2) + }) + it('oncePerUniqueArg()', function () { let counter = 0 const fn = oncePerUniqueArg((s: string) => {