diff --git a/packages/core/src/awsService/sagemaker/commands.ts b/packages/core/src/awsService/sagemaker/commands.ts index 66ffe35fbee..8850314a23d 100644 --- a/packages/core/src/awsService/sagemaker/commands.ts +++ b/packages/core/src/awsService/sagemaker/commands.ts @@ -96,10 +96,16 @@ export async function deeplinkConnect( wsUrl: string, token: string, domain: string, - appType?: string + appType?: string, + isSMUS: boolean = false ) { getLogger().debug( - `sm:deeplinkConnect: connectionIdentifier: ${connectionIdentifier} session: ${session} wsUrl: ${wsUrl} token: ${token}` + 'sm:deeplinkConnect: connectionIdentifier: %s session: %s wsUrl: %s token: %s isSMUS: %s', + connectionIdentifier, + session, + wsUrl, + token, + isSMUS ) if (isRemoteWorkspace()) { @@ -112,7 +118,7 @@ export async function deeplinkConnect( connectionIdentifier, ctx.extensionContext, 'sm_dl', - false /* isSMUS */, + isSMUS, undefined /* node */, session, wsUrl, @@ -130,7 +136,10 @@ export async function deeplinkConnect( ) } catch (err: any) { getLogger().error( - `sm:OpenRemoteConnect: Unable to connect to target space with arn: ${connectionIdentifier} error: ${err}` + 'sm:OpenRemoteConnect: Unable to connect to target space with arn: %s error: %s isSMUS: %s', + connectionIdentifier, + err, + isSMUS ) if (![RemoteSessionError.MissingExtension, RemoteSessionError.ExtensionVersionTooLow].includes(err.code)) { diff --git a/packages/core/src/awsService/sagemaker/constants.ts b/packages/core/src/awsService/sagemaker/constants.ts index 1e0875cd385..6e5f33195a0 100644 --- a/packages/core/src/awsService/sagemaker/constants.ts +++ b/packages/core/src/awsService/sagemaker/constants.ts @@ -45,3 +45,11 @@ export const InstanceTypeNotSelectedMessage = (spaceName: string) => { export const RemoteAccessRequiredMessage = 'This space requires remote access to be enabled.\nWould you like to restart the space and connect?\nAny unsaved work will be lost.' + +export const SmusDeeplinkSessionExpiredError = { + title: 'Session Disconnected', + message: + 'Your SageMaker Unified Studio session has been disconnected. Select a local (non-remote) VS Code window and use the SageMaker Unified Studio portal to connect again.', + code: 'SMUS_SESSION_DISCONNECTED', + shortMessage: 'Session disconnected, re-connect from SageMaker Unified Studio portal.', +} as const diff --git a/packages/core/src/awsService/sagemaker/credentialMapping.ts b/packages/core/src/awsService/sagemaker/credentialMapping.ts index e84b16bb415..f8d58758f11 100644 --- a/packages/core/src/awsService/sagemaker/credentialMapping.ts +++ b/packages/core/src/awsService/sagemaker/credentialMapping.ts @@ -90,6 +90,8 @@ export async function persistSmusProjectCreds(spaceArn: string, node: SagemakerU * @param session - SSM session ID. * @param wsUrl - SSM WebSocket URL. * @param token - Bearer token for the session. + * @param appType - Application type (e.g., 'jupyterlab', 'codeeditor'). + * @param isSMUS - If true, skip refreshUrl construction (SMUS connections cannot refresh). */ export async function persistSSMConnection( spaceArn: string, @@ -97,34 +99,42 @@ export async function persistSSMConnection( session?: string, wsUrl?: string, token?: string, - appType?: string + appType?: string, + isSMUS?: boolean ): Promise { - const { region } = parseArn(spaceArn) - const endpoint = DevSettings.instance.get('endpoints', {})['sagemaker'] ?? '' + let refreshUrl: string | undefined - let appSubDomain = 'jupyterlab' - if (appType && appType.toLowerCase() === 'codeeditor') { - appSubDomain = 'code-editor' - } + if (!isSMUS) { + // Construct refreshUrl for SageMaker AI connections + const { region } = parseArn(spaceArn) + const endpoint = DevSettings.instance.get('endpoints', {})['sagemaker'] ?? '' - let envSubdomain: string + let appSubDomain = 'jupyterlab' + if (appType && appType.toLowerCase() === 'codeeditor') { + appSubDomain = 'code-editor' + } - if (endpoint.includes('beta')) { - envSubdomain = 'devo' - } else if (endpoint.includes('gamma')) { - envSubdomain = 'loadtest' - } else { - envSubdomain = 'studio' - } + let envSubdomain: string - // Use the standard AWS domain for 'studio' (prod). - // For non-prod environments, use the obfuscated domain 'asfiovnxocqpcry.com'. - const baseDomain = - envSubdomain === 'studio' - ? `studio.${region}.sagemaker.aws` - : `${envSubdomain}.studio.${region}.asfiovnxocqpcry.com` + if (endpoint.includes('beta')) { + envSubdomain = 'devo' + } else if (endpoint.includes('gamma')) { + envSubdomain = 'loadtest' + } else { + envSubdomain = 'studio' + } + + // Use the standard AWS domain for 'studio' (prod). + // For non-prod environments, use the obfuscated domain 'asfiovnxocqpcry.com'. + const baseDomain = + envSubdomain === 'studio' + ? `studio.${region}.sagemaker.aws` + : `${envSubdomain}.studio.${region}.asfiovnxocqpcry.com` + + refreshUrl = `https://studio-${domain}.${baseDomain}/${appSubDomain}` + } + // For SMUS connections, refreshUrl remains undefined - const refreshUrl = `https://studio-${domain}.${baseDomain}/${appSubDomain}` await setSpaceCredentials(spaceArn, refreshUrl, { sessionId: session ?? '-', url: wsUrl ?? '-', @@ -179,12 +189,12 @@ export async function setSmusSpaceSsoProfile(spaceArn: string, projectId: string * Stores SSM connection information for a given space, typically from a deep link session. * This initializes the request as 'fresh' and includes a refresh URL if provided. * @param spaceArn - The arn of the SageMaker space. - * @param refreshUrl - URL to use for refreshing session tokens. + * @param refreshUrl - URL to use for refreshing session tokens (undefined for SMUS connections). * @param credentials - The session information used to initiate the connection. */ export async function setSpaceCredentials( spaceArn: string, - refreshUrl: string, + refreshUrl: string | undefined, credentials: SsmConnectionInfo ): Promise { const data = await loadMappings() diff --git a/packages/core/src/awsService/sagemaker/detached-server/routes/getSessionAsync.ts b/packages/core/src/awsService/sagemaker/detached-server/routes/getSessionAsync.ts index f8dad504067..c0db2712d07 100644 --- a/packages/core/src/awsService/sagemaker/detached-server/routes/getSessionAsync.ts +++ b/packages/core/src/awsService/sagemaker/detached-server/routes/getSessionAsync.ts @@ -9,6 +9,8 @@ import { IncomingMessage, ServerResponse } from 'http' import url from 'url' import { SessionStore } from '../sessionStore' import { open, parseArn, readServerInfo } from '../utils' +import { openErrorPage } from '../errorPage' +import { SmusDeeplinkSessionExpiredError } from '../../constants' export async function handleGetSessionAsync(req: IncomingMessage, res: ServerResponse): Promise { const parsedUrl = url.parse(req.url || '', true) @@ -46,8 +48,34 @@ export async function handleGetSessionAsync(req: IncomingMessage, res: ServerRes res.end() return } else if (status === 'not-started') { - const serverInfo = await readServerInfo() const refreshUrl = await store.getRefreshUrl(connectionIdentifier) + + // Check if this is a SMUS connection (no refreshUrl available) + if (refreshUrl === undefined) { + console.log(`SMUS session expired for connection: ${connectionIdentifier}`) + + // Clean up the expired connection entry + try { + await store.cleanupExpiredConnection(connectionIdentifier) + console.log(`Cleaned up expired connection: ${connectionIdentifier}`) + } catch (cleanupErr) { + console.error(`Failed to cleanup expired connection: ${cleanupErr}`) + // Continue with error response even if cleanup fails + } + + await openErrorPage(SmusDeeplinkSessionExpiredError.title, SmusDeeplinkSessionExpiredError.message) + res.writeHead(400, { 'Content-Type': 'application/json' }) + res.end( + JSON.stringify({ + error: SmusDeeplinkSessionExpiredError.code, + message: SmusDeeplinkSessionExpiredError.shortMessage, + }) + ) + return + } + + // Continue with existing SageMaker AI refresh flow + const serverInfo = await readServerInfo() const { spaceName } = parseArn(connectionIdentifier) const url = `${refreshUrl}/${encodeURIComponent(spaceName)}?remote_access_token_refresh=true&reconnect_identifier=${encodeURIComponent( diff --git a/packages/core/src/awsService/sagemaker/detached-server/sessionStore.ts b/packages/core/src/awsService/sagemaker/detached-server/sessionStore.ts index 04098f68c89..9a09ad2418d 100644 --- a/packages/core/src/awsService/sagemaker/detached-server/sessionStore.ts +++ b/packages/core/src/awsService/sagemaker/detached-server/sessionStore.ts @@ -9,7 +9,7 @@ import { readMapping, writeMapping } from './utils' export type SessionStatus = 'pending' | 'fresh' | 'consumed' | 'not-started' export class SessionStore { - async getRefreshUrl(connectionId: string) { + async getRefreshUrl(connectionId: string): Promise { const mapping = await readMapping() if (!mapping.deepLink) { @@ -21,10 +21,6 @@ export class SessionStore { throw new Error(`No mapping found for connectionId: "${connectionId}"`) } - if (!entry.refreshUrl) { - throw new Error(`No refreshUrl found for connectionId: "${connectionId}"`) - } - return entry.refreshUrl } @@ -113,6 +109,20 @@ export class SessionStore { await writeMapping(mapping) } + async cleanupExpiredConnection(connectionId: string) { + const mapping = await readMapping() + + if (!mapping.deepLink) { + throw new Error('No deepLink mapping found') + } + + // Remove the entire connection entry for the expired space + if (mapping.deepLink[connectionId]) { + delete mapping.deepLink[connectionId] + await writeMapping(mapping) + } + } + async setSession(connectionId: string, requestId: string, ssmConnectionInfo: SsmConnectionInfo) { const mapping = await readMapping() diff --git a/packages/core/src/awsService/sagemaker/model.ts b/packages/core/src/awsService/sagemaker/model.ts index e25e8791d4f..a9ab87647bf 100644 --- a/packages/core/src/awsService/sagemaker/model.ts +++ b/packages/core/src/awsService/sagemaker/model.ts @@ -85,7 +85,7 @@ export async function prepareDevEnvConnection( await persistSmusProjectCreds(spaceArn, node as SagemakerUnifiedStudioSpaceNode) } } else if (connectionType === 'sm_dl') { - await persistSSMConnection(spaceArn, domain ?? '', session, wsUrl, token, appType) + await persistSSMConnection(spaceArn, domain ?? '', session, wsUrl, token, appType, isSMUS) } await startLocalServer(ctx) diff --git a/packages/core/src/extensionNode.ts b/packages/core/src/extensionNode.ts index a8a7855913e..221ed32500e 100644 --- a/packages/core/src/extensionNode.ts +++ b/packages/core/src/extensionNode.ts @@ -199,7 +199,7 @@ export async function activate(context: vscode.ExtensionContext) { await handleAmazonQInstall() } - await activateSageMakerUnifiedStudio(context) + await activateSageMakerUnifiedStudio(extContext) await activateApplicationComposer(context) await activateThreatComposerEditor(context) diff --git a/packages/core/src/sagemakerunifiedstudio/activation.ts b/packages/core/src/sagemakerunifiedstudio/activation.ts index 7fefd2eb44a..9c47137d6da 100644 --- a/packages/core/src/sagemakerunifiedstudio/activation.ts +++ b/packages/core/src/sagemakerunifiedstudio/activation.ts @@ -3,21 +3,25 @@ * SPDX-License-Identifier: Apache-2.0 */ -import * as vscode from 'vscode' import { activate as activateConnectionMagicsSelector } from './connectionMagicsSelector/activation' import { activate as activateExplorer } from './explorer/activation' import { isSageMaker } from '../shared/extensionUtilities' import { initializeResourceMetadata } from './shared/utils/resourceMetadataUtils' import { setContext } from '../shared/vscode/setContext' import { SmusUtils } from './shared/smusUtils' +import * as smusUriHandlers from './uriHandlers' +import { ExtContext } from '../shared/extensions' -export async function activate(extensionContext: vscode.ExtensionContext): Promise { +export async function activate(ctx: ExtContext): Promise { // Only run when environment is a SageMaker Unified Studio space if (isSageMaker('SMUS') || isSageMaker('SMUS-SPACE-REMOTE-ACCESS')) { await initializeResourceMetadata() // Setting context before any getContext calls to avoid potential race conditions. await setContext('aws.smus.inSmusSpaceEnvironment', SmusUtils.isInSmusSpaceEnvironment()) - await activateConnectionMagicsSelector(extensionContext) + await activateConnectionMagicsSelector(ctx.extensionContext) } - await activateExplorer(extensionContext) + await activateExplorer(ctx.extensionContext) + + // Register SMUS URI handler for deeplink connections + ctx.extensionContext.subscriptions.push(smusUriHandlers.register(ctx)) } diff --git a/packages/core/src/sagemakerunifiedstudio/uriHandlers.ts b/packages/core/src/sagemakerunifiedstudio/uriHandlers.ts new file mode 100644 index 00000000000..590fa1e2e72 --- /dev/null +++ b/packages/core/src/sagemakerunifiedstudio/uriHandlers.ts @@ -0,0 +1,121 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as vscode from 'vscode' +import { SearchParams } from '../shared/vscode/uriHandler' +import { ExtContext } from '../shared/extensions' +import { deeplinkConnect } from '../awsService/sagemaker/commands' +import { telemetry } from '../shared/telemetry/telemetry' +/** + * Registers the SMUS deeplink URI handler at path `/connect/smus`. + * + * This handler processes deeplink URLs from the SageMaker Unified Studio console + * to establish remote connections to SMUS spaces. + * + * @param ctx Extension context containing the URI handler + * @returns Disposable for cleanup + */ +export function register(ctx: ExtContext) { + async function connectHandler(params: ReturnType) { + await telemetry.smus_deeplinkConnect.run(async (span) => { + span.record(extractTelemetryMetadata(params)) + + // WORKAROUND: The ws_url from the startSession API call contains a query parameter + // 'cell-number' within itself. When the entire deeplink URL is processed by the URI + // handler, 'cell-number' is parsed as a standalone query parameter at the top level + // instead of remaining part of the ws_url. This causes the ws_url to lose the + // cell-number context it needs. To fix this, we manually re-append the cell-number + // query parameter back to the ws_url to restore the original intended URL structure. + await deeplinkConnect( + ctx, + params.connection_identifier, + params.session, + `${params.ws_url}&cell-number=${params['cell-number']}`, // Re-append cell-number to ws_url + params.token, + params.domain, + params.app_type, + true // isSMUS=true for SMUS connections + ) + }) + } + + return vscode.Disposable.from(ctx.uriHandler.onPath('/connect/smus', connectHandler, parseConnectParams)) +} + +/** + * Parses and validates SMUS deeplink URI parameters. + * + * Required parameters: + * - connection_identifier: Space ARN identifying the SMUS space + * - domain: Domain ID for the SMUS space (SM AI side) + * - user_profile: User profile name + * - session: SSM session ID + * - ws_url: WebSocket URL for SSM connection (originally contains cell-number as a query param) + * - cell-number: extracted from ws_url during URI parsing + * - token: Authentication token + * + * Optional parameters: + * - app_type: Application type (e.g., JupyterLab, CodeEditor) + * - smus_domain_id: SMUS domain identifier + * - smus_domain_account_id: SMUS domain account ID + * - smus_project_id: SMUS project identifier + * - smus_domain_region: SMUS domain region + * + * Note: The ws_url from startSession API originally includes cell-number as a query parameter. + * However, when the deeplink URL is processed, the URI handler extracts cell-number as a + * separate top-level parameter. This is why we need to re-append it in the connectHandler. + * + * @param query URI query parameters + * @returns Parsed parameters object + * @throws Error if required parameters are missing + */ +export function parseConnectParams(query: SearchParams) { + const requiredParams = query.getFromKeysOrThrow( + 'connection_identifier', + 'domain', + 'user_profile', + 'session', + 'ws_url', + 'cell-number', + 'token' + ) + const optionalParams = query.getFromKeys( + 'app_type', + 'smus_domain_id', + 'smus_domain_account_id', + 'smus_project_id', + 'smus_domain_region' + ) + + return { ...requiredParams, ...optionalParams } +} + +/** + * Extracts telemetry metadata from URI parameters and space ARN. + * + * @param params Parsed URI parameters + * @returns Telemetry metadata object + */ +function extractTelemetryMetadata(params: ReturnType) { + // Extract metadata from space ARN + // ARN format: arn:aws:sagemaker:region:account-id:space/domain-id/space-name + const arnParts = params.connection_identifier.split(':') + const resourceParts = arnParts[5]?.split('/') // Gets "space/domain-id/space-name" + + const projectRegion = arnParts[3] // region from ARN + const projectAccountId = arnParts[4] // account-id from ARN + const domainIdFromArn = resourceParts?.[1] // domain-id from ARN + const spaceName = resourceParts?.[2] // space-name from ARN + + return { + smusDomainId: params.smus_domain_id, + smusDomainAccountId: params.smus_domain_account_id, + smusProjectId: params.smus_project_id, + smusDomainRegion: params.smus_domain_region, + smusProjectRegion: projectRegion, + smusProjectAccountId: projectAccountId, + smusSpaceKey: domainIdFromArn && spaceName ? `${domainIdFromArn}/${spaceName}` : undefined, + } +} diff --git a/packages/core/src/shared/telemetry/vscodeTelemetry.json b/packages/core/src/shared/telemetry/vscodeTelemetry.json index fcf6140eb13..433ac87de69 100644 --- a/packages/core/src/shared/telemetry/vscodeTelemetry.json +++ b/packages/core/src/shared/telemetry/vscodeTelemetry.json @@ -1634,6 +1634,47 @@ "required": false } ] + }, + { + "name": "smus_deeplinkConnect", + "description": "Emitted when a user connects to a SMUS space via deeplink", + "metadata": [ + { + "type": "result" + }, + { + "type": "reason", + "required": false + }, + { + "type": "smusDomainId", + "required": false + }, + { + "type": "smusDomainAccountId", + "required": false + }, + { + "type": "smusProjectId", + "required": false + }, + { + "type": "smusDomainRegion", + "required": false + }, + { + "type": "smusProjectRegion", + "required": false + }, + { + "type": "smusProjectAccountId", + "required": false + }, + { + "type": "smusSpaceKey", + "required": false + } + ] } ] } diff --git a/packages/core/src/test/awsService/sagemaker/credentialMapping.test.ts b/packages/core/src/test/awsService/sagemaker/credentialMapping.test.ts index 3134f11e5e0..c114c8b0bba 100644 --- a/packages/core/src/test/awsService/sagemaker/credentialMapping.test.ts +++ b/packages/core/src/test/awsService/sagemaker/credentialMapping.test.ts @@ -218,6 +218,75 @@ describe('credentialMapping', () => { 'Unsupported or missing app type for space. Expected JupyterLab or CodeEditor, got: UnsupportedApp', }) }) + + it('stores undefined refreshUrl when isSMUS=true', async () => { + sandbox.stub(DevSettings.instance, 'get').returns({}) + sandbox.stub(fs, 'existsFile').resolves(false) + const writeStub = sandbox.stub(fs, 'writeFile').resolves() + + await persistSSMConnection(appArn, domain, 'sess-123', 'wss://smus-ws', 'token-xyz', 'jupyterlab', true) + + const raw = writeStub.firstCall.args[1] + const data = JSON.parse(typeof raw === 'string' ? raw : raw.toString()) + + // Verify refreshUrl is undefined for SMUS connections + assert.strictEqual(data.deepLink?.[appArn]?.refreshUrl, undefined) + + // Verify SSM connection info is stored correctly + assert.deepStrictEqual(data.deepLink?.[appArn]?.requests['initial-connection'], { + sessionId: 'sess-123', + url: 'wss://smus-ws', + token: 'token-xyz', + status: 'fresh', + }) + }) + + it('stores valid refreshUrl when isSMUS=false (SageMaker AI behavior)', async () => { + sandbox.stub(DevSettings.instance, 'get').returns({}) + sandbox.stub(fs, 'existsFile').resolves(false) + const writeStub = sandbox.stub(fs, 'writeFile').resolves() + + await persistSSMConnection(appArn, domain, 'sess-456', 'wss://sm-ws', 'token-abc', 'jupyterlab', false) + + const raw = writeStub.firstCall.args[1] + const data = JSON.parse(typeof raw === 'string' ? raw : raw.toString()) + + // Verify refreshUrl is present for SageMaker AI connections + assert.ok(data.deepLink?.[appArn]?.refreshUrl) + assertRefreshUrlMatches(data.deepLink?.[appArn]?.refreshUrl, 'studio.us-west-2.sagemaker.aws') + + // Verify SSM connection info is stored correctly + assert.deepStrictEqual(data.deepLink?.[appArn]?.requests['initial-connection'], { + sessionId: 'sess-456', + url: 'wss://sm-ws', + token: 'token-abc', + status: 'fresh', + }) + }) + + it('stores valid refreshUrl when isSMUS is undefined (default SageMaker AI behavior)', async () => { + sandbox.stub(DevSettings.instance, 'get').returns({}) + sandbox.stub(fs, 'existsFile').resolves(false) + const writeStub = sandbox.stub(fs, 'writeFile').resolves() + + // Call without isSMUS parameter (should default to SageMaker AI behavior) + await persistSSMConnection(appArn, domain, 'sess-789', 'wss://default-ws', 'token-def', 'jupyterlab') + + const raw = writeStub.firstCall.args[1] + const data = JSON.parse(typeof raw === 'string' ? raw : raw.toString()) + + // Verify refreshUrl is present when isSMUS is not specified + assert.ok(data.deepLink?.[appArn]?.refreshUrl) + assertRefreshUrlMatches(data.deepLink?.[appArn]?.refreshUrl, 'studio.us-west-2.sagemaker.aws') + + // Verify SSM connection info is stored correctly + assert.deepStrictEqual(data.deepLink?.[appArn]?.requests['initial-connection'], { + sessionId: 'sess-789', + url: 'wss://default-ws', + token: 'token-def', + status: 'fresh', + }) + }) }) describe('persistSmusProjectCreds', () => { diff --git a/packages/core/src/test/awsService/sagemaker/detached-server/routes/getSessionAsync.test.ts b/packages/core/src/test/awsService/sagemaker/detached-server/routes/getSessionAsync.test.ts index 8d3ab8563ee..9b3ecb2f2c9 100644 --- a/packages/core/src/test/awsService/sagemaker/detached-server/routes/getSessionAsync.test.ts +++ b/packages/core/src/test/awsService/sagemaker/detached-server/routes/getSessionAsync.test.ts @@ -9,6 +9,8 @@ import assert from 'assert' import { SessionStore } from '../../../../../awsService/sagemaker/detached-server/sessionStore' import { handleGetSessionAsync } from '../../../../../awsService/sagemaker/detached-server/routes/getSessionAsync' import * as utils from '../../../../../awsService/sagemaker/detached-server/utils' +import * as errorPage from '../../../../../awsService/sagemaker/detached-server/errorPage' +import { SmusDeeplinkSessionExpiredError } from '../../../../../awsService/sagemaker/constants' describe('handleGetSessionAsync', () => { let req: Partial @@ -27,6 +29,7 @@ describe('handleGetSessionAsync', () => { sinon.stub(SessionStore.prototype, 'getStatus').callsFake(storeStub.getStatus) sinon.stub(SessionStore.prototype, 'getRefreshUrl').callsFake(storeStub.getRefreshUrl) sinon.stub(SessionStore.prototype, 'markPending').callsFake(storeStub.markPending) + sinon.stub(SessionStore.prototype, 'cleanupExpiredConnection').callsFake(storeStub.cleanupExpiredConnection) }) it('responds with 400 if required query parameters are missing', async () => { @@ -93,6 +96,99 @@ describe('handleGetSessionAsync', () => { assert(resEnd.calledWith('Unexpected error')) }) + describe('SMUS session expiration handling', () => { + let openErrorPageStub: sinon.SinonStub + + beforeEach(() => { + // Stub the openErrorPage function to prevent actual browser opening + openErrorPageStub = sinon.stub(errorPage, 'openErrorPage').resolves() + }) + + it('handles SMUS session expiration when refreshUrl is undefined', async () => { + req = { url: '/session_async?connection_identifier=abc&request_id=req123' } + + storeStub.getFreshEntry.returns(Promise.resolve(undefined)) + storeStub.getStatus.returns(Promise.resolve('not-started')) + storeStub.getRefreshUrl.returns(Promise.resolve(undefined)) // SMUS case: no refreshUrl + storeStub.cleanupExpiredConnection.resolves() + + await handleGetSessionAsync(req as http.IncomingMessage, res as http.ServerResponse) + + // Verify HTTP 400 response with correct error structure + assert(resWriteHead.calledWith(400)) + const actualJson = JSON.parse(resEnd.firstCall.args[0]) + assert.strictEqual(actualJson.error, SmusDeeplinkSessionExpiredError.code) + assert.strictEqual(actualJson.message, SmusDeeplinkSessionExpiredError.shortMessage) + + // Verify cleanup was called + assert(storeStub.cleanupExpiredConnection.calledOnce) + assert(storeStub.cleanupExpiredConnection.calledWith('abc')) + + // Verify error page was opened with correct message + assert(openErrorPageStub.calledOnce) + assert.strictEqual(openErrorPageStub.firstCall.args[0], SmusDeeplinkSessionExpiredError.title) + assert.strictEqual(openErrorPageStub.firstCall.args[1], SmusDeeplinkSessionExpiredError.message) + }) + + it('responds with 400 even if cleanup fails', async () => { + req = { url: '/session_async?connection_identifier=abc&request_id=req123' } + + storeStub.getFreshEntry.returns(Promise.resolve(undefined)) + storeStub.getStatus.returns(Promise.resolve('not-started')) + storeStub.getRefreshUrl.returns(Promise.resolve(undefined)) + storeStub.cleanupExpiredConnection.rejects(new Error('cleanup failed')) + + await handleGetSessionAsync(req as http.IncomingMessage, res as http.ServerResponse) + + assert(resWriteHead.calledWith(400)) + const actualJson = JSON.parse(resEnd.firstCall.args[0]) + assert.strictEqual(actualJson.error, SmusDeeplinkSessionExpiredError.code) + }) + + it('responds with 202 when refreshUrl is valid (existing SageMaker AI flow)', async () => { + req = { url: '/session_async?connection_identifier=abc&request_id=req123' } + + storeStub.getFreshEntry.returns(Promise.resolve(undefined)) + storeStub.getStatus.returns(Promise.resolve('not-started')) + storeStub.getRefreshUrl.returns(Promise.resolve('https://example.com/refresh')) // Valid refreshUrl + storeStub.markPending.returns(Promise.resolve()) + + sinon.stub(utils, 'readServerInfo').resolves({ pid: 1234, port: 4567 }) + sinon + .stub(utils, 'parseArn') + .returns({ region: 'us-east-1', accountId: '123456789012', spaceName: 'test-space' }) + sinon.stub(utils, 'open').resolves() + + await handleGetSessionAsync(req as http.IncomingMessage, res as http.ServerResponse) + + // Verify SageMaker AI flow still works correctly + assert(resWriteHead.calledWith(202)) + assert(resEnd.calledWithMatch(/Session is not ready yet/)) + assert(storeStub.markPending.calledWith('abc', 'req123')) + }) + + it('does not call cleanupExpiredConnection for SageMaker AI connections', async () => { + req = { url: '/session_async?connection_identifier=abc&request_id=req123' } + + storeStub.getFreshEntry.returns(Promise.resolve(undefined)) + storeStub.getStatus.returns(Promise.resolve('not-started')) + storeStub.getRefreshUrl.returns(Promise.resolve('https://example.com/refresh')) + storeStub.markPending.returns(Promise.resolve()) + storeStub.cleanupExpiredConnection.resolves() + + sinon.stub(utils, 'readServerInfo').resolves({ pid: 1234, port: 4567 }) + sinon + .stub(utils, 'parseArn') + .returns({ region: 'us-east-1', accountId: '123456789012', spaceName: 'test-space' }) + sinon.stub(utils, 'open').resolves() + + await handleGetSessionAsync(req as http.IncomingMessage, res as http.ServerResponse) + + // Verify cleanup was NOT called + assert(storeStub.cleanupExpiredConnection.notCalled) + }) + }) + afterEach(() => { sinon.restore() }) diff --git a/packages/core/src/test/awsService/sagemaker/detached-server/sessionStore.test.ts b/packages/core/src/test/awsService/sagemaker/detached-server/sessionStore.test.ts index 2a7828a4951..468b92faa15 100644 --- a/packages/core/src/test/awsService/sagemaker/detached-server/sessionStore.test.ts +++ b/packages/core/src/test/awsService/sagemaker/detached-server/sessionStore.test.ts @@ -40,6 +40,28 @@ describe('SessionStore', () => { assert.strictEqual(result, 'https://refresh.url') }) + it('returns undefined for SMUS connections (no refreshUrl)', async () => { + const store = new SessionStore() + readMappingStub.returns({ + deepLink: { + [connectionId]: { + refreshUrl: undefined, + requests: { + 'initial-connection': { sessionId: 's0', token: 't0', url: 'u0', status: 'fresh' }, + }, + }, + }, + }) + const result = await store.getRefreshUrl(connectionId) + assert.strictEqual(result, undefined) + }) + + it('returns valid URL for SageMaker AI connections (existing behavior)', async () => { + const store = new SessionStore() + const result = await store.getRefreshUrl(connectionId) + assert.strictEqual(result, 'https://refresh.url') + }) + it('throws if no mapping exists for connectionId', async () => { const store = new SessionStore() readMappingStub.returns({ deepLink: {} }) @@ -47,6 +69,13 @@ describe('SessionStore', () => { await assert.rejects(() => store.getRefreshUrl('missing'), /No mapping found/) }) + it('throws if no deepLink mapping exists', async () => { + const store = new SessionStore() + readMappingStub.returns({}) + + await assert.rejects(() => store.getRefreshUrl(connectionId), /No deepLink mapping found/) + }) + it('returns fresh entry and marks consumed', async () => { const store = new SessionStore() const result = await store.getFreshEntry(connectionId, requestId) @@ -142,4 +171,44 @@ describe('SessionStore', () => { status: 'fresh', }) }) + + it('cleans up expired connection', async () => { + const store = new SessionStore() + await store.cleanupExpiredConnection(connectionId) + const updated = writeMappingStub.firstCall.args[0] + assert.strictEqual(updated.deepLink[connectionId], undefined) + }) + + it('does not throw when cleaning up non-existent connection', async () => { + const store = new SessionStore() + await store.cleanupExpiredConnection('non-existent-connection') + assert(writeMappingStub.notCalled) + }) + + it('cleans up only the specified connection without affecting other connections', async () => { + const store = new SessionStore() + const otherConnectionId = 'other-connection' + readMappingStub.returns({ + deepLink: { + [connectionId]: { + refreshUrl: undefined, + requests: { + 'initial-connection': { sessionId: 's1', token: 't1', url: 'u1', status: 'fresh' }, + }, + }, + [otherConnectionId]: { + refreshUrl: 'https://refresh.url', + requests: { + 'initial-connection': { sessionId: 's2', token: 't2', url: 'u2', status: 'fresh' }, + }, + }, + }, + }) + + await store.cleanupExpiredConnection(connectionId) + const updated = writeMappingStub.firstCall.args[0] + assert.strictEqual(updated.deepLink[connectionId], undefined) + assert.ok(updated.deepLink[otherConnectionId]) + assert.ok(updated.deepLink[otherConnectionId].requests['initial-connection']) + }) }) diff --git a/packages/core/src/test/sagemakerunifiedstudio/activation.test.ts b/packages/core/src/test/sagemakerunifiedstudio/activation.test.ts index 0756cdcbe88..7c920771995 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/activation.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/activation.test.ts @@ -13,8 +13,11 @@ import * as explorerActivation from '../../sagemakerunifiedstudio/explorer/activ import * as resourceMetadataUtils from '../../sagemakerunifiedstudio/shared/utils/resourceMetadataUtils' import * as setContext from '../../shared/vscode/setContext' import { SmusUtils } from '../../sagemakerunifiedstudio/shared/smusUtils' +import * as smusUriHandlers from '../../sagemakerunifiedstudio/uriHandlers' +import { ExtContext } from '../../shared/extensions' describe('SageMaker Unified Studio Main Activation', function () { + let mockToolkitExtContext: ExtContext let mockExtensionContext: vscode.ExtensionContext let isSageMakerStub: sinon.SinonStub let initializeResourceMetadataStub: sinon.SinonStub @@ -22,6 +25,7 @@ describe('SageMaker Unified Studio Main Activation', function () { let isInSmusSpaceEnvironmentStub: sinon.SinonStub let activateConnectionMagicsSelectorStub: sinon.SinonStub let activateExplorerStub: sinon.SinonStub + let registerUriHandlerStub: sinon.SinonStub beforeEach(function () { mockExtensionContext = { @@ -37,6 +41,17 @@ describe('SageMaker Unified Studio Main Activation', function () { }, } as any + mockToolkitExtContext = { + extensionContext: mockExtensionContext, + awsContext: {} as any, + samCliContext: sinon.stub() as any, + regionProvider: {} as any, + outputChannel: {} as any, + telemetryService: {} as any, + uriHandler: {} as any, + credentialsStore: {} as any, + } + // Stub all dependencies isSageMakerStub = sinon.stub(extensionUtilities, 'isSageMaker') initializeResourceMetadataStub = sinon.stub(resourceMetadataUtils, 'initializeResourceMetadata') @@ -44,6 +59,7 @@ describe('SageMaker Unified Studio Main Activation', function () { isInSmusSpaceEnvironmentStub = sinon.stub(SmusUtils, 'isInSmusSpaceEnvironment') activateConnectionMagicsSelectorStub = sinon.stub(connectionMagicsSelectorActivation, 'activate') activateExplorerStub = sinon.stub(explorerActivation, 'activate') + registerUriHandlerStub = sinon.stub(smusUriHandlers, 'register') // Set default return values isSageMakerStub.returns(false) @@ -52,6 +68,7 @@ describe('SageMaker Unified Studio Main Activation', function () { isInSmusSpaceEnvironmentStub.returns(false) activateConnectionMagicsSelectorStub.resolves() activateExplorerStub.resolves() + registerUriHandlerStub.returns({ dispose: sinon.stub() } as any) }) afterEach(function () { @@ -62,7 +79,7 @@ describe('SageMaker Unified Studio Main Activation', function () { it('should always activate explorer regardless of environment', async function () { isSageMakerStub.returns(false) - await activate(mockExtensionContext) + await activate(mockToolkitExtContext) assert.ok(activateExplorerStub.calledOnceWith(mockExtensionContext)) }) @@ -70,7 +87,7 @@ describe('SageMaker Unified Studio Main Activation', function () { it('should not initialize SMUS components when not in SageMaker environment', async function () { isSageMakerStub.returns(false) - await activate(mockExtensionContext) + await activate(mockToolkitExtContext) assert.ok(initializeResourceMetadataStub.notCalled) assert.ok(setContextStub.notCalled) @@ -83,7 +100,7 @@ describe('SageMaker Unified Studio Main Activation', function () { isSageMakerStub.withArgs('SMUS-SPACE-REMOTE-ACCESS').returns(false) isInSmusSpaceEnvironmentStub.returns(true) - await activate(mockExtensionContext) + await activate(mockToolkitExtContext) assert.ok(initializeResourceMetadataStub.calledOnce) assert.ok(setContextStub.calledOnceWith('aws.smus.inSmusSpaceEnvironment', true)) @@ -96,7 +113,7 @@ describe('SageMaker Unified Studio Main Activation', function () { isSageMakerStub.withArgs('SMUS-SPACE-REMOTE-ACCESS').returns(true) isInSmusSpaceEnvironmentStub.returns(false) - await activate(mockExtensionContext) + await activate(mockToolkitExtContext) assert.ok(initializeResourceMetadataStub.calledOnce) assert.ok(setContextStub.calledOnceWith('aws.smus.inSmusSpaceEnvironment', false)) @@ -109,7 +126,7 @@ describe('SageMaker Unified Studio Main Activation', function () { isSageMakerStub.withArgs('SMUS-SPACE-REMOTE-ACCESS').returns(false) isInSmusSpaceEnvironmentStub.returns(true) - await activate(mockExtensionContext) + await activate(mockToolkitExtContext) // Verify the order of calls assert.ok(initializeResourceMetadataStub.calledBefore(setContextStub)) @@ -122,7 +139,7 @@ describe('SageMaker Unified Studio Main Activation', function () { const error = new Error('Resource metadata initialization failed') initializeResourceMetadataStub.rejects(error) - await assert.rejects(() => activate(mockExtensionContext), /Resource metadata initialization failed/) + await assert.rejects(() => activate(mockToolkitExtContext), /Resource metadata initialization failed/) assert.ok(initializeResourceMetadataStub.calledOnce) assert.ok(setContextStub.notCalled) @@ -135,7 +152,7 @@ describe('SageMaker Unified Studio Main Activation', function () { const error = new Error('Set context failed') setContextStub.rejects(error) - await assert.rejects(() => activate(mockExtensionContext), /Set context failed/) + await assert.rejects(() => activate(mockToolkitExtContext), /Set context failed/) assert.ok(initializeResourceMetadataStub.calledOnce) assert.ok(setContextStub.calledOnce) @@ -148,7 +165,7 @@ describe('SageMaker Unified Studio Main Activation', function () { const error = new Error('Connection magics selector activation failed') activateConnectionMagicsSelectorStub.rejects(error) - await assert.rejects(() => activate(mockExtensionContext), /Connection magics selector activation failed/) + await assert.rejects(() => activate(mockToolkitExtContext), /Connection magics selector activation failed/) assert.ok(initializeResourceMetadataStub.calledOnce) assert.ok(setContextStub.calledOnce) @@ -159,7 +176,7 @@ describe('SageMaker Unified Studio Main Activation', function () { const error = new Error('Explorer activation failed') activateExplorerStub.rejects(error) - await assert.rejects(() => activate(mockExtensionContext), /Explorer activation failed/) + await assert.rejects(() => activate(mockToolkitExtContext), /Explorer activation failed/) assert.ok(activateExplorerStub.calledOnce) }) @@ -168,11 +185,18 @@ describe('SageMaker Unified Studio Main Activation', function () { isSageMakerStub.withArgs('SMUS').returns(true) isInSmusSpaceEnvironmentStub.returns(true) - await activate(mockExtensionContext) + await activate(mockToolkitExtContext) assert.ok(activateConnectionMagicsSelectorStub.calledWith(mockExtensionContext)) assert.ok(activateExplorerStub.calledWith(mockExtensionContext)) }) + + it('should register URI handler', async function () { + await activate(mockToolkitExtContext) + + assert.ok(registerUriHandlerStub.calledOnceWith(mockToolkitExtContext)) + assert.ok(mockExtensionContext.subscriptions.length > 0) + }) }) describe('environment detection logic', function () { @@ -180,7 +204,7 @@ describe('SageMaker Unified Studio Main Activation', function () { isSageMakerStub.withArgs('SMUS').returns(false) isSageMakerStub.withArgs('SMUS-SPACE-REMOTE-ACCESS').returns(false) - await activate(mockExtensionContext) + await activate(mockToolkitExtContext) assert.ok(isSageMakerStub.calledWith('SMUS')) assert.ok(isSageMakerStub.calledWith('SMUS-SPACE-REMOTE-ACCESS')) @@ -192,7 +216,7 @@ describe('SageMaker Unified Studio Main Activation', function () { isSageMakerStub.withArgs('SMUS-SPACE-REMOTE-ACCESS').returns(false) isInSmusSpaceEnvironmentStub.returns(true) - await activate(mockExtensionContext) + await activate(mockToolkitExtContext) assert.ok(initializeResourceMetadataStub.calledOnce) assert.ok(activateConnectionMagicsSelectorStub.calledOnce) @@ -206,7 +230,7 @@ describe('SageMaker Unified Studio Main Activation', function () { isSageMakerStub.withArgs('SMUS-SPACE-REMOTE-ACCESS').returns(true) isInSmusSpaceEnvironmentStub.returns(false) - await activate(mockExtensionContext) + await activate(mockToolkitExtContext) assert.ok(initializeResourceMetadataStub.calledOnce) assert.ok(activateConnectionMagicsSelectorStub.calledOnce) @@ -217,13 +241,13 @@ describe('SageMaker Unified Studio Main Activation', function () { // Test with true isInSmusSpaceEnvironmentStub.returns(true) - await activate(mockExtensionContext) + await activate(mockToolkitExtContext) assert.ok(setContextStub.calledWith('aws.smus.inSmusSpaceEnvironment', true)) // Reset and test with false setContextStub.resetHistory() isInSmusSpaceEnvironmentStub.returns(false) - await activate(mockExtensionContext) + await activate(mockToolkitExtContext) assert.ok(setContextStub.calledWith('aws.smus.inSmusSpaceEnvironment', false)) }) }) @@ -237,7 +261,7 @@ describe('SageMaker Unified Studio Main Activation', function () { const setContextError = new Error('Context setting failed') setContextStub.rejects(setContextError) - await assert.rejects(() => activate(mockExtensionContext), /Context setting failed/) + await assert.rejects(() => activate(mockToolkitExtContext), /Context setting failed/) // Verify that initializeResourceMetadata was called but subsequent functions were not assert.ok(initializeResourceMetadataStub.calledOnce) @@ -252,22 +276,25 @@ describe('SageMaker Unified Studio Main Activation', function () { isInSmusSpaceEnvironmentStub.returns(true) // All functions should succeed - await activate(mockExtensionContext) + await activate(mockToolkitExtContext) // Verify all expected functions were called assert.ok(initializeResourceMetadataStub.calledOnce) assert.ok(setContextStub.calledOnce) assert.ok(activateConnectionMagicsSelectorStub.calledOnce) assert.ok(activateExplorerStub.calledOnce) + assert.ok(registerUriHandlerStub.calledOnce) }) - it('should handle undefined extension context gracefully', async function () { - const undefinedContext = undefined as any + it('should handle minimal extension context gracefully', async function () { + const minimalContext = { + extensionContext: mockExtensionContext, + } as any - // Should not throw for undefined context, but let the individual activation functions handle it - await activate(undefinedContext) + // Should not throw with minimal context + await activate(minimalContext) - assert.ok(activateExplorerStub.calledWith(undefinedContext)) + assert.ok(activateExplorerStub.called) }) }) }) diff --git a/packages/core/src/test/sagemakerunifiedstudio/uriHandlers.test.ts b/packages/core/src/test/sagemakerunifiedstudio/uriHandlers.test.ts new file mode 100644 index 00000000000..ba3aff2b629 --- /dev/null +++ b/packages/core/src/test/sagemakerunifiedstudio/uriHandlers.test.ts @@ -0,0 +1,88 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +import assert from 'assert' +import { SearchParams } from '../../shared/vscode/uriHandler' +import { parseConnectParams } from '../../sagemakerunifiedstudio/uriHandlers' + +describe('SMUS URI Handler', function () { + describe('parseConnectParams', function () { + const validParams = { + connection_identifier: 'arn:aws:sagemaker:us-west-2:123456789012:space/d-abc123/my-space', + domain: 'd-abc123', + user_profile: 'test-user', + session: 'sess-abc123', + ws_url: 'wss://ssm.us-west-2.amazonaws.com/stream', + 'cell-number': '1', + token: 'bearer-token-xyz', + } + + it('successfully parses all required parameters', function () { + const query = new SearchParams(validParams) + const result = parseConnectParams(query) + + assert.strictEqual(result.connection_identifier, validParams.connection_identifier) + assert.strictEqual(result.domain, validParams.domain) + assert.strictEqual(result.user_profile, validParams.user_profile) + assert.strictEqual(result.session, validParams.session) + assert.strictEqual(result.ws_url, validParams.ws_url) + assert.strictEqual(result['cell-number'], validParams['cell-number']) + assert.strictEqual(result.token, validParams.token) + }) + + it('throws error when required parameters are missing', function () { + const requiredParams = [ + 'connection_identifier', + 'domain', + 'user_profile', + 'session', + 'ws_url', + 'cell-number', + 'token', + ] as const + + for (const param of requiredParams) { + const { [param]: _removed, ...paramsWithoutOne } = validParams + const query = new SearchParams(paramsWithoutOne) + + assert.throws( + () => parseConnectParams(query), + new RegExp(`${param}.*must be provided`), + `Should throw error for missing ${param}` + ) + } + }) + + it('handles optional parameters correctly', function () { + // Test with all optional parameters present + const paramsWithAllOptional = { + ...validParams, + app_type: 'CodeEditor', + smus_domain_id: 'smus-domain-789', + smus_domain_account_id: '111222333444', + smus_project_id: 'project-999', + smus_domain_region: 'eu-west-1', + } + const queryWithOptional = new SearchParams(paramsWithAllOptional) + const resultWithOptional = parseConnectParams(queryWithOptional) + + assert.strictEqual(resultWithOptional.app_type, 'CodeEditor') + assert.strictEqual(resultWithOptional.smus_domain_id, 'smus-domain-789') + assert.strictEqual(resultWithOptional.smus_domain_account_id, '111222333444') + assert.strictEqual(resultWithOptional.smus_project_id, 'project-999') + assert.strictEqual(resultWithOptional.smus_domain_region, 'eu-west-1') + + // Test without optional parameters - should return undefined + const queryWithoutOptional = new SearchParams(validParams) + const resultWithoutOptional = parseConnectParams(queryWithoutOptional) + + assert.strictEqual(resultWithoutOptional.app_type, undefined) + assert.strictEqual(resultWithoutOptional.smus_domain_id, undefined) + assert.strictEqual(resultWithoutOptional.smus_domain_account_id, undefined) + assert.strictEqual(resultWithoutOptional.smus_project_id, undefined) + assert.strictEqual(resultWithoutOptional.smus_domain_region, undefined) + }) + }) +}) diff --git a/packages/toolkit/.changes/next-release/Feature-d8fd25bc-f07e-4581-a176-4ebf9d9eb606.json b/packages/toolkit/.changes/next-release/Feature-d8fd25bc-f07e-4581-a176-4ebf9d9eb606.json new file mode 100644 index 00000000000..313c0d5a54b --- /dev/null +++ b/packages/toolkit/.changes/next-release/Feature-d8fd25bc-f07e-4581-a176-4ebf9d9eb606.json @@ -0,0 +1,4 @@ +{ + "type": "Feature", + "description": "Deeplink support for SageMaker Unified Studio" +}