Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion packages/core/src/auth/sso/ssoAccessTokenProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import { randomUUID } from '../../shared/crypto'
import { getExtRuntimeContext } from '../../shared/vscode/env'
import { showInputBox } from '../../shared/ui/inputPrompter'
import { AmazonQPromptSettings, DevSettings, PromptSettings, ToolkitPromptSettings } from '../../shared/settings'
import { onceChanged } from '../../shared/utilities/functionUtils'
import { debounce, onceChanged } from '../../shared/utilities/functionUtils'
import { NestedMap } from '../../shared/utilities/map'
import { asStringifiedStack } from '../../shared/telemetry/spans'
import { showViewLogsMessage } from '../../shared/utilities/messages'
Expand Down Expand Up @@ -97,7 +97,20 @@ export abstract class SsoAccessTokenProvider {
this.reAuthState.set(this.profile, { reAuthReason: `invalidate():${reason}` })
}

/**
* Sometimes we get many calls at once and this
* can trigger redundant disk reads, or token refreshes.
* We debounce to avoid this.
*
* NOTE: The property {@link getTokenDebounced()} does not work with being stubbed for tests, so
* this redundant function was created to work around that.
*/
public async getToken(): Promise<SsoToken | undefined> {
return this.getTokenDebounced()
}
private getTokenDebounced = debounce(() => this._getToken(), 50)
/** Exposed for testing purposes only */
public async _getToken(): Promise<SsoToken | undefined> {
const data = await this.cache.token.load(this.tokenCacheKey)
SsoAccessTokenProvider.logIfChanged(
indent(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
*/

import assert from 'assert'
import * as FakeTimers from '@sinonjs/fake-timers'
import * as sinon from 'sinon'
import { SharedCredentialsProvider } from '../../../auth/providers/sharedCredentialsProvider'
import { stripUndefined } from '../../../shared/utilities/collectionUtils'
import * as process from '@aws-sdk/credential-provider-process'
import { ParsedIniData } from '@smithy/shared-ini-file-loader'
import { installFakeClock } from '../../testUtil'
import { SsoClient } from '../../../auth/sso/clients'
import { stub } from '../../utilities/stubber'
import { SsoAccessTokenProvider } from '../../../auth/sso/ssoAccessTokenProvider'
Expand All @@ -19,20 +17,13 @@ import { createTestSections } from '../testUtil'
const missingPropertiesFragment = 'missing properties'

describe('SharedCredentialsProvider', async function () {
let clock: FakeTimers.InstalledClock
let sandbox: sinon.SinonSandbox

before(function () {
sandbox = sinon.createSandbox()
clock = installFakeClock()
})

after(function () {
clock.uninstall()
})

afterEach(function () {
clock.reset()
sandbox.restore()
})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import { ToolkitError } from '../../../shared/errors'
import * as fs from 'fs' // eslint-disable-line no-restricted-imports
import * as path from 'path'
import { Stub, stub } from '../../utilities/stubber'
import { globals } from '../../../shared'

const hourInMs = 3600000

Expand All @@ -37,14 +38,14 @@ describe('SsoAccessTokenProvider', function () {
let oidcClient: Stub<OidcClient>
let sut: SsoAccessTokenProvider
let cache: ReturnType<typeof getCache>
let clock: FakeTimers.InstalledClock
let clock: FakeTimers.InstalledClock | undefined
let tempDir: string
let reAuthState: TestReAuthState

function createToken(timeDelta: number, extras: Partial<SsoToken> = {}) {
return {
accessToken: 'dummyAccessToken',
expiresAt: new clock.Date(clock.Date.now() + timeDelta),
expiresAt: new globals.clock.Date(globals.clock.Date.now() + timeDelta),
...extras,
}
}
Expand All @@ -54,7 +55,7 @@ describe('SsoAccessTokenProvider', function () {
scopes: [],
clientId: 'dummyClientId',
clientSecret: 'dummyClientSecret',
expiresAt: new clock.Date(clock.Date.now() + timeDelta),
expiresAt: new globals.clock.Date(globals.clock.Date.now() + timeDelta),
startUrl,
...extras,
}
Expand All @@ -66,7 +67,7 @@ describe('SsoAccessTokenProvider', function () {
deviceCode: 'dummyCode',
userCode: 'dummyUserCode',
verificationUri: 'dummyLink',
expiresAt: new clock.Date(clock.Date.now() + timeDelta),
expiresAt: new globals.clock.Date(globals.clock.Date.now() + timeDelta),
}
}

Expand All @@ -77,14 +78,6 @@ describe('SsoAccessTokenProvider', function () {
return cacheDir
}

before(function () {
clock = installFakeClock()
})

after(function () {
clock.uninstall()
})

beforeEach(async function () {
oidcClient = stub(OidcClient)
tempDir = await makeTemporaryTokenCacheFolder()
Expand All @@ -95,7 +88,7 @@ describe('SsoAccessTokenProvider', function () {

afterEach(async function () {
sinon.restore()
clock.reset()
clock?.uninstall()
await tryRemoveFolder(tempDir)
})

Expand Down Expand Up @@ -163,6 +156,20 @@ describe('SsoAccessTokenProvider', function () {
assert.strictEqual(cachedToken, undefined)
})

it('concurrent calls are debounced', async function () {
const validToken = createToken(hourInMs)
await cache.token.save(startUrl, { region, startUrl, token: validToken })
const actualGetToken = sinon.spy(sut, '_getToken')

const result = await Promise.all([sut.getToken(), sut.getToken(), sut.getToken()])

// Subsequent other calls were debounced so this was only called once
assert.strictEqual(actualGetToken.callCount, 1)
for (const r of result) {
assert.deepStrictEqual(r, validToken)
}
})

describe('Exceptions', function () {
it('drops expired tokens if failure was a client-fault', async function () {
const exception = new UnauthorizedClientException({ message: '', $metadata: {} })
Expand Down Expand Up @@ -267,6 +274,7 @@ describe('SsoAccessTokenProvider', function () {
})

it(`emits session duration between logins of the same startUrl`, async function () {
clock = installFakeClock()
setupFlow()
stubOpen()

Expand Down Expand Up @@ -311,6 +319,7 @@ describe('SsoAccessTokenProvider', function () {
})

it('respects the device authorization expiration time', async function () {
clock = installFakeClock()
setupFlow()
stubOpen()
const exception = new AuthorizationPendingException({ message: '', $metadata: {} })
Expand Down Expand Up @@ -352,7 +361,7 @@ describe('SsoAccessTokenProvider', function () {
const registration = {
clientId: 'myExpiredClientId',
clientSecret: 'myExpiredClientSecret',
expiresAt: new clock.Date(clock.Date.now() - 1), // expired date
expiresAt: new globals.clock.Date(globals.clock.Date.now() - 1), // expired date
startUrl: key.startUrl,
}
await cache.registration.save(key, registration)
Expand Down
Loading