Skip to content

Commit d8cce7c

Browse files
committed
fix(ssoAccessTokenProvider): use globalState abstraction
1 parent 68983d0 commit d8cce7c

File tree

4 files changed

+100
-31
lines changed

4 files changed

+100
-31
lines changed

packages/core/src/auth/sso/ssoAccessTokenProvider.ts

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ export abstract class SsoAccessTokenProvider {
107107
const access = await this.runFlow()
108108
const identity = this.tokenCacheKey
109109
await this.cache.token.save(identity, access)
110-
await setSessionCreationDate(this.tokenCacheKey, new Date())
110+
await globals.globalState.setSsoSessionCreationDate(this.tokenCacheKey, new globals.clock.Date())
111111

112112
return { ...access.token, identity }
113113
}
@@ -309,25 +309,13 @@ async function pollForTokenWithProgress<T extends { requestId?: string }>(
309309
)
310310
}
311311

312-
const sessionCreationDateKey = '#sessionCreationDates'
313-
async function setSessionCreationDate(id: string, date: Date, memento = globals.context.globalState) {
314-
try {
315-
await memento.update(sessionCreationDateKey, {
316-
...memento.get(sessionCreationDateKey),
317-
[id]: date.getTime(),
318-
})
319-
} catch (err) {
320-
getLogger().verbose('auth: failed to set session creation date: %s', err)
321-
}
322-
}
323-
324-
function getSessionCreationDate(id: string, memento = globals.context.globalState): number | undefined {
325-
return memento.get(sessionCreationDateKey, {} as Record<string, number>)[id]
326-
}
327-
328-
function getSessionDuration(id: string, memento = globals.context.globalState) {
329-
const creationDate = getSessionCreationDate(id, memento)
330-
312+
/**
313+
* Gets SSO session creation timestamp for the given session `id`.
314+
*
315+
* @param id Session id
316+
*/
317+
function getSessionDuration(id: string) {
318+
const creationDate = globals.globalState.getSsoSessionCreationDate(id)
331319
return creationDate !== undefined ? Date.now() - creationDate : undefined
332320
}
333321

packages/core/src/shared/extensionGlobals.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ export { globals as default }
170170
*/
171171
interface ToolkitGlobals {
172172
readonly context: ExtensionContext
173-
/** Global, shared, mutable, persisted state (survives IDE restart), namespaced to the extension (i.e. not shared with other vscode extensions). */
173+
/** Global, shared (with all vscode instances, including remote!), mutable, persisted state (survives IDE restart), namespaced to the extension (not shared with other vscode extensions). */
174174
readonly globalState: GlobalState
175175
/** Decides the prefix for package.json extension parameters, e.g. commands, 'setContext' values, etc. */
176176
contextPrefix: string

packages/core/src/shared/globalState.ts

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ type globalKey =
2020
| 'CODEWHISPERER_USER_GROUP'
2121
| 'gumby.wasQCodeTransformationUsed'
2222
| 'hasAlreadyOpenedAmazonQ'
23+
// Legacy name from `ssoAccessTokenProvider.ts`.
24+
| '#sessionCreationDates'
2325

2426
/**
2527
* Extension-local (not visible to other vscode extensions) shared state which persists after IDE
@@ -136,20 +138,66 @@ export class GlobalState implements vscode.Memento {
136138
* explorer node is not refreshed yet.
137139
*/
138140
getRedshiftConnection(warehouseArn: string): redshift.ConnectionParams | undefined | 'DELETE_CONNECTION' {
139-
const allCxns = this.tryGet(
141+
const all = this.tryGet<Record<string, redshift.ConnectionParams | 'DELETE_CONNECTION'>>(
140142
'aws.redshift.connections',
141143
v => {
142144
if (v !== undefined && typeof v !== 'object') {
143145
throw new Error()
144146
}
145-
const cxn = (v as any)?.[warehouseArn]
146-
if (cxn !== undefined && typeof cxn !== 'object' && cxn !== 'DELETE_CONNECTION') {
147+
const item = (v as any)?.[warehouseArn]
148+
// Requested item must be object or 'DELETE_CONNECTION'.
149+
if (item !== undefined && typeof item !== 'object' && item !== 'DELETE_CONNECTION') {
147150
throw new Error()
148151
}
149152
return v
150153
},
151154
undefined
152155
)
153-
return (allCxns as any)?.[warehouseArn]
156+
return all?.[warehouseArn]
157+
}
158+
159+
/**
160+
* Sets SSO session creation timestamp for the given session `id`.
161+
*
162+
* TODO: this never garbage-collects old connections, so the state will grow forever...
163+
*
164+
* @param id Session id
165+
* @param date Session timestamp
166+
*/
167+
async setSsoSessionCreationDate(id: string, date: Date) {
168+
try {
169+
const all = this.tryGet('#sessionCreationDates', Object, {})
170+
// TODO: race condition...
171+
await this.update('#sessionCreationDates', {
172+
...all,
173+
[id]: date.getTime(),
174+
})
175+
} catch (err) {
176+
getLogger().error('auth: failed to set session creation date: %O', err)
177+
}
178+
}
179+
180+
/**
181+
* Gets SSO session creation timestamp for the given session `id`.
182+
*
183+
* @param id Session id
184+
*/
185+
getSsoSessionCreationDate(id: string): number | undefined {
186+
const all = this.tryGet<Record<string, number>>(
187+
'#sessionCreationDates',
188+
v => {
189+
if (v !== undefined && typeof v !== 'object') {
190+
throw new Error()
191+
}
192+
const item = (v as any)?.[id]
193+
// Requested item must be a number.
194+
if (item !== undefined && typeof item !== 'number') {
195+
throw new Error()
196+
}
197+
return v
198+
},
199+
undefined
200+
)
201+
return all?.[id]
154202
}
155203
}

packages/core/src/test/shared/globalState.test.ts

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ describe('GlobalState', function () {
9898
})
9999

100100
describe('redshift state', function () {
101+
const testArn1 = 'arn:foo/bar/baz/1'
102+
const testArn2 = 'arn:foo/bar/baz/2'
103+
101104
const fakeCxn1: redshift.ConnectionParams = {
102105
connectionType: redshift.ConnectionType.SecretsManager,
103106
database: 'fake-db',
@@ -119,8 +122,6 @@ describe('GlobalState', function () {
119122
}
120123

121124
it('get/set connection state and special DELETE_CONNECTION value', async () => {
122-
const testArn1 = 'arn:foo/bar/baz/1'
123-
const testArn2 = 'arn:foo/bar/baz/2'
124125
await globalState.saveRedshiftConnection(testArn1, 'DELETE_CONNECTION')
125126
await globalState.saveRedshiftConnection(testArn2, undefined)
126127
assert.deepStrictEqual(globalState.getRedshiftConnection(testArn1), 'DELETE_CONNECTION')
@@ -131,9 +132,7 @@ describe('GlobalState', function () {
131132
assert.deepStrictEqual(globalState.getRedshiftConnection(testArn2), fakeCxn2)
132133
})
133134

134-
it('validates state', async () => {
135-
const testArn1 = 'arn:foo/bar/baz/1'
136-
const testArn2 = 'arn:foo/bar/baz/2'
135+
it('validation', async () => {
137136
await globalState.saveRedshiftConnection(testArn1, 'foo' as any)
138137
await globalState.saveRedshiftConnection(testArn2, 99 as any)
139138

@@ -146,14 +145,48 @@ describe('GlobalState', function () {
146145
// Bad state is logged and returns undefined.
147146
assert.deepStrictEqual(globalState.getRedshiftConnection(testArn1), undefined)
148147
assert.deepStrictEqual(globalState.getRedshiftConnection(testArn2), undefined)
148+
149149
await globalState.saveRedshiftConnection(testArn2, fakeCxn2)
150150
assert.deepStrictEqual(globalState.getRedshiftConnection(testArn2), fakeCxn2)
151-
152151
// Stored state is now "partially bad".
153152
assert.deepStrictEqual(globalState.get('aws.redshift.connections'), {
154153
[testArn1]: 'foo',
155154
[testArn2]: fakeCxn2,
156155
})
157156
})
158157
})
158+
159+
describe('SSO sessions', function () {
160+
const session1 = 'session-1'
161+
const session2 = 'session-2'
162+
const time1 = new Date(Date.now() - 42 * 1000) // in the past.
163+
const time2 = new Date()
164+
165+
it('get/set', async () => {
166+
await globalState.setSsoSessionCreationDate(session1, time1)
167+
await globalState.setSsoSessionCreationDate(session2, time2)
168+
assert.deepStrictEqual(globalState.getSsoSessionCreationDate(session1), time1.getTime())
169+
assert.deepStrictEqual(globalState.getSsoSessionCreationDate(session2), time2.getTime())
170+
})
171+
172+
it('validation', async () => {
173+
// Set bad state.
174+
await globalState.update('#sessionCreationDates', {
175+
[session1]: 'foo',
176+
[session2]: {},
177+
})
178+
179+
// Bad state is logged and returns undefined.
180+
assert.deepStrictEqual(globalState.getSsoSessionCreationDate(session1), undefined)
181+
assert.deepStrictEqual(globalState.getSsoSessionCreationDate(session2), undefined)
182+
183+
await globalState.setSsoSessionCreationDate(session2, time2)
184+
assert.deepStrictEqual(globalState.getSsoSessionCreationDate(session2), time2.getTime())
185+
// Stored state is now "partially bad".
186+
assert.deepStrictEqual(globalState.get('#sessionCreationDates'), {
187+
[session1]: 'foo',
188+
[session2]: time2.getTime(),
189+
})
190+
})
191+
})
159192
})

0 commit comments

Comments
 (0)