Skip to content

Commit e8a8c59

Browse files
vpbhargavBhargava Varadharajan
andauthored
fix(smsus): Update DER cred expiry time and project pick (aws#2206)
**Description** Reduced the DER cred expiry time to 10 min default. The API is being updated as well. Also updated the auth logic to invoke project picker on sign in, re-auth. **Testing Done** Unit tests, tested manually on VSCode as well - The signin, re-auth and sign-out cases. --- - Treat all work as PUBLIC. Private `feature/x` branches will not be squash-merged at release time. - Your code changes must meet the guidelines in [CONTRIBUTING.md](https://github.com/aws/aws-toolkit-vscode/blob/master/CONTRIBUTING.md#guidelines). - License: I confirm that my contribution is made under the terms of the Apache 2.0 license. Co-authored-by: Bhargava Varadharajan <[email protected]>
1 parent 8ce40c2 commit e8a8c59

File tree

6 files changed

+161
-12
lines changed

6 files changed

+161
-12
lines changed

packages/core/src/sagemakerunifiedstudio/auth/providers/domainExecRoleCredentialsProvider.ts

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import { SmusCredentialExpiry, SmusTimeouts, SmusErrorCodes, validateCredentialF
1616
* Credentials provider for SageMaker Unified Studio Domain Execution Role (DER)
1717
* Uses SSO tokens to get DER credentials via the /sso/redeem-token endpoint
1818
*
19-
* This provider implements internal caching with 55-minute expiry and handles
19+
* This provider implements internal caching with 10-minute expiry and handles
2020
* its own credential lifecycle independently
2121
*/
2222
export class DomainExecRoleCredentialsProvider implements CredentialsProvider {
@@ -122,7 +122,7 @@ export class DomainExecRoleCredentialsProvider implements CredentialsProvider {
122122
public async getCredentials(): Promise<AWS.Credentials> {
123123
this.logger.debug(`SMUS DER: Getting DER credentials for domain ${this.domainId}`)
124124

125-
// Check cache first (55-minute expiry with 5-minute buffer for proactive refresh)
125+
// Check cache first (10-minute expiry with 5-minute buffer for proactive refresh)
126126
if (this.credentialCache && this.credentialCache.expiresAt > new Date()) {
127127
this.logger.debug(`SMUS DER: Using cached DER credentials for domain ${this.domainId}`)
128128
return this.credentialCache.credentials
@@ -207,6 +207,7 @@ export class DomainExecRoleCredentialsProvider implements CredentialsProvider {
207207
accessKeyId: string
208208
secretAccessKey: string
209209
sessionToken: string
210+
expiration: string
210211
}
211212
}
212213

@@ -225,9 +226,23 @@ export class DomainExecRoleCredentialsProvider implements CredentialsProvider {
225226
validateCredentialFields(credentials, 'InvalidCredentialResponse', 'API response')
226227

227228
// Create credentials with expiration
228-
// Note: The response doesn't include expiration, so we set it to 55 minutes from now
229-
// TODO: Update when the API provides actual expiration time
230-
const credentialExpiresAt = new globals.clock.Date(Date.now() + SmusCredentialExpiry.derExpiryMs)
229+
// Note: The response doesn't include expiration yet, so we set it to 10 minutes for now if it does't exist
230+
let credentialExpiresAt: Date
231+
if (credentials.expiration) {
232+
// The API returns expiration as a string, convert to Date
233+
const parsedExpiration = new Date(credentials.expiration)
234+
// Check if the parsed date is valid
235+
if (isNaN(parsedExpiration.getTime())) {
236+
this.logger.warn(
237+
`SMUS DER: Invalid expiration date string: ${credentials.expiration}, using default expiration`
238+
)
239+
credentialExpiresAt = new Date(Date.now() + SmusCredentialExpiry.derExpiryMs)
240+
} else {
241+
credentialExpiresAt = parsedExpiration
242+
}
243+
} else {
244+
credentialExpiresAt = new Date(Date.now() + SmusCredentialExpiry.derExpiryMs)
245+
}
231246

232247
const awsCredentials: AWS.Credentials = {
233248
accessKeyId: credentials.accessKeyId as string,
@@ -236,7 +251,7 @@ export class DomainExecRoleCredentialsProvider implements CredentialsProvider {
236251
expiration: credentialExpiresAt,
237252
}
238253

239-
// Cache DER credentials with 55-minute expiry (5-minute buffer for proactive refresh)
254+
// Cache DER credentials with 10-minute expiry (5-minute buffer for proactive refresh)
240255
const cacheExpiresAt = new globals.clock.Date(Date.now() + SmusCredentialExpiry.derExpiryMs)
241256
this.credentialCache = {
242257
credentials: awsCredentials,

packages/core/src/sagemakerunifiedstudio/auth/providers/smusAuthenticationProvider.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,12 @@ export class SmusAuthenticationProvider {
175175

176176
// Use the existing connection
177177
const result = await this.secondaryAuth.useNewConnection(existingConn)
178+
179+
// Auto-invoke project selection after successful sign-in (but not in SMUS space environment)
180+
if (!SmusUtils.isInSmusSpaceEnvironment()) {
181+
void vscode.commands.executeCommand('aws.smus.switchProject')
182+
}
183+
178184
return result
179185
}
180186

@@ -192,6 +198,12 @@ export class SmusAuthenticationProvider {
192198

193199
const result = await this.secondaryAuth.useNewConnection(smusConn)
194200
logger.debug(`SMUS: Reauthenticated connection successfully, id=${result.id}`)
201+
202+
// Auto-invoke project selection after successful reauthentication (but not in SMUS space environment)
203+
if (!SmusUtils.isInSmusSpaceEnvironment()) {
204+
void vscode.commands.executeCommand('aws.smus.switchProject')
205+
}
206+
195207
return result
196208
}
197209
}
@@ -214,6 +226,12 @@ export class SmusAuthenticationProvider {
214226
}
215227

216228
const result = await this.secondaryAuth.useNewConnection(smusConn)
229+
230+
// Auto-invoke project selection after successful sign-in (but not in SMUS space environment)
231+
if (!SmusUtils.isInSmusSpaceEnvironment()) {
232+
void vscode.commands.executeCommand('aws.smus.switchProject')
233+
}
234+
217235
return result
218236
} catch (e) {
219237
throw ToolkitError.chain(e, 'Failed to connect to SageMaker Unified Studio', {

packages/core/src/sagemakerunifiedstudio/shared/smusUtils.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ interface DataZoneSsoLoginResponse {
3030
* Credential expiry time constants for SMUS providers (in milliseconds)
3131
*/
3232
export const SmusCredentialExpiry = {
33-
/** Domain Execution Role (DER) credentials expiry time: 55 minutes */
34-
derExpiryMs: 55 * 60 * 1000,
33+
/** Domain Execution Role (DER) credentials expiry time: 10 minutes */
34+
derExpiryMs: 10 * 60 * 1000,
3535
/** Project Role credentials expiry time: 10 minutes */
3636
projectExpiryMs: 10 * 60 * 1000,
3737
/** Connection credentials expiry time: 10 minutes */

packages/core/src/test/sagemakerunifiedstudio/auth/domainExecRoleCredentialsProvider.test.ts

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -303,12 +303,40 @@ describe('DomainExecRoleCredentialsProvider', function () {
303303
it('should set default expiration when not provided in response', async function () {
304304
const credentials = await derProvider.getCredentials()
305305

306-
// Should have expiration set to 55 mins from now
306+
// Should have expiration set to 10 mins from now
307307
assert.ok(credentials.expiration)
308308
const expirationTime = credentials.expiration!.getTime()
309-
const expectedTime = Date.now() + 55 * 60 * 1000 // 1 hour
309+
const expectedTime = Date.now() + 10 * 60 * 1000 // 10 minutes
310310
const timeDiff = Math.abs(expirationTime - expectedTime)
311-
assert.ok(timeDiff < 5000, 'Expiration should be 55 mins from now')
311+
assert.ok(timeDiff < 5000, 'Expiration should be 10 mins from now')
312+
})
313+
314+
it('should use expiration from API response when provided', async function () {
315+
const futureExpiration = new Date(Date.now() + 2 * 60 * 60 * 1000) // 2 hours from now
316+
const responseWithExpiration = {
317+
credentials: {
318+
accessKeyId: 'AKIA-DER-KEY',
319+
secretAccessKey: 'der-secret-key',
320+
sessionToken: 'der-session-token',
321+
expiration: futureExpiration.toISOString(), // API returns expiration as ISO string
322+
},
323+
}
324+
325+
fetchStub.resolves({
326+
ok: true,
327+
status: 200,
328+
statusText: 'OK',
329+
json: sinon.stub().resolves(responseWithExpiration),
330+
} as any)
331+
332+
const credentials = await derProvider.getCredentials()
333+
334+
// Should use the expiration from the API response
335+
assert.ok(credentials.expiration)
336+
const expirationTime = credentials.expiration!.getTime()
337+
const expectedTime = futureExpiration.getTime()
338+
const timeDiff = Math.abs(expirationTime - expectedTime)
339+
assert.ok(timeDiff < 1000, 'Should use expiration from API response')
312340
})
313341

314342
it('should handle JSON parsing errors', async function () {
@@ -326,6 +354,38 @@ describe('DomainExecRoleCredentialsProvider', function () {
326354
}
327355
)
328356
})
357+
358+
it('should handle invalid expiration string in response', async function () {
359+
const responseWithInvalidExpiration = {
360+
credentials: {
361+
accessKeyId: 'AKIA-DER-KEY',
362+
secretAccessKey: 'der-secret-key',
363+
sessionToken: 'der-session-token',
364+
expiration: 'invalid-date-string', // Invalid date string
365+
},
366+
}
367+
368+
fetchStub.resolves({
369+
ok: true,
370+
status: 200,
371+
statusText: 'OK',
372+
json: sinon.stub().resolves(responseWithInvalidExpiration),
373+
} as any)
374+
375+
const credentials = await derProvider.getCredentials()
376+
377+
// Should fall back to default expiration when date parsing fails
378+
assert.ok(credentials.expiration)
379+
const expirationTime = credentials.expiration!.getTime()
380+
381+
// Should be a valid timestamp (not NaN) using the default expiration
382+
assert.ok(!isNaN(expirationTime), 'Should have valid expiration timestamp')
383+
384+
// Should be close to now + 10 minutes (default expiration)
385+
const expectedTime = Date.now() + 10 * 60 * 1000
386+
const timeDiff = Math.abs(expirationTime - expectedTime)
387+
assert.ok(timeDiff < 5000, 'Should fall back to default expiration for invalid date string')
388+
})
329389
})
330390

331391
describe('invalidate', function () {

packages/core/src/test/sagemakerunifiedstudio/auth/smusAuthenticationProvider.test.ts

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import assert from 'assert'
77
import sinon from 'sinon'
8+
import * as vscode from 'vscode'
89

910
// Mock the setContext function BEFORE importing modules that use it
1011
const setContextModule = require('../../../shared/vscode/setContext')
@@ -24,6 +25,8 @@ describe('SmusAuthenticationProvider', function () {
2425
let smusAuthProvider: SmusAuthenticationProvider
2526
let extractDomainInfoStub: sinon.SinonStub
2627
let getSsoInstanceInfoStub: sinon.SinonStub
28+
let isInSmusSpaceEnvironmentStub: sinon.SinonStub
29+
let executeCommandStub: sinon.SinonStub
2730
let mockSecondaryAuthState: {
2831
activeConnection: SmusConnection | undefined
2932
hasSavedConnection: boolean
@@ -94,9 +97,14 @@ describe('SmusAuthenticationProvider', function () {
9497
.stub(SmusUtils, 'extractDomainInfoFromUrl')
9598
.returns({ domainId: testDomainId, region: testRegion })
9699
getSsoInstanceInfoStub = sinon.stub(SmusUtils, 'getSsoInstanceInfo').resolves(testSsoInstanceInfo)
100+
isInSmusSpaceEnvironmentStub = sinon.stub(SmusUtils, 'isInSmusSpaceEnvironment').returns(false)
101+
executeCommandStub = sinon.stub(vscode.commands, 'executeCommand').resolves()
97102
sinon.stub(require('../../../auth/secondaryAuth'), 'getSecondaryAuth').returns(mockSecondaryAuth)
98103

99104
smusAuthProvider = new SmusAuthenticationProvider(mockAuth, mockSecondaryAuth)
105+
106+
// Reset the executeCommand stub for clean state
107+
executeCommandStub.resetHistory()
100108
})
101109

102110
afterEach(function () {
@@ -189,6 +197,7 @@ describe('SmusAuthenticationProvider', function () {
189197
assert.ok(getSsoInstanceInfoStub.calledWith(testDomainUrl))
190198
assert.ok(mockAuth.createConnection.called)
191199
assert.ok(mockSecondaryAuth.useNewConnection.called)
200+
assert.ok(executeCommandStub.calledWith('aws.smus.switchProject'))
192201
})
193202

194203
it('should reuse existing valid connection', async function () {
@@ -201,6 +210,7 @@ describe('SmusAuthenticationProvider', function () {
201210
assert.strictEqual(result, mockSmusConnection)
202211
assert.ok(mockAuth.createConnection.notCalled)
203212
assert.ok(mockSecondaryAuth.useNewConnection.calledWith(existingConnection))
213+
assert.ok(executeCommandStub.calledWith('aws.smus.switchProject'))
204214
})
205215

206216
it('should reauthenticate existing invalid connection', async function () {
@@ -213,6 +223,7 @@ describe('SmusAuthenticationProvider', function () {
213223
assert.strictEqual(result, mockSmusConnection)
214224
assert.ok(mockAuth.reauthenticate.calledWith(existingConnection))
215225
assert.ok(mockSecondaryAuth.useNewConnection.called)
226+
assert.ok(executeCommandStub.calledWith('aws.smus.switchProject'))
216227
})
217228

218229
it('should throw error for invalid domain URL', async function () {
@@ -225,6 +236,8 @@ describe('SmusAuthenticationProvider', function () {
225236
return err.code === 'FailedToConnect' && (err.cause as any)?.code === 'InvalidDomainUrl'
226237
}
227238
)
239+
// Should not trigger project selection on error
240+
assert.ok(executeCommandStub.notCalled)
228241
})
229242

230243
it('should handle SmusUtils errors', async function () {
@@ -235,6 +248,8 @@ describe('SmusAuthenticationProvider', function () {
235248
() => smusAuthProvider.connectToSmus(testDomainUrl),
236249
(err: ToolkitError) => err.code === 'FailedToConnect'
237250
)
251+
// Should not trigger project selection on error
252+
assert.ok(executeCommandStub.notCalled)
238253
})
239254

240255
it('should handle auth creation errors', async function () {
@@ -245,6 +260,47 @@ describe('SmusAuthenticationProvider', function () {
245260
() => smusAuthProvider.connectToSmus(testDomainUrl),
246261
(err: ToolkitError) => err.code === 'FailedToConnect'
247262
)
263+
// Should not trigger project selection on error
264+
assert.ok(executeCommandStub.notCalled)
265+
})
266+
267+
it('should not trigger project selection in SMUS space environment', async function () {
268+
isInSmusSpaceEnvironmentStub.returns(true)
269+
mockAuth.listConnections.resolves([])
270+
271+
const result = await smusAuthProvider.connectToSmus(testDomainUrl)
272+
273+
assert.strictEqual(result, mockSmusConnection)
274+
assert.ok(mockAuth.createConnection.called)
275+
assert.ok(mockSecondaryAuth.useNewConnection.called)
276+
assert.ok(executeCommandStub.notCalled)
277+
})
278+
279+
it('should not trigger project selection when reusing connection in SMUS space environment', async function () {
280+
isInSmusSpaceEnvironmentStub.returns(true)
281+
const existingConnection = { ...mockSmusConnection, domainUrl: testDomainUrl.toLowerCase() }
282+
mockAuth.listConnections.resolves([existingConnection])
283+
mockAuth.getConnectionState.returns('valid')
284+
285+
const result = await smusAuthProvider.connectToSmus(testDomainUrl)
286+
287+
assert.strictEqual(result, mockSmusConnection)
288+
assert.ok(mockSecondaryAuth.useNewConnection.calledWith(existingConnection))
289+
assert.ok(executeCommandStub.notCalled)
290+
})
291+
292+
it('should not trigger project selection when reauthenticating in SMUS space environment', async function () {
293+
isInSmusSpaceEnvironmentStub.returns(true)
294+
const existingConnection = { ...mockSmusConnection, domainUrl: testDomainUrl.toLowerCase() }
295+
mockAuth.listConnections.resolves([existingConnection])
296+
mockAuth.getConnectionState.returns('invalid')
297+
298+
const result = await smusAuthProvider.connectToSmus(testDomainUrl)
299+
300+
assert.strictEqual(result, mockSmusConnection)
301+
assert.ok(mockAuth.reauthenticate.calledWith(existingConnection))
302+
assert.ok(mockSecondaryAuth.useNewConnection.called)
303+
assert.ok(executeCommandStub.notCalled)
248304
})
249305
})
250306

packages/core/src/test/sagemakerunifiedstudio/shared/smusUtils.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ describe('SmusUtils', () => {
164164
})
165165

166166
it('should export SmusCredentialExpiry with correct values', () => {
167-
assert.strictEqual(SmusCredentialExpiry.derExpiryMs, 55 * 60 * 1000)
167+
assert.strictEqual(SmusCredentialExpiry.derExpiryMs, 10 * 60 * 1000)
168168
assert.strictEqual(SmusCredentialExpiry.projectExpiryMs, 10 * 60 * 1000)
169169
assert.strictEqual(SmusCredentialExpiry.connectionExpiryMs, 10 * 60 * 1000)
170170
})

0 commit comments

Comments
 (0)