Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions packages/core/src/awsService/sagemaker/commands.ts
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,16 @@ export async function deeplinkConnect(
wsUrl: string,
token: string,
domain: string,
appType?: string
appType?: string,
isSMUS: boolean = false
) {
getLogger().debug(
`sm:deeplinkConnect: connectionIdentifier: ${connectionIdentifier} session: ${session} wsUrl: ${wsUrl} token: ${token}`
'sm:deeplinkConnect: connectionIdentifier: %s session: %s wsUrl: %s token: %s isSMUS: %s',
connectionIdentifier,
session,
wsUrl,
token,
isSMUS
)

if (isRemoteWorkspace()) {
Expand All @@ -112,7 +118,7 @@ export async function deeplinkConnect(
connectionIdentifier,
ctx.extensionContext,
'sm_dl',
false /* isSMUS */,
isSMUS,
undefined /* node */,
session,
wsUrl,
Expand All @@ -130,7 +136,10 @@ export async function deeplinkConnect(
)
} catch (err: any) {
getLogger().error(
`sm:OpenRemoteConnect: Unable to connect to target space with arn: ${connectionIdentifier} error: ${err}`
'sm:OpenRemoteConnect: Unable to connect to target space with arn: %s error: %s isSMUS: %s',
connectionIdentifier,
err,
isSMUS
)

if (![RemoteSessionError.MissingExtension, RemoteSessionError.ExtensionVersionTooLow].includes(err.code)) {
Expand Down
8 changes: 8 additions & 0 deletions packages/core/src/awsService/sagemaker/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,11 @@ export const InstanceTypeNotSelectedMessage = (spaceName: string) => {

export const RemoteAccessRequiredMessage =
'This space requires remote access to be enabled.\nWould you like to restart the space and connect?\nAny unsaved work will be lost.'

export const SmusDeeplinkSessionExpiredError = {
title: 'Session Disconnected',
message:
'Your SageMaker Unified Studio session has been disconnected. Select a local (non-remote) VS Code window and use the SageMaker Unified Studio portal to connect again.',
code: 'SMUS_SESSION_DISCONNECTED',
shortMessage: 'Session disconnected, re-connect from SageMaker Unified Studio portal.',
} as const
58 changes: 34 additions & 24 deletions packages/core/src/awsService/sagemaker/credentialMapping.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,41 +90,51 @@ export async function persistSmusProjectCreds(spaceArn: string, node: SagemakerU
* @param session - SSM session ID.
* @param wsUrl - SSM WebSocket URL.
* @param token - Bearer token for the session.
* @param appType - Application type (e.g., 'jupyterlab', 'codeeditor').
* @param isSMUS - If true, skip refreshUrl construction (SMUS connections cannot refresh).
*/
export async function persistSSMConnection(
spaceArn: string,
domain: string,
session?: string,
wsUrl?: string,
token?: string,
appType?: string
appType?: string,
isSMUS?: boolean
): Promise<void> {
const { region } = parseArn(spaceArn)
const endpoint = DevSettings.instance.get('endpoints', {})['sagemaker'] ?? ''
let refreshUrl: string | undefined

let appSubDomain = 'jupyterlab'
if (appType && appType.toLowerCase() === 'codeeditor') {
appSubDomain = 'code-editor'
}
if (!isSMUS) {
// Construct refreshUrl for SageMaker AI connections
const { region } = parseArn(spaceArn)
const endpoint = DevSettings.instance.get('endpoints', {})['sagemaker'] ?? ''

let envSubdomain: string
let appSubDomain = 'jupyterlab'
if (appType && appType.toLowerCase() === 'codeeditor') {
appSubDomain = 'code-editor'
}

if (endpoint.includes('beta')) {
envSubdomain = 'devo'
} else if (endpoint.includes('gamma')) {
envSubdomain = 'loadtest'
} else {
envSubdomain = 'studio'
}
let envSubdomain: string

// Use the standard AWS domain for 'studio' (prod).
// For non-prod environments, use the obfuscated domain 'asfiovnxocqpcry.com'.
const baseDomain =
envSubdomain === 'studio'
? `studio.${region}.sagemaker.aws`
: `${envSubdomain}.studio.${region}.asfiovnxocqpcry.com`
if (endpoint.includes('beta')) {
envSubdomain = 'devo'
} else if (endpoint.includes('gamma')) {
envSubdomain = 'loadtest'
} else {
envSubdomain = 'studio'
}

// Use the standard AWS domain for 'studio' (prod).
// For non-prod environments, use the obfuscated domain 'asfiovnxocqpcry.com'.
const baseDomain =
envSubdomain === 'studio'
? `studio.${region}.sagemaker.aws`
: `${envSubdomain}.studio.${region}.asfiovnxocqpcry.com`

refreshUrl = `https://studio-${domain}.${baseDomain}/${appSubDomain}`
}
// For SMUS connections, refreshUrl remains undefined

const refreshUrl = `https://studio-${domain}.${baseDomain}/${appSubDomain}`
await setSpaceCredentials(spaceArn, refreshUrl, {
sessionId: session ?? '-',
url: wsUrl ?? '-',
Expand Down Expand Up @@ -179,12 +189,12 @@ export async function setSmusSpaceSsoProfile(spaceArn: string, projectId: string
* Stores SSM connection information for a given space, typically from a deep link session.
* This initializes the request as 'fresh' and includes a refresh URL if provided.
* @param spaceArn - The arn of the SageMaker space.
* @param refreshUrl - URL to use for refreshing session tokens.
* @param refreshUrl - URL to use for refreshing session tokens (undefined for SMUS connections).
* @param credentials - The session information used to initiate the connection.
*/
export async function setSpaceCredentials(
spaceArn: string,
refreshUrl: string,
refreshUrl: string | undefined,
credentials: SsmConnectionInfo
): Promise<void> {
const data = await loadMappings()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import { IncomingMessage, ServerResponse } from 'http'
import url from 'url'
import { SessionStore } from '../sessionStore'
import { open, parseArn, readServerInfo } from '../utils'
import { openErrorPage } from '../errorPage'
import { SmusDeeplinkSessionExpiredError } from '../../constants'

export async function handleGetSessionAsync(req: IncomingMessage, res: ServerResponse): Promise<void> {
const parsedUrl = url.parse(req.url || '', true)
Expand Down Expand Up @@ -46,8 +48,34 @@ export async function handleGetSessionAsync(req: IncomingMessage, res: ServerRes
res.end()
return
} else if (status === 'not-started') {
const serverInfo = await readServerInfo()
const refreshUrl = await store.getRefreshUrl(connectionIdentifier)

// Check if this is a SMUS connection (no refreshUrl available)
if (refreshUrl === undefined) {
console.log(`SMUS session expired for connection: ${connectionIdentifier}`)

// Clean up the expired connection entry
try {
await store.cleanupExpiredConnection(connectionIdentifier)
console.log(`Cleaned up expired connection: ${connectionIdentifier}`)
} catch (cleanupErr) {
console.error(`Failed to cleanup expired connection: ${cleanupErr}`)
// Continue with error response even if cleanup fails
}

await openErrorPage(SmusDeeplinkSessionExpiredError.title, SmusDeeplinkSessionExpiredError.message)
res.writeHead(400, { 'Content-Type': 'application/json' })
res.end(
JSON.stringify({
error: SmusDeeplinkSessionExpiredError.code,
message: SmusDeeplinkSessionExpiredError.shortMessage,
})
)
return
}

// Continue with existing SageMaker AI refresh flow
const serverInfo = await readServerInfo()
const { spaceName } = parseArn(connectionIdentifier)

const url = `${refreshUrl}/${encodeURIComponent(spaceName)}?remote_access_token_refresh=true&reconnect_identifier=${encodeURIComponent(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { readMapping, writeMapping } from './utils'
export type SessionStatus = 'pending' | 'fresh' | 'consumed' | 'not-started'

export class SessionStore {
async getRefreshUrl(connectionId: string) {
async getRefreshUrl(connectionId: string): Promise<string | undefined> {
const mapping = await readMapping()

if (!mapping.deepLink) {
Expand All @@ -21,10 +21,6 @@ export class SessionStore {
throw new Error(`No mapping found for connectionId: "${connectionId}"`)
}

if (!entry.refreshUrl) {
throw new Error(`No refreshUrl found for connectionId: "${connectionId}"`)
}

return entry.refreshUrl
}

Expand Down Expand Up @@ -113,6 +109,20 @@ export class SessionStore {
await writeMapping(mapping)
}

async cleanupExpiredConnection(connectionId: string) {
const mapping = await readMapping()

if (!mapping.deepLink) {
throw new Error('No deepLink mapping found')
}

// Remove the entire connection entry for the expired space
if (mapping.deepLink[connectionId]) {
delete mapping.deepLink[connectionId]
await writeMapping(mapping)
}
}

async setSession(connectionId: string, requestId: string, ssmConnectionInfo: SsmConnectionInfo) {
const mapping = await readMapping()

Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/awsService/sagemaker/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ export async function prepareDevEnvConnection(
await persistSmusProjectCreds(spaceArn, node as SagemakerUnifiedStudioSpaceNode)
}
} else if (connectionType === 'sm_dl') {
await persistSSMConnection(spaceArn, domain ?? '', session, wsUrl, token, appType)
await persistSSMConnection(spaceArn, domain ?? '', session, wsUrl, token, appType, isSMUS)
}

await startLocalServer(ctx)
Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/extensionNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ export async function activate(context: vscode.ExtensionContext) {
await handleAmazonQInstall()
}

await activateSageMakerUnifiedStudio(context)
await activateSageMakerUnifiedStudio(extContext)

await activateApplicationComposer(context)
await activateThreatComposerEditor(context)
Expand Down
12 changes: 8 additions & 4 deletions packages/core/src/sagemakerunifiedstudio/activation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,25 @@
* SPDX-License-Identifier: Apache-2.0
*/

import * as vscode from 'vscode'
import { activate as activateConnectionMagicsSelector } from './connectionMagicsSelector/activation'
import { activate as activateExplorer } from './explorer/activation'
import { isSageMaker } from '../shared/extensionUtilities'
import { initializeResourceMetadata } from './shared/utils/resourceMetadataUtils'
import { setContext } from '../shared/vscode/setContext'
import { SmusUtils } from './shared/smusUtils'
import * as smusUriHandlers from './uriHandlers'
import { ExtContext } from '../shared/extensions'

export async function activate(extensionContext: vscode.ExtensionContext): Promise<void> {
export async function activate(ctx: ExtContext): Promise<void> {
// Only run when environment is a SageMaker Unified Studio space
if (isSageMaker('SMUS') || isSageMaker('SMUS-SPACE-REMOTE-ACCESS')) {
await initializeResourceMetadata()
// Setting context before any getContext calls to avoid potential race conditions.
await setContext('aws.smus.inSmusSpaceEnvironment', SmusUtils.isInSmusSpaceEnvironment())
await activateConnectionMagicsSelector(extensionContext)
await activateConnectionMagicsSelector(ctx.extensionContext)
}
await activateExplorer(extensionContext)
await activateExplorer(ctx.extensionContext)

// Register SMUS URI handler for deeplink connections
ctx.extensionContext.subscriptions.push(smusUriHandlers.register(ctx))
}
121 changes: 121 additions & 0 deletions packages/core/src/sagemakerunifiedstudio/uriHandlers.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*!
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

import * as vscode from 'vscode'
import { SearchParams } from '../shared/vscode/uriHandler'
import { ExtContext } from '../shared/extensions'
import { deeplinkConnect } from '../awsService/sagemaker/commands'
import { telemetry } from '../shared/telemetry/telemetry'
/**
* Registers the SMUS deeplink URI handler at path `/connect/smus`.
*
* This handler processes deeplink URLs from the SageMaker Unified Studio console
* to establish remote connections to SMUS spaces.
*
* @param ctx Extension context containing the URI handler
* @returns Disposable for cleanup
*/
export function register(ctx: ExtContext) {
async function connectHandler(params: ReturnType<typeof parseConnectParams>) {
await telemetry.smus_deeplinkConnect.run(async (span) => {
span.record(extractTelemetryMetadata(params))

// WORKAROUND: The ws_url from the startSession API call contains a query parameter
// 'cell-number' within itself. When the entire deeplink URL is processed by the URI
// handler, 'cell-number' is parsed as a standalone query parameter at the top level
// instead of remaining part of the ws_url. This causes the ws_url to lose the
// cell-number context it needs. To fix this, we manually re-append the cell-number
// query parameter back to the ws_url to restore the original intended URL structure.
await deeplinkConnect(
ctx,
params.connection_identifier,
params.session,
`${params.ws_url}&cell-number=${params['cell-number']}`, // Re-append cell-number to ws_url
params.token,
params.domain,
params.app_type,
true // isSMUS=true for SMUS connections
)
})
}

return vscode.Disposable.from(ctx.uriHandler.onPath('/connect/smus', connectHandler, parseConnectParams))
}

/**
* Parses and validates SMUS deeplink URI parameters.
*
* Required parameters:
* - connection_identifier: Space ARN identifying the SMUS space
* - domain: Domain ID for the SMUS space (SM AI side)
* - user_profile: User profile name
* - session: SSM session ID
* - ws_url: WebSocket URL for SSM connection (originally contains cell-number as a query param)
* - cell-number: extracted from ws_url during URI parsing
* - token: Authentication token
*
* Optional parameters:
* - app_type: Application type (e.g., JupyterLab, CodeEditor)
* - smus_domain_id: SMUS domain identifier
* - smus_domain_account_id: SMUS domain account ID
* - smus_project_id: SMUS project identifier
* - smus_domain_region: SMUS domain region
*
* Note: The ws_url from startSession API originally includes cell-number as a query parameter.
* However, when the deeplink URL is processed, the URI handler extracts cell-number as a
* separate top-level parameter. This is why we need to re-append it in the connectHandler.
*
* @param query URI query parameters
* @returns Parsed parameters object
* @throws Error if required parameters are missing
*/
export function parseConnectParams(query: SearchParams) {
const requiredParams = query.getFromKeysOrThrow(
'connection_identifier',
'domain',
'user_profile',
'session',
'ws_url',
'cell-number',
'token'
)
const optionalParams = query.getFromKeys(
'app_type',
'smus_domain_id',
'smus_domain_account_id',
'smus_project_id',
'smus_domain_region'
)

return { ...requiredParams, ...optionalParams }
}

/**
* Extracts telemetry metadata from URI parameters and space ARN.
*
* @param params Parsed URI parameters
* @returns Telemetry metadata object
*/
function extractTelemetryMetadata(params: ReturnType<typeof parseConnectParams>) {
// Extract metadata from space ARN
// ARN format: arn:aws:sagemaker:region:account-id:space/domain-id/space-name
const arnParts = params.connection_identifier.split(':')
const resourceParts = arnParts[5]?.split('/') // Gets "space/domain-id/space-name"

const projectRegion = arnParts[3] // region from ARN
const projectAccountId = arnParts[4] // account-id from ARN
const domainIdFromArn = resourceParts?.[1] // domain-id from ARN
const spaceName = resourceParts?.[2] // space-name from ARN

return {
smusDomainId: params.smus_domain_id,
smusDomainAccountId: params.smus_domain_account_id,
smusProjectId: params.smus_project_id,
smusDomainRegion: params.smus_domain_region,
smusProjectRegion: projectRegion,
smusProjectAccountId: projectAccountId,
smusSpaceKey: domainIdFromArn && spaceName ? `${domainIdFromArn}/${spaceName}` : undefined,
}
}
Loading
Loading