diff --git a/packages/core/src/sagemakerunifiedstudio/auth/providers/connectionCredentialsProvider.ts b/packages/core/src/sagemakerunifiedstudio/auth/providers/connectionCredentialsProvider.ts index 828e848f810..f060e6477ab 100644 --- a/packages/core/src/sagemakerunifiedstudio/auth/providers/connectionCredentialsProvider.ts +++ b/packages/core/src/sagemakerunifiedstudio/auth/providers/connectionCredentialsProvider.ts @@ -73,6 +73,14 @@ export class ConnectionCredentialsProvider implements CredentialsProvider { return this.smusAuthProvider.getDomainRegion() } + /** + * Gets the domain AWS account ID + * @returns Promise resolving to the domain account ID + */ + public async getDomainAccountId(): Promise { + return this.smusAuthProvider.getDomainAccountId() + } + /** * Gets the hash code * @returns Hash code diff --git a/packages/core/src/sagemakerunifiedstudio/auth/providers/smusAuthenticationProvider.ts b/packages/core/src/sagemakerunifiedstudio/auth/providers/smusAuthenticationProvider.ts index 7f331ae4f46..8aadda5fcc0 100644 --- a/packages/core/src/sagemakerunifiedstudio/auth/providers/smusAuthenticationProvider.ts +++ b/packages/core/src/sagemakerunifiedstudio/auth/providers/smusAuthenticationProvider.ts @@ -14,7 +14,7 @@ import * as localizedText from '../../../shared/localizedText' import { ToolkitPromptSettings } from '../../../shared/settings' import { setContext, getContext } from '../../../shared/vscode/setContext' import { getLogger } from '../../../shared/logger/logger' -import { SmusUtils, SmusErrorCodes } from '../../shared/smusUtils' +import { SmusUtils, SmusErrorCodes, extractAccountIdFromArn } from '../../shared/smusUtils' import { createSmusProfile, isValidSmusConnection, SmusConnection } from '../model' import { DomainExecRoleCredentialsProvider } from './domainExecRoleCredentialsProvider' import { ProjectRoleCredentialsProvider } from './projectRoleCredentialsProvider' @@ -23,6 +23,7 @@ import { ConnectionClientStore } from '../../shared/client/connectionClientStore import { getResourceMetadata } from '../../shared/utils/resourceMetadataUtils' import { fromIni } from '@aws-sdk/credential-providers' import { randomUUID } from '../../../shared/crypto' +import { DefaultStsClient } from '../../../shared/clients/stsClient' /** * Sets the context variable for SageMaker Unified Studio connection state @@ -53,6 +54,7 @@ export class SmusAuthenticationProvider { private credentialsProviderCache = new Map() private projectCredentialProvidersCache = new Map() private connectionCredentialProvidersCache = new Map() + private cachedDomainAccountId: string | undefined public constructor( public readonly auth = Auth.instance, @@ -75,6 +77,8 @@ export class SmusAuthenticationProvider { this.projectCredentialProvidersCache.clear() // Clear connection provider cache when connection changes this.connectionCredentialProvidersCache.clear() + // Clear cached domain account ID when connection changes + this.cachedDomainAccountId = undefined // Clear all clients in client store when connection changes ConnectionClientStore.getInstance().clearAll() await setSmusConnectedContext(this.isConnected()) @@ -423,6 +427,99 @@ export class SmusAuthenticationProvider { return this.activeConnection.domainUrl } + /** + * Gets the AWS account ID for the active domain connection + * In SMUS space environment, extracts from ResourceArn in metadata + * Otherwise, makes an STS GetCallerIdentity call using DER credentials and caches the result + * @returns Promise resolving to the domain's AWS account ID + * @throws ToolkitError if unable to retrieve account ID + */ + public async getDomainAccountId(): Promise { + const logger = getLogger() + + // Return cached value if available + if (this.cachedDomainAccountId) { + logger.debug('SMUS: Using cached domain account ID') + return this.cachedDomainAccountId + } + + // If in SMUS space environment, extract account ID from resource-metadata file + if (getContext('aws.smus.inSmusSpaceEnvironment')) { + try { + logger.debug('SMUS: Extracting domain account ID from ResourceArn in resource-metadata file') + + const resourceMetadata = getResourceMetadata()! + const resourceArn = resourceMetadata.ResourceArn + + if (!resourceArn) { + throw new ToolkitError('ResourceArn not found in metadata file', { + code: SmusErrorCodes.AccountIdNotFound, + }) + } + + // Extract account ID from ResourceArn using SmusUtils + const accountId = extractAccountIdFromArn(resourceArn) + + // Cache the account ID + this.cachedDomainAccountId = accountId + + logger.debug( + `Successfully extracted and cached domain account ID from resource-metadata file: ${accountId}` + ) + + return accountId + } catch (err) { + logger.error(`Failed to extract domain account ID from ResourceArn: %s`, err) + + throw new ToolkitError('Failed to extract AWS account ID from ResourceArn in SMUS space environment', { + code: SmusErrorCodes.GetDomainAccountIdFailed, + cause: err instanceof Error ? err : undefined, + }) + } + } + + if (!this.activeConnection) { + throw new ToolkitError('No active SMUS connection available', { code: SmusErrorCodes.NoActiveConnection }) + } + + // Use existing STS GetCallerIdentity implementation for non-SMUS space environments + try { + logger.debug('Fetching domain account ID via STS GetCallerIdentity') + + // Get DER credentials provider + const derCredProvider = await this.getDerCredentialsProvider() + + // Get the region for STS client + const region = this.getDomainRegion() + + // Create STS client with DER credentials + const stsClient = new DefaultStsClient(region, await derCredProvider.getCredentials()) + + // Make GetCallerIdentity call + const callerIdentity = await stsClient.getCallerIdentity() + + if (!callerIdentity.Account) { + throw new ToolkitError('Account ID not found in STS GetCallerIdentity response', { + code: SmusErrorCodes.AccountIdNotFound, + }) + } + + // Cache the account ID + this.cachedDomainAccountId = callerIdentity.Account + + logger.debug(`Successfully retrieved and cached domain account ID: ${callerIdentity.Account}`) + + return callerIdentity.Account + } catch (err) { + logger.error(`Failed to retrieve domain account ID: %s`, err) + + throw new ToolkitError('Failed to retrieve AWS account ID for active domain connection', { + code: SmusErrorCodes.GetDomainAccountIdFailed, + cause: err instanceof Error ? err : undefined, + }) + } + } + public getDomainRegion(): string { if (getContext('aws.smus.inSmusSpaceEnvironment')) { const resourceMetadata = getResourceMetadata()! @@ -516,6 +613,10 @@ export class SmusAuthenticationProvider { logger.warn(`SMUS: Failed to invalidate connection credentials for cache key ${cacheKey}: %s`, err) } } + + // Clear cached domain account ID + this.cachedDomainAccountId = undefined + logger.debug('SMUS: Cleared cached domain account ID') } /** @@ -560,6 +661,10 @@ export class SmusAuthenticationProvider { } } this.credentialsProviderCache.clear() + + // Clear cached domain account ID + this.cachedDomainAccountId = undefined + this.logger.debug('SMUS Auth: Successfully disposed authentication provider') } diff --git a/packages/core/src/sagemakerunifiedstudio/explorer/activation.ts b/packages/core/src/sagemakerunifiedstudio/explorer/activation.ts index 65aed68e670..8a686b48654 100644 --- a/packages/core/src/sagemakerunifiedstudio/explorer/activation.ts +++ b/packages/core/src/sagemakerunifiedstudio/explorer/activation.ts @@ -20,8 +20,8 @@ import { getLogger } from '../../shared/logger/logger' import { setSmusConnectedContext, SmusAuthenticationProvider } from '../auth/providers/smusAuthenticationProvider' import { setupUserActivityMonitoring } from '../../awsService/sagemaker/sagemakerSpace' import { telemetry } from '../../shared/telemetry/telemetry' -import { SageMakerUnifiedStudioSpacesParentNode } from './nodes/sageMakerUnifiedStudioSpacesParentNode' import { isSageMaker } from '../../shared/extensionUtilities' +import { recordSpaceTelemetry } from '../shared/telemetry' export async function activate(extensionContext: vscode.ExtensionContext): Promise { // Initialize the SMUS authentication provider @@ -75,16 +75,7 @@ export async function activate(extensionContext: vscode.ExtensionContext): Promi return } await telemetry.smus_stopSpace.run(async (span) => { - span.record({ - smusSpaceKey: node.resource.DomainSpaceKey, - smusDomainRegion: node.resource.regionCode, - smusDomainId: ( - node.resource.getParent() as SageMakerUnifiedStudioSpacesParentNode - )?.getAuthProvider()?.activeConnection?.domainId, - smusProjectId: ( - node.resource.getParent() as SageMakerUnifiedStudioSpacesParentNode - )?.getProjectId(), - }) + await recordSpaceTelemetry(span, node) await stopSpace(node.resource, extensionContext, node.resource.sageMakerClient) }) }), @@ -95,17 +86,8 @@ export async function activate(extensionContext: vscode.ExtensionContext): Promi if (!validateNode(node)) { return } - await telemetry.smus_startSpace.run(async (span) => { - span.record({ - smusSpaceKey: node.resource.DomainSpaceKey, - smusDomainRegion: node.resource.regionCode, - smusDomainId: ( - node.resource.getParent() as SageMakerUnifiedStudioSpacesParentNode - )?.getAuthProvider()?.activeConnection?.domainId, - smusProjectId: ( - node.resource.getParent() as SageMakerUnifiedStudioSpacesParentNode - )?.getProjectId(), - }) + await telemetry.smus_openRemoteConnection.run(async (span) => { + await recordSpaceTelemetry(span, node) await openRemoteConnect(node.resource, extensionContext, node.resource.sageMakerClient) }) } diff --git a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/lakehouseStrategy.ts b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/lakehouseStrategy.ts index 43c51997d3b..546a73135c6 100644 --- a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/lakehouseStrategy.ts +++ b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/lakehouseStrategy.ts @@ -34,7 +34,7 @@ import { createPlaceholderItem } from '../../../shared/treeview/utils' import { Column, Database, Table } from '@aws-sdk/client-glue' import { ConnectionCredentialsProvider } from '../../auth/providers/connectionCredentialsProvider' import { telemetry } from '../../../shared/telemetry/telemetry' -import { getContext } from '../../../shared/vscode/setContext' +import { recordDataConnectionTelemetry } from '../../shared/telemetry' /** * Lakehouse data node for SageMaker Unified Studio @@ -152,16 +152,7 @@ export function createLakehouseConnectionNode( }, async (node) => { return telemetry.smus_renderLakehouseNode.run(async (span) => { - const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment') - - span.record({ - smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local', - smusDomainId: connection.domainId, - smusProjectId: connection.projectId, - smusConnectionId: connection.connectionId, - smusConnectionType: connection.type, - smusProjectRegion: connection.location?.awsRegion, - }) + await recordDataConnectionTelemetry(span, connection, connectionCredentialsProvider) try { logger.info(`Loading Lakehouse catalogs for connection ${connection.name}`) diff --git a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/redshiftStrategy.ts b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/redshiftStrategy.ts index 8857aec1449..9dc38b33a7a 100644 --- a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/redshiftStrategy.ts +++ b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/redshiftStrategy.ts @@ -131,10 +131,11 @@ export function createRedshiftConnectionNode( async (node) => { return telemetry.smus_renderRedshiftNode.run(async (span) => { const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment') - + const accountId = await connectionCredentialsProvider.getDomainAccountId() span.record({ smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local', smusDomainId: connection.domainId, + smusDomainAccountId: accountId, smusProjectId: connection.projectId, smusConnectionId: connection.connectionId, smusConnectionType: connection.type, diff --git a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/s3Strategy.ts b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/s3Strategy.ts index 9c1130c2553..4106a0b4889 100644 --- a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/s3Strategy.ts +++ b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/s3Strategy.ts @@ -20,7 +20,7 @@ import { import { S3, ListObjectsV2Command } from '@aws-sdk/client-s3' import { ConnectionCredentialsProvider } from '../../auth/providers/connectionCredentialsProvider' import { telemetry } from '../../../shared/telemetry/telemetry' -import { getContext } from '../../../shared/vscode/setContext' +import { recordDataConnectionTelemetry } from '../../shared/telemetry' // Regex to match default S3 connection names // eslint-disable-next-line @typescript-eslint/naming-convention @@ -144,16 +144,7 @@ export function createS3ConnectionNode( }, async (node) => { return telemetry.smus_renderS3Node.run(async (span) => { - const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment') - - span.record({ - smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local', - smusDomainId: connection.domainId, - smusProjectId: connection.projectId, - smusConnectionId: connection.connectionId, - smusConnectionType: connection.type, - smusProjectRegion: connection.location?.awsRegion, - }) + await recordDataConnectionTelemetry(span, connection, connectionCredentialsProvider) try { if (isDefaultConnection && s3Info.prefix) { // For default connections, show the full path as the first node diff --git a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioProjectNode.ts b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioProjectNode.ts index d47933e9948..8097ceed9e7 100644 --- a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioProjectNode.ts +++ b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioProjectNode.ts @@ -82,10 +82,11 @@ export class SageMakerUnifiedStudioProjectNode implements TreeNode { return telemetry.smus_renderProjectChildrenNode.run(async (span) => { try { const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment') - + const accountId = await this.authProvider.getDomainAccountId() span.record({ smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local', smusDomainId: this.project?.domainId, + smusDomainAccountId: accountId, smusProjectId: this.project?.id, smusDomainRegion: this.authProvider.getDomainRegion(), }) diff --git a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioRootNode.ts b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioRootNode.ts index a72db66ea69..db3f6959969 100644 --- a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioRootNode.ts +++ b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioRootNode.ts @@ -13,10 +13,10 @@ import { telemetry } from '../../../shared/telemetry/telemetry' import { createQuickPick } from '../../../shared/ui/pickerPrompter' import { SageMakerUnifiedStudioProjectNode } from './sageMakerUnifiedStudioProjectNode' import { SageMakerUnifiedStudioAuthInfoNode } from './sageMakerUnifiedStudioAuthInfoNode' -import { SmusUtils } from '../../shared/smusUtils' +import { SmusErrorCodes, SmusUtils } from '../../shared/smusUtils' import { SmusAuthenticationProvider } from '../../auth/providers/smusAuthenticationProvider' import { ToolkitError } from '../../../../src/shared/errors' -import { errorCode } from '../../shared/errors' +import { recordAuthTelemetry } from '../../shared/telemetry' const contextValueSmusRoot = 'sageMakerUnifiedStudioRoot' const contextValueSmusLogin = 'sageMakerUnifiedStudioLogin' @@ -237,7 +237,10 @@ export const smusLoginCommand = Commands.declare('aws.smus.login', () => async ( if (!domainUrl) { // User cancelled logger.debug('User cancelled domain URL input') - return + throw new ToolkitError('User cancelled domain URL input', { + cancelled: true, + code: SmusErrorCodes.UserCancelled, + }) } // Show a simple status bar message instead of progress dialog @@ -252,19 +255,16 @@ export const smusLoginCommand = Commands.declare('aws.smus.login', () => async ( if (!connection) { throw new ToolkitError('Failed to establish connection', { - code: errorCode.failedAuthConnecton, + code: SmusErrorCodes.FailedAuthConnecton, }) } - // Extract domain ID and region for logging + // Extract domain account ID, domain ID, and region for logging const domainId = connection.domainId const region = connection.ssoRegion logger.info(`Connected to SageMaker Unified Studio domain: ${domainId} in region ${region}`) - span.record({ - smusDomainId: domainId, - awsRegion: region, - }) + await recordAuthTelemetry(span, authProvider, domainId, region) // Show success message void vscode.window.showInformationMessage( @@ -292,9 +292,12 @@ export const smusLoginCommand = Commands.declare('aws.smus.login', () => async ( }) } } catch (err) { - void vscode.window.showErrorMessage( - `SageMaker Unified Studio: Failed to initiate login: ${(err as Error).message}` - ) + const isUserCancelled = err instanceof ToolkitError && err.code === SmusErrorCodes.UserCancelled + if (!isUserCancelled) { + void vscode.window.showErrorMessage( + `SageMaker Unified Studio: Failed to initiate login: ${(err as Error).message}` + ) + } logger.error('Failed to initiate login: %s', (err as Error).message) throw new ToolkitError('Failed to initiate login.', { cause: err as Error, @@ -329,11 +332,7 @@ export const smusSignOutCommand = Commands.declare('aws.smus.signOut', () => asy // Show status message vscode.window.setStatusBarMessage('Signing out from SageMaker Unified Studio...', 5000) - - span.record({ - smusDomainId: domainId, - awsRegion: region, - }) + await recordAuthTelemetry(span, authProvider, domainId, region) // Delete the connection (this will also invalidate tokens and clear cache) if (activeConnection) { @@ -425,10 +424,12 @@ export async function selectSMUSProject(projectNode?: SageMakerUnifiedStudioProj } const selectedProject = await showQuickPick(items) + const accountId = await authProvider.getDomainAccountId() span.record({ smusDomainId: authProvider.getDomainId(), smusProjectId: (selectedProject as DataZoneProject).id as string | undefined, smusDomainRegion: authProvider.getDomainRegion(), + smusDomainAccountId: accountId, }) if ( selectedProject && diff --git a/packages/core/src/sagemakerunifiedstudio/shared/errors.ts b/packages/core/src/sagemakerunifiedstudio/shared/errors.ts deleted file mode 100644 index 1d582c22d74..00000000000 --- a/packages/core/src/sagemakerunifiedstudio/shared/errors.ts +++ /dev/null @@ -1,8 +0,0 @@ -/*! - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -export const errorCode = { - failedAuthConnecton: 'FailedAuthConnecton', -} diff --git a/packages/core/src/sagemakerunifiedstudio/shared/smusUtils.ts b/packages/core/src/sagemakerunifiedstudio/shared/smusUtils.ts index 3fe1dd9b27b..5950ba16c65 100644 --- a/packages/core/src/sagemakerunifiedstudio/shared/smusUtils.ts +++ b/packages/core/src/sagemakerunifiedstudio/shared/smusUtils.ts @@ -50,6 +50,14 @@ export const SmusErrorCodes = { SmusLoginFailed: 'SmusLoginFailed', /** Error code for when redeeming access token fails */ RedeemAccessTokenFailed: 'RedeemAccessTokenFailed', + /** Error code for when connection establish fails */ + FailedAuthConnecton: 'FailedAuthConnecton', + /** Error code for when user cancels an operation */ + UserCancelled: 'UserCancelled', + /** Error code for when domain account Id is missing */ + AccountIdNotFound: 'AccountIdNotFound', + /** Error code for when fails to get domain account Id */ + GetDomainAccountIdFailed: 'GetDomainAccountIdFailed', } as const /** @@ -351,3 +359,24 @@ export class SmusUtils { return isSMUSspace && !!resourceMetadata?.AdditionalMetadata?.DataZoneDomainId } } + +/** + * Extracts the account ID from a SageMaker ARN. + * Supports formats like: + * arn:aws:sagemaker:::app/* + * + * @param arn - The full SageMaker ARN string + * @returns The account ID from the ARN + * @throws If the ARN format is invalid + */ +export function extractAccountIdFromArn(arn: string): string { + // Match the ARN components to extract account ID + const regex = /^arn:aws:sagemaker:(?[^:]+):(?\d+):(app|space|domain)\/.+$/i + const match = arn.match(regex) + + if (!match?.groups) { + throw new ToolkitError(`Invalid SageMaker ARN format: "${arn}"`) + } + + return match.groups.accountId +} diff --git a/packages/core/src/sagemakerunifiedstudio/shared/telemetry.ts b/packages/core/src/sagemakerunifiedstudio/shared/telemetry.ts new file mode 100644 index 00000000000..b97762270b9 --- /dev/null +++ b/packages/core/src/sagemakerunifiedstudio/shared/telemetry.ts @@ -0,0 +1,92 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + SmusLogin, + SmusOpenRemoteConnection, + SmusRenderLakehouseNode, + SmusRenderS3Node, + SmusSignOut, + SmusStopSpace, + Span, +} from '../../shared/telemetry/telemetry' +import { SagemakerUnifiedStudioSpaceNode } from '../explorer/nodes/sageMakerUnifiedStudioSpaceNode' +import { SageMakerUnifiedStudioSpacesParentNode } from '../explorer/nodes/sageMakerUnifiedStudioSpacesParentNode' +import { SmusAuthenticationProvider } from '../auth/providers/smusAuthenticationProvider' +import { getLogger } from '../../shared/logger/logger' +import { getContext } from '../../shared/vscode/setContext' +import { ConnectionCredentialsProvider } from '../auth/providers/connectionCredentialsProvider' +import { DataZoneConnection } from './client/datazoneClient' + +/** + * Records space telemetry + */ +export async function recordSpaceTelemetry( + span: Span | Span, + node: SagemakerUnifiedStudioSpaceNode +) { + const parent = node.resource.getParent() as SageMakerUnifiedStudioSpacesParentNode + const authProvider = SmusAuthenticationProvider.fromContext() + const accountId = await authProvider.getDomainAccountId() + span.record({ + smusSpaceKey: node.resource.DomainSpaceKey, + smusDomainRegion: node.resource.regionCode, + smusDomainId: parent?.getAuthProvider()?.activeConnection?.domainId, + smusDomainAccountId: accountId, + smusProjectId: parent?.getProjectId(), + }) +} + +/** + * Records auth telemetry + */ +export async function recordAuthTelemetry( + span: Span | Span, + authProvider: SmusAuthenticationProvider, + domainId: string | undefined, + region: string | undefined +) { + const logger = getLogger() + + span.record({ + smusDomainId: domainId, + awsRegion: region, + }) + + try { + if (!region) { + throw new Error(`Region is undefined for domain ${domainId}`) + } + const accountId = await authProvider.getDomainAccountId() + span.record({ + smusDomainAccountId: accountId, + }) + } catch (err) { + logger.error( + `Failed to resolve AWS account ID via STS Client for domain ${domainId} in region ${region}: ${err}` + ) + } +} + +/** + * Records data connection telemetry for SMUS nodes + */ +export async function recordDataConnectionTelemetry( + span: Span | Span, + connection: DataZoneConnection, + connectionCredentialsProvider: ConnectionCredentialsProvider +) { + const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment') + const accountId = await connectionCredentialsProvider.getDomainAccountId() + span.record({ + smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local', + smusDomainId: connection.domainId, + smusDomainAccountId: accountId, + smusProjectId: connection.projectId, + smusConnectionId: connection.connectionId, + smusConnectionType: connection.type, + smusProjectRegion: connection.location?.awsRegion, + }) +} diff --git a/packages/core/src/shared/telemetry/vscodeTelemetry.json b/packages/core/src/shared/telemetry/vscodeTelemetry.json index aefb6e3a1b1..6fbfa22e394 100644 --- a/packages/core/src/shared/telemetry/vscodeTelemetry.json +++ b/packages/core/src/shared/telemetry/vscodeTelemetry.json @@ -295,6 +295,16 @@ "name": "smusConnectionType", "type": "string", "description": "SMUS connection type" + }, + { + "name": "smusDomainAccountId", + "type": "string", + "description": "SMUS domain account id" + }, + { + "name": "smusProjectAccountId", + "type": "string", + "description": "SMUS project account id" } ], "metrics": [ @@ -1359,6 +1369,10 @@ { "type": "smusDomainId", "required": false + }, + { + "type": "smusDomainAccountId", + "required": false } ] }, @@ -1369,6 +1383,10 @@ { "type": "smusDomainId", "required": false + }, + { + "type": "smusDomainAccountId", + "required": false } ] }, @@ -1380,6 +1398,10 @@ "type": "smusDomainId", "required": false }, + { + "type": "smusDomainAccountId", + "required": false + }, { "type": "smusProjectId", "required": false @@ -1402,6 +1424,10 @@ "type": "smusDomainId", "required": false }, + { + "type": "smusDomainAccountId", + "required": false + }, { "type": "smusProjectId", "required": false @@ -1414,13 +1440,17 @@ "passive": true }, { - "name": "smus_startSpace", + "name": "smus_openRemoteConnection", "description": "Emitted whenever a user starts a SMUS space", "metadata": [ { "type": "smusDomainId", "required": false }, + { + "type": "smusDomainAccountId", + "required": false + }, { "type": "smusProjectId", "required": false @@ -1443,6 +1473,10 @@ "type": "smusDomainId", "required": false }, + { + "type": "smusDomainAccountId", + "required": false + }, { "type": "smusProjectId", "required": false @@ -1469,6 +1503,10 @@ "type": "smusDomainId", "required": false }, + { + "type": "smusDomainAccountId", + "required": false + }, { "type": "smusProjectId", "required": false @@ -1499,6 +1537,10 @@ "type": "smusDomainId", "required": false }, + { + "type": "smusDomainAccountId", + "required": false + }, { "type": "smusProjectId", "required": false @@ -1529,6 +1571,10 @@ "type": "smusDomainId", "required": false }, + { + "type": "smusDomainAccountId", + "required": false + }, { "type": "smusProjectId", "required": false diff --git a/packages/core/src/test/sagemakerunifiedstudio/auth/smusAuthenticationProvider.test.ts b/packages/core/src/test/sagemakerunifiedstudio/auth/smusAuthenticationProvider.test.ts index dafd1a630ec..f971ff5520f 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/auth/smusAuthenticationProvider.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/auth/smusAuthenticationProvider.test.ts @@ -14,8 +14,12 @@ import { SmusAuthenticationProvider } from '../../../sagemakerunifiedstudio/auth import { SmusConnection } from '../../../sagemakerunifiedstudio/auth/model' import { DataZoneClient } from '../../../sagemakerunifiedstudio/shared/client/datazoneClient' import { SmusUtils } from '../../../sagemakerunifiedstudio/shared/smusUtils' +import * as smusUtils from '../../../sagemakerunifiedstudio/shared/smusUtils' import { ToolkitError } from '../../../shared/errors' import * as messages from '../../../shared/utilities/messages' +import * as vscodeSetContext from '../../../shared/vscode/setContext' +import * as resourceMetadataUtils from '../../../sagemakerunifiedstudio/shared/utils/resourceMetadataUtils' +import { DefaultStsClient } from '../../../shared/clients/stsClient' describe('SmusAuthenticationProvider', function () { let mockAuth: any @@ -410,4 +414,166 @@ describe('SmusAuthenticationProvider', function () { assert.strictEqual(SmusAuthenticationProvider.instance, instance) }) }) + + describe('getDomainAccountId', function () { + let getContextStub: sinon.SinonStub + let getResourceMetadataStub: sinon.SinonStub + let extractAccountIdFromArnStub: sinon.SinonStub + let getDerCredentialsProviderStub: sinon.SinonStub + let getDomainRegionStub: sinon.SinonStub + let mockStsClient: any + let mockCredentialsProvider: any + + beforeEach(function () { + // Mock dependencies + getContextStub = sinon.stub(vscodeSetContext, 'getContext') + getResourceMetadataStub = sinon.stub(resourceMetadataUtils, 'getResourceMetadata') + extractAccountIdFromArnStub = sinon.stub(smusUtils, 'extractAccountIdFromArn') + + // Mock STS client + mockStsClient = { + getCallerIdentity: sinon.stub(), + } + sinon + .stub(DefaultStsClient.prototype, 'getCallerIdentity') + .callsFake(() => mockStsClient.getCallerIdentity()) + + // Mock credentials provider + mockCredentialsProvider = { + getCredentials: sinon.stub().resolves({ + accessKeyId: 'test-key', + secretAccessKey: 'test-secret', + sessionToken: 'test-token', + }), + } + + // Stub methods on the provider instance + getDerCredentialsProviderStub = sinon + .stub(smusAuthProvider, 'getDerCredentialsProvider') + .resolves(mockCredentialsProvider) + getDomainRegionStub = sinon.stub(smusAuthProvider, 'getDomainRegion').returns('us-east-1') + + // Reset cached value + smusAuthProvider['cachedDomainAccountId'] = undefined + }) + + afterEach(function () { + sinon.restore() + }) + + describe('when cached value exists', function () { + it('should return cached account ID without making any calls', async function () { + const cachedAccountId = '123456789012' + smusAuthProvider['cachedDomainAccountId'] = cachedAccountId + + const result = await smusAuthProvider.getDomainAccountId() + + assert.strictEqual(result, cachedAccountId) + assert.ok(getContextStub.notCalled) + assert.ok(getResourceMetadataStub.notCalled) + assert.ok(mockStsClient.getCallerIdentity.notCalled) + }) + }) + + describe('in SMUS space environment', function () { + beforeEach(function () { + getContextStub.withArgs('aws.smus.inSmusSpaceEnvironment').returns(true) + }) + + it('should extract account ID from ResourceArn and cache it', async function () { + const testAccountId = '123456789012' + const testResourceArn = `arn:aws:sagemaker:us-east-1:${testAccountId}:domain/test-domain` + + getResourceMetadataStub.returns({ + ResourceArn: testResourceArn, + }) + extractAccountIdFromArnStub.returns(testAccountId) + + const result = await smusAuthProvider.getDomainAccountId() + + assert.strictEqual(result, testAccountId) + assert.strictEqual(smusAuthProvider['cachedDomainAccountId'], testAccountId) + assert.ok(getResourceMetadataStub.called) + assert.ok(extractAccountIdFromArnStub.calledWith(testResourceArn)) + assert.ok(mockStsClient.getCallerIdentity.notCalled) + }) + + it('should throw error when ResourceArn is missing from metadata', async function () { + getResourceMetadataStub.returns({}) + + await assert.rejects( + () => smusAuthProvider.getDomainAccountId(), + (err: ToolkitError) => { + return ( + err.code === 'GetDomainAccountIdFailed' && + err.message.includes( + 'Failed to extract AWS account ID from ResourceArn in SMUS space environment' + ) + ) + } + ) + + assert.strictEqual(smusAuthProvider['cachedDomainAccountId'], undefined) + }) + }) + + describe('in non-SMUS space environment', function () { + beforeEach(function () { + getContextStub.withArgs('aws.smus.inSmusSpaceEnvironment').returns(false) + mockSecondaryAuthState.activeConnection = mockSmusConnection + }) + + it('should use STS GetCallerIdentity to get account ID and cache it', async function () { + const testAccountId = '123456789012' + mockStsClient.getCallerIdentity.resolves({ + Account: testAccountId, + UserId: 'test-user-id', + Arn: 'arn:aws:sts::123456789012:assumed-role/test-role/test-session', + }) + + const result = await smusAuthProvider.getDomainAccountId() + + assert.strictEqual(result, testAccountId) + assert.strictEqual(smusAuthProvider['cachedDomainAccountId'], testAccountId) + assert.ok(getDerCredentialsProviderStub.called) + assert.ok(getDomainRegionStub.called) + assert.ok(mockCredentialsProvider.getCredentials.called) + assert.ok(mockStsClient.getCallerIdentity.called) + }) + + it('should throw error when no active connection exists', async function () { + mockSecondaryAuthState.activeConnection = undefined + + await assert.rejects( + () => smusAuthProvider.getDomainAccountId(), + (err: ToolkitError) => { + return ( + err.code === 'NoActiveConnection' && + err.message.includes('No active SMUS connection available') + ) + } + ) + + assert.strictEqual(smusAuthProvider['cachedDomainAccountId'], undefined) + assert.ok(getDerCredentialsProviderStub.notCalled) + assert.ok(mockStsClient.getCallerIdentity.notCalled) + }) + + it('should throw error when STS GetCallerIdentity fails', async function () { + mockStsClient.getCallerIdentity.rejects(new Error('STS call failed')) + + await assert.rejects( + () => smusAuthProvider.getDomainAccountId(), + (err: ToolkitError) => { + return ( + err.code === 'GetDomainAccountIdFailed' && + err.message.includes('Failed to retrieve AWS account ID for active domain connection') + ) + } + ) + + assert.strictEqual(smusAuthProvider['cachedDomainAccountId'], undefined) + }) + }) + }) }) diff --git a/packages/core/src/test/sagemakerunifiedstudio/explorer/activation.test.ts b/packages/core/src/test/sagemakerunifiedstudio/explorer/activation.test.ts index fcb76325293..982aa481bd3 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/explorer/activation.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/explorer/activation.test.ts @@ -18,6 +18,7 @@ import { getLogger } from '../../../shared/logger/logger' import { getTestWindow } from '../../shared/vscode/window' import { SeverityLevel } from '../../shared/vscode/message' import * as extensionUtilities from '../../../shared/extensionUtilities' +import { createMockSpaceNode } from '../testUtils' describe('SMUS Explorer Activation', function () { let mockExtensionContext: vscode.ExtensionContext @@ -30,7 +31,7 @@ describe('SMUS Explorer Activation', function () { let dataZoneDisposeStub: sinon.SinonStub let setupUserActivityMonitoringStub: sinon.SinonStub - beforeEach(function () { + beforeEach(async function () { mockExtensionContext = { subscriptions: [], } as any @@ -45,6 +46,7 @@ describe('SMUS Explorer Activation', function () { domainId: 'test-domain', ssoRegion: 'us-east-1', }, + getDomainAccountId: sinon.stub().resolves('123456789012'), } as any mockTreeView = { @@ -285,25 +287,7 @@ describe('SMUS Explorer Activation', function () { assert.ok(stopSpaceCommand) - const mockSpaceNode = { - resource: { - sageMakerClient: {}, - DomainSpaceKey: 'test-space-key', - regionCode: 'us-east-1', - getParent: sinon.stub().returns({ - getAuthProvider: sinon.stub().returns({ - activeConnection: { domainId: 'test-domain' }, - }), - getProjectId: sinon.stub().returns('test-project'), - }), - }, - getParent: sinon.stub().returns({ - getAuthProvider: sinon.stub().returns({ - activeConnection: { domainId: 'test-domain' }, - }), - getProjectId: sinon.stub().returns('test-project'), - }), - } as any + const mockSpaceNode = createMockSpaceNode() // Mock the stopSpace function const stopSpaceStub = sinon.stub() @@ -346,25 +330,7 @@ describe('SMUS Explorer Activation', function () { assert.ok(openRemoteCommand) - const mockSpaceNode = { - resource: { - sageMakerClient: {}, - DomainSpaceKey: 'test-space-key', - regionCode: 'us-east-1', - getParent: sinon.stub().returns({ - getAuthProvider: sinon.stub().returns({ - activeConnection: { domainId: 'test-domain' }, - }), - getProjectId: sinon.stub().returns('test-project'), - }), - }, - getParent: sinon.stub().returns({ - getAuthProvider: sinon.stub().returns({ - activeConnection: { domainId: 'test-domain' }, - }), - getProjectId: sinon.stub().returns('test-project'), - }), - } as any + const mockSpaceNode = createMockSpaceNode() // Mock the openRemoteConnect function const openRemoteConnectStub = sinon.stub() diff --git a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/lakehouseStrategy.test.ts b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/lakehouseStrategy.test.ts index 6039e5e4e02..63e87c25f23 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/lakehouseStrategy.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/lakehouseStrategy.test.ts @@ -35,6 +35,7 @@ describe('LakehouseStrategy', function () { secretAccessKey: 'test-secret', sessionToken: 'test-token', }), + getDomainAccountId: async () => '123456789012', } beforeEach(function () { diff --git a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/redshiftStrategy.test.ts b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/redshiftStrategy.test.ts index 31ec0e8bb24..50b5e36e251 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/redshiftStrategy.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/redshiftStrategy.test.ts @@ -262,6 +262,7 @@ describe('redshiftStrategy', function () { accessKeyId: 'test-key', secretAccessKey: 'test-secret', }), + getDomainAccountId: async () => '123456789012', } const node = createRedshiftConnectionNode( diff --git a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/s3Strategy.test.ts b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/s3Strategy.test.ts index 9b193ea0106..f6838ab483e 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/s3Strategy.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/s3Strategy.test.ts @@ -11,6 +11,7 @@ import { S3Client } from '../../../../sagemakerunifiedstudio/shared/client/s3Cli import { ConnectionClientStore } from '../../../../sagemakerunifiedstudio/shared/client/connectionClientStore' import { NodeType, ConnectionType } from '../../../../sagemakerunifiedstudio/explorer/nodes/types' import { ConnectionCredentialsProvider } from '../../../../sagemakerunifiedstudio/auth/providers/connectionCredentialsProvider' +import { createMockS3Connection, createMockCredentialsProvider } from '../../testUtils' describe('s3Strategy', function () { let sandbox: sinon.SinonSandbox @@ -145,23 +146,8 @@ describe('s3Strategy', function () { }) it('should create S3 connection node for default connection with full path', function () { - const connection = { - connectionId: 'conn-123', - name: 'project.s3_default_folder', - type: 'S3Connection', - props: { - s3Properties: { - s3Uri: 's3://test-bucket/domain/project/', - }, - }, - } - - const credentialsProvider = { - getCredentials: async () => ({ - accessKeyId: 'test-key', - secretAccessKey: 'test-secret', - }), - } + const connection = createMockS3Connection() + const credentialsProvider = createMockCredentialsProvider() const node = createS3ConnectionNode( connection as any, @@ -209,12 +195,7 @@ describe('s3Strategy', function () { }, } - const credentialsProvider = { - getCredentials: async () => ({ - accessKeyId: 'test-key', - secretAccessKey: 'test-secret', - }), - } + const credentialsProvider = createMockCredentialsProvider() mockS3Client.listPaths.resolves({ paths: [ @@ -240,23 +221,8 @@ describe('s3Strategy', function () { }) it('should handle bucket listing for default connection with full path display', async function () { - const connection = { - connectionId: 'conn-123', - name: 'project.s3_default_folder', - type: 'S3Connection', - props: { - s3Properties: { - s3Uri: 's3://test-bucket/domain/project/', - }, - }, - } - - const credentialsProvider = { - getCredentials: async () => ({ - accessKeyId: 'test-key', - secretAccessKey: 'test-secret', - }), - } + const connection = createMockS3Connection() + const credentialsProvider = createMockCredentialsProvider() mockS3Client.listPaths.resolves({ paths: [ diff --git a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioProjectNode.test.ts b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioProjectNode.test.ts index 02a0c1078d9..2fd8317fe06 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioProjectNode.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioProjectNode.test.ts @@ -14,6 +14,7 @@ import { SagemakerClient } from '../../../../shared/clients/sagemaker' import { SageMakerUnifiedStudioDataNode } from '../../../../sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioDataNode' import { SageMakerUnifiedStudioComputeNode } from '../../../../sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioComputeNode' import * as vscodeUtils from '../../../../shared/vscode/setContext' +import { createMockExtensionContext } from '../../testUtils' describe('SageMakerUnifiedStudioProjectNode', function () { let projectNode: SageMakerUnifiedStudioProjectNode @@ -36,20 +37,11 @@ describe('SageMakerUnifiedStudioProjectNode', function () { invalidateAllProjectCredentialsInCache: sinon.stub(), getProjectCredentialProvider: sinon.stub(), getDomainRegion: sinon.stub().returns('us-west-2'), + getDomainAccountId: sinon.stub().resolves('123456789012'), } as any // Create mock extension context - const mockExtensionContext = { - subscriptions: [], - workspaceState: { - get: sinon.stub(), - update: sinon.stub(), - }, - globalState: { - get: sinon.stub(), - update: sinon.stub(), - }, - } as any + const mockExtensionContext = createMockExtensionContext() projectNode = new SageMakerUnifiedStudioProjectNode(mockParent, mockAuthProvider, mockExtensionContext) diff --git a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioRootNode.test.ts b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioRootNode.test.ts index f89413e9528..64b866c7704 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioRootNode.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioRootNode.test.ts @@ -17,6 +17,7 @@ import { SmusAuthenticationProvider } from '../../../../sagemakerunifiedstudio/a import * as pickerPrompter from '../../../../shared/ui/pickerPrompter' import { getTestWindow } from '../../../shared/vscode/window' import { assertTelemetry } from '../../../../../src/test/testUtil' +import { createMockExtensionContext, createMockUnauthenticatedAuthProvider } from '../../testUtils' describe('SmusRootNode', function () { let rootNode: SageMakerUnifiedStudioRootNode @@ -30,19 +31,36 @@ describe('SmusRootNode', function () { domainId: testDomainId, } + /** + * Helper function to verify login and learn more nodes + */ + async function verifyLoginAndLearnMoreNodes(children: any[]) { + assert.strictEqual(children.length, 2) + assert.strictEqual(children[0].id, 'smusLogin') + assert.strictEqual(children[1].id, 'smusLearnMore') + + // Check login node + const loginTreeItem = await children[0].getTreeItem() + assert.strictEqual(loginTreeItem.label, 'Sign in to get started') + assert.strictEqual(loginTreeItem.contextValue, 'sageMakerUnifiedStudioLogin') + assert.deepStrictEqual(loginTreeItem.command, { + command: 'aws.smus.login', + title: 'Sign in to SageMaker Unified Studio', + }) + + // Check learn more node + const learnMoreTreeItem = await children[1].getTreeItem() + assert.strictEqual(learnMoreTreeItem.label, 'Learn more about SageMaker Unified Studio') + assert.strictEqual(learnMoreTreeItem.contextValue, 'sageMakerUnifiedStudioLearnMore') + assert.deepStrictEqual(learnMoreTreeItem.command, { + command: 'aws.smus.learnMore', + title: 'Learn more about SageMaker Unified Studio', + }) + } + beforeEach(function () { // Create mock extension context - const mockExtensionContext = { - subscriptions: [], - workspaceState: { - get: sinon.stub(), - update: sinon.stub(), - }, - globalState: { - get: sinon.stub(), - update: sinon.stub(), - }, - } as any + const mockExtensionContext = createMockExtensionContext() // Create a mock auth provider const mockAuthProvider = { @@ -80,17 +98,7 @@ describe('SmusRootNode', function () { onDidChange: sinon.stub().returns({ dispose: sinon.stub() }), } as any - const mockExtensionContext = { - subscriptions: [], - workspaceState: { - get: sinon.stub(), - update: sinon.stub(), - }, - globalState: { - get: sinon.stub(), - update: sinon.stub(), - }, - } as any + const mockExtensionContext = createMockExtensionContext() const node = new SageMakerUnifiedStudioRootNode(mockAuthProvider, mockExtensionContext) assert.strictEqual(node.id, 'smusRootNode') @@ -115,24 +123,8 @@ describe('SmusRootNode', function () { it('returns correct tree item when not authenticated', async function () { // Create a mock auth provider for unauthenticated state - const mockAuthProvider = { - isConnected: sinon.stub().returns(false), - isConnectionValid: sinon.stub().returns(false), - activeConnection: undefined, - onDidChange: sinon.stub().returns({ dispose: sinon.stub() }), - } as any - - const mockExtensionContext = { - subscriptions: [], - workspaceState: { - get: sinon.stub(), - update: sinon.stub(), - }, - globalState: { - get: sinon.stub(), - update: sinon.stub(), - }, - } as any + const mockAuthProvider = createMockUnauthenticatedAuthProvider() + const mockExtensionContext = createMockExtensionContext() const unauthenticatedNode = new SageMakerUnifiedStudioRootNode(mockAuthProvider, mockExtensionContext) const treeItem = unauthenticatedNode.getTreeItem() @@ -148,49 +140,12 @@ describe('SmusRootNode', function () { describe('getChildren', function () { it('returns login node when not authenticated (empty domain ID)', async function () { // Create a mock auth provider for unauthenticated state - const mockAuthProvider = { - isConnected: sinon.stub().returns(false), - isConnectionValid: sinon.stub().returns(false), - activeConnection: undefined, - onDidChange: sinon.stub().returns({ dispose: sinon.stub() }), - } as any - - const mockExtensionContext = { - subscriptions: [], - workspaceState: { - get: sinon.stub(), - update: sinon.stub(), - }, - globalState: { - get: sinon.stub(), - update: sinon.stub(), - }, - } as any + const mockAuthProvider = createMockUnauthenticatedAuthProvider() + const mockExtensionContext = createMockExtensionContext() const unauthenticatedNode = new SageMakerUnifiedStudioRootNode(mockAuthProvider, mockExtensionContext) const children = await unauthenticatedNode.getChildren() - - assert.strictEqual(children.length, 2) - assert.strictEqual(children[0].id, 'smusLogin') - assert.strictEqual(children[1].id, 'smusLearnMore') - - // Check login node - const loginTreeItem = await children[0].getTreeItem() - assert.strictEqual(loginTreeItem.label, 'Sign in to get started') - assert.strictEqual(loginTreeItem.contextValue, 'sageMakerUnifiedStudioLogin') - assert.deepStrictEqual(loginTreeItem.command, { - command: 'aws.smus.login', - title: 'Sign in to SageMaker Unified Studio', - }) - - // Check learn more node - const learnMoreTreeItem = await children[1].getTreeItem() - assert.strictEqual(learnMoreTreeItem.label, 'Learn more about SageMaker Unified Studio') - assert.strictEqual(learnMoreTreeItem.contextValue, 'sageMakerUnifiedStudioLearnMore') - assert.deepStrictEqual(learnMoreTreeItem.command, { - command: 'aws.smus.learnMore', - title: 'Learn more about SageMaker Unified Studio', - }) + await verifyLoginAndLearnMoreNodes(children) }) it('returns login node when DataZone client throws error', async function () { @@ -202,34 +157,11 @@ describe('SmusRootNode', function () { onDidChange: sinon.stub().returns({ dispose: sinon.stub() }), } as any - const mockExtensionContext = { - subscriptions: [], - workspaceState: { - get: sinon.stub(), - update: sinon.stub(), - }, - globalState: { - get: sinon.stub(), - update: sinon.stub(), - }, - } as any + const mockExtensionContext = createMockExtensionContext() const errorNode = new SageMakerUnifiedStudioRootNode(mockAuthProvider, mockExtensionContext) const children = await errorNode.getChildren() - - assert.strictEqual(children.length, 2) - assert.strictEqual(children[0].id, 'smusLogin') - assert.strictEqual(children[1].id, 'smusLearnMore') - - // Check login node - const loginTreeItem = await children[0].getTreeItem() - assert.strictEqual(loginTreeItem.label, 'Sign in to get started') - assert.strictEqual(loginTreeItem.contextValue, 'sageMakerUnifiedStudioLogin') - - // Check learn more node - const learnMoreTreeItem = await children[1].getTreeItem() - assert.strictEqual(learnMoreTreeItem.label, 'Learn more about SageMaker Unified Studio') - assert.strictEqual(learnMoreTreeItem.contextValue, 'sageMakerUnifiedStudioLearnMore') + await verifyLoginAndLearnMoreNodes(children) }) it('returns root nodes when authenticated', async function () { @@ -267,17 +199,7 @@ describe('SmusRootNode', function () { showReauthenticationPrompt: sinon.stub(), } as any - const mockExtensionContext = { - subscriptions: [], - workspaceState: { - get: sinon.stub(), - update: sinon.stub(), - }, - globalState: { - get: sinon.stub(), - update: sinon.stub(), - }, - } as any + const mockExtensionContext = createMockExtensionContext() const expiredNode = new SageMakerUnifiedStudioRootNode(mockAuthProvider, mockExtensionContext) const children = await expiredNode.getChildren() @@ -350,6 +272,7 @@ describe('SelectSMUSProject', function () { isConnected: sinon.stub().returns(true), isConnectionValid: sinon.stub().returns(true), activeConnection: { domainId: testDomainId, ssoRegion: 'us-west-2' }, + getDomainAccountId: sinon.stub().resolves('123456789012'), getDomainId: sinon.stub().returns(testDomainId), getDomainRegion: sinon.stub().returns('us-west-2'), } as any) @@ -532,6 +455,7 @@ describe('selectSMUSProject - Additional Tests', function () { sinon.stub(DataZoneClient, 'getInstance').returns(mockDataZoneClient as any) sinon.stub(SmusAuthenticationProvider, 'fromContext').returns({ activeConnection: { domainId: testDomainId, ssoRegion: 'us-west-2' }, + getDomainAccountId: sinon.stub().resolves('123456789012'), getDomainId: sinon.stub().returns(testDomainId), getDomainRegion: sinon.stub().returns('us-west-2'), } as any) diff --git a/packages/core/src/test/sagemakerunifiedstudio/shared/smusUtils.test.ts b/packages/core/src/test/sagemakerunifiedstudio/shared/smusUtils.test.ts index e8f17d486a3..f895fe6ea13 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/shared/smusUtils.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/shared/smusUtils.test.ts @@ -11,6 +11,7 @@ import { SmusTimeouts, SmusCredentialExpiry, validateCredentialFields, + extractAccountIdFromArn, } from '../../../sagemakerunifiedstudio/shared/smusUtils' import { ToolkitError } from '../../../shared/errors' import * as extensionUtilities from '../../../shared/extensionUtilities' @@ -461,3 +462,62 @@ describe('SmusUtils', () => { }) }) }) + +describe('extractAccountIdFromArn', () => { + describe('valid ARN formats', () => { + it('should extract account ID from valid ARN', () => { + const arn = 'arn:aws:sagemaker:us-west-2:123456789012:app/domain-id/ce/CodeEditor/default' + const result = extractAccountIdFromArn(arn) + + assert.strictEqual(result, '123456789012') + }) + }) + + describe('invalid ARN formats', () => { + it('should throw error for empty ARN', () => { + assert.throws( + () => extractAccountIdFromArn(''), + (error: any) => { + assert.ok(error instanceof ToolkitError) + assert.ok(error.message.includes('Invalid SageMaker ARN format')) + return true + } + ) + }) + + it('should throw error for non-ARN string', () => { + assert.throws( + () => extractAccountIdFromArn('not-an-arn'), + (error: any) => { + assert.ok(error instanceof ToolkitError) + assert.ok(error.message.includes('Invalid SageMaker ARN format')) + return true + } + ) + }) + + it('should throw error for wrong service', () => { + const arn = 'arn:aws:s3:us-east-1:123456789012:bucket/my-bucket' + assert.throws( + () => extractAccountIdFromArn(arn), + (error: any) => { + assert.ok(error instanceof ToolkitError) + assert.ok(error.message.includes('Invalid SageMaker ARN format')) + return true + } + ) + }) + + it('should throw error for missing account ID', () => { + const arn = 'arn:aws:sagemaker:us-east-1::space/domain/space' + assert.throws( + () => extractAccountIdFromArn(arn), + (error: any) => { + assert.ok(error instanceof ToolkitError) + assert.ok(error.message.includes('Invalid SageMaker ARN format')) + return true + } + ) + }) + }) +}) diff --git a/packages/core/src/test/sagemakerunifiedstudio/testUtils.ts b/packages/core/src/test/sagemakerunifiedstudio/testUtils.ts new file mode 100644 index 00000000000..ce1a706325d --- /dev/null +++ b/packages/core/src/test/sagemakerunifiedstudio/testUtils.ts @@ -0,0 +1,89 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as sinon from 'sinon' + +/** + * Creates a mock extension context for SageMaker Unified Studio tests + */ +export function createMockExtensionContext(): any { + return { + subscriptions: [], + workspaceState: { + get: sinon.stub(), + update: sinon.stub(), + }, + globalState: { + get: sinon.stub(), + update: sinon.stub(), + }, + } +} + +/** + * Creates a mock S3 connection for SageMaker Unified Studio tests + */ +export function createMockS3Connection() { + return { + connectionId: 'conn-123', + name: 'project.s3_default_folder', + type: 'S3Connection', + props: { + s3Properties: { + s3Uri: 's3://test-bucket/domain/project/', + }, + }, + } +} + +/** + * Creates a mock credentials provider for SageMaker Unified Studio tests + */ +export function createMockCredentialsProvider() { + return { + getCredentials: async () => ({ + accessKeyId: 'test-key', + secretAccessKey: 'test-secret', + }), + getDomainAccountId: async () => '123456789012', + } +} +/** + * Creates a mock unauthenticated auth provider for SageMaker Unified Studio tests + */ +export function createMockUnauthenticatedAuthProvider(): any { + return { + isConnected: sinon.stub().returns(false), + isConnectionValid: sinon.stub().returns(false), + activeConnection: undefined, + onDidChange: sinon.stub().returns({ dispose: sinon.stub() }), + } +} /** + * + Creates a mock space node for SageMaker Unified Studio tests + */ +export function createMockSpaceNode(): any { + return { + resource: { + sageMakerClient: {}, + DomainSpaceKey: 'test-space-key', + regionCode: 'us-east-1', + getParent: sinon.stub().returns({ + getAuthProvider: sinon.stub().returns({ + activeConnection: { domainId: 'test-domain' }, + getDomainAccountId: sinon.stub().resolves('123456789012'), + }), + getProjectId: sinon.stub().returns('test-project'), + }), + }, + getParent: sinon.stub().returns({ + getAuthProvider: sinon.stub().returns({ + activeConnection: { domainId: 'test-domain' }, + getDomainAccountId: sinon.stub().resolves('123456789012'), + }), + getProjectId: sinon.stub().returns('test-project'), + }), + } +}