Skip to content

Commit 42e38c6

Browse files
authored
fix(amazonq): fix iam credential update logic to use custom comparator and added buffer time in cred validation (aws#8070)
## Problem - IAM credentials is not updating in Sagemaker instances due to incorrect comparison logic which prevents credential refresh and hence users cant interact with Q chat after the initial expiration time - <img width="2513" height="1284" alt="image" src="https://github.com/user-attachments/assets/0dc8d158-00ef-4c86-aff8-7e147a101881" /> ## Solution - Add custom comparator logic and method to properly compare credentials by their actual values (accessKeyId, secretAccessKey, sessionToken) instead of string comparison - Added 60-second expiration buffer to credential validation. similar to SSO token logic [here](https://github.com/aws/aws-toolkit-vscode/blob/c3685274fc4e09e72c98db4c43b7959634bc63b0/packages/core/src/auth/sso/model.ts#L158) for grace-time - Tested by building a debug artefact on a SMUS CodeEditor instance and verified q chat is triggering refresh credentials - ```npm run package && npm run test``` succeeded - https://drive.corp.amazon.com/documents/parameja@/PR-8070/IAM-Credentials-Refresh-Q-Chat.mov --- - Treat all work as PUBLIC. Private `feature/x` branches will not be squash-merged at release time. - Your code changes must meet the guidelines in [CONTRIBUTING.md](https://github.com/aws/aws-toolkit-vscode/blob/master/CONTRIBUTING.md#guidelines). - License: I confirm that my contribution is made under the terms of the Apache 2.0 license.
1 parent c368527 commit 42e38c6

File tree

8 files changed

+94
-6
lines changed

8 files changed

+94
-6
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "Bug Fix",
3+
"description": "Amazon Q automatically refreshes expired IAM Credentials in Sagemaker instances"
4+
}

packages/amazonq/src/lsp/auth.ts

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ import * as crypto from 'crypto'
1717
import { LanguageClient } from 'vscode-languageclient'
1818
import { AuthUtil } from 'aws-core-vscode/codewhisperer'
1919
import { Writable } from 'stream'
20-
import { onceChanged } from 'aws-core-vscode/utils'
20+
import { onceChanged, onceChangedWithComparator } from 'aws-core-vscode/utils'
2121
import { getLogger, oneMinute, isSageMaker } from 'aws-core-vscode/shared'
22-
import { isSsoConnection, isIamConnection } from 'aws-core-vscode/auth'
22+
import { isSsoConnection, isIamConnection, areCredentialsEqual } from 'aws-core-vscode/auth'
2323

2424
export const encryptionKey = crypto.randomBytes(32)
2525

@@ -108,7 +108,10 @@ export class AmazonQLspAuth {
108108
this.client.info(`UpdateBearerToken: ${JSON.stringify(request)}`)
109109
}
110110

