diff --git a/packages/core/src/awsService/sagemaker/commands.ts b/packages/core/src/awsService/sagemaker/commands.ts index 0075d7e5dff..8ed485dc9dc 100644 --- a/packages/core/src/awsService/sagemaker/commands.ts +++ b/packages/core/src/awsService/sagemaker/commands.ts @@ -85,7 +85,8 @@ export async function deeplinkConnect( session: string, wsUrl: string, token: string, - domain: string + domain: string, + appType?: string ) { getLogger().debug( `sm:deeplinkConnect: connectionIdentifier: ${connectionIdentifier} session: ${session} wsUrl: ${wsUrl} token: ${token}` @@ -104,7 +105,8 @@ export async function deeplinkConnect( session, wsUrl, token, - domain + domain, + appType ) await startVscodeRemote( diff --git a/packages/core/src/awsService/sagemaker/credentialMapping.ts b/packages/core/src/awsService/sagemaker/credentialMapping.ts index 60d4e94260e..6567a3fa223 100644 --- a/packages/core/src/awsService/sagemaker/credentialMapping.ts +++ b/packages/core/src/awsService/sagemaker/credentialMapping.ts @@ -79,15 +79,16 @@ export async function persistSSMConnection( domain: string, session?: string, wsUrl?: string, - token?: string + token?: string, + appType?: string ): Promise { const { region } = parseArn(appArn) const endpoint = DevSettings.instance.get('endpoints', {})['sagemaker'] ?? '' - // TODO: Hardcoded to 'jupyterlab' due to a bug in Studio that only supports refreshing - // the token for both CodeEditor and JupyterLab Apps in the jupyterlab subdomain. - // This will be fixed shortly after NYSummit launch to support refresh URL in CodeEditor subdomain. - const appSubDomain = 'jupyterlab' + let appSubDomain = 'jupyterlab' + if (appType && appType.toLowerCase() === 'codeeditor') { + appSubDomain = 'code-editor' + } let envSubdomain: string diff --git a/packages/core/src/awsService/sagemaker/model.ts b/packages/core/src/awsService/sagemaker/model.ts index 20a667a0bfa..041cb155692 100644 --- a/packages/core/src/awsService/sagemaker/model.ts +++ b/packages/core/src/awsService/sagemaker/model.ts @@ -50,7 +50,8 @@ export async function prepareDevEnvConnection( session?: string, wsUrl?: string, token?: string, - domain?: string + domain?: string, + appType?: string ) { const remoteLogger = configureRemoteConnectionLogger() const { ssm, vsc, ssh } = (await ensureDependencies()).unwrap() @@ -72,7 +73,7 @@ export async function prepareDevEnvConnection( if (connectionType === 'sm_lc') { await persistLocalCredentials(appArn) } else if (connectionType === 'sm_dl') { - await persistSSMConnection(appArn, domain ?? '', session, wsUrl, token) + await persistSSMConnection(appArn, domain ?? '', session, wsUrl, token, appType) } await startLocalServer(ctx) diff --git a/packages/core/src/awsService/sagemaker/uriHandlers.ts b/packages/core/src/awsService/sagemaker/uriHandlers.ts index 17c3c512272..6f1143d9054 100644 --- a/packages/core/src/awsService/sagemaker/uriHandlers.ts +++ b/packages/core/src/awsService/sagemaker/uriHandlers.ts @@ -18,7 +18,8 @@ export function register(ctx: ExtContext) { params.session, `${params.ws_url}&cell-number=${params['cell-number']}`, params.token, - params.domain + params.domain, + params.app_type ) }) } @@ -27,7 +28,7 @@ export function register(ctx: ExtContext) { } export function parseConnectParams(query: SearchParams) { - const params = query.getFromKeysOrThrow( + const requiredParams = query.getFromKeysOrThrow( 'connection_identifier', 'domain', 'user_profile', @@ -36,5 +37,7 @@ export function parseConnectParams(query: SearchParams) { 'cell-number', 'token' ) - return params + const optionalParams = query.getFromKeys('app_type') + + return { ...requiredParams, ...optionalParams } } diff --git a/packages/core/src/test/awsService/sagemaker/uriHandlers.test.ts b/packages/core/src/test/awsService/sagemaker/uriHandlers.test.ts index 9ff24b2a3f9..07e20e424b6 100644 --- a/packages/core/src/test/awsService/sagemaker/uriHandlers.test.ts +++ b/packages/core/src/test/awsService/sagemaker/uriHandlers.test.ts @@ -44,6 +44,7 @@ describe('SageMaker URI handler', function () { ws_url: 'wss://example.com', 'cell-number': '4', token: 'my-token', + app_type: 'jupyterlab', } const uri = createConnectUri(params) @@ -55,5 +56,29 @@ describe('SageMaker URI handler', function () { assert.deepStrictEqual(deeplinkConnectStub.firstCall.args[3], 'wss://example.com&cell-number=4') assert.deepStrictEqual(deeplinkConnectStub.firstCall.args[4], 'my-token') assert.deepStrictEqual(deeplinkConnectStub.firstCall.args[5], 'my-domain') + assert.deepStrictEqual(deeplinkConnectStub.firstCall.args[6], 'jupyterlab') + }) + + it('calls deeplinkConnect with undefined app_type when not provided', async function () { + const params = { + connection_identifier: 'abc123', + domain: 'my-domain', + user_profile: 'me', + session: 'sess-xyz', + ws_url: 'wss://example.com', + 'cell-number': '4', + token: 'my-token', + } + + const uri = createConnectUri(params) + await handler.handleUri(uri) + + assert.ok(deeplinkConnectStub.calledOnce) + assert.deepStrictEqual(deeplinkConnectStub.firstCall.args[1], 'abc123') + assert.deepStrictEqual(deeplinkConnectStub.firstCall.args[2], 'sess-xyz') + assert.deepStrictEqual(deeplinkConnectStub.firstCall.args[3], 'wss://example.com&cell-number=4') + assert.deepStrictEqual(deeplinkConnectStub.firstCall.args[4], 'my-token') + assert.deepStrictEqual(deeplinkConnectStub.firstCall.args[5], 'my-domain') + assert.deepStrictEqual(deeplinkConnectStub.firstCall.args[6], undefined) }) })