Skip to content

Commit 8049826

Browse files
Bhargava Varadharajanguntamb
authored andcommitted
fix(smus): Add proactive cred refresh for active SSH connections
**Description** Added a proactive cred check and refresh when SSH connections are established. Also updated the error messages to be actionable for users. **Motivation** Bug : Previously once cred expired, we were throwing blanket error which did not tell user what the issue was and there was no path to recovery as well. Now with proactive cred refresh, user should be able to retry in ~10-15 seconds. **Testing Done** Tested all flows manually. Unit tests partial, needs to be updated.
1 parent 75d77b8 commit 8049826

File tree

14 files changed

+401
-25
lines changed

14 files changed

+401
-25
lines changed

packages/core/src/auth/auth.ts

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,8 @@ export class Auth implements AuthService, ConnectionManager {
233233
}
234234

235235
const provider = this.getSsoTokenProvider(connection.id, profile)
236-
const token = await provider.getToken()
236+
// Calling existing getToken private method - It will handle setting the connection state etc.
237+
const token = await this._getToken(connection.id, provider)
237238

238239
if (!token?.accessToken) {
239240
throw new Error(`No access token available for connection ${connection.id}`)
@@ -946,10 +947,22 @@ export class Auth implements AuthService, ConnectionManager {
946947
if (previousState === 'valid') {
947948
// Non-token expiration errors can happen. We must log it here, otherwise they are lost.
948949
getLogger().warn(`auth: valid connection became invalid. Last error: %s`, this.#validationErrors.get(id))
949-
950950
const timeout = new Timeout(60000)
951951
this.#invalidCredentialsTimeouts.set(id, timeout)
952952

953+
// Check if this is a SMUS profile - if so, skip the generic prompt
954+
// as SMUS has its own reauthentication flow
955+
const isSmusConnection = profile.type === 'sso' && 'domainUrl' in profile && 'domainId' in profile
956+
if (isSmusConnection) {
957+
getLogger().debug(`auth: Skipping generic reauthentication prompt for SMUS connection ${id}`)
958+
// For SMUS connections, just throw the InvalidConnection error
959+
// The SMUS auth provider will handle showing the appropriate prompt
960+
throw new ToolkitError('Connection is invalid or expired. Try logging in again.', {
961+
code: errorCode.invalidConnection,
962+
cause: this.#validationErrors.get(id),
963+
})
964+
}
965+
953966
const connLabel = profile.metadata.label ?? (profile.type === 'sso' ? this.getSsoProfileLabel(profile) : id)
954967
const message = localize(
955968
'aws.auth.invalidConnection',

packages/core/src/awsService/sagemaker/credentialMapping.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,12 @@ export async function persistLocalCredentials(spaceArn: string): Promise<void> {
7474
export async function persistSmusProjectCreds(spaceArn: string, node: SagemakerUnifiedStudioSpaceNode): Promise<void> {
7575
const nodeParent = node.getParent() as SageMakerUnifiedStudioSpacesParentNode
7676
const authProvider = nodeParent.getAuthProvider()
77-
const projectAuthProvider = await authProvider.getProjectCredentialProvider(nodeParent.getProjectId())
77+
const projectId = nodeParent.getProjectId()
78+
const projectAuthProvider = await authProvider.getProjectCredentialProvider(projectId)
7879
await projectAuthProvider.getCredentials()
79-
await setSmusSpaceSsoProfile(spaceArn, nodeParent.getProjectId())
80+
await setSmusSpaceSsoProfile(spaceArn, projectId)
81+
// Trigger SSH credential refresh for the project
82+
projectAuthProvider.startProactiveCredentialRefresh()
8083
}
8184

8285
/**

packages/core/src/awsService/sagemaker/detached-server/errorPage.ts

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import { open } from './utils'
1515
export enum ExceptionType {
1616
ACCESS_DENIED = 'AccessDeniedException',
1717
DEFAULT = 'Default',
18+
EXPIRED_TOKEN = 'ExpiredTokenException',
1819
INTERNAL_FAILURE = 'InternalFailure',
1920
RESOURCE_LIMIT_EXCEEDED = 'ResourceLimitExceeded',
2021
THROTTLING = 'ThrottlingException',
@@ -31,13 +32,18 @@ export const getVSCodeErrorTitle = (error: SageMakerServiceException): string =>
3132
return ErrorText.StartSession[ExceptionType.DEFAULT].Title
3233
}
3334

34-
export const getVSCodeErrorText = (error: SageMakerServiceException): string => {
35+
export const getVSCodeErrorText = (error: SageMakerServiceException, isSmus?: boolean): string => {
3536
const exceptionType = error.name as ExceptionType
3637

3738
switch (exceptionType) {
3839
case ExceptionType.ACCESS_DENIED:
3940
case ExceptionType.VALIDATION:
4041
return ErrorText.StartSession[exceptionType].Text.replace('{message}', error.message)
42+
case ExceptionType.EXPIRED_TOKEN:
43+
// Use SMUS-specific message if in SMUS context
44+
return isSmus
45+
? ErrorText.StartSession[ExceptionType.EXPIRED_TOKEN].SmusText
46+
: ErrorText.StartSession[exceptionType].Text
4147
case ExceptionType.INTERNAL_FAILURE:
4248
case ExceptionType.RESOURCE_LIMIT_EXCEEDED:
4349
case ExceptionType.THROTTLING:
@@ -57,6 +63,12 @@ export const ErrorText = {
5763
Title: 'Unexpected system error',
5864
Text: 'We encountered an unexpected error: [{exceptionType}]. Please contact your administrator and provide them with this error so they can investigate the issue.',
5965
},
66+
[ExceptionType.EXPIRED_TOKEN]: {
67+
Title: 'Authentication expired',
68+
Text: 'Your session has expired. Please refresh your credentials and try again.',
69+
SmusText:
70+
'Your session has expired. This is likely due to network connectivity issues after machine sleep/resume. Please wait 10-30 seconds for automatic credential refresh, then try again. If the issue persists, try reconnecting through AWS Toolkit.',
71+
},
6072
[ExceptionType.INTERNAL_FAILURE]: {
6173
Title: 'Failed to connect remotely to VSCode',
6274
Text: 'Unable to establish remote connection to VSCode. This could be due to several factors. Please try again by clicking the VSCode button. If the problem persists, please contact your admin.',

packages/core/src/awsService/sagemaker/detached-server/routes/getSession.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
// Disabled: detached server files cannot import vscode.
77
/* eslint-disable aws-toolkits/no-console-log */
88
import { IncomingMessage, ServerResponse } from 'http'
9-
import { startSagemakerSession, parseArn } from '../utils'
9+
import { startSagemakerSession, parseArn, isSmusConnection } from '../utils'
1010
import { resolveCredentialsFor } from '../credentials'
1111
import url from 'url'
1212
import { SageMakerServiceException } from '@amzn/sagemaker-client'
@@ -33,6 +33,8 @@ export async function handleGetSession(req: IncomingMessage, res: ServerResponse
3333
}
3434

3535
const { region } = parseArn(connectionIdentifier)
36+
// Detect if this is a SMUS connection for specialized error handling
37+
const isSmus = await isSmusConnection(connectionIdentifier)
3638

3739
try {
3840
const session = await startSagemakerSession({ region, connectionIdentifier, credentials })
@@ -48,7 +50,7 @@ export async function handleGetSession(req: IncomingMessage, res: ServerResponse
4850
const error = err as SageMakerServiceException
4951
console.error(`Failed to start SageMaker session for ${connectionIdentifier}:`, err)
5052
const errorTitle = getVSCodeErrorTitle(error)
51-
const errorText = getVSCodeErrorText(error)
53+
const errorText = getVSCodeErrorText(error, isSmus)
5254
await openErrorPage(errorTitle, errorText)
5355
res.writeHead(500, { 'Content-Type': 'text/plain' })
5456
res.end('Failed to start SageMaker session')

packages/core/src/awsService/sagemaker/detached-server/utils.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,24 @@ async function processWriteQueue() {
121121
}
122122
}
123123

124+
/**
125+
* Detects if the connection identifier is using SMUS credentials
126+
* @param connectionIdentifier - The connection identifier to check
127+
* @returns Promise<boolean> - true if SMUS, false otherwise
128+
*/
129+
export async function isSmusConnection(connectionIdentifier: string): Promise<boolean> {
130+
try {
131+
const mapping = await readMapping()
132+
const profile = mapping.localCredential?.[connectionIdentifier]
133+
134+
// Check if profile exists and has smusProjectId
135+
return profile && 'smusProjectId' in profile
136+
} catch (err) {
137+
// If we can't read the mapping, assume not SMUS to avoid breaking existing functionality
138+
return false
139+
}
140+
}
141+
124142
/**
125143
* Writes the mapping to a temp file and atomically renames it to the target path.
126144
* Uses a queue to prevent race conditions when multiple requests try to write simultaneously.

packages/core/src/sagemakerunifiedstudio/auth/providers/connectionCredentialsProvider.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,4 +218,18 @@ export class ConnectionCredentialsProvider implements CredentialsProvider {
218218
`SMUS Connection: Successfully invalidated connection credentials cache for connection ${this.connectionId}`
219219
)
220220
}
221+
222+
/**
223+
* Disposes of the provider and cleans up resources
224+
*/
225+
public dispose(): void {
226+
this.logger.debug(
227+
`SMUS Connection: Disposing connection credentials provider for connection ${this.connectionId}`
228+
)
229+
// Clear cache to clean up resources
230+
this.invalidate()
231+
this.logger.debug(
232+
`SMUS Connection: Successfully disposed connection credentials provider for connection ${this.connectionId}`
233+
)
234+
}
221235
}

packages/core/src/sagemakerunifiedstudio/auth/providers/domainExecRoleCredentialsProvider.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,4 +314,12 @@ export class DomainExecRoleCredentialsProvider implements CredentialsProvider {
314314
this.credentialCache = undefined
315315
this.logger.debug(`SMUS DER: Successfully invalidated DER credentials cache for domain ${this.domainId}`)
316316
}
317+
/**
318+
* Disposes of the provider and cleans up resources
319+
*/
320+
public dispose(): void {
321+
this.logger.debug(`SMUS DER: Disposing DER credentials provider for domain ${this.domainId}`)
322+
this.invalidate()
323+
this.logger.debug(`SMUS DER: Successfully disposed DER credentials provider for domain ${this.domainId}`)
324+
}
317325
}

packages/core/src/sagemakerunifiedstudio/auth/providers/projectRoleCredentialsProvider.ts

Lines changed: 161 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ export class ProjectRoleCredentialsProvider implements CredentialsProvider {
2828
credentials: AWS.Credentials
2929
expiresAt: Date
3030
}
31+
private refreshTimer?: NodeJS.Timeout
32+
private readonly refreshInterval = 10 * 60 * 1000 // 10 minutes
33+
private readonly checkInterval = 10 * 1000 // 10 seconds - check frequently, refresh based on actual time
34+
private sshRefreshActive = false
35+
private lastRefreshTime?: Date
3136

3237
constructor(
3338
private readonly smusAuthProvider: SmusAuthenticationProvider,
@@ -162,11 +167,21 @@ export class ProjectRoleCredentialsProvider implements CredentialsProvider {
162167

163168
return awsCredentials
164169
} catch (err) {
165-
this.logger.error(
166-
'SMUS Project: Failed to get project credentials for project %s: %s',
167-
this.projectId,
168-
(err as Error).message
169-
)
170+
this.logger.error('SMUS Project: Failed to get project credentials for project %s: %s', this.projectId, err)
171+
172+
// Handle InvalidGrantException specially - indicates need for reauthentication
173+
if (err instanceof Error && err.name === 'InvalidGrantException') {
174+
// Invalidate cache when authentication fails
175+
this.invalidate()
176+
throw new ToolkitError(
177+
`Failed to get project credentials for project ${this.projectId}: ${err.message}. Reauthentication required.`,
178+
{
179+
code: 'InvalidRefreshToken',
180+
cause: err,
181+
}
182+
)
183+
}
184+
170185
throw new ToolkitError(`Failed to get project credentials for project ${this.projectId}: ${err}`, {
171186
code: 'ProjectCredentialsFetchFailed',
172187
cause: err instanceof Error ? err : undefined,
@@ -192,6 +207,139 @@ export class ProjectRoleCredentialsProvider implements CredentialsProvider {
192207
}
193208
}
194209

210+
/**
211+
* Starts proactive credential refresh for SSH connections
212+
*
213+
* Uses an expiry-based approach with safety buffer:
214+
* - Checks every 10 seconds using setTimeout
215+
* - Refreshes when credentials expire within 5 minutes (safety buffer)
216+
* - Falls back to 10-minute time-based refresh if no expiry information available
217+
* - Handles sleep/resume because it uses wall-clock time for expiry checks
218+
*
219+
* This means credentials are refreshed just before they expire, reducing
220+
* unnecessary API calls while ensuring credentials remain valid.
221+
*/
222+
public startProactiveCredentialRefresh(): void {
223+
if (this.sshRefreshActive) {
224+
this.logger.debug(`SMUS Project: SSH refresh already active for project ${this.projectId}`)
225+
return
226+
}
227+
228+
this.logger.info(`SMUS Project: Starting SSH credential refresh for project ${this.projectId}`)
229+
this.sshRefreshActive = true
230+
this.lastRefreshTime = new Date() // Initialize refresh time
231+
232+
// Start the check timer (checks every 10 seconds, refreshes every 10 minutes based on actual time)
233+
this.scheduleNextCheck()
234+
}
235+
236+
/**
237+
* Stops proactive credential refresh
238+
* Called when SSH connection ends or SMUS disconnects
239+
*/
240+
public stopProactiveCredentialRefresh(): void {
241+
if (!this.sshRefreshActive) {
242+
return
243+
}
244+
245+
this.logger.info(`SMUS Project: Stopping SSH credential refresh for project ${this.projectId}`)
246+
this.sshRefreshActive = false
247+
this.lastRefreshTime = undefined
248+
249+
// Clean up timer
250+
if (this.refreshTimer) {
251+
clearTimeout(this.refreshTimer)
252+
this.refreshTimer = undefined
253+
}
254+
}
255+
256+
/**
257+
* Schedules the next credential check (every 10 seconds)
258+
* Refreshes credentials when they expire within 5 minutes (safety buffer)
259+
* Falls back to 10-minute time-based refresh if no expiry information available
260+
* This handles sleep/resume scenarios correctly
261+
*/
262+
private scheduleNextCheck(): void {
263+
if (!this.sshRefreshActive) {
264+
return
265+
}
266+
// Check every 10 seconds, but only refresh every 10 minutes based on actual time elapsed
267+
this.refreshTimer = setTimeout(async () => {
268+
try {
269+
const now = new Date()
270+
// Check if we need to refresh based on actual time elapsed
271+
if (this.shouldPerformRefresh(now)) {
272+
await this.refresh()
273+
}
274+
// Schedule next check if still active
275+
if (this.sshRefreshActive) {
276+
this.scheduleNextCheck()
277+
}
278+
} catch (error) {
279+
this.logger.error(
280+
`SMUS Project: Failed to refresh credentials for project ${this.projectId}: %O`,
281+
error
282+
)
283+
// Continue trying even if refresh fails. Dispose will handle stopping the refresh.
284+
if (this.sshRefreshActive) {
285+
this.scheduleNextCheck()
286+
}
287+
}
288+
}, this.checkInterval)
289+
}
290+
291+
/**
292+
* Determines if a credential refresh should be performed based on credential expiration
293+
* This handles sleep/resume scenarios properly and is more efficient than time-based refresh
294+
*/
295+
private shouldPerformRefresh(now: Date): boolean {
296+
if (!this.lastRefreshTime || !this.credentialCache) {
297+
// First refresh or no cached credentials
298+
this.logger.debug(`SMUS Project: First refresh - no previous credentials for ${this.projectId}`)
299+
return true
300+
}
301+
302+
// Check if credentials expire soon (with 5-minute safety buffer)
303+
const safetyBufferMs = 5 * 60 * 1000 // 5 minutes before expiry
304+
const expiryTime = this.credentialCache.credentials.expiration?.getTime()
305+
306+
if (!expiryTime) {
307+
// No expiry info - fall back to time-based refresh as safety net
308+
const timeSinceLastRefresh = now.getTime() - this.lastRefreshTime.getTime()
309+
const shouldRefresh = timeSinceLastRefresh >= this.refreshInterval
310+
return shouldRefresh
311+
}
312+
313+
const timeUntilExpiry = expiryTime - now.getTime()
314+
const shouldRefresh = timeUntilExpiry < safetyBufferMs
315+
return shouldRefresh
316+
}
317+
318+
/**
319+
* Performs credential refresh by invalidating cache and fetching fresh credentials
320+
*/
321+
private async refresh(): Promise<void> {
322+
const now = new Date()
323+
const expiryTime = this.credentialCache?.credentials.expiration?.getTime()
324+
325+
if (expiryTime) {
326+
const minutesUntilExpiry = Math.round((expiryTime - now.getTime()) / 60000)
327+
this.logger.debug(
328+
`SMUS Project: Refreshing credentials for project ${this.projectId} - expires in ${minutesUntilExpiry} minutes`
329+
)
330+
} else {
331+
const minutesSinceLastRefresh = this.lastRefreshTime
332+
? Math.round((now.getTime() - this.lastRefreshTime.getTime()) / 60000)
333+
: 0
334+
this.logger.debug(
335+
`SMUS Project: Refreshing credentials for project ${this.projectId} - time-based refresh after ${minutesSinceLastRefresh} minutes`
336+
)
337+
}
338+
339+
await this.getCredentials()
340+
this.lastRefreshTime = new Date()
341+
}
342+
195343
/**
196344
* Invalidates cached project credentials
197345
* Clears the internal cache without fetching new credentials
@@ -204,4 +352,12 @@ export class ProjectRoleCredentialsProvider implements CredentialsProvider {
204352
`SMUS Project: Successfully invalidated project credentials cache for project ${this.projectId}`
205353
)
206354
}
355+
356+
/**
357+
* Disposes of the provider and cleans up resources
358+
*/
359+
public dispose(): void {
360+
this.stopProactiveCredentialRefresh()
361+
this.invalidate()
362+
}
207363
}

0 commit comments

Comments
 (0)