111-
public updateIamCredentials = onceChanged(this._updateIamCredentials.bind(this))
111+
public updateIamCredentials = onceChangedWithComparator(
112+
this._updateIamCredentials.bind(this),
113+
([prevCreds], [currentCreds]) => areCredentialsEqual(prevCreds, currentCreds)
114+
)
112115
private async _updateIamCredentials(credentials: any) {
113116
getLogger().info(
114117
`[SageMaker Debug] Updating IAM credentials - credentials received: ${credentials ? 'YES' : 'NO'}`

packages/core/src/auth/auth.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,7 @@ export class Auth implements AuthService, ConnectionManager {
862862

863863
private async createCachedCredentials(provider: CredentialsProvider) {
864864
const providerId = provider.getCredentialsId()
865+
getLogger().debug(`credentials: create cache credentials for ${provider.getProviderType()}`)
865866
globals.loginManager.store.invalidateCredentials(providerId)
866867
const { credentials, endpointUrl } = await globals.loginManager.store.upsertCredentials(providerId, provider)
867868
await globals.loginManager.validateCredentials(credentials, endpointUrl, provider.getDefaultRegion())

packages/core/src/auth/connection.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,18 @@ export const isBuilderIdConnection = (conn?: Connection): conn is SsoConnection
7171
export const isValidCodeCatalystConnection = (conn?: Connection): conn is SsoConnection =>
7272
isSsoConnection(conn) && hasScopes(conn, scopesCodeCatalyst)
7373

74+
export const areCredentialsEqual = (creds1: any, creds2: any): boolean => {
75+
if (!creds1 || !creds2) {
76+
return creds1 === creds2
77+
}
78+
79+
return (
80+
creds1.accessKeyId === creds2.accessKeyId &&
81+
creds1.secretAccessKey === creds2.secretAccessKey &&
82+
creds1.sessionToken === creds2.sessionToken
83+
)
84+
}
85+
7486
export function hasScopes(target: SsoConnection | SsoProfile | string[], scopes: string[]): boolean {
7587
return scopes?.every((s) => (Array.isArray(target) ? target : target.scopes)?.includes(s))
7688
}

packages/core/src/auth/credentials/store.ts

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,16 @@ export class CredentialsStore {
3131
* If the expiration property does not exist, it is assumed to never expire.
3232
*/
3333
public isValid(key: string): boolean {
34+
// Apply 60-second buffer similar to SSO token expiry logic
35+
const expirationBufferMs = 60000
36+
3437
if (this.credentialsCache[key]) {
3538
const expiration = this.credentialsCache[key].credentials.expiration
36-
return expiration !== undefined ? expiration >= new globals.clock.Date() : true
39+
const now = new globals.clock.Date()
40+
const bufferedNow = new globals.clock.Date(now.getTime() + expirationBufferMs)
41+
return expiration !== undefined ? expiration >= bufferedNow : true
3742
}
38-
43+
getLogger().debug(`credentials: no credentials found for ${key}`)
3944
return false
4045
}
4146

packages/core/src/auth/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ export {
1919
getTelemetryMetadataForConn,
2020
isIamConnection,
2121
isSsoConnection,
22+
areCredentialsEqual,
2223
} from './connection'
2324
export { Auth } from './auth'
2425
export { CredentialsStore } from './credentials/store'

packages/core/src/shared/utilities/functionUtils.ts

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,32 @@ export function onceChanged<T, U extends any[]>(fn: (...args: U) => T): (...args
6363
: ((val = fn(...args)), (ran = true), (prevArgs = args.map(String).join(':')), val)
6464
}
6565

66+
/**
67+
* Creates a function that runs only if the args changed versus the previous invocation,
68+
* using a custom comparator function for argument comparison.
69+
*
70+
* @param fn The function to wrap
71+
* @param comparator Function that returns true if arguments are equal
72+
*/
73+
export function onceChangedWithComparator<T, U extends any[]>(
74+
fn: (...args: U) => T,
75+
comparator: (prev: U, current: U) => boolean
76+
): (...args: U) => T {
77+
let val: T
78+
let ran = false
79+
let prevArgs: U
80+
81+
return (...args) => {
82+
if (ran && comparator(prevArgs, args)) {
83+
return val
84+
}
85+
val = fn(...args)
86+
ran = true
87+
prevArgs = args
88+
return val
89+
}
90+
}
91+
6692
/**
6793
* Creates a new function that stores the result of a call.
6894
*

packages/core/src/test/shared/utilities/functionUtils.test.ts

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44
*/
55

66
import assert from 'assert'
7-
import { once, onceChanged, debounce, oncePerUniqueArg } from '../../../shared/utilities/functionUtils'
7+
import {
8+
once,
9+
onceChanged,
10+
debounce,
11+
oncePerUniqueArg,
12+
onceChangedWithComparator,
13+
} from '../../../shared/utilities/functionUtils'
814
import { installFakeClock } from '../../testUtil'
915

1016
describe('functionUtils', function () {
@@ -49,6 +55,36 @@ describe('functionUtils', function () {
4955
assert.strictEqual(counter, 3)
5056
})
5157

58+
it('onceChangedWithComparator()', function () {
59+
let counter = 0
60+
const credentialsEqual = ([prev]: [any], [current]: [any]) => {
61+
if (!prev && !current) {
62+
return true
63+
}
64+
if (!prev || !current) {
65+
return false
66+
}
67+
return prev.accessKeyId === current.accessKeyId && prev.secretAccessKey === current.secretAccessKey
68+
}
69+
const fn = onceChangedWithComparator((creds: any) => void counter++, credentialsEqual)
70+
71+
const creds1 = { accessKeyId: 'key1', secretAccessKey: 'secret1' }
72+
const creds2 = { accessKeyId: 'key1', secretAccessKey: 'secret1' }
73+
const creds3 = { accessKeyId: 'key2', secretAccessKey: 'secret2' }
74+
75+
fn(creds1)
76+
assert.strictEqual(counter, 1)
77+
78+
fn(creds2) // Same values, should not execute
79+
assert.strictEqual(counter, 1)
80+
81+
fn(creds3) // Different values, should execute
82+
assert.strictEqual(counter, 2)
83+
84+
fn(creds3) // Same as previous, should not execute
85+
assert.strictEqual(counter, 2)
86+
})
87+
5288
it('oncePerUniqueArg()', function () {
5389
let counter = 0
5490
const fn = oncePerUniqueArg((s: string) => {

0 commit comments

Comments
 (0)