diff --git a/.gitignore b/.gitignore index fb06d810f42..8a8b0fbe406 100644 --- a/.gitignore +++ b/.gitignore @@ -34,6 +34,7 @@ src.gen/* **/src/auth/sso/oidcclientpkce.d.ts **/src/sagemakerunifiedstudio/shared/client/gluecatalogapi.d.ts **/src/sagemakerunifiedstudio/shared/client/sqlworkbench.d.ts +**/src/sagemakerunifiedstudio/shared/client/datazonecustomclient.d.ts # Generated by tests **/src/testFixtures/**/bin diff --git a/package-lock.json b/package-lock.json index 2dd14a509fd..364db5ba08a 100644 --- a/package-lock.json +++ b/package-lock.json @@ -22,7 +22,7 @@ "vscode-nls-dev": "^4.0.4" }, "devDependencies": { - "@aws-toolkits/telemetry": "^1.0.329", + "@aws-toolkits/telemetry": "^1.0.338", "@playwright/browser-chromium": "^1.43.1", "@stylistic/eslint-plugin": "^2.11.0", "@types/he": "^1.2.3", @@ -24526,9 +24526,9 @@ } }, "node_modules/@aws-toolkits/telemetry": { - "version": "1.0.329", - "resolved": "https://registry.npmjs.org/@aws-toolkits/telemetry/-/telemetry-1.0.329.tgz", - "integrity": "sha512-zMkljZDtIAxuZzPTLL5zIxn+zGmk767sbqGIc2ZYuv0sSU+UoYgB3tqwV5KVV2oDPKs5593nwJC97NVHJqzowQ==", + "version": "1.0.338", + "resolved": "https://registry.npmjs.org/@aws-toolkits/telemetry/-/telemetry-1.0.338.tgz", + "integrity": "sha512-fg3zCqH4GEBjgL3Wo+xiijRbkyxMh4hXPsOD8Q52k8bvmq5rL9tjbp2IqmHI8JVLOhoomedicXK9ZVEdKSsatw==", "dev": true, "license": "Apache-2.0", "dependencies": { @@ -42520,7 +42520,7 @@ }, "packages/toolkit": { "name": "aws-toolkit-vscode", - "version": "3.86.0-SNAPSHOT", + "version": "3.87.0-SNAPSHOT", "license": "Apache-2.0", "dependencies": { "aws-core-vscode": "file:../core/" diff --git a/package.json b/package.json index 67135911fe7..242e4ee9594 100644 --- a/package.json +++ b/package.json @@ -42,7 +42,7 @@ "scan-licenses": "ts-node ./scripts/scan-licenses.ts" }, "devDependencies": { - "@aws-toolkits/telemetry": "^1.0.329", + "@aws-toolkits/telemetry": "^1.0.338", "@playwright/browser-chromium": "^1.43.1", "@stylistic/eslint-plugin": "^2.11.0", "@types/he": "^1.2.3", diff --git a/packages/core/package.nls.json b/packages/core/package.nls.json index 9e5c91e39e8..bda0d97b022 100644 --- a/packages/core/package.nls.json +++ b/packages/core/package.nls.json @@ -231,6 +231,7 @@ "AWS.command.s3.uploadFileToParent": "Upload to Parent...", "AWS.command.smus.switchProject": "Switch Project", "AWS.command.smus.refreshProject": "Refresh Project", + "AWS.command.smus.refresh": "Refresh", "AWS.command.smus.signOut": "Sign Out", "AWS.command.sagemaker.filterSpaces": "Filter Sagemaker Spaces", "AWS.command.stepFunctions.createStateMachineFromTemplate": "Create a new Step Functions state machine", diff --git a/packages/core/scripts/build/generateServiceClient.ts b/packages/core/scripts/build/generateServiceClient.ts index de601e6ee44..ac46c307a0f 100644 --- a/packages/core/scripts/build/generateServiceClient.ts +++ b/packages/core/scripts/build/generateServiceClient.ts @@ -249,6 +249,10 @@ void (async () => { serviceJsonPath: 'src/sagemakerunifiedstudio/shared/client/sqlworkbench.json', serviceName: 'SQLWorkbench', }, + { + serviceJsonPath: 'src/sagemakerunifiedstudio/shared/client/datazonecustomclient.json', + serviceName: 'DataZoneCustomClient', + }, ] await generateServiceClients(serviceClientDefinitions) })() diff --git a/packages/core/src/auth/auth.ts b/packages/core/src/auth/auth.ts index 053df50321e..18f4fdd5b14 100644 --- a/packages/core/src/auth/auth.ts +++ b/packages/core/src/auth/auth.ts @@ -598,7 +598,7 @@ export class Auth implements AuthService, ConnectionManager { } @withTelemetryContext({ name: 'updateConnectionState', class: authClassName }) - private async updateConnectionState(id: Connection['id'], connectionState: ProfileMetadata['connectionState']) { + public async updateConnectionState(id: Connection['id'], connectionState: ProfileMetadata['connectionState']) { getLogger().info(`auth: Updating connection state of ${id} to ${connectionState}`) if (connectionState === 'authenticating') { diff --git a/packages/core/src/auth/sso/clients.ts b/packages/core/src/auth/sso/clients.ts index 2fa8b1a3854..14bc35c039e 100644 --- a/packages/core/src/auth/sso/clients.ts +++ b/packages/core/src/auth/sso/clients.ts @@ -36,6 +36,8 @@ import { AuthenticationFlow } from './model' import { toSnakeCase } from '../../shared/utilities/textUtilities' import { getUserAgent, withTelemetryContext } from '../../shared/telemetry/util' import { oneSecond } from '../../shared/datetime' +import { telemetry } from '../../shared/telemetry/telemetry' +import { getTelemetryReason, getTelemetryReasonDesc, getHttpStatusCode } from '../../shared/errors' export class OidcClient { public constructor( @@ -86,15 +88,40 @@ export class OidcClient { } public async createToken(request: CreateTokenRequest) { + const startTime = this.clock.Date.now() + const grantType = request.grantType + let response try { response = await this.client.createToken(request as CreateTokenRequest) } catch (err) { + const statusCode = getHttpStatusCode(err) + telemetry.auth_ssoTokenOperation.emit({ + result: 'Failed', + grantType: grantType ?? 'unknown', + duration: this.clock.Date.now() - startTime, + reason: getTelemetryReason(err), + reasonDesc: getTelemetryReasonDesc(err), + ...(statusCode !== undefined ? { httpStatusCode: String(statusCode) } : {}), + }) + + getLogger().error(`sso-oidc: createToken failed (grantType=${grantType}): ${err}`) + const newError = AwsClientResponseError.instanceIf(err) throw newError } assertHasProps(response, 'accessToken', 'expiresIn') + telemetry.auth_ssoTokenOperation.emit({ + result: 'Succeeded', + grantType: grantType ?? 'unknown', + duration: this.clock.Date.now() - startTime, + }) + + getLogger().debug( + `sso-oidc: createToken succeeded (grantType=${grantType}, requestId=${response.$metadata.requestId})` + ) + return { ...selectFrom(response, 'accessToken', 'refreshToken', 'tokenType'), requestId: response.$metadata.requestId, diff --git a/packages/core/src/awsService/sagemaker/credentialMapping.ts b/packages/core/src/awsService/sagemaker/credentialMapping.ts index f8d58758f11..931384e8811 100644 --- a/packages/core/src/awsService/sagemaker/credentialMapping.ts +++ b/packages/core/src/awsService/sagemaker/credentialMapping.ts @@ -15,6 +15,7 @@ import { getLogger } from '../../shared/logger/logger' import { parseArn } from './detached-server/utils' import { SagemakerUnifiedStudioSpaceNode } from '../../sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioSpaceNode' import { SageMakerUnifiedStudioSpacesParentNode } from '../../sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioSpacesParentNode' +import { isSmusSsoConnection } from '../../sagemakerunifiedstudio/auth/model' const mappingFileName = '.sagemaker-space-profiles' const mappingFilePath = path.join(os.homedir(), '.aws', mappingFileName) @@ -74,10 +75,11 @@ export async function persistLocalCredentials(spaceArn: string): Promise { export async function persistSmusProjectCreds(spaceArn: string, node: SagemakerUnifiedStudioSpaceNode): Promise { const nodeParent = node.getParent() as SageMakerUnifiedStudioSpacesParentNode const authProvider = nodeParent.getAuthProvider() + const activeConnection = authProvider.activeConnection const projectId = nodeParent.getProjectId() const projectAuthProvider = await authProvider.getProjectCredentialProvider(projectId) await projectAuthProvider.getCredentials() - await setSmusSpaceSsoProfile(spaceArn, projectId) + await setSmusSpaceProfile(spaceArn, projectId, isSmusSsoConnection(activeConnection) ? 'sso' : 'iam') // Trigger SSH credential refresh for the project projectAuthProvider.startProactiveCredentialRefresh() } @@ -177,11 +179,16 @@ export async function setSpaceSsoProfile( * Sets the SM Space to map to SageMaker Unified Studio Project. * @param spaceArn - The arn of the SageMaker Unified Studio space. * @param projectId - The project ID associated with the SageMaker Unified Studio space. + * @param credentialType - The type of credential ('sso' or 'iam'). */ -export async function setSmusSpaceSsoProfile(spaceArn: string, projectId: string): Promise { +export async function setSmusSpaceProfile( + spaceArn: string, + projectId: string, + credentialType: 'iam' | 'sso' +): Promise { const data = await loadMappings() data.localCredential ??= {} - data.localCredential[spaceArn] = { type: 'sso', smusProjectId: projectId } + data.localCredential[spaceArn] = { type: credentialType, smusProjectId: projectId } await saveMappings(data) } diff --git a/packages/core/src/awsService/sagemaker/detached-server/credentials.ts b/packages/core/src/awsService/sagemaker/detached-server/credentials.ts index 748679309c8..033cf1d3b8e 100644 --- a/packages/core/src/awsService/sagemaker/detached-server/credentials.ts +++ b/packages/core/src/awsService/sagemaker/detached-server/credentials.ts @@ -29,11 +29,25 @@ export async function resolveCredentialsFor(connectionIdentifier: string): Promi switch (profile.type) { case 'iam': { - const name = profile.profileName?.split(':')[1] - if (!name) { - throw new Error(`Invalid IAM profile name for "${connectionIdentifier}"`) + if ('profileName' in profile) { + const name = profile.profileName?.split(':')[1] + if (!name) { + throw new Error(`Invalid IAM profile name for "${connectionIdentifier}"`) + } + return fromIni({ profile: name }) + } else if ('smusProjectId' in profile) { + const { accessKey, secret, token } = mapping.smusProjects?.[profile.smusProjectId] || {} + if (!accessKey || !secret || !token) { + throw new Error(`Missing ProjectRole credentials for SMUS Space "${connectionIdentifier}"`) + } + return { + accessKeyId: accessKey, + secretAccessKey: secret, + sessionToken: token, + } + } else { + throw new Error(`Missing IAM credentials for "${connectionIdentifier}"`) } - return fromIni({ profile: name }) } case 'sso': { if ('accessKey' in profile && 'secret' in profile && 'token' in profile) { diff --git a/packages/core/src/awsService/sagemaker/detached-server/errorPage.ts b/packages/core/src/awsService/sagemaker/detached-server/errorPage.ts index bff3e62ae61..d1c88ca3fec 100644 --- a/packages/core/src/awsService/sagemaker/detached-server/errorPage.ts +++ b/packages/core/src/awsService/sagemaker/detached-server/errorPage.ts @@ -32,7 +32,11 @@ export const getVSCodeErrorTitle = (error: SageMakerServiceException): string => return ErrorText.StartSession[ExceptionType.DEFAULT].Title } -export const getVSCodeErrorText = (error: SageMakerServiceException, isSmus?: boolean): string => { +export const getVSCodeErrorText = ( + error: SageMakerServiceException, + isSmus?: boolean, + isSmusIamConn?: boolean +): string => { const exceptionType = error.name as ExceptionType switch (exceptionType) { @@ -41,9 +45,12 @@ export const getVSCodeErrorText = (error: SageMakerServiceException, isSmus?: bo return ErrorText.StartSession[exceptionType].Text.replace('{message}', error.message) case ExceptionType.EXPIRED_TOKEN: // Use SMUS-specific message if in SMUS context - return isSmus - ? ErrorText.StartSession[ExceptionType.EXPIRED_TOKEN].SmusText - : ErrorText.StartSession[exceptionType].Text + if (isSmus) { + return isSmusIamConn + ? ErrorText.StartSession[ExceptionType.EXPIRED_TOKEN].SmusIamText + : ErrorText.StartSession[ExceptionType.EXPIRED_TOKEN].SmusSsoText + } + return ErrorText.StartSession[exceptionType].Text case ExceptionType.INTERNAL_FAILURE: case ExceptionType.RESOURCE_LIMIT_EXCEEDED: case ExceptionType.THROTTLING: @@ -66,8 +73,10 @@ export const ErrorText = { [ExceptionType.EXPIRED_TOKEN]: { Title: 'Authentication expired', Text: 'Your session has expired. Please refresh your credentials and try again.', - SmusText: - 'Your session has expired. This is likely due to network connectivity issues after machine sleep/resume. Please wait 10-30 seconds for automatic credential refresh, then try again. If the issue persists, try reconnecting through AWS Toolkit.', + SmusSsoText: + 'Your session has expired. This is likely due to network connectivity issues after machine sleep/resume. Wait 10-30 seconds for automatic credential refresh, then try again. If the issue persists, try reconnecting through AWS Toolkit.', + SmusIamText: + 'Your session has expired. Update the credentials associated with the IAM profile or use a valid IAM profile, then try again.', }, [ExceptionType.INTERNAL_FAILURE]: { Title: 'Failed to connect remotely to VSCode', diff --git a/packages/core/src/awsService/sagemaker/detached-server/routes/getSession.ts b/packages/core/src/awsService/sagemaker/detached-server/routes/getSession.ts index 0c9ce74ad30..2db2d11ddeb 100644 --- a/packages/core/src/awsService/sagemaker/detached-server/routes/getSession.ts +++ b/packages/core/src/awsService/sagemaker/detached-server/routes/getSession.ts @@ -6,7 +6,7 @@ // Disabled: detached server files cannot import vscode. /* eslint-disable aws-toolkits/no-console-log */ import { IncomingMessage, ServerResponse } from 'http' -import { startSagemakerSession, parseArn, isSmusConnection } from '../utils' +import { startSagemakerSession, parseArn, isSmusConnection, isSmusIamConnection } from '../utils' import { resolveCredentialsFor } from '../credentials' import url from 'url' import { SageMakerServiceException } from '@amzn/sagemaker-client' @@ -35,6 +35,7 @@ export async function handleGetSession(req: IncomingMessage, res: ServerResponse const { region } = parseArn(connectionIdentifier) // Detect if this is a SMUS connection for specialized error handling const isSmus = await isSmusConnection(connectionIdentifier) + const isSmusIamConn = await isSmusIamConnection(connectionIdentifier) try { const session = await startSagemakerSession({ region, connectionIdentifier, credentials }) @@ -50,7 +51,7 @@ export async function handleGetSession(req: IncomingMessage, res: ServerResponse const error = err as SageMakerServiceException console.error(`Failed to start SageMaker session for ${connectionIdentifier}:`, err) const errorTitle = getVSCodeErrorTitle(error) - const errorText = getVSCodeErrorText(error, isSmus) + const errorText = getVSCodeErrorText(error, isSmus, isSmusIamConn) await openErrorPage(errorTitle, errorText) res.writeHead(500, { 'Content-Type': 'text/plain' }) res.end('Failed to start SageMaker session') diff --git a/packages/core/src/awsService/sagemaker/detached-server/utils.ts b/packages/core/src/awsService/sagemaker/detached-server/utils.ts index d4c963c40ff..fdbd1da1ab2 100644 --- a/packages/core/src/awsService/sagemaker/detached-server/utils.ts +++ b/packages/core/src/awsService/sagemaker/detached-server/utils.ts @@ -147,6 +147,24 @@ export async function isSmusConnection(connectionIdentifier: string): Promise - true if SMUS IAM connection, false otherwise + */ +export async function isSmusIamConnection(connectionIdentifier: string): Promise { + try { + const mapping = await readMapping() + const profile = mapping.localCredential?.[connectionIdentifier] + + // Check if profile exists, has smusProjectId, and type is 'iam' + return profile && 'smusProjectId' in profile && profile.type === 'iam' + } catch (err) { + // If we can't detect it is iam connection, assume not SMUS IAM to avoid breaking existing functionality + return false + } +} + /** * Writes the mapping to a temp file and atomically renames it to the target path. * Uses a queue to prevent race conditions when multiple requests try to write simultaneously. diff --git a/packages/core/src/awsService/sagemaker/hyperpodCommands.ts b/packages/core/src/awsService/sagemaker/hyperpodCommands.ts index 059da96a9dd..454bcc49c1d 100644 --- a/packages/core/src/awsService/sagemaker/hyperpodCommands.ts +++ b/packages/core/src/awsService/sagemaker/hyperpodCommands.ts @@ -16,9 +16,30 @@ const localize = nls.loadMessageBundle() export async function openHyperPodRemoteConnection(node: SagemakerDevSpaceNode): Promise { await startHyperpodSpaceCommand(node) + await waitForDevSpaceRunning(node) await connectToHyperPodDevSpace(node) } +async function waitForDevSpaceRunning(node: SagemakerDevSpaceNode): Promise { + const kubectlClient = node.getParent().getKubectlClient(node.hpCluster.clusterName) + if (!kubectlClient) { + getLogger().error(`No kubectlClient available for cluster: ${node.hpCluster.clusterName}`) + return + } + const timeout = 5 * 60 * 1000 // 5 minutes + const startTime = Date.now() + + while (Date.now() - startTime < timeout) { + const status = await kubectlClient.getHyperpodSpaceStatus(node.devSpace) + if (status === 'Running') { + return + } + await new Promise((resolve) => setTimeout(resolve, 5000)) + } + + throw new Error('Timeout waiting for dev space to reach Running status') +} + export async function connectToHyperPodDevSpace(node: SagemakerDevSpaceNode): Promise { const logger = getLogger() diff --git a/packages/core/src/awsService/sagemaker/types.ts b/packages/core/src/awsService/sagemaker/types.ts index 82f4d4f92d6..76eb3c23ea9 100644 --- a/packages/core/src/awsService/sagemaker/types.ts +++ b/packages/core/src/awsService/sagemaker/types.ts @@ -12,7 +12,7 @@ export interface SpaceMappings { export type LocalCredentialProfile = | { type: 'iam'; profileName: string } | { type: 'sso'; accessKey: string; secret: string; token: string } - | { type: 'sso'; smusProjectId: string } + | { type: 'sso' | 'iam'; smusProjectId: string } export interface DeeplinkSession { requests: Record diff --git a/packages/core/src/sagemakerunifiedstudio/auth/authenticationOrchestrator.ts b/packages/core/src/sagemakerunifiedstudio/auth/authenticationOrchestrator.ts new file mode 100644 index 00000000000..208f7ccd95c --- /dev/null +++ b/packages/core/src/sagemakerunifiedstudio/auth/authenticationOrchestrator.ts @@ -0,0 +1,330 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as vscode from 'vscode' +import { getLogger } from '../../shared/logger/logger' +import { ToolkitError } from '../../shared/errors' +import { SmusErrorCodes } from '../shared/smusUtils' +import { SmusAuthenticationProvider } from './providers/smusAuthenticationProvider' + +import { SmusSsoAuthenticationUI } from './ui/ssoAuthentication' +import { + SmusIamProfileSelector, + IamProfileSelection, + IamProfileEditingInProgress, + IamProfileBackNavigation, +} from './ui/iamProfileSelection' +import { SmusAuthenticationPreferencesManager } from './preferences/authenticationPreferences' +import { DataZoneCustomClientHelper } from '../shared/client/datazoneCustomClientHelper' +import { recordAuthTelemetry } from '../shared/telemetry' + +export type SmusAuthenticationMethod = 'sso' | 'iam' + +export type SmusAuthenticationResult = + | { status: 'SUCCESS' } + | { status: 'BACK' } + | { status: 'EDITING' } + | { status: 'INVALID_PROFILE'; error: string } + +/** + * Orchestrates SMUS authentication flows + */ +export class SmusAuthenticationOrchestrator { + private static readonly logger = getLogger('smus') + + /** + * Handles IAM authentication flow + * @param authProvider The SMUS authentication provider + * @param span Telemetry span + * @param context Extension context + * @param existingProfileName Optional profile name to re-authenticate with (skips profile selection) + * @param existingRegion Optional region to use (skips region selection) + */ + public static async handleIamAuthentication( + authProvider: SmusAuthenticationProvider, + span: any, + context: vscode.ExtensionContext, + existingProfileName?: string, + existingRegion?: string + ): Promise { + const logger = this.logger + + try { + let profileSelection: IamProfileSelection | IamProfileEditingInProgress | IamProfileBackNavigation + + // If profile and region are provided, skip profile selection (re-authentication case) + if (existingProfileName && existingRegion) { + logger.debug( + `Auth: Re-authenticating with existing profile: ${existingProfileName}, region: ${existingRegion}` + ) + profileSelection = { + profileName: existingProfileName, + region: existingRegion, + } + } else { + // Show IAM profile selection dialog + profileSelection = await SmusIamProfileSelector.showIamProfileSelection() + } + + // Handle different result types + if ('isBack' in profileSelection) { + // User chose to go back to authentication method selection + logger.debug('User chose to go back to authentication method selection') + return { status: 'BACK' } + } + + if ('isEditing' in profileSelection) { + // User chose to edit credentials or is in editing mode + logger.debug('User is editing credentials') + return { status: 'EDITING' } + } + + // At this point, we have a profile selected + logger.debug(`Selected profile: ${profileSelection.profileName}, region: ${profileSelection.region}`) + + // Validate the selected profile + const validation = await authProvider.validateIamProfile(profileSelection.profileName) + if (!validation.isValid) { + logger.debug(`Profile validation failed: ${validation.error}`) + return { status: 'INVALID_PROFILE', error: validation.error || 'Profile validation failed' } + } + + // Discover IAM-based domain using IAM credential. If IAM-based domain is not present, we should throw an appropriate error + // and exit + logger.debug('Discovering IAM-based domain using IAM credentials') + + const domainUrl = await this.findSmusIamDomain( + authProvider, + profileSelection.profileName, + profileSelection.region + ) + if (!domainUrl) { + throw new ToolkitError('No IAM-based domains found in the specified region', { + code: SmusErrorCodes.IamDomainNotFound, + cancelled: true, + }) + } + + // Connect using IAM profile with IAM-based domain flag + const connection = await authProvider.connectWithIamProfile( + profileSelection.profileName, + profileSelection.region, + domainUrl, + true // isIamDomain - we found an IAM-based domain + ) + + if (!connection) { + throw new ToolkitError('Failed to establish IAM connection', { + code: SmusErrorCodes.FailedAuthConnecton, + }) + } + + logger.info( + `Successfully connected with IAM profile ${profileSelection.profileName} in region ${profileSelection.region} to IAM-based domain` + ) + + // Extract domain ID and region for telemetry logging + const domainId = connection.domainId + const region = authProvider.getDomainRegion() + + logger.info(`Connected to SageMaker Unified Studio domain: ${domainId} in region ${region}`) + await this.recordAuthTelemetry(span, authProvider, domainId, region) + + // Refresh the tree view to show authenticated state + try { + await vscode.commands.executeCommand('aws.smus.rootView.refresh') + } catch (refreshErr) { + logger.debug(`Failed to refresh views after login: ${(refreshErr as Error).message}`) + } + + // After successful IAM authentication (IAM mode), automatically open project picker + logger.debug('IAM authentication successful, opening project picker') + try { + await vscode.commands.executeCommand('aws.smus.switchProject') + } catch (pickerErr) { + logger.debug(`Failed to open project picker: ${(pickerErr as Error).message}`) + } + + // Ask to remember authentication method preference (non-blocking) + void this.askToRememberAuthMethod(context, 'iam') + + // Return success to complete the authentication flow gracefully + return { status: 'SUCCESS' } + } catch (error) { + // Handle user cancellation (including editing mode) + if ( + error instanceof ToolkitError && + (error.code === SmusErrorCodes.UserCancelled || error.code === SmusErrorCodes.IamDomainNotFound) + ) { + logger.debug('IAM authentication cancelled by user or failed due to customer error') + throw error // Re-throw to be handled by the main loop + } else { + // Log the error for actual failures + logger.error('IAM authentication failed: %s', (error as Error).message) + throw error + } + } + } + + /** + * Handles SSO authentication flow + */ + public static async handleSsoAuthentication( + authProvider: SmusAuthenticationProvider, + span: any, + context: vscode.ExtensionContext + ): Promise { + const logger = this.logger + logger.debug('Starting SSO authentication flow') + + // Show domain URL input dialog with back button support + const domainUrl = await SmusSsoAuthenticationUI.showDomainUrlInput() + + logger.debug(`Domain URL input result: ${domainUrl ? 'provided' : 'cancelled or back'}`) + + if (domainUrl === 'BACK') { + // User wants to go back to authentication method selection + logger.debug('User chose to go back from domain URL input') + return { status: 'BACK' } + } + + if (!domainUrl) { + // User cancelled + logger.debug('User cancelled domain URL input') + throw new ToolkitError('User cancelled domain URL input', { + cancelled: true, + code: SmusErrorCodes.UserCancelled, + }) + } + + try { + // Connect to SMUS using the authentication provider + const connection = await authProvider.connectToSmusWithSso(domainUrl) + + if (!connection) { + throw new ToolkitError('Failed to establish connection', { + code: SmusErrorCodes.FailedAuthConnecton, + }) + } + + // Extract domain account ID, domain ID, and region for logging + const domainId = connection.domainId + const region = authProvider.getDomainRegion() // Use the auth provider method that handles both connection types + + logger.info(`Connected to SageMaker Unified Studio domain: ${domainId} in region ${region}`) + await this.recordAuthTelemetry(span, authProvider, domainId, region) + + // Ask to remember authentication method preference + await this.askToRememberAuthMethod(context, 'sso') + + // Immediately refresh the tree view to show authenticated state + try { + await vscode.commands.executeCommand('aws.smus.rootView.refresh') + } catch (refreshErr) { + logger.debug(`Failed to refresh views after login: ${(refreshErr as Error).message}`) + } + + return { status: 'SUCCESS' } + } catch (connectionErr) { + // Clear the status bar message + vscode.window.setStatusBarMessage('Connection to SageMaker Unified Studio Failed') + + // Log the error and re-throw to be handled by the outer catch block + logger.error('Connection failed: %s', (connectionErr as Error).message) + throw new ToolkitError('Connection failed.', { + cause: connectionErr as Error, + code: (connectionErr as Error).name, + }) + } + } + + /** + * Asks the user if they want to remember their authentication method choice after successful login + */ + private static async askToRememberAuthMethod( + context: vscode.ExtensionContext, + method: SmusAuthenticationMethod + ): Promise { + const logger = this.logger + + try { + const methodName = method === 'sso' ? 'SSO Authentication' : 'IAM Credential Profile' + + const result = await vscode.window.showInformationMessage( + `Remember ${methodName} as your preferred authentication method for SageMaker Unified Studio?`, + 'Yes', + 'No' + ) + + if (result === 'Yes') { + logger.debug(`Saving user preference: ${method}`) + await SmusAuthenticationPreferencesManager.setPreferredMethod(context, method, true) + logger.debug(`Preference saved successfully`) + } + } catch (error) { + // Not a hard failure, so not throwing error + logger.warn('Error asking to remember auth method: %s', error) + } + } + + /** + * Finds SMUS IAM-based domain using IAM credentials + * @param authProvider The SMUS authentication provider + * @param profileName The AWS credential profile name + * @param region The AWS region + * @returns Promise resolving to domain URL or undefined if no IAM-based domain found + */ + private static async findSmusIamDomain( + authProvider: SmusAuthenticationProvider, + profileName: string, + region: string + ): Promise { + const logger = this.logger + + try { + logger.debug(`Finding IAM-based domain in region ${region} using profile ${profileName}`) + + // Get DataZoneCustomClientHelper instance + const datazoneCustomClientHelper = DataZoneCustomClientHelper.getInstance( + await authProvider.getCredentialsProviderForIamProfile(profileName), + region + ) + + // Find the IAM-based domain using the client + const iamDomain = await datazoneCustomClientHelper.getIamDomain() + + if (!iamDomain) { + logger.warn(`No IAM-based domain found in region ${region}`) + return undefined + } + + logger.debug(`Found IAM-based domain: ${iamDomain.name} (${iamDomain.id})`) + + // Construct domain URL from the IAM-based domain + const domainUrl = iamDomain.portalUrl || `https://${iamDomain.id}.sagemaker.${region}.on.aws/` + logger.info(`Discovered IAM-based domain URL: ${domainUrl}`) + + return domainUrl + } catch (error) { + logger.error(`Failed to find IAM-based domain: %s`, error) + throw new ToolkitError(`Failed to find IAM-based domain: ${(error as Error).message}`, { + code: SmusErrorCodes.ApiTimeout, + cause: error instanceof Error ? error : undefined, + }) + } + } + + /** + * Records authentication telemetry + */ + private static async recordAuthTelemetry( + span: any, + authProvider: SmusAuthenticationProvider, + domainId: string, + region: string + ): Promise { + await recordAuthTelemetry(span, authProvider, domainId, region) + } +} diff --git a/packages/core/src/sagemakerunifiedstudio/auth/credentialExpiryHandler.ts b/packages/core/src/sagemakerunifiedstudio/auth/credentialExpiryHandler.ts new file mode 100644 index 00000000000..3c03b2875e4 --- /dev/null +++ b/packages/core/src/sagemakerunifiedstudio/auth/credentialExpiryHandler.ts @@ -0,0 +1,241 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as vscode from 'vscode' +import { getLogger } from '../../shared/logger/logger' +import { ToolkitError } from '../../shared/errors' +import { SmusErrorCodes } from '../shared/smusUtils' +import { SmusIamProfileSelector } from './ui/iamProfileSelection' +import { getCredentialsFilename, getConfigFilename } from '../../auth/credentials/sharedCredentialsFile' +import type { SmusAuthenticationProvider } from './providers/smusAuthenticationProvider' + +export enum IamCredentialExpiryAction { + Reauthenticate = 'reauthenticate', + EditCredentials = 'editCredentials', + SwitchProfile = 'switchProfile', + SignOut = 'signOut', + Cancelled = 'cancelled', +} + +export type IamCredentialExpiryResult = + | { action: IamCredentialExpiryAction.Reauthenticate } + | { action: IamCredentialExpiryAction.EditCredentials } + | { action: IamCredentialExpiryAction.SwitchProfile } + | { action: IamCredentialExpiryAction.SignOut } + | { action: IamCredentialExpiryAction.Cancelled } + +/** + * Shows credential expiry options for IAM connections + * Provides options to re-authenticate, edit credentials, switch profiles, or sign out + * @param authProvider The SMUS authentication provider + * @param connection The expired IAM connection + * @param extensionContext The extension context + * @returns Promise that resolves with the action taken + */ +export async function showIamCredentialExpiryOptions( + authProvider: SmusAuthenticationProvider, + connection: any, + extensionContext: vscode.ExtensionContext +): Promise { + const logger = getLogger('smus') + + type QuickPickItemWithAction = vscode.QuickPickItem & { action: IamCredentialExpiryAction } + const options: QuickPickItemWithAction[] = [ + { + label: '$(sync) Re-authenticate with current profile', + description: `Profile: ${connection.profileName}`, + detail: 'Refresh credentials using the same IAM profile', + action: IamCredentialExpiryAction.Reauthenticate, + }, + { + label: '$(file-text) Edit credentials file', + description: 'Open ~/.aws/credentials and ~/.aws/config', + detail: 'Manually update your AWS credentials', + action: IamCredentialExpiryAction.EditCredentials, + }, + { + label: '$(arrow-swap) Switch to another profile', + description: 'Select a different IAM profile', + detail: 'Choose from available credential profiles', + action: IamCredentialExpiryAction.SwitchProfile, + }, + { + label: '$(trash) Sign out', + description: 'Sign out from this connection', + detail: 'Remove the expired connection', + action: IamCredentialExpiryAction.SignOut, + }, + ] + + const quickPick = vscode.window.createQuickPick() + quickPick.title = 'IAM Credentials Expired' + quickPick.placeholder = 'Choose how to fix your expired credentials' + quickPick.items = options + quickPick.canSelectMany = false + quickPick.ignoreFocusOut = true + + return new Promise((resolve, reject) => { + let isCompleted = false + + quickPick.onDidAccept(async () => { + const selectedItem = quickPick.selectedItems[0] + if (!selectedItem) { + quickPick.dispose() + reject(new ToolkitError('No option selected', { code: SmusErrorCodes.UserCancelled, cancelled: true })) + return + } + + isCompleted = true + quickPick.dispose() + + const itemWithAction = selectedItem as QuickPickItemWithAction + + try { + switch (itemWithAction.action) { + case IamCredentialExpiryAction.Reauthenticate: { + logger.debug( + `SMUS: Re-authenticating with current IAM profile: ${connection.profileName} in region ${connection.region}` + ) + // For IAM connections, just validate the credentials are still valid + // The auth system will handle refreshing them automatically + const validation = await authProvider.validateIamProfile(connection.profileName) + if (validation.isValid) { + // Credentials are valid, refresh the connection state + await authProvider.auth.refreshConnectionState(connection) + void vscode.window.showInformationMessage( + 'Successfully reauthenticated with SageMaker Unified Studio' + ) + resolve({ action: IamCredentialExpiryAction.Reauthenticate }) + } else { + const errorMsg = validation.error || 'Unknown validation error' + // Throw error for telemetry - activation.ts will show the notification + throw new ToolkitError( + `Failed to re-authenticate, ensure credential has been updated: ${errorMsg}`, + { code: SmusErrorCodes.IamValidationFailed } + ) + } + break + } + case IamCredentialExpiryAction.EditCredentials: { + logger.debug('Opening AWS credentials and config files for editing') + // Open both credentials and config files like AWS Explorer does + const credentialsPath = getCredentialsFilename() + const configPath = getConfigFilename() + + // Open both files + const [credentialsDoc, configDoc] = await Promise.all([ + vscode.workspace.openTextDocument(credentialsPath), + vscode.workspace.openTextDocument(configPath), + ]) + + // Show both documents + await vscode.window.showTextDocument(credentialsDoc, { preview: false }) + await vscode.window.showTextDocument(configDoc, { + preview: false, + viewColumn: vscode.ViewColumn.Beside, + }) + + void vscode.window.showInformationMessage( + 'AWS credentials and config files opened. Please update your credentials and try reconnecting.' + ) + resolve({ action: IamCredentialExpiryAction.EditCredentials }) + break + } + case IamCredentialExpiryAction.SwitchProfile: { + logger.debug('Switching to another IAM profile') + try { + const profileSelection = await SmusIamProfileSelector.showIamProfileSelection() + + // Handle back navigation - show the credential expiry menu again + if ('isBack' in profileSelection) { + logger.debug('User clicked back, showing credential expiry options again') + // Recursively show the credential expiry options menu + const result = await showIamCredentialExpiryOptions( + authProvider, + connection, + extensionContext + ) + resolve(result) + return + } + + // Handle editing mode - This is if user picks edit during the profile selection + if ('isEditing' in profileSelection) { + logger.debug('User is editing credentials') + resolve({ action: IamCredentialExpiryAction.EditCredentials }) + return + } + + // User selected a new profile, authenticate with it using the selected profile + // Use dynamic import to avoid circular dependency + const { SmusAuthenticationOrchestrator } = await import('./authenticationOrchestrator.js') + const result = await SmusAuthenticationOrchestrator.handleIamAuthentication( + authProvider, + { record: () => {} }, // Minimal span object + extensionContext, + profileSelection.profileName, + profileSelection.region + ) + + if (result.status === 'SUCCESS') { + void vscode.window.showInformationMessage( + `Successfully switched to profile: ${profileSelection.profileName}` + ) + resolve({ action: IamCredentialExpiryAction.SwitchProfile }) + } else if (result.status === 'INVALID_PROFILE') { + void vscode.window.showErrorMessage(`Failed to switch profile: ${result.error}`) + resolve({ action: IamCredentialExpiryAction.SwitchProfile }) + } else { + // BACK or EDITING - shouldn't happen here but handle gracefully + resolve({ action: IamCredentialExpiryAction.Cancelled }) + } + } catch (switchError) { + // Handle user cancellation gracefully + if ( + switchError instanceof ToolkitError && + switchError.code === SmusErrorCodes.UserCancelled + ) { + logger.debug('Profile switch cancelled by user') + resolve({ action: IamCredentialExpiryAction.Cancelled }) + } else { + // Show error message for actual failures + const errorMsg = (switchError as Error).message + void vscode.window.showErrorMessage(`Failed to switch profile: ${errorMsg}`) + logger.error('Profile switch failed: %s', switchError) + resolve({ action: IamCredentialExpiryAction.SwitchProfile }) + } + } + break + } + case IamCredentialExpiryAction.SignOut: { + logger.debug('Signing out from connection') + // Use the provider's signOut method which properly handles metadata cleanup + await authProvider.signOut() + void vscode.window.showInformationMessage('Successfully signed out') + resolve({ action: IamCredentialExpiryAction.SignOut }) + break + } + } + } catch (error) { + logger.error('Failed to handle credential expiry action: %s', error) + // Only show error for non-reauthenticate cases (reauthenticate handles its own errors) + if (itemWithAction.action !== IamCredentialExpiryAction.Reauthenticate) { + void vscode.window.showErrorMessage(`Failed to complete action: ${(error as Error).message}`) + } + reject(error) + } + }) + + quickPick.onDidHide(() => { + if (!isCompleted) { + quickPick.dispose() + logger.debug('Credential expiry options cancelled by user') + resolve({ action: IamCredentialExpiryAction.Cancelled }) + } + }) + + quickPick.show() + }) +} diff --git a/packages/core/src/sagemakerunifiedstudio/auth/model.ts b/packages/core/src/sagemakerunifiedstudio/auth/model.ts index 6e60fa20e96..adc6fd29455 100644 --- a/packages/core/src/sagemakerunifiedstudio/auth/model.ts +++ b/packages/core/src/sagemakerunifiedstudio/auth/model.ts @@ -3,29 +3,53 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { SsoProfile, SsoConnection } from '../../auth/connection' +import { SsoProfile, SsoConnection, Connection, IamConnection } from '../../auth/connection' +import { DevSettings } from '../../shared/settings' /** - * Scope for SageMaker Unified Studio authentication + * Default scope for SageMaker Unified Studio authentication */ export const scopeSmus = 'datazone:domain:access' +/** + * Gets the DataZone SSO scope from user settings or returns the default + */ +export function getDataZoneSsoScope(): string { + const devSettings = DevSettings.instance + return devSettings.get('datazoneScope', scopeSmus) +} + /** * SageMaker Unified Studio profile extending the base SSO profile */ -export interface SmusProfile extends SsoProfile { +export interface SmusSsoProfile extends SsoProfile { readonly domainUrl: string readonly domainId: string } /** - * SageMaker Unified Studio connection extending the base SSO connection + * SageMaker Unified Studio SSO connection extending the base SSO connection */ -export interface SmusConnection extends SmusProfile, SsoConnection { +export interface SmusSsoConnection extends SmusSsoProfile, SsoConnection { readonly id: string readonly label: string } +/** + * SageMaker Unified Studio IAM connection for credential profile authentication + */ +export interface SmusIamConnection extends IamConnection { + readonly profileName: string + readonly region: string + readonly domainUrl: string + readonly domainId: string +} + +/** + * Union type for all SMUS connection types (SSO and IAM) + */ +export type SmusConnection = SmusSsoConnection | SmusIamConnection + /** * Creates a SageMaker Unified Studio profile * @param domainUrl The SageMaker Unified Studio domain URL @@ -39,8 +63,8 @@ export function createSmusProfile( domainId: string, startUrl: string, region: string, - scopes = [scopeSmus] -): SmusProfile & { readonly scopes: string[] } { + scopes = [getDataZoneSsoScope()] +): SmusSsoProfile & { readonly scopes: string[] } { return { scopes, type: 'sso', @@ -52,17 +76,72 @@ export function createSmusProfile( } /** - * Checks if a connection is a valid SageMaker Unified Studio connection + * Type guard to check if a connection is a SMUS IAM connection * @param conn Connection to check - * @returns True if the connection is a valid SMUS connection + * @returns True if the connection is a SMUS IAM connection + */ +export function isSmusIamConnection(conn?: Connection): conn is SmusIamConnection { + return !!( + conn && + conn.type === 'iam' && + 'profileName' in conn && + 'region' in conn && + 'domainId' in conn && + typeof conn.profileName === 'string' && + typeof conn.region === 'string' && + typeof conn.domainId === 'string' + ) +} + +/** + * Type guard to check if a connection is a SMUS SSO connection + * @param conn Connection to check + * @returns True if the connection is a SMUS SSO connection */ -export function isValidSmusConnection(conn?: any): conn is SmusConnection { +export function isSmusSsoConnection(conn?: Connection): conn is SmusSsoConnection { if (!conn || conn.type !== 'sso') { return false } - // Check if the connection has the required SMUS scope - const hasScope = Array.isArray(conn.scopes) && conn.scopes.includes(scopeSmus) + // Check if the connection has the required SMUS scope (check both default and custom scope) + const configuredScope = getDataZoneSsoScope() + const hasScope = + Array.isArray((conn as any).scopes) && + ((conn as any).scopes.includes(scopeSmus) || (conn as any).scopes.includes(configuredScope)) // Check if the connection has the required SMUS properties const hasSmusProps = 'domainUrl' in conn && 'domainId' in conn return !!hasScope && !!hasSmusProps } + +/** + * Checks if a connection is a valid SageMaker Unified Studio connection (either SSO or IAM) + * @param conn Connection to check + * @param smusMetadata Optional SMUS metadata for IAM connections + * @returns True if the connection is a valid SMUS connection + */ +export function isValidSmusConnection(conn?: any, smusMetadata?: any): conn is SmusConnection | IamConnection { + // Accept SMUS SSO connections + if (isSmusSsoConnection(conn)) { + return true + } + + // For IAM connections, check if they have SMUS metadata either in the connection or separately + if (conn && conn.type === 'iam') { + // Check if connection already has SMUS properties + if (isSmusIamConnection(conn)) { + return true + } + + // Check if we have separate SMUS metadata for this IAM connection + if ( + smusMetadata && + typeof smusMetadata.profileName === 'string' && + typeof smusMetadata.region === 'string' && + typeof smusMetadata.domainUrl === 'string' && + typeof smusMetadata.domainId === 'string' + ) { + return true + } + } + + return false +} diff --git a/packages/core/src/sagemakerunifiedstudio/auth/preferences/authenticationPreferences.ts b/packages/core/src/sagemakerunifiedstudio/auth/preferences/authenticationPreferences.ts new file mode 100644 index 00000000000..47416e73250 --- /dev/null +++ b/packages/core/src/sagemakerunifiedstudio/auth/preferences/authenticationPreferences.ts @@ -0,0 +1,191 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as vscode from 'vscode' +import { getLogger } from '../../../shared/logger/logger' +import globals from '../../../shared/extensionGlobals.js' +import { SmusAuthenticationMethod } from '../ui/authenticationMethodSelection.js' + +/** + * Configuration for IAM profile preferences + */ +export interface SmusIamProfileConfig { + profileName: string + region: string + lastUsed?: Date + isDefault?: boolean +} + +/** + * SMUS authentication preferences + */ +export interface SmusAuthenticationPreferences { + preferredMethod?: SmusAuthenticationMethod + lastUsedSsoConnection?: string + lastUsedIamProfile?: SmusIamProfileConfig + rememberChoice: boolean +} + +/** + * Manager for SMUS authentication preferences + */ +export class SmusAuthenticationPreferencesManager { + private static readonly logger = getLogger('smus') + // eslint-disable-next-line @typescript-eslint/naming-convention + private static readonly PREFERENCES_KEY = 'aws.smus.authenticationPreferences' + + /** + * Gets the current authentication preferences + * @param context VS Code extension context (unused, kept for API compatibility) + * @returns Current authentication preferences + */ + public static getPreferences(context?: vscode.ExtensionContext): SmusAuthenticationPreferences { + const stored = globals.globalState.get(this.PREFERENCES_KEY) + + return { + rememberChoice: false, + ...stored, + } + } + + /** + * Updates authentication preferences + * @param context VS Code extension context (unused, kept for API compatibility) + * @param preferences Preferences to update + */ + public static async updatePreferences( + context: vscode.ExtensionContext, + preferences: Partial + ): Promise { + const logger = this.logger + + const current = this.getPreferences() + const updated = { ...current, ...preferences } + + logger.debug( + `SMUS Auth: Updating authentication preferences - preferredMethod: ${updated.preferredMethod}, rememberChoice: ${updated.rememberChoice}` + ) + + await globals.globalState.update(this.PREFERENCES_KEY, updated) + } + + /** + * Sets the preferred authentication method + * @param context VS Code extension context + * @param method Preferred authentication method + * @param rememberChoice Whether to remember this choice + */ + public static async setPreferredMethod( + context: vscode.ExtensionContext, + method: SmusAuthenticationMethod, + rememberChoice: boolean + ): Promise { + await this.updatePreferences(context, { + preferredMethod: method, + rememberChoice, + }) + } + + /** + * Gets the preferred authentication method + * @param context VS Code extension context (unused, kept for API compatibility) + * @returns Preferred authentication method or undefined if not set + */ + public static getPreferredMethod(context?: vscode.ExtensionContext): SmusAuthenticationMethod | undefined { + const preferences = this.getPreferences() + return preferences.rememberChoice ? preferences.preferredMethod : undefined + } + + /** + * Sets the last used SSO connection + * @param context VS Code extension context + * @param connectionId Connection ID + */ + public static async setLastUsedSsoConnection( + context: vscode.ExtensionContext, + connectionId: string + ): Promise { + await this.updatePreferences(context, { + lastUsedSsoConnection: connectionId, + }) + } + + /** + * Sets the last used IAM profile configuration + * @param context VS Code extension context + * @param profileConfig IAM profile configuration + */ + public static async setLastUsedIamProfile( + context: vscode.ExtensionContext, + profileConfig: SmusIamProfileConfig + ): Promise { + await this.updatePreferences(context, { + lastUsedIamProfile: { + ...profileConfig, + lastUsed: new Date(), + }, + }) + } + + /** + * Gets the last used IAM profile configuration + * @param context VS Code extension context (unused, kept for API compatibility) + * @returns Last used IAM profile configuration or undefined + */ + public static getLastUsedIamProfile(context?: vscode.ExtensionContext): SmusIamProfileConfig | undefined { + const preferences = this.getPreferences() + return preferences.lastUsedIamProfile + } + + /** + * Clears all authentication preferences + * @param context VS Code extension context (unused, kept for API compatibility) + */ + public static async clearPreferences(context?: vscode.ExtensionContext): Promise { + const logger = this.logger + logger.debug('Clearing authentication preferences') + + await globals.globalState.update(this.PREFERENCES_KEY, undefined) + } + + /** + * Clears only connection-specific preferences, preserving authentication method preference + * @param context VS Code extension context (unused, kept for API compatibility) + */ + public static async clearConnectionPreferences(context?: vscode.ExtensionContext): Promise { + const logger = this.logger + logger.debug('Clearing connection-specific preferences (preserving auth method preference)') + + const currentPrefs = this.getPreferences() + + // Keep only the authentication method preference and rememberChoice flag + const preservedPrefs: SmusAuthenticationPreferences = { + preferredMethod: currentPrefs.preferredMethod, + rememberChoice: currentPrefs.rememberChoice, + // Clear connection-specific data + lastUsedSsoConnection: undefined, + lastUsedIamProfile: undefined, + } + + await globals.globalState.update(this.PREFERENCES_KEY, preservedPrefs) + } + + /** + * Switches the authentication method preference + * @param context VS Code extension context + * @param newMethod New authentication method to switch to + */ + public static async switchAuthenticationMethod( + context: vscode.ExtensionContext, + newMethod: SmusAuthenticationMethod + ): Promise { + const logger = this.logger + logger.debug(`Switching authentication method to: ${newMethod}`) + + await this.updatePreferences(context, { + preferredMethod: newMethod, + }) + } +} diff --git a/packages/core/src/sagemakerunifiedstudio/auth/providers/connectionCredentialsProvider.ts b/packages/core/src/sagemakerunifiedstudio/auth/providers/connectionCredentialsProvider.ts index f060e6477ab..30b8477f6cb 100644 --- a/packages/core/src/sagemakerunifiedstudio/auth/providers/connectionCredentialsProvider.ts +++ b/packages/core/src/sagemakerunifiedstudio/auth/providers/connectionCredentialsProvider.ts @@ -12,6 +12,7 @@ import { DataZoneClient } from '../../shared/client/datazoneClient' import { SmusAuthenticationProvider } from './smusAuthenticationProvider' import { CredentialType } from '../../../shared/telemetry/telemetry' import { SmusCredentialExpiry, validateCredentialFields } from '../../shared/smusUtils' +import { getContext } from '../../../shared/vscode/setContext' /** * Credentials provider for SageMaker Unified Studio Connection credentials @@ -19,7 +20,7 @@ import { SmusCredentialExpiry, validateCredentialFields } from '../../shared/smu * This provider implements independent caching with 10-minute expiry */ export class ConnectionCredentialsProvider implements CredentialsProvider { - private readonly logger = getLogger() + private readonly logger = getLogger('smus') private credentialCache?: { credentials: AWS.Credentials expiresAt: Date @@ -27,7 +28,8 @@ export class ConnectionCredentialsProvider implements CredentialsProvider { constructor( private readonly smusAuthProvider: SmusAuthenticationProvider, - private readonly connectionId: string + private readonly connectionId: string, + private readonly projectId: string ) {} /** @@ -106,7 +108,7 @@ export class ConnectionCredentialsProvider implements CredentialsProvider { try { return this.smusAuthProvider.isConnected() } catch (err) { - this.logger.error('SMUS Connection: Error checking if auth provider is connected: %s', err) + this.logger.error('Error checking if auth provider is connected: %s', err) return false } } @@ -116,7 +118,7 @@ export class ConnectionCredentialsProvider implements CredentialsProvider { * @returns Promise resolving to credentials */ public async getCredentials(): Promise { - this.logger.debug(`SMUS Connection: Getting credentials for connection ${this.connectionId}`) + this.logger.debug(`Getting credentials for connection ${this.connectionId}`) // Check cache first (10-minute expiry) if (this.credentialCache && this.credentialCache.expiresAt > new Date()) { @@ -131,14 +133,21 @@ export class ConnectionCredentialsProvider implements CredentialsProvider { ) try { - const datazoneClient = await DataZoneClient.getInstance(this.smusAuthProvider) + if (getContext('aws.smus.isIamMode') && this.projectId) { + return (await this.smusAuthProvider.getProjectCredentialProvider(this.projectId)).getCredentials() + } + const datazoneClient = DataZoneClient.createWithCredentials( + this.smusAuthProvider.getDomainRegion(), + this.smusAuthProvider.getDomainId(), + await this.smusAuthProvider.getDerCredentialsProvider() + ) const getConnectionResponse = await datazoneClient.getConnection({ domainIdentifier: this.smusAuthProvider.getDomainId(), identifier: this.connectionId, withSecret: true, }) - this.logger.debug(`SMUS Connection: Successfully retrieved connection details for ${this.connectionId}`) + this.logger.debug(`Successfully retrieved connection details for ${this.connectionId}`) // Extract connection credentials const connectionCredentials = getConnectionResponse.connectionCredentials @@ -219,7 +228,7 @@ export class ConnectionCredentialsProvider implements CredentialsProvider { * Clears the internal cache without fetching new credentials */ public invalidate(): void { - this.logger.debug(`SMUS Connection: Invalidating cached credentials for connection ${this.connectionId}`) + this.logger.debug(`Invalidating cached credentials for connection ${this.connectionId}`) // Clear cache to force fresh fetch on next getCredentials() call this.credentialCache = undefined this.logger.debug( diff --git a/packages/core/src/sagemakerunifiedstudio/auth/providers/domainExecRoleCredentialsProvider.ts b/packages/core/src/sagemakerunifiedstudio/auth/providers/domainExecRoleCredentialsProvider.ts index 968749a9c9c..85da4901a94 100644 --- a/packages/core/src/sagemakerunifiedstudio/auth/providers/domainExecRoleCredentialsProvider.ts +++ b/packages/core/src/sagemakerunifiedstudio/auth/providers/domainExecRoleCredentialsProvider.ts @@ -20,7 +20,7 @@ import { SmusCredentialExpiry, SmusTimeouts, SmusErrorCodes, validateCredentialF * its own credential lifecycle independently */ export class DomainExecRoleCredentialsProvider implements CredentialsProvider { - private readonly logger = getLogger() + private readonly logger = getLogger('smus') private credentialCache?: { credentials: AWS.Credentials expiresAt: Date @@ -120,15 +120,15 @@ export class DomainExecRoleCredentialsProvider implements CredentialsProvider { * @returns Promise resolving to credentials */ public async getCredentials(): Promise { - this.logger.debug(`SMUS DER: Getting DER credentials for domain ${this.domainId}`) + this.logger.debug(`Getting DER credentials for domain ${this.domainId}`) // Check cache first (10-minute expiry with 5-minute buffer for proactive refresh) if (this.credentialCache && this.credentialCache.expiresAt > new Date()) { - this.logger.debug(`SMUS DER: Using cached DER credentials for domain ${this.domainId}`) + this.logger.debug(`Using cached DER credentials for domain ${this.domainId}`) return this.credentialCache.credentials } - this.logger.debug(`SMUS DER: Fetching credentials from API for domain ${this.domainId}`) + this.logger.debug(`Fetching credentials from API for domain ${this.domainId}`) try { // Get current SSO access token @@ -139,11 +139,11 @@ export class DomainExecRoleCredentialsProvider implements CredentialsProvider { }) } - this.logger.debug(`SMUS DER: Got access token for refresh for domain ${this.domainId}`) + this.logger.debug(`Got access token for refresh for domain ${this.domainId}`) // Call SMUS redeem token API to get DER credentials const redeemUrl = new URL('/sso/redeem-token', this.domainUrl) - this.logger.debug(`SMUS DER: Calling redeem token endpoint: ${redeemUrl.toString()}`) + this.logger.debug(`Calling redeem token endpoint: ${redeemUrl.toString()}`) const requestBody = { domainId: this.domainId, @@ -182,14 +182,14 @@ export class DomainExecRoleCredentialsProvider implements CredentialsProvider { throw fetchError } - this.logger.debug(`SMUS DER: Redeem token response status: ${response.status} for domain ${this.domainId}`) + this.logger.debug(`Redeem token response status: ${response.status} for domain ${this.domainId}`) if (!response.ok) { // Try to get response body for more details let responseBody = '' try { responseBody = await response.text() - this.logger.debug(`SMUS DER: Error response body for domain ${this.domainId}: ${responseBody}`) + this.logger.debug(`Error response body for domain ${this.domainId}: ${responseBody}`) } catch (bodyErr) { this.logger.debug( `SMUS DER: Could not read error response body for domain ${this.domainId}: ${bodyErr}` @@ -212,7 +212,7 @@ export class DomainExecRoleCredentialsProvider implements CredentialsProvider { expiration: string } } - this.logger.debug(`SMUS DER: Successfully received credentials from API for domain ${this.domainId}`) + this.logger.debug(`Successfully received credentials from API for domain ${this.domainId}`) // Validate the response data structure if (!data.credentials) { @@ -265,12 +265,12 @@ export class DomainExecRoleCredentialsProvider implements CredentialsProvider { credentialExpiresAt = parsedExpiration } if (!isNaN(credentialExpiresAt.getTime())) { - this.logger.debug(`SMUS DER: Credential expires at ${credentialExpiresAt.toISOString()}`) + this.logger.debug(`Credential expires at ${credentialExpiresAt.toISOString()}`) } else { - this.logger.debug(`SMUS DER: Invalid credential expiration date, using default`) + this.logger.debug(`Invalid credential expiration date, using default`) } } else { - this.logger.debug(`SMUS DER: No expiration provided, using default`) + this.logger.debug(`No expiration provided, using default`) credentialExpiresAt = new Date(Date.now() + SmusCredentialExpiry.derExpiryMs) } @@ -296,7 +296,7 @@ export class DomainExecRoleCredentialsProvider implements CredentialsProvider { return awsCredentials } catch (err) { - this.logger.error('SMUS DER: Failed to fetch credentials for domain %s: %s', this.domainId, err) + this.logger.error('Failed to fetch credentials for domain %s: %s', this.domainId, err) throw new ToolkitError(`Failed to fetch DER credentials for domain ${this.domainId}: ${err}`, { code: 'DerCredentialsFetchFailed', cause: err instanceof Error ? err : undefined, @@ -309,17 +309,17 @@ export class DomainExecRoleCredentialsProvider implements CredentialsProvider { * Clears the internal cache without fetching new credentials */ public invalidate(): void { - this.logger.debug(`SMUS DER: Invalidating cached DER credentials for domain ${this.domainId}`) + this.logger.debug(`Invalidating cached DER credentials for domain ${this.domainId}`) // Clear cache to force fresh fetch on next getCredentials() call this.credentialCache = undefined - this.logger.debug(`SMUS DER: Successfully invalidated DER credentials cache for domain ${this.domainId}`) + this.logger.debug(`Successfully invalidated DER credentials cache for domain ${this.domainId}`) } /** * Disposes of the provider and cleans up resources */ public dispose(): void { - this.logger.debug(`SMUS DER: Disposing DER credentials provider for domain ${this.domainId}`) + this.logger.debug(`Disposing DER credentials provider for domain ${this.domainId}`) this.invalidate() - this.logger.debug(`SMUS DER: Successfully disposed DER credentials provider for domain ${this.domainId}`) + this.logger.debug(`Successfully disposed DER credentials provider for domain ${this.domainId}`) } } diff --git a/packages/core/src/sagemakerunifiedstudio/auth/providers/projectRoleCredentialsProvider.ts b/packages/core/src/sagemakerunifiedstudio/auth/providers/projectRoleCredentialsProvider.ts index 5eb42e1fd5f..7f67d0a97af 100644 --- a/packages/core/src/sagemakerunifiedstudio/auth/providers/projectRoleCredentialsProvider.ts +++ b/packages/core/src/sagemakerunifiedstudio/auth/providers/projectRoleCredentialsProvider.ts @@ -8,11 +8,11 @@ import { ToolkitError } from '../../../shared/errors' import * as AWS from '@aws-sdk/types' import { CredentialsId, CredentialsProvider, CredentialsProviderType } from '../../../auth/providers/credentials' -import { DataZoneClient } from '../../shared/client/datazoneClient' import { SmusAuthenticationProvider } from './smusAuthenticationProvider' import { CredentialType } from '../../../shared/telemetry/telemetry' import { SmusCredentialExpiry, validateCredentialFields } from '../../shared/smusUtils' import { loadMappings, saveMappings } from '../../../awsService/sagemaker/credentialMapping' +import { createDZClientBaseOnDomainMode } from '../../explorer/nodes/utils' /** * Credentials provider for SageMaker Unified Studio Project Role credentials @@ -23,7 +23,7 @@ import { loadMappings, saveMappings } from '../../../awsService/sagemaker/creden * with any AWS SDK client (S3Client, LambdaClient, etc.) */ export class ProjectRoleCredentialsProvider implements CredentialsProvider { - private readonly logger = getLogger() + private readonly logger = getLogger('smus') private credentialCache?: { credentials: AWS.Credentials expiresAt: Date @@ -112,18 +112,18 @@ export class ProjectRoleCredentialsProvider implements CredentialsProvider { * @returns Promise resolving to credentials */ public async getCredentials(): Promise { - this.logger.debug(`SMUS Project: Getting credentials for project ${this.projectId}`) + this.logger.debug(`Getting credentials for project ${this.projectId}`) // Check cache first (10-minute expiry) if (this.credentialCache && this.credentialCache.expiresAt > new Date()) { - this.logger.debug(`SMUS Project: Using cached project credentials for project ${this.projectId}`) + this.logger.debug(`Using cached project credentials for project ${this.projectId}`) return this.credentialCache.credentials } - this.logger.debug(`SMUS Project: Fetching project credentials from API for project ${this.projectId}`) + this.logger.debug(`Fetching project credentials from API for project ${this.projectId}`) try { - const dataZoneClient = await DataZoneClient.getInstance(this.smusAuthProvider) + const dataZoneClient = await createDZClientBaseOnDomainMode(this.smusAuthProvider) const response = await dataZoneClient.getProjectDefaultEnvironmentCreds(this.projectId) this.logger.debug( @@ -167,7 +167,7 @@ export class ProjectRoleCredentialsProvider implements CredentialsProvider { return awsCredentials } catch (err) { - this.logger.error('SMUS Project: Failed to get project credentials for project %s: %s', this.projectId, err) + this.logger.error('Failed to get project credentials for project %s: %s', this.projectId, err) // Handle InvalidGrantException specially - indicates need for reauthentication if (err instanceof Error && err.name === 'InvalidGrantException') { @@ -203,7 +203,7 @@ export class ProjectRoleCredentialsProvider implements CredentialsProvider { } await saveMappings(mapping) } catch (err) { - this.logger.warn('SMUS Project: Failed to write project credentials to mapping file: %s', err) + this.logger.warn('Failed to write project credentials to mapping file: %s', err) } } @@ -221,11 +221,11 @@ export class ProjectRoleCredentialsProvider implements CredentialsProvider { */ public startProactiveCredentialRefresh(): void { if (this.sshRefreshActive) { - this.logger.debug(`SMUS Project: SSH refresh already active for project ${this.projectId}`) + this.logger.debug(`SSH refresh already active for project ${this.projectId}`) return } - this.logger.info(`SMUS Project: Starting SSH credential refresh for project ${this.projectId}`) + this.logger.info(`Starting SSH credential refresh for project ${this.projectId}`) this.sshRefreshActive = true this.lastRefreshTime = new Date() // Initialize refresh time @@ -242,7 +242,7 @@ export class ProjectRoleCredentialsProvider implements CredentialsProvider { return } - this.logger.info(`SMUS Project: Stopping SSH credential refresh for project ${this.projectId}`) + this.logger.info(`Stopping SSH credential refresh for project ${this.projectId}`) this.sshRefreshActive = false this.lastRefreshTime = undefined @@ -295,7 +295,7 @@ export class ProjectRoleCredentialsProvider implements CredentialsProvider { private shouldPerformRefresh(now: Date): boolean { if (!this.lastRefreshTime || !this.credentialCache) { // First refresh or no cached credentials - this.logger.debug(`SMUS Project: First refresh - no previous credentials for ${this.projectId}`) + this.logger.debug(`First refresh - no previous credentials for ${this.projectId}`) return true } @@ -345,7 +345,7 @@ export class ProjectRoleCredentialsProvider implements CredentialsProvider { * Clears the internal cache without fetching new credentials */ public invalidate(): void { - this.logger.debug(`SMUS Project: Invalidating cached credentials for project ${this.projectId}`) + this.logger.debug(`Invalidating cached credentials for project ${this.projectId}`) // Clear cache to force fresh fetch on next getCredentials() call this.credentialCache = undefined this.logger.debug( diff --git a/packages/core/src/sagemakerunifiedstudio/auth/providers/smusAuthenticationProvider.ts b/packages/core/src/sagemakerunifiedstudio/auth/providers/smusAuthenticationProvider.ts index 6c0f204cbd3..c83d2e1544b 100644 --- a/packages/core/src/sagemakerunifiedstudio/auth/providers/smusAuthenticationProvider.ts +++ b/packages/core/src/sagemakerunifiedstudio/auth/providers/smusAuthenticationProvider.ts @@ -4,6 +4,7 @@ */ import * as vscode from 'vscode' +import { AwsCredentialIdentity } from '@aws-sdk/types' import { Auth } from '../../../auth/auth' import { getSecondaryAuth } from '../../../auth/secondaryAuth' import { ToolkitError } from '../../../shared/errors' @@ -14,17 +15,39 @@ import * as localizedText from '../../../shared/localizedText' import { ToolkitPromptSettings } from '../../../shared/settings' import { setContext, getContext } from '../../../shared/vscode/setContext' import { getLogger } from '../../../shared/logger/logger' -import { SmusUtils, SmusErrorCodes, extractAccountIdFromResourceMetadata } from '../../shared/smusUtils' -import { createSmusProfile, isValidSmusConnection, SmusConnection } from '../model' +import { + SmusUtils, + SmusErrorCodes, + extractAccountIdFromResourceMetadata, + convertToToolkitCredentialProvider, +} from '../../shared/smusUtils' +import { + createSmusProfile, + isValidSmusConnection, + SmusConnection, + SmusIamConnection, + isSmusSsoConnection, + isSmusIamConnection, +} from '../model' +import { IamCredentialExpiryAction, showIamCredentialExpiryOptions } from '../credentialExpiryHandler' + import { DomainExecRoleCredentialsProvider } from './domainExecRoleCredentialsProvider' import { ProjectRoleCredentialsProvider } from './projectRoleCredentialsProvider' import { ConnectionCredentialsProvider } from './connectionCredentialsProvider' import { ConnectionClientStore } from '../../shared/client/connectionClientStore' import { getResourceMetadata } from '../../shared/utils/resourceMetadataUtils' -import { fromIni } from '@aws-sdk/credential-providers' +import { CredentialsProviderManager } from '../../../auth/providers/credentialsProviderManager' +import { SharedCredentialsProvider } from '../../../auth/providers/sharedCredentialsProvider' +import { CredentialsId, CredentialsProvider } from '../../../auth/providers/credentials' +import globals from '../../../shared/extensionGlobals' +import { fromContainerMetadata, fromIni, fromNodeProviderChain } from '@aws-sdk/credential-providers' import { randomUUID } from '../../../shared/crypto' import { DefaultStsClient } from '../../../shared/clients/stsClient' +import { DataZoneCustomClientHelper } from '../../shared/client/datazoneCustomClientHelper' +import { createDZClientBaseOnDomainMode } from '../../explorer/nodes/utils' import { DataZoneClient } from '../../shared/client/datazoneClient' +import { loadSharedConfigFiles } from '@smithy/shared-ini-file-loader' +import { loadSharedCredentialsProfiles } from '../../../auth/credentials/sharedCredentials' /** * Sets the context variable for SageMaker Unified Studio connection state @@ -41,6 +64,14 @@ export function setSmusConnectedContext(isConnected: boolean): Promise { export function setSmusSpaceEnvironmentContext(inSmusSpace: boolean): Promise { return setContext('aws.smus.inSmusSpaceEnvironment', inSmusSpace) } + +/** + * Sets the context variable for SMUS IAM mode state + * @param isIamMode Whether the current domain is in IAM mode + */ +export function setSmusIamModeContext(isIamMode: boolean): Promise { + return setContext('aws.smus.isIamMode', isIamMode) +} const authClassName = 'SmusAuthenticationProvider' /** @@ -48,8 +79,8 @@ const authClassName = 'SmusAuthenticationProvider' * Manages authentication state and credentials for SMUS */ export class SmusAuthenticationProvider { - private readonly logger = getLogger() - public readonly onDidChangeActiveConnection = this.secondaryAuth.onDidChangeActiveConnection + private readonly logger = getLogger('smus') + public readonly onDidChangeActiveConnection: vscode.Event private readonly onDidChangeEmitter = new vscode.EventEmitter() public readonly onDidChange = this.onDidChangeEmitter.event private credentialsProviderCache = new Map() @@ -57,17 +88,50 @@ export class SmusAuthenticationProvider { private connectionCredentialProvidersCache = new Map() private cachedDomainAccountId: string | undefined private cachedProjectAccountIds = new Map() + private iamCallerIdentityCache: { arn: string; connectionId: string } | undefined - public constructor( - public readonly auth = Auth.instance, - public readonly secondaryAuth = getSecondaryAuth( + public readonly secondaryAuth: ReturnType + + public constructor(public readonly auth = Auth.instance) { + // Create secondaryAuth after the class is constructed so we can reference instance methods + this.secondaryAuth = getSecondaryAuth( auth, 'smus', 'SageMaker Unified Studio', - isValidSmusConnection + (conn): conn is SmusConnection => { + // Use auth's state directly since secondaryAuth isn't available yet during initialization + const state = auth.getStateMemento() + const smusConnections = state.get('smus.connections') as any + const savedConnectionId = state.get('smus.savedConnectionId') as string + + // Only accept IAM connections that are currently saved for SMUS + if (conn && conn.type === 'iam') { + // Must be the exact connection that SMUS has saved AND have metadata + return ( + conn.id === savedConnectionId && + smusConnections && + smusConnections[conn.id] && + isValidSmusConnection(conn, smusConnections[conn.id]) + ) + } + + // SSO connections: Check if they have SMUS scope (always SMUS-specific) + if (conn && conn.type === 'sso') { + return isValidSmusConnection(conn) // Checks for SMUS scope + } + + // Reject everything else + return false + } ) - ) { - this.onDidChangeActiveConnection(async () => { + + // Initialize the event property + this.onDidChangeActiveConnection = this.secondaryAuth.onDidChangeActiveConnection as vscode.Event< + SmusConnection | undefined + > + + // Set up event listeners + this.secondaryAuth.onDidChangeActiveConnection(async () => { // Stop SSH credential refresh for all projects when connection changes this.stopAllSshCredentialRefresh() @@ -83,16 +147,81 @@ export class SmusAuthenticationProvider { this.cachedDomainAccountId = undefined // Clear cached project account IDs when connection changes this.cachedProjectAccountIds.clear() + // Clear cached IAM caller identity when connection changes + this.clearIamCallerIdentityCache() // Clear all clients in client store when connection changes ConnectionClientStore.getInstance().clearAll() await setSmusConnectedContext(this.isConnected()) await setSmusSpaceEnvironmentContext(SmusUtils.isInSmusSpaceEnvironment()) + + // Set IAM mode context based on connection metadata + const activeConn = this.activeConnection + if (activeConn && 'type' in activeConn && activeConn.type === 'iam') { + const smusConnections = (this.secondaryAuth.state.get('smus.connections') as any) || {} + const connectionMetadata = smusConnections[activeConn.id] + const isIamDomain = connectionMetadata?.isIamDomain || false + await setSmusIamModeContext(isIamDomain) + } else { + // Clear IAM mode context for non-IAM connections or no connection + await setSmusIamModeContext(false) + } + // Update IAM mode context in SMUS space environment + if (getContext('aws.smus.inSmusSpaceEnvironment')) { + await this.initIamModeContextInSpaceEnvironment() + } + this.onDidChangeEmitter.fire() }) // Set initial context in case event does not trigger void setSmusConnectedContext(this.isConnectionValid()) void setSmusSpaceEnvironmentContext(SmusUtils.isInSmusSpaceEnvironment()) + + // Set initial IAM mode context + void (async () => { + // Update IAM mode context in SMUS space environment + if (getContext('aws.smus.inSmusSpaceEnvironment')) { + await this.initIamModeContextInSpaceEnvironment() + } else { + const activeConn = this.activeConnection + if (activeConn && 'type' in activeConn && activeConn.type === 'iam') { + const state = this.auth.getStateMemento() + const smusConnections = (state.get('smus.connections') as any) || {} + const connectionMetadata = smusConnections[activeConn.id] + const isIamDomain = connectionMetadata?.isIamDomain || false + await setSmusIamModeContext(isIamDomain) + } else { + await setSmusIamModeContext(false) + } + } + })() + } + + /** + * Initializes IAM mode context in SMUS space environment + */ + private async initIamModeContextInSpaceEnvironment(): Promise { + try { + const resourceMetadata = getResourceMetadata() + if ( + resourceMetadata?.AdditionalMetadata?.DataZoneDomainId && + resourceMetadata?.AdditionalMetadata?.DataZoneDomainRegion + ) { + const domainId = resourceMetadata.AdditionalMetadata.DataZoneDomainId + const region = resourceMetadata.AdditionalMetadata.DataZoneDomainRegion + + const credentialsProvider = (await this.getDerCredentialsProvider()) as CredentialsProvider + + // Get DataZoneCustomClientHelper instance and check if domain is IAM mode + const datazoneCustomClientHelper = DataZoneCustomClientHelper.getInstance(credentialsProvider, region) + const isIamMode = await datazoneCustomClientHelper.isIamDomain(domainId) + this.logger.debug(`is in IAM mode ${isIamMode}`) + await setSmusIamModeContext(isIamMode) + } + } catch (error) { + this.logger.error('Failed to check IAM mode in SMUS space environment: %s', error) + await setSmusIamModeContext(false) + } } /** @@ -100,7 +229,7 @@ export class SmusAuthenticationProvider { * Called when SMUS connection changes or extension deactivates */ public stopAllSshCredentialRefresh(): void { - this.logger.debug('SMUS Auth: Stopping SSH credential refresh for all projects') + this.logger.debug('Stopping SSH credential refresh for all projects') for (const provider of this.projectCredentialProvidersCache.values()) { provider.stopProactiveCredentialRefresh() } @@ -109,24 +238,56 @@ export class SmusAuthenticationProvider { /** * Gets the active connection */ - public get activeConnection() { + public get activeConnection(): SmusConnection | undefined { if (getContext('aws.smus.inSmusSpaceEnvironment')) { const resourceMetadata = getResourceMetadata()! if (resourceMetadata.AdditionalMetadata!.DataZoneDomainRegion) { + // Return a mock connection object for SMUS space environment + // Include type property based on IAM mode context for telemetry + // Note: type will be undefined initially until mode is detected + const isIamMode = getContext('aws.smus.isIamMode') return { domainId: resourceMetadata.AdditionalMetadata!.DataZoneDomainId!, ssoRegion: resourceMetadata.AdditionalMetadata!.DataZoneDomainRegion!, - // The following fields won't be needed in SMUS space environment - // Craft the domain url with known information - // Use randome id as placeholder domainUrl: `https://${resourceMetadata.AdditionalMetadata!.DataZoneDomainId!}.sagemaker.${resourceMetadata.AdditionalMetadata!.DataZoneDomainRegion!}.on.aws/`, id: randomUUID(), - } + type: isIamMode !== undefined ? (isIamMode ? 'iam' : 'sso') : undefined, + } as any as SmusConnection } else { throw new ToolkitError('Domain region not found in metadata file.') } } - return this.secondaryAuth.activeConnection + const baseConnection = this.secondaryAuth.activeConnection + + // If we have a connection, wrap it with SMUS metadata if available + if (baseConnection) { + const smusConnections = this.secondaryAuth.state.get('smus.connections') as any + const connectionMetadata = smusConnections?.[baseConnection.id] + + if (connectionMetadata) { + // For IAM connections, add the profile-specific metadata + if (baseConnection.type === 'iam') { + return { + ...baseConnection, + profileName: connectionMetadata.profileName, + region: connectionMetadata.region, + domainUrl: connectionMetadata.domainUrl, + domainId: connectionMetadata.domainId, + } as SmusIamConnection + } + // For SSO connections, the metadata is already in the connection object + // but we can ensure consistency by adding any missing properties + else if (baseConnection.type === 'sso') { + return { + ...baseConnection, + domainUrl: connectionMetadata.domainUrl || (baseConnection as any).domainUrl, + domainId: connectionMetadata.domainId || (baseConnection as any).domainId, + } as SmusConnection + } + } + } + + return baseConnection as SmusConnection | undefined } /** @@ -162,20 +323,188 @@ export class SmusAuthenticationProvider { /** * Restores the previous connection - * Uses a promise to prevent multiple simultaneous restore calls + * Validates domain metadata against profile and updates if needed before using saved connection */ public async restore() { + const logger = getLogger('smus') + + // Get the saved connection ID before restoring + const savedConnectionId = this.secondaryAuth.state.get('smus.savedConnectionId') as string + if (!savedConnectionId) { + logger.debug('No saved connection ID found, proceeding with normal restore') + await this.secondaryAuth.restoreConnection() + return + } + + // Get the saved connection metadata + const smusConnections = (this.secondaryAuth.state.get('smus.connections') as any) || {} + const connectionMetadata = smusConnections[savedConnectionId] + + // If no connection metadata exists, proceed with normal restore + if (!connectionMetadata) { + logger.debug('No connection metadata found, proceeding with normal restore') + await this.secondaryAuth.restoreConnection() + return + } + + const savedProfileName = connectionMetadata.profileName + + // If no profile name in metadata, proceed with normal restore + if (!savedProfileName) { + logger.debug('No profile name in metadata, proceeding with normal restore') + await this.secondaryAuth.restoreConnection() + return + } + + const profiles = await loadSharedCredentialsProfiles() + const profile = profiles[savedProfileName] + if (!profile) { + logger.debug(`No profile found with name: ${savedProfileName}`) + await this.secondaryAuth.restoreConnection() + return + } + const region = profile.region || 'not-set' + + const validation = await this.validateIamProfile(savedProfileName) + if (!validation.isValid) { + logger.debug(`Profile validation failed: ${validation.error}, proceeding with normal restore`) + await this.secondaryAuth.restoreConnection() + return + } + + let domainUrl + try { + logger.debug(`Finding IAM-based domain in region using profile ${savedProfileName}`) + + // Get DataZoneCustomClientHelper instance + const datazoneCustomClientHelper = DataZoneCustomClientHelper.getInstance( + await this.getCredentialsProviderForIamProfile(savedProfileName), + region + ) + + // Find the IAM-based domain using the client + const iamDomain = await datazoneCustomClientHelper.getIamDomain() + + if (!iamDomain) { + logger.warn(`No IAM-based domain found in region ${region}, proceeding with normal restore`) + await this.secondaryAuth.restoreConnection() + return + } + + logger.debug(`Found IAM-based domain: ${iamDomain.name} (${iamDomain.id})`) + + // Construct domain URL from the IAM-based domain + domainUrl = iamDomain.portalUrl || `https://${iamDomain.id}.sagemaker.${region}.on.aws/` + logger.debug(`Discovered IAM-based domain URL: ${domainUrl}`) + } catch (error) { + logger.error(`Failed to find IAM-based domain: ${error} , proceeding with normal restore`) + await this.secondaryAuth.restoreConnection() + return + } + + try { + logger.debug(`Validating domain metadata for saved connection ${savedConnectionId}`) + + if (!domainUrl) { + logger.info('No domain URL constructed, proceeding with normal restore') + await this.secondaryAuth.restoreConnection() + return + } + + const { domainId } = SmusUtils.extractDomainInfoFromUrl(domainUrl) + + // Compare with saved metadata + const savedDomainId = connectionMetadata.domainId + const savedRegion = connectionMetadata.region + + if (domainId === savedDomainId && region === savedRegion) { + logger.debug('Domain metadata matches, proceeding with normal restore') + } else { + logger.debug( + `SMUS: Domain metadata mismatch detected. Saved: ${savedDomainId}@${savedRegion}, Profile: ${domainId}@${region}. Updating metadata.` + ) + + // Update the metadata with API values + connectionMetadata.domainId = domainId + connectionMetadata.region = region + + // Save updated metadata + smusConnections[savedConnectionId] = connectionMetadata + await this.secondaryAuth.state.update('smus.connections', smusConnections) + + logger.debug('Successfully updated domain metadata') + } + } catch (error) { + logger.warn(`Failed to validate domain metadata: ${error}. Proceeding with normal restore.`) + } + + // Proceed with normal restore await this.secondaryAuth.restoreConnection() } /** - * Authenticates with SageMaker Unified Studio using a domain URL + * Signs out from SMUS with different behavior based on connection type: + * - SSO connections: Deletes the connection (old behavior) + * - IAM connections: Forgets the connection without affecting the underlying IAM profile + */ + @withTelemetryContext({ name: 'signOut', class: authClassName }) + public async signOut() { + const logger = getLogger('smus') + + const activeConnection = this.activeConnection + if (!activeConnection) { + logger.debug('No active connection to sign out from') + return + } + + const connectionId = activeConnection.id + logger.info(`Signing out from connection ${connectionId}`) + + try { + // Clear SMUS-specific metadata from connections registry + const smusConnections = (this.secondaryAuth.state.get('smus.connections') as any) || {} + if (smusConnections[connectionId]) { + delete smusConnections[connectionId] + await this.secondaryAuth.state.update('smus.connections', smusConnections) + } + + // Handle sign-out based on connection type + // Check if this is a real connection (has 'type' property) vs mock connection in SMUS space + if ('type' in activeConnection && isSmusSsoConnection(activeConnection)) { + // For SSO connections, delete the connection (old behavior) + await this.secondaryAuth.deleteConnection() + logger.info(`Deleted SSO connection ${connectionId}`) + } else if ('type' in activeConnection) { + // For IAM connections, forget the connection without affecting the underlying IAM profile + await this.secondaryAuth.forgetConnection() + logger.info(`Forgot IAM connection ${connectionId} (preserved for other services)`) + + // Clear IAM mode context for IAM connections + await setSmusIamModeContext(false) + logger.debug('Cleared IAM mode context') + } else { + // Mock connection in SMUS space environment - no action needed + logger.info(`Sign out completed for mock connection ${connectionId}`) + } + + logger.info(`Successfully signed out from connection ${connectionId}`) + } catch (error) { + logger.error(`Failed to sign out from connection ${connectionId}:`, error) + throw new ToolkitError('Failed to sign out from SageMaker Unified Studio', { + code: SmusErrorCodes.SignOutFailed, + cause: error instanceof Error ? error : undefined, + }) + } + } + + /** + * Authenticates with SageMaker Unified Studio using SSO and a domain URL * @param domainUrl The SageMaker Unified Studio domain URL - * @returns Promise resolving to the connection + * @returns Promise resolving to the SSO connection */ - @withTelemetryContext({ name: 'connectToSmus', class: authClassName }) - public async connectToSmus(domainUrl: string): Promise { - const logger = getLogger() + @withTelemetryContext({ name: 'connectToSmusWithSso', class: authClassName }) + public async connectToSmusWithSso(domainUrl: string): Promise { + const logger = getLogger('smus') try { // Extract domain info using SmusUtils @@ -183,10 +512,10 @@ export class SmusAuthenticationProvider { // Validate domain ID if (!domainId) { - throw new ToolkitError('Invalid domain URL format', { code: 'InvalidDomainUrl' }) + throw new ToolkitError('Invalid domain URL format', { code: SmusErrorCodes.InvalidDomainUrl }) } - logger.info(`SMUS: Connecting to domain ${domainId} in region ${region}`) + logger.info(`Connecting to domain ${domainId} in region ${region}`) // Check if we already have a connection for this domain const existingConn = (await this.auth.listConnections()).find( @@ -196,49 +525,55 @@ export class SmusAuthenticationProvider { if (existingConn) { const connectionState = this.auth.getConnectionState(existingConn) - logger.info(`SMUS: Found existing connection ${existingConn.id} with state: ${connectionState}`) + logger.info(`Found existing connection ${existingConn.id} with state: ${connectionState}`) // If connection is valid, use it directly without triggering new auth flow if (connectionState === 'valid') { - logger.info('SMUS: Using existing valid connection') + logger.info('Using existing valid connection') - // Use the existing connection - const result = await this.secondaryAuth.useNewConnection(existingConn) + // Only SSO connections can be used with connectToSmusWithSso + if (isSmusSsoConnection(existingConn)) { + // Use the existing SSO connection + const result = await this.secondaryAuth.useNewConnection(existingConn) - // Auto-invoke project selection after successful sign-in (but not in SMUS space environment) - if (!SmusUtils.isInSmusSpaceEnvironment()) { - void vscode.commands.executeCommand('aws.smus.switchProject') - } + // Auto-invoke project selection after successful sign-in (but not in SMUS space environment) + if (!SmusUtils.isInSmusSpaceEnvironment()) { + void vscode.commands.executeCommand('aws.smus.switchProject') + } - return result + return result as SmusConnection + } } - // If connection is invalid or expired, reauthenticate + // If connection is invalid or expired, handle based on connection type if (connectionState === 'invalid') { - logger.info('SMUS: Existing connection is invalid, reauthenticating') - const reauthenticatedConn = await this.reauthenticate(existingConn) - - // Create the SMUS connection wrapper - const smusConn: SmusConnection = { - ...reauthenticatedConn, - domainUrl, - domainId, - } - - const result = await this.secondaryAuth.useNewConnection(smusConn) - logger.debug(`SMUS: Reauthenticated connection successfully, id=${result.id}`) - - // Auto-invoke project selection after successful reauthentication (but not in SMUS space environment) - if (!SmusUtils.isInSmusSpaceEnvironment()) { - void vscode.commands.executeCommand('aws.smus.switchProject') + // Only SSO connections can be reauthenticated + if (isSmusSsoConnection(existingConn)) { + logger.info('Existing SSO connection is invalid, reauthenticating') + const reauthenticatedConn = await this.reauthenticate(existingConn) + + // Create the SMUS connection wrapper + const smusConn: SmusConnection = { + ...reauthenticatedConn, + domainUrl, + domainId, + } + + const result = await this.secondaryAuth.useNewConnection(smusConn) + logger.debug(`Reauthenticated connection successfully, id=${result.id}`) + + // Auto-invoke project selection after successful reauthentication (but not in SMUS space environment) + if (!SmusUtils.isInSmusSpaceEnvironment()) { + void vscode.commands.executeCommand('aws.smus.switchProject') + } + + return result as SmusConnection } - - return result } } // No existing connection found, create a new one - logger.info('SMUS: No existing connection found, creating new connection') + logger.info('No existing connection found, creating new connection') // Get SSO instance info from DataZone const ssoInstanceInfo = await SmusUtils.getSsoInstanceInfo(domainUrl) @@ -246,7 +581,7 @@ export class SmusAuthenticationProvider { // Create a new connection with appropriate scope based on domain URL const profile = createSmusProfile(domainUrl, domainId, ssoInstanceInfo.issuerUrl, ssoInstanceInfo.region) const newConn = await this.auth.createConnection(profile) - logger.debug(`SMUS: Created new connection ${newConn.id}`) + logger.debug(`Created new connection ${newConn.id}`) const smusConn: SmusConnection = { ...newConn, @@ -261,23 +596,353 @@ export class SmusAuthenticationProvider { void vscode.commands.executeCommand('aws.smus.switchProject') } - return result + return result as SmusConnection } catch (e) { throw ToolkitError.chain(e, 'Failed to connect to SageMaker Unified Studio', { - code: 'FailedToConnect', + code: SmusErrorCodes.FailedToConnect, }) } } + /** + * Authenticates with SageMaker Unified Studio using IAM credential profile + * @param profileName The AWS credential profile name + * @param region The AWS region + * @param domainUrl The SageMaker Unified Studio domain URL + * @param isIamDomain Whether the domain is an IAM-based domain + * @returns Promise resolving to the IAM connection + */ + @withTelemetryContext({ name: 'connectWithIamProfile', class: authClassName }) + public async connectWithIamProfile( + profileName: string, + region: string, + domainUrl: string, + isIamDomain: boolean = false + ): Promise { + const logger = getLogger('smus') + + try { + // Extract domain info using SmusUtils + const { domainId } = SmusUtils.extractDomainInfoFromUrl(domainUrl) + + // Validate domain ID + if (!domainId) { + throw new ToolkitError('Invalid domain URL format', { code: SmusErrorCodes.InvalidDomainUrl }) + } + + logger.info(`Connecting with IAM profile ${profileName} to domain ${domainId} in region ${region}`) + + // Note: Credential validation is already done in the orchestrator via validateIamProfile() + // No need for redundant validation here + + // Check if we already have a basic IAM connection for this profile + const profileId = `profile:${profileName}` + const existingConn = await this.auth.getConnection({ id: profileId }) + + if (existingConn && existingConn.type === 'iam') { + logger.info(`Found existing IAM profile connection ${profileId}`) + + // Store SMUS metadata in the connections registry + const smusConnections = (this.secondaryAuth.state.get('smus.connections') as any) || {} + smusConnections[existingConn.id] = { + profileName, + region, + domainUrl, + domainId, + isIamDomain, + } + await this.secondaryAuth.state.update('smus.connections', smusConnections) + + // Use the basic IAM connection with secondaryAuth + await this.secondaryAuth.useNewConnection(existingConn) + + // Ensure the connection state is validated + await this.auth.refreshConnectionState(existingConn) + logger.debug( + `SMUS: Using existing IAM connection as SMUS connection successfully, id=${existingConn.id}` + ) + + // Set IAM mode context if this is an IAM-based domain + if (isIamDomain) { + await setSmusIamModeContext(true) + logger.debug('Set IAM mode context to true') + } + + // Return a SMUS IAM connection wrapper for the caller + const smusIamConn: SmusIamConnection = { + ...existingConn, + profileName, + region, + domainUrl, + domainId, + } + + return smusIamConn + } + + // If no existing connection, the auth system should have created one during profile validation + // This shouldn't happen if credentials are valid, but let's handle it gracefully + throw new ToolkitError( + `IAM profile connection not found for '${profileName}'. Please check your AWS credentials configuration.`, + { + code: SmusErrorCodes.ConnectionNotFound, + } + ) + } catch (e) { + throw ToolkitError.chain(e, 'Failed to connect to SageMaker Unified Studio with IAM profile', { + code: SmusErrorCodes.FailedToConnect, + }) + } + } + + /** + * Validates an IAM credential profile using the existing Toolkit validation infrastructure + * @param profileName Profile name to validate + * @returns Promise resolving to validation result + */ + public async validateIamProfile(profileName: string): Promise<{ isValid: boolean; error?: string }> { + const logger = getLogger('smus') + + try { + logger.debug(`Validating IAM profile: ${profileName}`) + + // Create credentials ID for the profile using the existing Toolkit pattern + const credentialsId: CredentialsId = { + credentialSource: SharedCredentialsProvider.getProviderType(), + credentialTypeId: profileName, + } + + // Get the provider using the existing manager + const provider = await CredentialsProviderManager.getInstance().getCredentialsProvider(credentialsId) + if (!provider) { + return { + isValid: false, + error: `Profile '${profileName}' not found or not available`, + } + } + + // Get credentials and validate using the existing Toolkit validation logic + // This includes proper telemetry and error handling + const credentials = await provider.getCredentials() + await globals.loginManager.validateCredentials( + credentials, + provider.getEndpointUrl?.(), + provider.getDefaultRegion() // Use the region from the profile, not hardcoded + ) + + logger.debug(`Profile validation successful: ${profileName}`) + return { isValid: true } + } catch (error) { + logger.error(`Profile validation failed: ${profileName}`, error) + return { + isValid: false, + error: `Invalid profile '${profileName}' - ${(error as Error).message}`, + } + } + } + + /** + * Gets credentials for an IAM profile using Toolkit providers + * @param profileName AWS profile name + * @returns Promise resolving to credentials + */ + public async getCredentialsForIamProfile(profileName: string): Promise { + const logger = getLogger('smus') + + try { + logger.debug(`Getting credentials for IAM profile: ${profileName}`) + + // Create credentials ID for the profile using the existing Toolkit pattern + const credentialsId: CredentialsId = { + credentialSource: SharedCredentialsProvider.getProviderType(), + credentialTypeId: profileName, + } + + // Get the provider using the existing manager + const provider = await CredentialsProviderManager.getInstance().getCredentialsProvider(credentialsId) + if (!provider) { + throw new ToolkitError(`Profile '${profileName}' not found or not available`, { + code: SmusErrorCodes.ProfileNotFound, + }) + } + + // Get credentials using the existing Toolkit provider + const credentials = await provider.getCredentials() + + logger.debug(`Successfully retrieved credentials for IAM profile: ${profileName}`) + return credentials + } catch (error) { + logger.error(`Failed to get credentials for IAM profile ${profileName}: %s`, error) + throw new ToolkitError( + `Failed to get credentials for profile '${profileName}': ${(error as Error).message}`, + { + code: SmusErrorCodes.CredentialRetrievalFailed, + cause: error instanceof Error ? error : undefined, + } + ) + } + } + + /** + * Gets the underlying credentials provider for an IAM profile + * @param profileName AWS profile name + * @returns Promise resolving to the credentials provider + */ + public async getCredentialsProviderForIamProfile(profileName: string): Promise { + const logger = getLogger('smus') + logger.debug(`Getting credentials provider for IAM profile: ${profileName}`) + + // Create credentials ID for the profile using the existing Toolkit pattern + const credentialsId: CredentialsId = { + credentialSource: SharedCredentialsProvider.getProviderType(), + credentialTypeId: profileName, + } + + // Get the provider using the existing manager + const provider = await CredentialsProviderManager.getInstance().getCredentialsProvider(credentialsId) + if (!provider) { + throw new ToolkitError(`Profile '${profileName}' not found or not available`, { + code: SmusErrorCodes.ProfileNotFound, + }) + } + + // Return the underlying provider directly + // This allows callers to use the provider's full interface including caching and refresh + return provider + } + + /** + * Gets the cached caller identity ARN for the active IAM connection + * Fetches from STS if not cached or if connection has changed + * Only works for IAM connections - returns undefined for SSO connections + * @returns Promise resolving to the ARN, or undefined if not available or not an IAM connection + */ + public async getCachedIamCallerIdentityArn(): Promise { + const logger = getLogger('smus') + try { + const activeConn = this.activeConnection + // Only cache for IAM connections + if (!activeConn || activeConn.type !== 'iam') { + return undefined + } + + // Check if we have a cached ARN for this connection + if (this.iamCallerIdentityCache && this.iamCallerIdentityCache.connectionId === activeConn.id) { + logger.debug('Using cached IAM caller identity ARN') + return this.iamCallerIdentityCache.arn + } + + // Fetch fresh caller identity + logger.debug('Fetching IAM caller identity from STS') + const smusConnections = (this.secondaryAuth.state.get('smus.connections') as any) || {} + const connectionMetadata = smusConnections[activeConn.id] + + if (!connectionMetadata?.profileName || !connectionMetadata?.region) { + logger.debug('Missing profile name or region in connection metadata') + return undefined + } + + const credentials = await this.getCredentialsForIamProfile(connectionMetadata.profileName) + const stsClient = new DefaultStsClient(connectionMetadata.region, credentials) + const callerIdentity = await stsClient.getCallerIdentity() + + if (!callerIdentity.Arn) { + logger.debug('No ARN found in caller identity') + return undefined + } + + // Cache the result + this.iamCallerIdentityCache = { + arn: callerIdentity.Arn, + connectionId: activeConn.id, + } + logger.debug(`Cached IAM caller identity ARN for connection ${activeConn.id}`) + + return callerIdentity.Arn + } catch (error) { + logger.warn(`Failed to get IAM caller identity: %s`, error) + return undefined + } + } + + /** + * Gets the session name from the cached IAM caller identity + * Only works for IAM connections - returns undefined for SSO connections + * @returns Promise resolving to the session name, or undefined if not available or not an IAM connection + */ + public async getSessionName(): Promise { + const arn = await this.getCachedIamCallerIdentityArn() + if (!arn) { + return undefined + } + + const sessionName = SmusUtils.extractSessionNameFromArn(arn) + this.logger.debug(`Extracted session name: ${sessionName || 'none'}`) + return sessionName + } + + /** + * Gets the role ARN from the cached IAM caller identity + * Converts assumed role ARN to IAM role ARN format + * Only works for IAM connections - returns undefined for SSO connections + * @returns Promise resolving to the IAM role ARN, or undefined if not available or not an IAM connection + */ + public async getIamPrincipalArn(): Promise { + const arn = await this.getCachedIamCallerIdentityArn() + if (!arn) { + return undefined + } + + // Convert assumed role ARN to IAM role ARN + const roleArn = SmusUtils.convertAssumedRoleArnToIamRoleArn(arn) + this.logger.debug(`Extracted role ARN: ${roleArn || 'none'}`) + return roleArn + } + + /** + * Clears the cached IAM caller identity + * Should be called when connection changes or credentials are refreshed + */ + private clearIamCallerIdentityCache(): void { + this.iamCallerIdentityCache = undefined + this.logger.debug('Cleared IAM caller identity cache') + } + /** * Reauthenticates an existing connection * @param conn Connection to reauthenticate * @returns Promise resolving to the reauthenticated connection */ @withTelemetryContext({ name: 'reauthenticate', class: authClassName }) - public async reauthenticate(conn: SsoConnection) { + public async reauthenticate(conn: SmusConnection): Promise { try { - return await this.auth.reauthenticate(conn) + // Check if this is an IAM connection + if (isSmusIamConnection(conn)) { + // For IAM connections, show options menu + this.logger.debug('Showing IAM credential expiry options for reauthentication') + const result = await showIamCredentialExpiryOptions(this, conn, globals.context) + + // Handle the result - for most actions, return the original connection + // The actions have already been performed (sign out, edit credentials, etc.) + if (result.action === IamCredentialExpiryAction.SignOut) { + throw new ToolkitError('User signed out from connection', { cancelled: true }) + } else if (result.action === IamCredentialExpiryAction.Cancelled) { + throw new ToolkitError('Reauthentication cancelled by user', { cancelled: true }) + } + + // For Reauthenticate, EditCredentials, and SwitchProfile, return the connection + return conn + } else { + // For SSO connections, use existing re-auth flow + const reauthenticatedConn = await this.auth.reauthenticate(conn) + + // Re-add SMUS-specific properties that aren't preserved by the base auth system + return { + ...reauthenticatedConn, + domainUrl: conn.domainUrl, + domainId: conn.domainId, + } as SmusConnection + } } catch (err) { throw ToolkitError.chain(err, 'Unable to reauthenticate SageMaker Unified Studio connection.') } @@ -287,7 +952,7 @@ export class SmusAuthenticationProvider { * Shows a reauthentication prompt to the user * @param conn Connection to reauthenticate */ - public async showReauthenticationPrompt(conn: SsoConnection): Promise { + public async showReauthenticationPrompt(conn: SmusConnection): Promise { await showReauthenticateMessage({ message: localizedText.connectionExpired('SageMaker Unified Studio'), connect: localizedText.reauthenticate, @@ -306,22 +971,28 @@ export class SmusAuthenticationProvider { * @throws ToolkitError if unable to retrieve access token */ public async getAccessToken(): Promise { - const logger = getLogger() + const logger = getLogger('smus') - if (!this.activeConnection) { + const connection = this.activeConnection + if (!connection) { throw new ToolkitError('No active SMUS connection available', { code: SmusErrorCodes.NoActiveConnection }) } + // Only SSO connections have access tokens + if (!isSmusSsoConnection(connection)) { + throw new ToolkitError('Access tokens are only available for SSO connections', { + code: SmusErrorCodes.InvalidConnectionType, + }) + } + try { - const accessToken = await this.auth.getSsoAccessToken(this.activeConnection) - logger.debug(`SMUS: Successfully retrieved SSO access token for connection ${this.activeConnection.id}`) + // Type assertion is safe here because we've already checked with isSmusSsoConnection + const accessToken = await this.auth.getSsoAccessToken(connection as SsoConnection) + logger.debug(`Successfully retrieved SSO access token for connection ${connection.id}`) return accessToken } catch (err) { - logger.error( - `SMUS: Failed to retrieve SSO access token for connection ${this.activeConnection.id}: %s`, - err - ) + logger.error(`Failed to retrieve SSO access token for connection ${connection.id}: %s`, err) // Check if this is a reauth error that should be handled by showing SMUS-specific prompt if (err instanceof ToolkitError && err.code === 'InvalidConnection') { @@ -331,7 +1002,7 @@ export class SmusAuthenticationProvider { ) } - throw new ToolkitError(`Failed to retrieve SSO access token for connection ${this.activeConnection.id}`, { + throw new ToolkitError(`Failed to retrieve SSO access token for connection ${connection.id}`, { code: SmusErrorCodes.RedeemAccessTokenFailed, cause: err instanceof Error ? err : undefined, }) @@ -344,26 +1015,26 @@ export class SmusAuthenticationProvider { * @returns Promise resolving to the project credentials provider */ public async getProjectCredentialProvider(projectId: string): Promise { - const logger = getLogger() + const logger = getLogger('smus') if (!this.activeConnection) { throw new ToolkitError('No active SMUS connection available', { code: SmusErrorCodes.NoActiveConnection }) } - logger.debug(`SMUS: Getting project provider for project ${projectId}`) + logger.debug(`Getting project provider for project ${projectId}`) // Check if we already have a cached provider for this project if (this.projectCredentialProvidersCache.has(projectId)) { - logger.debug('SMUS: Using cached project provider') + logger.debug('Using cached project provider') return this.projectCredentialProvidersCache.get(projectId)! } - logger.debug('SMUS: Creating new project provider') + logger.debug('Creating new project provider') // Create a new project provider and cache it const projectProvider = new ProjectRoleCredentialsProvider(this, projectId) this.projectCredentialProvidersCache.set(projectId, projectProvider) - logger.debug('SMUS: Cached new project provider') + logger.debug('Cached new project provider') return projectProvider } @@ -380,27 +1051,27 @@ export class SmusAuthenticationProvider { projectId: string, region: string ): Promise { - const logger = getLogger() + const logger = getLogger('smus') if (!this.activeConnection) { throw new ToolkitError('No active SMUS connection available', { code: SmusErrorCodes.NoActiveConnection }) } - const cacheKey = `${this.activeConnection.domainId}:${projectId}:${connectionId}` - logger.debug(`SMUS: Getting connection provider for connection ${connectionId}`) + const cacheKey = `${this.getDomainId()}:${projectId}:${connectionId}` + logger.debug(`Getting connection provider for connection ${connectionId}`) // Check if we already have a cached provider for this connection if (this.connectionCredentialProvidersCache.has(cacheKey)) { - logger.debug('SMUS: Using cached connection provider') + logger.debug('Using cached connection provider') return this.connectionCredentialProvidersCache.get(cacheKey)! } - logger.debug('SMUS: Creating new connection provider') + logger.debug('Creating new connection provider') // Create a new connection provider and cache it - const connectionProvider = new ConnectionCredentialsProvider(this, connectionId) + const connectionProvider = new ConnectionCredentialsProvider(this, connectionId, projectId) this.connectionCredentialProvidersCache.set(cacheKey, connectionProvider) - logger.debug('SMUS: Cached new connection provider') + logger.debug('Cached new connection provider') return connectionProvider } @@ -417,7 +1088,15 @@ export class SmusAuthenticationProvider { if (!this.activeConnection) { throw new ToolkitError('No active SMUS connection available', { code: SmusErrorCodes.NoActiveConnection }) } - return this.activeConnection.domainId + + // For SMUS connections (both SSO and IAM) with domainId property + if ('domainId' in this.activeConnection) { + return (this.activeConnection as any).domainId + } + + throw new ToolkitError('Domain ID not available. Please reconnect to SMUS.', { + code: SmusErrorCodes.NoActiveConnection, + }) } /** @@ -428,7 +1107,15 @@ export class SmusAuthenticationProvider { if (!this.activeConnection) { throw new ToolkitError('No active SMUS connection available', { code: SmusErrorCodes.NoActiveConnection }) } - return this.activeConnection.domainUrl + + // For SMUS connections (both SSO and IAM) with domainUrl property + if ('domainUrl' in this.activeConnection) { + return (this.activeConnection as any).domainUrl + } + + throw new ToolkitError('Domain URL not available. Please reconnect to SMUS.', { + code: SmusErrorCodes.NoActiveConnection, + }) } /** @@ -439,11 +1126,11 @@ export class SmusAuthenticationProvider { * @throws ToolkitError if unable to retrieve account ID */ public async getDomainAccountId(): Promise { - const logger = getLogger() + const logger = getLogger('smus') // Return cached value if available if (this.cachedDomainAccountId) { - logger.debug('SMUS: Using cached domain account ID') + logger.debug('Using cached domain account ID') return this.cachedDomainAccountId } @@ -466,14 +1153,19 @@ export class SmusAuthenticationProvider { try { logger.debug('Fetching domain account ID via STS GetCallerIdentity') - // Get DER credentials provider - const derCredProvider = await this.getDerCredentialsProvider() - + let credentialsProvider + if (getContext('aws.smus.isIamMode')) { + credentialsProvider = await this.getCredentialsProviderForIamProfile( + (this.activeConnection as SmusIamConnection).profileName + ) + } else { + credentialsProvider = await this.getDerCredentialsProvider() + } // Get the region for STS client const region = this.getDomainRegion() // Create STS client with DER credentials - const stsClient = new DefaultStsClient(region, await derCredProvider.getCredentials()) + const stsClient = new DefaultStsClient(region, await credentialsProvider.getCredentials()) // Make GetCallerIdentity call const callerIdentity = await stsClient.getCallerIdentity() @@ -508,11 +1200,11 @@ export class SmusAuthenticationProvider { * @returns Promise resolving to the project's AWS account ID */ public async getProjectAccountId(projectId: string): Promise { - const logger = getLogger() + const logger = getLogger('smus') // Return cached value if available if (this.cachedProjectAccountIds.has(projectId)) { - logger.debug(`SMUS: Using cached project account ID for project ${projectId}`) + logger.debug(`Using cached project account ID for project ${projectId}`) return this.cachedProjectAccountIds.get(projectId)! } @@ -540,7 +1232,7 @@ export class SmusAuthenticationProvider { const projectCreds = await projectCredProvider.getCredentials() // Get project region from tooling environment - const dzClient = await DataZoneClient.getInstance(this) + const dzClient = await createDZClientBaseOnDomainMode(this) const toolingEnv = await dzClient.getToolingEnvironment(projectId) const projectRegion = toolingEnv.awsAccountRegion @@ -585,10 +1277,24 @@ export class SmusAuthenticationProvider { } } - if (!this.activeConnection) { + const connection = this.activeConnection + if (!connection) { throw new ToolkitError('No active SMUS connection available', { code: SmusErrorCodes.NoActiveConnection }) } - return this.activeConnection.ssoRegion + + // Handle different connection types + if (isSmusSsoConnection(connection)) { + return connection.ssoRegion + } + + // For SMUS connections (both SSO and IAM) with region property + if ('region' in connection) { + return (connection as any).region + } + + throw new ToolkitError('Domain region not available. Please reconnect to SMUS.', { + code: SmusErrorCodes.NoActiveConnection, + }) } /** @@ -596,44 +1302,101 @@ export class SmusAuthenticationProvider { * @returns Promise resolving to the credentials provider */ public async getDerCredentialsProvider(): Promise { - const logger = getLogger() + const logger = getLogger('smus') if (getContext('aws.smus.inSmusSpaceEnvironment')) { // When in SMUS space, DomainExecutionRoleCreds can be found in config file // Read the credentials from credential profile DomainExecutionRoleCreds - const credentials = fromIni({ profile: 'DomainExecutionRoleCreds' }) - return { - getCredentials: async () => await credentials(), + try { + // Load AWS config file to check profile configuration + const { configFile } = await loadSharedConfigFiles() + const profileConfig = configFile['DomainExecutionRoleCreds'] + + if (profileConfig?.credential_process) { + // Normal SMUS domain: Use the profile with credential_process + logger.debug('Using DomainExecutionRoleCreds profile with credential_process') + const credentials = fromIni({ profile: 'DomainExecutionRoleCreds' }) + return convertToToolkitCredentialProvider( + async () => await credentials(), + 'DomainExecutionRoleCreds', + `smus-der-profile:${this.getDomainId()}:${this.getDomainRegion()}`, + this.getDomainRegion() + ) + } else if (profileConfig?.credential_source === 'EcsContainer') { + // IAM-based domain with EcsContainer: Use ECS container credentials directly + // The environment has AWS_CONTAINER_CREDENTIALS_RELATIVE_URI set, so use fromContainerMetadata + // which properly handles the ECS credential endpoint + logger.debug('IAM-based domain detected, using ECS container credentials') + const credentials = fromContainerMetadata({ + timeout: 5000, + maxRetries: 3, + }) + return convertToToolkitCredentialProvider( + async () => await credentials(), + 'EcsContainer', + `smus-ecs-container:${this.getDomainId()}:${this.getDomainRegion()}`, + this.getDomainRegion() + ) + } else { + // Fallback: try the profile anyway + logger.debug( + 'SMUS: Unknown profile configuration, attempting to use DomainExecutionRoleCreds profile' + ) + const credentials = fromIni({ profile: 'DomainExecutionRoleCreds' }) + return convertToToolkitCredentialProvider( + async () => await credentials(), + 'DomainExecutionRoleCreds-fallback', + `smus-der-fallback:${this.getDomainId()}:${this.getDomainRegion()}`, + this.getDomainRegion() + ) + } + } catch (error) { + logger.error('Failed to load config file, falling back to default credential chain: %s', error) + const credentials = fromNodeProviderChain() + return convertToToolkitCredentialProvider( + async () => await credentials(), + 'NodeProviderChain', + `smus-node-provider-chain:${this.getDomainId()}:${this.getDomainRegion()}`, + this.getDomainRegion() + ) } } - if (!this.activeConnection) { + const connection = this.activeConnection + if (!connection) { throw new ToolkitError('No active SMUS connection available', { code: SmusErrorCodes.NoActiveConnection }) } + // Domain Execution Role credentials are only available for SSO connections + if (!isSmusSsoConnection(connection)) { + throw new ToolkitError('Domain Execution Role credentials are only available for SSO connections', { + code: SmusErrorCodes.InvalidConnectionType, + }) + } + // Create a cache key based on the connection details - const cacheKey = `${this.activeConnection.ssoRegion}:${this.activeConnection.domainId}` + const cacheKey = `${connection.ssoRegion}:${connection.domainId}` - logger.debug(`SMUS: Getting credentials provider for cache key: ${cacheKey}`) + logger.debug(`Getting credentials provider for cache key: ${cacheKey}`) // Check if we already have a cached provider if (this.credentialsProviderCache.has(cacheKey)) { - logger.debug('SMUS: Using cached credentials provider') + logger.debug('Using cached credentials provider') return this.credentialsProviderCache.get(cacheKey) } - logger.debug('SMUS: Creating new credentials provider') + logger.debug('Creating new credentials provider') // Create a new provider and cache it const provider = new DomainExecRoleCredentialsProvider( - this.activeConnection.domainUrl, - this.activeConnection.domainId, - this.activeConnection.ssoRegion, + connection.domainUrl, + connection.domainId, + connection.ssoRegion, async () => await this.getAccessToken() ) this.credentialsProviderCache.set(cacheKey, provider) - logger.debug('SMUS: Cached new credentials provider') + logger.debug('Cached new credentials provider') return provider } @@ -643,16 +1406,16 @@ export class SmusAuthenticationProvider { * Used during connection changes or logout */ private async invalidateAllCredentialsInCache(): Promise { - const logger = getLogger() - logger.debug('SMUS: Invalidating all cached credentials') + const logger = getLogger('smus') + logger.debug('Invalidating all cached credentials') // Clear all cached DER providers and their internal credentials for (const [cacheKey, provider] of this.credentialsProviderCache.entries()) { try { provider.invalidate() // This will clear the provider's internal cache - logger.debug(`SMUS: Invalidated credentials for cache key: ${cacheKey}`) + logger.debug(`Invalidated credentials for cache key: ${cacheKey}`) } catch (err) { - logger.warn(`SMUS: Failed to invalidate credentials for cache key ${cacheKey}: %s`, err) + logger.warn(`Failed to invalidate credentials for cache key ${cacheKey}: %s`, err) } } @@ -663,34 +1426,34 @@ export class SmusAuthenticationProvider { for (const [cacheKey, connectionProvider] of this.connectionCredentialProvidersCache.entries()) { try { connectionProvider.invalidate() // This will clear the connection provider's internal cache - logger.debug(`SMUS: Invalidated connection credentials for cache key: ${cacheKey}`) + logger.debug(`Invalidated connection credentials for cache key: ${cacheKey}`) } catch (err) { - logger.warn(`SMUS: Failed to invalidate connection credentials for cache key ${cacheKey}: %s`, err) + logger.warn(`Failed to invalidate connection credentials for cache key ${cacheKey}: %s`, err) } } // Clear cached domain account ID this.cachedDomainAccountId = undefined - logger.debug('SMUS: Cleared cached domain account ID') + logger.debug('Cleared cached domain account ID') // Clear cached project account IDs this.cachedProjectAccountIds.clear() - logger.debug('SMUS: Cleared cached project account IDs') + logger.debug('Cleared cached project account IDs') } /** * Invalidates all project cached credentials */ public async invalidateAllProjectCredentialsInCache(): Promise { - const logger = getLogger() - logger.debug('SMUS: Invalidating all cached project credentials') + const logger = getLogger('smus') + logger.debug('Invalidating all cached project credentials') for (const [projectId, projectProvider] of this.projectCredentialProvidersCache.entries()) { try { projectProvider.invalidate() // This will clear the project provider's internal cache - logger.debug(`SMUS: Invalidated project credentials for project: ${projectId}`) + logger.debug(`Invalidated project credentials for project: ${projectId}`) } catch (err) { - logger.warn(`SMUS: Failed to invalidate project credentials for project ${projectId}: %s`, err) + logger.warn(`Failed to invalidate project credentials for project ${projectId}: %s`, err) } } } @@ -699,7 +1462,7 @@ export class SmusAuthenticationProvider { * Stops SSH credential refresh and cleans up resources */ public dispose(): void { - this.logger.debug('SMUS Auth: Disposing authentication provider and all cached providers') + this.logger.debug('Disposing authentication provider and all cached providers') // Dispose all project providers for (const provider of this.projectCredentialProvidersCache.values()) { @@ -727,7 +1490,13 @@ export class SmusAuthenticationProvider { // Clear cached project account IDs this.cachedProjectAccountIds.clear() - this.logger.debug('SMUS Auth: Successfully disposed authentication provider') + // Clear cached IAM caller identity + this.clearIamCallerIdentityCache() + + DataZoneClient.dispose() + DataZoneCustomClientHelper.dispose() + + this.logger.debug('Successfully disposed authentication provider') } static #instance: SmusAuthenticationProvider | undefined @@ -739,4 +1508,29 @@ export class SmusAuthenticationProvider { public static fromContext() { return (this.#instance ??= new this()) } + + public async invalidateConnection(): Promise { + // When in SMUS space, the extension is already running in projet context and sign in is not needed + if (getContext('aws.smus.inSmusSpaceEnvironment')) { + return + } + + if (!this.activeConnection) { + return + } + + // For IAM connections, actively validate credentials + // No action needed for SSO as the connection is automatically updated + if (isSmusIamConnection(this.activeConnection)) { + try { + const validation = await this.validateIamProfile(this.activeConnection.profileName) + await this.auth.updateConnectionState( + this.activeConnection.id, + validation.isValid ? 'valid' : 'invalid' + ) + } catch { + await this.auth.updateConnectionState(this.activeConnection.id, 'invalid') + } + } + } } diff --git a/packages/core/src/sagemakerunifiedstudio/auth/ui/authenticationMethodSelection.ts b/packages/core/src/sagemakerunifiedstudio/auth/ui/authenticationMethodSelection.ts new file mode 100644 index 00000000000..e209a15b1ca --- /dev/null +++ b/packages/core/src/sagemakerunifiedstudio/auth/ui/authenticationMethodSelection.ts @@ -0,0 +1,111 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as vscode from 'vscode' +import { getLogger } from '../../../shared/logger/logger' +import { ToolkitError } from '../../../shared/errors' +import { SmusErrorCodes } from '../../shared/smusUtils' + +/** + * Authentication method types supported by SMUS + */ +export type SmusAuthenticationMethod = 'sso' | 'iam' + +/** + * Result of authentication method selection + */ +export interface AuthenticationMethodSelection { + method: SmusAuthenticationMethod +} + +/** + * Authentication method selection dialog for SMUS + */ +export class SmusAuthenticationMethodSelector { + private static readonly logger = getLogger('smus') + + /** + * Shows the authentication method selection dialog matching the Figma design + * @param defaultMethod Optional default method to pre-select + * @returns Promise resolving to the selected authentication method + */ + public static async showAuthenticationMethodSelection( + defaultMethod?: SmusAuthenticationMethod + ): Promise { + const logger = this.logger + + const iamOption: vscode.QuickPickItem = { + label: '$(key) IAM Credential', + detail: 'Use IAM credentials to access resources in SageMaker Unified Studio IAM-based domains.', + } + + const ssoOption: vscode.QuickPickItem = { + label: '$(organization) IAM Identity Center', + detail: 'Use Identity Center to access resources in SageMaker Unified Studio IdC-based domains.', + } + + const options = [iamOption, ssoOption] + + // Set default selection based on preference + let defaultIndex = 0 + if (defaultMethod === 'sso') { + defaultIndex = 1 + } + + const quickPick = vscode.window.createQuickPick() + quickPick.title = 'Select a sign in method' + quickPick.placeholder = 'Choose how you want to authenticate with SageMaker Unified Studio' + quickPick.items = options + quickPick.canSelectMany = false + quickPick.ignoreFocusOut = true + + // Pre-select the default method + if (options[defaultIndex]) { + quickPick.activeItems = [options[defaultIndex]] + } + + return new Promise((resolve, reject) => { + let isCompleted = false + + quickPick.onDidAccept(() => { + const selectedItem = quickPick.selectedItems[0] + if (!selectedItem) { + quickPick.dispose() + reject( + new ToolkitError('No authentication method selected', { + code: SmusErrorCodes.UserCancelled, + cancelled: true, + }) + ) + return + } + + const method: SmusAuthenticationMethod = selectedItem === iamOption ? 'iam' : 'sso' + + logger.debug(`User selected authentication method: ${method}`) + + isCompleted = true + quickPick.dispose() + + // Return the selected method without asking about preferences + resolve({ method }) + }) + + quickPick.onDidHide(() => { + if (!isCompleted) { + quickPick.dispose() + reject( + new ToolkitError('Authentication method selection cancelled', { + code: SmusErrorCodes.UserCancelled, + cancelled: true, + }) + ) + } + }) + + quickPick.show() + }) + } +} diff --git a/packages/core/src/sagemakerunifiedstudio/auth/ui/iamProfileSelection.ts b/packages/core/src/sagemakerunifiedstudio/auth/ui/iamProfileSelection.ts new file mode 100644 index 00000000000..fd61fc70f50 --- /dev/null +++ b/packages/core/src/sagemakerunifiedstudio/auth/ui/iamProfileSelection.ts @@ -0,0 +1,1311 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as vscode from 'vscode' +import * as path from 'path' +import { getLogger } from '../../../shared/logger/logger' +import { ToolkitError } from '../../../shared/errors' +import { loadSharedCredentialsProfiles } from '../../../auth/credentials/sharedCredentials' +import { getCredentialsFilename, getConfigFilename } from '../../../auth/credentials/sharedCredentialsFile' +import { SmusErrorCodes, DataZoneServiceId } from '../../shared/smusUtils' +import globals from '../../../shared/extensionGlobals' +import fs from '../../../shared/fs/fs' + +/** + * Actions available in the credential management dialog + */ +enum CredentialManagementAction { + EditCredentialsFile = 'EDIT_CREDENTIALS_FILE', + EditConfigFile = 'EDIT_CONFIG_FILE', + AddNewProfile = 'ADD_NEW_PROFILE', +} + +/** + * Actions available in the profile selection dialog + */ +enum ProfileSelectionAction { + SelectProfile = 'SELECT_PROFILE', + ManageCredentials = 'MANAGE_CREDENTIALS', +} + +/** + * Actions available in the session token input dialog + */ +enum SessionTokenAction { + Skip = 'SKIP', + UseToken = 'USE_TOKEN', + Warning = 'WARNING', +} + +/** + * Result of IAM profile selection + */ +export interface IamProfileSelection { + profileName: string + region: string +} + +/** + * Result indicating user chose to edit credential files + */ +export interface IamProfileEditingInProgress { + isEditing: true + message: string +} + +/** + * Result indicating user chose to go back + */ +export interface IamProfileBackNavigation { + isBack: true + message: string +} + +/** + * IAM profile selection interface for SMUS + */ +export class SmusIamProfileSelector { + private static readonly logger = getLogger('smus') + + // Validation regex patterns (based on AWS STS API specifications) + // Reference: https://docs.aws.amazon.com/STS/latest/APIReference/API_Credentials.html + private static readonly profileNamePattern = /^[a-zA-Z0-9_-]+$/ + // AWS AccessKeyId: 16-128 chars, pattern [\w]* (alphanumeric + underscore) + private static readonly accessKeyIdPattern = /^[a-zA-Z0-9_]*$/ + // AWS SecretAccessKey and SessionToken: Required per STS API, but no pattern/length constraints specified + private static readonly regionLinePattern = /^region\s*=.*$/m + + /** + * Creates a QuickPick with common settings for input dialogs + * @param title Title for the QuickPick + * @param placeholder Placeholder text + * @returns Configured QuickPick instance + */ + private static createInputQuickPick(title: string, placeholder: string): vscode.QuickPick { + const quickPick = vscode.window.createQuickPick() + quickPick.title = title + quickPick.placeholder = placeholder + quickPick.canSelectMany = false + quickPick.ignoreFocusOut = true + quickPick.buttons = [vscode.QuickInputButtons.Back] + return quickPick + } + + /** + * Shows the IAM profile selection dialog matching the Figma design + * @returns Promise resolving to the selected profile and region, editing status, or back navigation + */ + public static async showIamProfileSelection(): Promise< + IamProfileSelection | IamProfileEditingInProgress | IamProfileBackNavigation + > { + const logger = this.logger + + try { + // Load available credential profiles + const profiles = await loadSharedCredentialsProfiles() + const profileNames = Object.keys(profiles) + + // Create QuickPick items for profiles + const profileItems: (vscode.QuickPickItem & { + action: ProfileSelectionAction + profileName: string + region: string + })[] = profileNames.map((profileName) => { + const profile = profiles[profileName] + const region = profile.region || 'not-set' + + return { + label: `$(key) ${profileName}`, + description: `IAM Credentials, configured locally (${region})`, + detail: `Profile: ${profileName} | Region: ${region}`, + action: ProfileSelectionAction.SelectProfile, + profileName, + region, + } + }) + + // Add "Add or edit credentials" option + const addCredentialsItem: vscode.QuickPickItem & { action: ProfileSelectionAction } = { + label: '$(add) Add or edit credentials', + description: 'Manage AWS credential profiles', + detail: 'Add new profiles or edit existing credential files', + action: ProfileSelectionAction.ManageCredentials, + } + + const options = [...profileItems, addCredentialsItem] + + const quickPick = vscode.window.createQuickPick() + quickPick.title = 'Select an IAM Profile' + quickPick.placeholder = 'Choose an AWS credential profile to authenticate with SageMaker Unified Studio' + quickPick.items = options + quickPick.canSelectMany = false + quickPick.ignoreFocusOut = true + + // Add back button + const backButton = vscode.QuickInputButtons.Back + quickPick.buttons = [backButton] + + return new Promise((resolve, reject) => { + let isCompleted = false + + quickPick.onDidAccept(() => { + const selectedItem = quickPick.selectedItems[0] + if (!selectedItem) { + quickPick.dispose() + reject( + new ToolkitError('No profile selected', { + code: SmusErrorCodes.UserCancelled, + cancelled: true, + }) + ) + return + } + + isCompleted = true + quickPick.dispose() + + const itemWithAction = selectedItem as vscode.QuickPickItem & { + action: ProfileSelectionAction + profileName?: string + region?: string + } + + // Check if user selected "Add or edit credentials" + if (itemWithAction.action === ProfileSelectionAction.ManageCredentials) { + // Handle the async credential management flow + void (async () => { + try { + const managementResult = await SmusIamProfileSelector.showCredentialManagement() + + // Check if a new profile was created (returns IamProfileSelection) + if (typeof managementResult === 'object' && 'profileName' in managementResult) { + // User created a new profile, use it directly + logger.debug( + `SMUS Auth: Using newly created profile: ${managementResult.profileName}` + ) + resolve(managementResult) + } else if (managementResult === true) { + // User wants to restart profile selection (e.g., clicked back) + const result = await SmusIamProfileSelector.showIamProfileSelection() + resolve(result) + } else { + // User chose to edit files, return a special result indicating this + resolve({ + isEditing: true, + message: + 'User chose to edit credential files. Please complete setup and try again.', + }) + } + } catch (error) { + // Handle user cancellation gracefully + if (error instanceof ToolkitError && error.code === SmusErrorCodes.UserCancelled) { + resolve({ + isEditing: true, + message: 'User cancelled credential management.', + }) + } else { + reject(error) + } + } + })() + return + } + + // User selected an existing profile + // Ensure we have profile data (should always be present for SelectProfile action) + if (!itemWithAction.profileName || !itemWithAction.region) { + reject(new ToolkitError('Invalid profile selection', { code: 'InvalidProfileSelection' })) + return + } + + const profileName = itemWithAction.profileName + const profileRegion = itemWithAction.region + + logger.debug(`User selected profile: ${profileName}`) + + // Check if region is not set and prompt for region selection + if (profileRegion === 'not-set') { + void (async () => { + try { + const selectedRegion = await SmusIamProfileSelector.showRegionSelection() + + // Check if user clicked back on region selection + if (selectedRegion === 'BACK') { + resolve({ + isBack: true, + message: 'User chose to go back from region selection.', + }) + return + } + + // Update the profile with the selected region + await SmusIamProfileSelector.updateProfileRegion(profileName, selectedRegion) + + resolve({ + profileName: profileName, + region: selectedRegion, + }) + } catch (error) { + reject(error) + } + })() + } else { + resolve({ + profileName: profileName, + region: profileRegion, + }) + } + }) + + quickPick.onDidTriggerButton((button) => { + if (button === vscode.QuickInputButtons.Back) { + isCompleted = true + quickPick.dispose() + resolve({ + isBack: true, + message: 'User chose to go back to authentication method selection.', + }) + } + }) + + quickPick.onDidHide(() => { + if (!isCompleted) { + quickPick.dispose() + reject( + new ToolkitError('Profile selection cancelled', { + code: SmusErrorCodes.UserCancelled, + cancelled: true, + }) + ) + } + }) + + quickPick.show() + }) + } catch (error) { + // Don't log or chain user cancellation as an error + if (error instanceof ToolkitError && error.code === SmusErrorCodes.UserCancelled) { + throw error + } + logger.error('Failed to show IAM profile selection: %s', error) + throw ToolkitError.chain(error, 'Failed to show IAM profile selection') + } + } + + /** + * Shows region selection dialog for IAM authentication + * @param options Configuration options for the region selection dialog + * @returns Promise resolving to the selected region or 'BACK' if user wants to go back + */ + public static async showRegionSelection(options?: { + defaultRegion?: string + title?: string + placeholder?: string + returnBackOnCancel?: boolean + }): Promise { + const logger = this.logger + + // Get regions where DataZone service is available + const allRegions = globals.regionProvider.getRegions() + const dataZoneRegions = allRegions.filter((region) => + globals.regionProvider.isServiceInRegion(DataZoneServiceId, region.id) + ) + + // If no regions found with DataZone service, fall back to all regions + const regions = dataZoneRegions.length > 0 ? dataZoneRegions : allRegions + + const regionItems: vscode.QuickPickItem[] = regions.map( + (region) => + ({ + label: region.name, + description: region.id, + detail: `AWS Region: ${region.id}`, + regionCode: region.id, + }) as vscode.QuickPickItem & { regionCode: string } + ) + + const quickPick = this.createInputQuickPick( + options?.title ?? 'Select AWS Region', + options?.placeholder ?? 'Choose the AWS region for SageMaker Unified Studio' + ) + quickPick.items = regionItems + + // Allow users to find matches by typing in the region code (e.g., us-east-1) + quickPick.matchOnDescription = true + + // Pre-select default region if provided + if (options?.defaultRegion) { + const defaultItem = regionItems.find((item) => (item as any).regionCode === options.defaultRegion) + if (defaultItem) { + quickPick.activeItems = [defaultItem] + } + } + + return new Promise((resolve, reject) => { + let isCompleted = false + + quickPick.onDidAccept(() => { + const selectedItem = quickPick.selectedItems[0] + if (!selectedItem) { + if (options?.returnBackOnCancel) { + quickPick.dispose() + resolve('BACK') + } else { + quickPick.dispose() + reject( + new ToolkitError('No region selected', { + code: SmusErrorCodes.UserCancelled, + cancelled: true, + }) + ) + } + return + } + + isCompleted = true + quickPick.dispose() + + const regionItem = selectedItem as vscode.QuickPickItem & { regionCode: string } + + logger.debug(`User selected region: ${regionItem.regionCode}`) + + resolve(regionItem.regionCode) + }) + + quickPick.onDidTriggerButton((button) => { + if (button === vscode.QuickInputButtons.Back) { + isCompleted = true + quickPick.dispose() + resolve('BACK') + } + }) + + quickPick.onDidHide(() => { + if (!isCompleted) { + quickPick.dispose() + if (options?.returnBackOnCancel) { + resolve('BACK') + } else { + reject( + new ToolkitError('Region selection cancelled', { + code: SmusErrorCodes.UserCancelled, + cancelled: true, + }) + ) + } + } + }) + + quickPick.show() + }) + } + + /** + * Shows credential management options (Add/Edit credentials) + * @returns Promise resolving to boolean indicating if profile selection should restart, or profile data if a new profile was created + */ + public static async showCredentialManagement(): Promise { + const logger = this.logger + + logger.debug('Showing credential management options') + + const options: (vscode.QuickPickItem & { action: CredentialManagementAction })[] = [ + { + label: '$(file-text) Edit AWS Credentials File', + description: 'Open ~/.aws/credentials file for editing', + detail: 'Edit existing credential profiles or add new ones', + action: CredentialManagementAction.EditCredentialsFile, + }, + { + label: '$(file-text) Edit AWS Config File', + description: 'Open ~/.aws/config file for editing', + detail: 'Edit AWS configuration settings and profiles', + action: CredentialManagementAction.EditConfigFile, + }, + { + label: '$(add) Add New Profile', + description: 'Create a new AWS credential profile', + detail: 'Interactive setup for a new credential profile', + action: CredentialManagementAction.AddNewProfile, + }, + ] + + const quickPick = vscode.window.createQuickPick() + quickPick.title = 'Manage AWS Credentials' + quickPick.placeholder = 'Choose how you want to manage your AWS credentials' + quickPick.items = options + quickPick.canSelectMany = false + quickPick.ignoreFocusOut = true + + // Add back button + const backButton = vscode.QuickInputButtons.Back + quickPick.buttons = [backButton] + + return new Promise((resolve, reject) => { + let isCompleted = false + + quickPick.onDidAccept(() => { + const selectedItem = quickPick.selectedItems[0] + if (!selectedItem) { + quickPick.dispose() + reject( + new ToolkitError('No option selected', { code: SmusErrorCodes.UserCancelled, cancelled: true }) + ) + return + } + + isCompleted = true + quickPick.dispose() + + // Handle the async operations after disposing the quick pick + void (async () => { + try { + const itemWithAction = selectedItem as vscode.QuickPickItem & { + action: CredentialManagementAction + } + + switch (itemWithAction.action) { + case CredentialManagementAction.EditCredentialsFile: { + const result = await this.openAwsFile('credentials') + // If user clicked "Select Profile", restart profile selection + resolve(result === 'RESTART_PROFILE_SELECTION') + break + } + case CredentialManagementAction.EditConfigFile: { + const result = await this.openAwsFile('config') + // If user clicked "Select Profile", restart profile selection + resolve(result === 'RESTART_PROFILE_SELECTION') + break + } + case CredentialManagementAction.AddNewProfile: { + const newProfile = await this.addNewProfile() + // Return the newly created profile data to use it directly + resolve(newProfile) + break + } + } + } catch (error) { + if (error instanceof ToolkitError && error.code === SmusErrorCodes.UserCancelled) { + // User cancelled, don't treat as error + reject(error) + } else { + reject(error) + } + } + })() + }) + + quickPick.onDidTriggerButton((button) => { + if (button === vscode.QuickInputButtons.Back) { + isCompleted = true + quickPick.dispose() + // User wants to go back to profile selection + resolve(true) + } + }) + + quickPick.onDidHide(() => { + if (!isCompleted) { + quickPick.dispose() + reject( + new ToolkitError('Credential management cancelled', { + code: SmusErrorCodes.UserCancelled, + cancelled: true, + }) + ) + } + }) + + quickPick.show() + }) + } + + /** + * Opens the AWS credentials file in VS Code editor + */ + /** + * Opens an AWS configuration file in VS Code editor + * @param fileType Type of file to open ('credentials' or 'config') + */ + private static async openAwsFile(fileType: 'credentials' | 'config'): Promise { + const logger = this.logger + const isCredentials = fileType === 'credentials' + + try { + const filePath = isCredentials ? getCredentialsFilename() : getConfigFilename() + const fileLabel = isCredentials ? 'credentials' : 'config' + + logger.debug(`Opening ${fileLabel} file: ${filePath}`) + + // Ensure the .aws directory exists + await this.ensureAwsDirectoryExists() + + // Create the file if it doesn't exist + if (!(await fs.existsFile(filePath))) { + await fs.writeFile(filePath, '') + logger.debug(`Created new ${fileLabel} file`) + } + + // Open the file in VS Code + const document = await vscode.workspace.openTextDocument(filePath) + await vscode.window.showTextDocument(document) + + logger.debug(`${fileLabel} file opened successfully`) + } catch (error) { + const fileLabel = isCredentials ? 'credentials' : 'config' + logger.error(`Failed to open ${fileLabel} file: %s`, error) + throw new ToolkitError(`Failed to open AWS ${fileLabel} file: ${(error as Error).message}`, { + code: isCredentials ? 'CredentialsFileError' : 'ConfigFileError', + }) + } + } + + /** + * Interactive flow to add a new AWS credential profile with back navigation + * @returns Promise resolving to the newly created profile data + */ + private static async addNewProfile(): Promise { + const logger = this.logger + + try { + logger.debug('Starting add new profile flow') + + const profileData = await this.collectProfileData() + + if (profileData === 'BACK') { + // User navigated back, throw error to go back to credential management + throw new ToolkitError('User navigated back', { code: SmusErrorCodes.UserCancelled, cancelled: true }) + } + + // Add the profile to credentials file + await this.addProfileToCredentialsFile( + profileData.profileName, + profileData.accessKeyId, + profileData.secretAccessKey, + profileData.sessionToken, + profileData.region + ) + + // Show success message + void vscode.window.showInformationMessage( + `AWS profile '${profileData.profileName}' has been added successfully and will be used for authentication.` + ) + + logger.debug(`Successfully added new profile: ${profileData.profileName}`) + + // Return the profile data to use it directly + return { + profileName: profileData.profileName, + region: profileData.region, + } + } catch (error) { + // Only log actual errors, not user cancellations + if (error instanceof ToolkitError && error.code === SmusErrorCodes.UserCancelled) { + logger.debug('User cancelled add new profile flow') + throw error // Re-throw for telemetry but don't log as error + } + logger.error('Failed to add new profile: %s', error) + throw new ToolkitError(`Failed to add new profile: ${(error as Error).message}`, { + code: 'AddProfileError', + }) + } + } + + /** + * Collects profile data through a multi-step flow with back navigation + */ + private static async collectProfileData(): Promise< + | { + profileName: string + accessKeyId: string + secretAccessKey: string + sessionToken?: string + region: string + } + | 'BACK' + > { + let currentStep = 1 + let profileName = '' + let accessKeyId = '' + let secretAccessKey = '' + let sessionToken = '' + let region = '' + + while (currentStep <= 5) { + switch (currentStep) { + case 1: { + // Step 1: Profile Name + const result = await this.getProfileNameInput() + if (result === 'BACK') { + return 'BACK' // User wants to go back - exit to credential management menu + } + profileName = result + currentStep = 2 + break + } + case 2: { + // Step 2: Access Key ID + const result = await this.getAccessKeyIdInput() + if (result === 'BACK') { + currentStep = 1 // Go back to step 1 + } else { + accessKeyId = result + currentStep = 3 + } + break + } + case 3: { + // Step 3: Secret Access Key + const result = await this.getSecretAccessKeyInput() + if (result === 'BACK') { + currentStep = 2 // Go back to step 2 + } else { + secretAccessKey = result + currentStep = 4 + } + break + } + case 4: { + // Step 4: Session Token (optional) + const result = await this.getSessionTokenInput() + if (result === 'BACK') { + currentStep = 3 // Go back to step 3 + } else { + sessionToken = result + currentStep = 5 + } + break + } + case 5: { + // Step 5: Region + const result = await this.showRegionSelection({ + title: 'Add New AWS Profile - Step 5 of 5', + placeholder: 'Select a default region', + returnBackOnCancel: true, + }) + if (result === 'BACK') { + currentStep = 4 // Go back to step 4 + } else { + region = result + currentStep = 6 // Exit the loop + } + break + } + } + } + + return { + profileName, + accessKeyId, + secretAccessKey, + sessionToken: sessionToken || undefined, + region, // Region is always set since step 5 is required + } + } + + /** + * Gets profile name input with back navigation and existing profile validation + */ + private static async getProfileNameInput(): Promise { + return new Promise((resolve) => { + const quickPick = this.createInputQuickPick( + 'Add New AWS Profile - Step 1 of 5', + 'Type a profile name (e.g., my-profile, dev, prod)' + ) + quickPick.items = [] + + let isCompleted = false + + quickPick.onDidTriggerButton((button) => { + if (button === vscode.QuickInputButtons.Back) { + isCompleted = true + quickPick.dispose() + resolve('BACK') + } + }) + + quickPick.onDidChangeValue(async (value) => { + // Show placeholder when empty + if (!value) { + quickPick.items = [ + { + label: '$(edit) Enter profile name', + description: 'e.g., my-profile, dev, prod', + detail: 'Profile names can contain letters, numbers, hyphens, and underscores', + }, + ] + return + } + + // Validate input as user types + if (value.includes(' ')) { + quickPick.items = [ + { + label: `${value}`, + description: '$(error) Cannot contain spaces', + detail: 'Valid characters: letters, numbers, hyphens, underscores', + }, + ] + } else if (!this.profileNamePattern.test(value)) { + quickPick.items = [ + { + label: `${value}`, + description: '$(error) Invalid characters', + detail: 'Profile names can only contain letters, numbers, hyphens, and underscores', + }, + ] + } else if (value.length < 2) { + quickPick.items = [ + { + label: `${value}`, + description: `$(info) Too short (${value.length}/2 min)`, + detail: 'Profile names should be at least 2 characters long', + }, + ] + } else { + // Check if profile already exists + try { + const profiles = await loadSharedCredentialsProfiles() + const profileExists = profiles[value] !== undefined + + if (profileExists) { + quickPick.items = [ + { + label: `${value}`, + description: '$(warning) Profile exists - will be overwritten', + detail: 'Press Enter to overwrite the existing profile', + }, + ] + } else { + quickPick.items = [ + { + label: `${value}`, + description: `$(check) Valid (${value.length} characters)`, + detail: 'Press Enter to use this profile name', + }, + ] + } + } catch (error) { + // If we can't load profiles, just show as valid + quickPick.items = [ + { + label: `${value}`, + description: `$(check) Valid (${value.length} characters)`, + detail: 'Press Enter to use this profile name', + }, + ] + } + } + }) + + quickPick.onDidAccept(async () => { + const value = quickPick.value.trim() + + // Validate final input + if (!value || value.length < 2) { + return // Don't accept empty or too short input + } + if (value.includes(' ')) { + return // Don't accept names with spaces + } + if (!this.profileNamePattern.test(value)) { + return // Don't accept invalid characters + } + + // Check if profile exists and ask for confirmation + try { + const profiles = await loadSharedCredentialsProfiles() + const profileExists = profiles[value] !== undefined + + if (profileExists) { + isCompleted = true + quickPick.dispose() + + // Ask for confirmation to overwrite + const overwrite = await vscode.window.showWarningMessage( + `Profile '${value}' already exists. Do you want to overwrite it?`, + { modal: true }, + 'Overwrite' + ) + + if (overwrite === 'Overwrite') { + resolve(value) + } else { + // User cancelled, restart the input + const result = await this.getProfileNameInput() + resolve(result) + } + return + } + } catch (error) { + // If we can't load profiles, just continue + } + + isCompleted = true + quickPick.dispose() + resolve(value) + }) + + quickPick.onDidHide(() => { + if (!isCompleted) { + quickPick.dispose() + resolve('BACK') + } + }) + + quickPick.show() + }) + } + + /** + * Gets access key ID input with back navigation + */ + private static async getAccessKeyIdInput(): Promise { + return new Promise((resolve) => { + const quickPick = this.createInputQuickPick( + 'Add New AWS Profile - Step 2 of 5', + 'Type your AWS Access Key ID (e.g., AKIAIOSFODNN7EXAMPLE)' + ) + quickPick.items = [] + + let isCompleted = false + + quickPick.onDidTriggerButton((button) => { + if (button === vscode.QuickInputButtons.Back) { + isCompleted = true + quickPick.dispose() + resolve('BACK') + } + }) + + quickPick.onDidChangeValue((value) => { + // Show placeholder when empty + if (!value) { + quickPick.items = [ + { + label: '$(key) Enter AWS Access Key ID', + description: 'e.g., AKIAIOSFODNN7EXAMPLE', + detail: 'Access Key IDs are typically 16-32 characters long', + }, + ] + return + } + + // Validate input as user types (AWS STS API: 16-128 chars, pattern [\w]*) + // Reference: https://docs.aws.amazon.com/STS/latest/APIReference/API_Credentials.html + if (!this.accessKeyIdPattern.test(value)) { + quickPick.items = [ + { + label: `${value}`, + description: '$(error) Invalid characters', + detail: 'Access Key IDs can only contain letters, numbers, and underscores', + }, + ] + } else if (value.length < 16) { + quickPick.items = [ + { + label: `${value}`, + description: `$(info) Too short (${value.length}/16 min)`, + detail: 'AWS Access Key IDs must be 16-128 characters long', + }, + ] + } else if (value.length > 128) { + quickPick.items = [ + { + label: `${value}`, + description: `$(error) Too long (${value.length}/128 max)`, + detail: 'AWS Access Key IDs must be 16-128 characters long', + }, + ] + } else { + quickPick.items = [ + { + label: `${value}`, + description: `$(check) Valid (${value.length} characters)`, + detail: 'Press Enter to use this Access Key ID', + }, + ] + } + }) + + quickPick.onDidAccept(() => { + const value = quickPick.value.trim() + + // Validate final input (AWS STS API: 16-128 chars, pattern [\w]*) + // Reference: https://docs.aws.amazon.com/STS/latest/APIReference/API_Credentials.html + if (!value) { + return // Don't accept empty input + } + if (!this.accessKeyIdPattern.test(value)) { + return // Don't accept invalid characters + } + if (value.length < 16 || value.length > 128) { + return // Don't accept invalid length + } + + isCompleted = true + quickPick.dispose() + resolve(value) + }) + + quickPick.onDidHide(() => { + if (!isCompleted) { + quickPick.dispose() + resolve('BACK') + } + }) + + quickPick.show() + }) + } + + /** + * Gets secret access key input with back navigation + */ + private static async getSecretAccessKeyInput(): Promise { + return new Promise((resolve) => { + const quickPick = this.createInputQuickPick( + 'Add New AWS Profile - Step 3 of 5', + 'Type your AWS Secret Access Key (will be hidden when typing)' + ) + quickPick.items = [] + + let isCompleted = false + + quickPick.onDidTriggerButton((button) => { + if (button === vscode.QuickInputButtons.Back) { + isCompleted = true + quickPick.dispose() + resolve('BACK') + } + }) + + quickPick.onDidChangeValue((value) => { + // Show placeholder when empty + if (!value) { + quickPick.items = [ + { + label: '$(lock) Enter AWS Secret Access Key', + description: 'Required field', + detail: 'Enter your AWS Secret Access Key', + }, + ] + return + } + + // AWS STS API: Required, no specific pattern/length constraints in docs + // Reference: https://docs.aws.amazon.com/STS/latest/APIReference/API_Credentials.html + quickPick.items = [ + { + label: '•'.repeat(Math.min(value.length, 40)), + description: `$(check) ${value.length} characters entered`, + detail: 'Press Enter to continue', + }, + ] + }) + + quickPick.onDidAccept(() => { + const value = quickPick.value.trim() + + // Validate final input - AWS STS API only requires non-empty + // Reference: https://docs.aws.amazon.com/STS/latest/APIReference/API_Credentials.html + if (!value) { + return // Don't accept empty input + } + + isCompleted = true + quickPick.dispose() + resolve(value) + }) + + quickPick.onDidHide(() => { + if (!isCompleted) { + quickPick.dispose() + resolve('BACK') + } + }) + + quickPick.show() + }) + } + + /** + * Gets session token input with back navigation + */ + private static async getSessionTokenInput(): Promise { + return new Promise((resolve) => { + const quickPick = this.createInputQuickPick( + 'Add New AWS Profile - Step 4 of 5', + 'Enter your AWS Session Token (optional for temporary credentials)' + ) + + // Start with skip option only + quickPick.items = [ + { + label: '$(arrow-right) Skip', + description: 'Skip session token (for permanent credentials)', + detail: 'Use this for regular IAM user access keys', + action: SessionTokenAction.Skip, + } as vscode.QuickPickItem & { action: SessionTokenAction }, + ] + + let isCompleted = false + + quickPick.onDidTriggerButton((button) => { + if (button === vscode.QuickInputButtons.Back) { + isCompleted = true + quickPick.dispose() + resolve('BACK') + } + }) + + quickPick.onDidChangeValue((value) => { + if (!value) { + // Show skip option when empty + quickPick.items = [ + { + label: '$(arrow-right) Skip', + description: 'Skip session token (for permanent credentials)', + detail: 'Use this for regular IAM user access keys', + action: SessionTokenAction.Skip, + } as vscode.QuickPickItem & { action: SessionTokenAction }, + ] + return + } + + // AWS STS API: Required for temporary credentials, no specific pattern/length constraints in docs + // Reference: https://docs.aws.amazon.com/STS/latest/APIReference/API_Credentials.html + quickPick.items = [ + { + label: '•'.repeat(Math.min(value.length, 40)), + description: `$(check) ${value.length} characters entered`, + detail: 'Press Enter to use this session token', + action: SessionTokenAction.UseToken, + } as vscode.QuickPickItem & { action: SessionTokenAction }, + { + label: '$(arrow-right) Skip', + description: 'Skip session token (for permanent credentials)', + detail: 'Use this for regular IAM user access keys', + action: SessionTokenAction.Skip, + } as vscode.QuickPickItem & { action: SessionTokenAction }, + ] + }) + + quickPick.onDidAccept(() => { + const selectedItem = quickPick.selectedItems[0] + const currentValue = quickPick.value + + isCompleted = true + quickPick.dispose() + + // If user typed something and pressed Enter without selecting an item, use the typed value (trimmed) + if (!selectedItem && currentValue) { + resolve(currentValue.trim()) + return + } + + // If no selection with empty value, skip + if (!selectedItem) { + resolve('') + return + } + + const itemWithAction = selectedItem as vscode.QuickPickItem & { action: SessionTokenAction } + + // Handle based on action + switch (itemWithAction.action) { + case SessionTokenAction.Skip: + resolve('') + break + case SessionTokenAction.UseToken: + resolve(currentValue.trim()) + break + case SessionTokenAction.Warning: + // User can still proceed with warning, use the typed value + resolve(currentValue.trim()) + break + default: + resolve('') + } + }) + + quickPick.onDidHide(() => { + if (!isCompleted) { + quickPick.dispose() + resolve('BACK') + } + }) + + quickPick.show() + }) + } + + /** + * Ensures the ~/.aws directory exists + */ + private static async ensureAwsDirectoryExists(): Promise { + const awsDir = path.join(fs.getUserHomeDir(), '.aws') + if (!(await fs.existsDir(awsDir))) { + await fs.mkdir(awsDir) + } + } + + /** + * Adds a new profile to the credentials file or overwrites existing one + */ + private static async addProfileToCredentialsFile( + profileName: string, + accessKeyId: string, + secretAccessKey: string, + sessionToken?: string, + region?: string + ): Promise { + const credentialsPath = getCredentialsFilename() + + // Ensure the .aws directory exists + await this.ensureAwsDirectoryExists() + + // Read existing content or create new + let content = '' + if (await fs.existsFile(credentialsPath)) { + content = await fs.readFileText(credentialsPath) + } + + // Create new profile lines (no spaces around =) + const newProfileLines = [ + `[${profileName}]`, + `aws_access_key_id=${accessKeyId}`, + `aws_secret_access_key=${secretAccessKey}`, + ] + + if (sessionToken) { + newProfileLines.push(`aws_session_token=${sessionToken}`) + } + + if (region) { + newProfileLines.push(`region=${region}`) + } + + // Parse the file line by line to handle profile replacement properly + const lines = content.split('\n') + const newLines: string[] = [] + let inTargetProfile = false + let profileFound = false + + for (let i = 0; i < lines.length; i++) { + const line = lines[i].trim() + + // Check if this is a profile header + if (line.startsWith('[') && line.endsWith(']')) { + const currentProfileName = line.slice(1, -1) + + if (currentProfileName === profileName) { + // Found the target profile - replace it + if (!profileFound) { + newLines.push(...newProfileLines) + profileFound = true + } + inTargetProfile = true + continue + } else { + // Different profile - end replacement mode + inTargetProfile = false + newLines.push(lines[i]) + } + } else if (!inTargetProfile) { + // Not in target profile, keep the line + newLines.push(lines[i]) + } + // If inTargetProfile is true, we skip the line (removing old profile content) + } + + // If profile wasn't found, add it at the end + if (!profileFound) { + if (newLines.length > 0 && newLines[newLines.length - 1].trim() !== '') { + newLines.push('') // Add blank line before new profile + } + newLines.push(...newProfileLines) + } + + // Update content with the new lines + content = newLines.join('\n') + + // Write back to file + await fs.writeFile(credentialsPath, content) + } + + /** + * Updates an existing profile with a new region + */ + private static async updateProfileRegion(profileName: string, region: string): Promise { + const logger = this.logger + + try { + logger.debug(`Updating profile ${profileName} with region ${region}`) + + const credentialsPath = getCredentialsFilename() + + if (!(await fs.existsFile(credentialsPath))) { + throw new ToolkitError('Credentials file not found', { code: 'CredentialsFileNotFound' }) + } + + // Read the current credentials file + const content = await fs.readFileText(credentialsPath) + + // Find the profile section + const profileSectionRegex = new RegExp(`^\\[${profileName}\\]$`, 'm') + const profileMatch = content.match(profileSectionRegex) + + if (!profileMatch) { + throw new ToolkitError(`Profile ${profileName} not found in credentials file`, { + code: 'ProfileNotFound', + }) + } + + // Find the next profile section or end of file + const profileStartIndex = profileMatch.index! + const nextProfileMatch = content.slice(profileStartIndex + 1).match(/^\[.*\]$/m) + const profileEndIndex = nextProfileMatch ? profileStartIndex + 1 + nextProfileMatch.index! : content.length + + // Extract the profile section + const profileSection = content.slice(profileStartIndex, profileEndIndex) + + // Check if region already exists in the profile + let updatedProfileSection: string + + if (this.regionLinePattern.test(profileSection)) { + // Replace existing region + updatedProfileSection = profileSection.replace(this.regionLinePattern, `region = ${region}`) + } else { + // Add region to the profile (before any empty lines at the end) + const lines = profileSection.split('\n') + // Find the last non-empty line index (compatible with older JS versions) + let lastNonEmptyIndex = -1 + for (let i = lines.length - 1; i >= 0; i--) { + if (lines[i].trim() !== '') { + lastNonEmptyIndex = i + break + } + } + lines.splice(lastNonEmptyIndex + 1, 0, `region = ${region}`) + updatedProfileSection = lines.join('\n') + } + + // Replace the profile section in the content + const updatedContent = + content.slice(0, profileStartIndex) + updatedProfileSection + content.slice(profileEndIndex) + + // Write back to file + await fs.writeFile(credentialsPath, updatedContent) + + logger.debug(`Successfully updated profile ${profileName} with region ${region}`) + } catch (error) { + logger.error('Failed to update profile region: %s', error) + throw new ToolkitError(`Failed to update profile region: ${(error as Error).message}`, { + code: 'UpdateProfileError', + }) + } + } +} diff --git a/packages/core/src/sagemakerunifiedstudio/auth/ui/ssoAuthentication.ts b/packages/core/src/sagemakerunifiedstudio/auth/ui/ssoAuthentication.ts new file mode 100644 index 00000000000..2d2efaa15f8 --- /dev/null +++ b/packages/core/src/sagemakerunifiedstudio/auth/ui/ssoAuthentication.ts @@ -0,0 +1,108 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as vscode from 'vscode' +import { SmusUtils } from '../../shared/smusUtils' + +/** + * SSO authentication UI components for SMUS + */ +export class SmusSsoAuthenticationUI { + /** + * Shows domain URL input with back button support + */ + public static async showDomainUrlInput(): Promise { + return new Promise((resolve) => { + const quickPick = vscode.window.createQuickPick() + quickPick.title = 'SageMaker Unified Studio Authentication' + quickPick.placeholder = 'Enter your SageMaker Unified Studio Domain URL' + quickPick.canSelectMany = false + quickPick.ignoreFocusOut = true + + // Add back button + const backButton = vscode.QuickInputButtons.Back + quickPick.buttons = [backButton] + + // Start with placeholder item + quickPick.items = [ + { + label: '$(globe) Enter Domain URL', + description: 'e.g., https://dzd_xxxxxxxxx.sagemaker.region.on.aws', + detail: 'Type your SageMaker Unified Studio domain URL above', + }, + ] + + let isCompleted = false + + quickPick.onDidTriggerButton((button) => { + if (button === backButton) { + isCompleted = true + quickPick.dispose() + resolve('BACK') + } + }) + + quickPick.onDidChangeValue((value) => { + if (!value) { + quickPick.items = [ + { + label: '$(globe) Enter Domain URL', + description: 'e.g., https://dzd_xxxxxxxxx.sagemaker.region.on.aws', + detail: 'Type your SageMaker Unified Studio domain URL above', + }, + ] + return + } + + // Validate input as user types + const validation = SmusUtils.validateDomainUrl(value) + if (validation) { + quickPick.items = [ + { + label: '$(error) Invalid Domain URL', + description: validation, + detail: `Current input: "${value}"`, + }, + ] + } else { + quickPick.items = [ + { + label: '$(check) Use this Domain URL', + description: 'Press Enter to connect', + detail: `Domain URL: ${value}`, + }, + ] + } + }) + + quickPick.onDidAccept(() => { + const value = quickPick.value.trim() + + // Validate final input + if (!value) { + return // Don't accept empty input + } + + const validation = SmusUtils.validateDomainUrl(value) + if (validation) { + return // Don't accept invalid URLs + } + + isCompleted = true + quickPick.dispose() + resolve(value) + }) + + quickPick.onDidHide(() => { + if (!isCompleted) { + quickPick.dispose() + resolve(undefined) // User cancelled + } + }) + + quickPick.show() + }) + } +} diff --git a/packages/core/src/sagemakerunifiedstudio/connectionMagicsSelector/client/connectedSpaceDataZoneClient.ts b/packages/core/src/sagemakerunifiedstudio/connectionMagicsSelector/client/connectedSpaceDataZoneClient.ts index 8f0998e295f..cc26dd3f431 100644 --- a/packages/core/src/sagemakerunifiedstudio/connectionMagicsSelector/client/connectedSpaceDataZoneClient.ts +++ b/packages/core/src/sagemakerunifiedstudio/connectionMagicsSelector/client/connectedSpaceDataZoneClient.ts @@ -22,7 +22,7 @@ export interface DataZoneConnection { */ export class ConnectedSpaceDataZoneClient { private datazoneClient: DataZone | undefined - private readonly logger = getLogger() + private readonly logger = getLogger('smus') constructor( private readonly region: string, diff --git a/packages/core/src/sagemakerunifiedstudio/connectionMagicsSelector/services/connectionOptionsService.ts b/packages/core/src/sagemakerunifiedstudio/connectionMagicsSelector/services/connectionOptionsService.ts index 901c2e5a60f..9c258536f68 100644 --- a/packages/core/src/sagemakerunifiedstudio/connectionMagicsSelector/services/connectionOptionsService.ts +++ b/packages/core/src/sagemakerunifiedstudio/connectionMagicsSelector/services/connectionOptionsService.ts @@ -133,7 +133,7 @@ class ConnectionOptionsService { this.cachedConnections = processedConnections return processedConnections } catch (error) { - getLogger().error('Failed to list DataZone connections: %s', error as Error) + getLogger('smus').error('Failed to list DataZone connections: %s', error as Error) return [] } } @@ -190,7 +190,7 @@ class ConnectionOptionsService { return options } catch (error) { - getLogger().error('Failed to get connection options: %s', error as Error) + getLogger('smus').error('Failed to get connection options: %s', error as Error) return [] } } @@ -231,7 +231,7 @@ class ConnectionOptionsService { return projectOptions } catch (error) { - getLogger().error('Failed to get project options: %s', error as Error) + getLogger('smus').error('Failed to get project options: %s', error as Error) return [] } } @@ -269,7 +269,7 @@ class ConnectionOptionsService { this.projectOptions = newProjectOptions } catch (error) { - getLogger().error('Failed to update connection and project options: %s', error as Error) + getLogger('smus').error('Failed to update connection and project options: %s', error as Error) this.connectionOptions = [] this.projectOptions = [] } diff --git a/packages/core/src/sagemakerunifiedstudio/connectionMagicsSelector/services/notebookStateManager.ts b/packages/core/src/sagemakerunifiedstudio/connectionMagicsSelector/services/notebookStateManager.ts index 80654b64ac0..f2d9ec2392b 100644 --- a/packages/core/src/sagemakerunifiedstudio/connectionMagicsSelector/services/notebookStateManager.ts +++ b/packages/core/src/sagemakerunifiedstudio/connectionMagicsSelector/services/notebookStateManager.ts @@ -41,7 +41,7 @@ class NotebookStateManager { edit.set(cell.notebook.uri, [notebookEdit]) await vscode.workspace.applyEdit(edit) } catch (error) { - getLogger().warn('setCellMetadata: Failed to set metadata, falling back to in-memory storage') + getLogger('smus').warn('setCellMetadata: Failed to set metadata, falling back to in-memory storage') } } @@ -326,7 +326,7 @@ class NotebookStateManager { await this.updateCellContent(cell, newCellContent) } } catch (error) { - getLogger().error(`Error updating cell with magic command: ${error}`) + getLogger('smus').error(`Error updating cell with magic command: ${error}`) } } @@ -356,7 +356,9 @@ class NotebookStateManager { } } } catch (error) { - getLogger().error(`NotebookEdit failed, attempting to update cell content with WorkspaceEdit: ${error}`) + getLogger('smus').error( + `NotebookEdit failed, attempting to update cell content with WorkspaceEdit: ${error}` + ) } try { @@ -371,10 +373,10 @@ class NotebookStateManager { const success = await vscode.workspace.applyEdit(edit) if (!success) { - getLogger().error('WorkspaceEdit failed to apply') + getLogger('smus').error('WorkspaceEdit failed to apply') } } catch (error) { - getLogger().error(`Failed to update cell content with WorkspaceEdit: ${error}`) + getLogger('smus').error(`Failed to update cell content with WorkspaceEdit: ${error}`) try { const document = cell.document @@ -386,7 +388,7 @@ class NotebookStateManager { await vscode.workspace.applyEdit(edit) } } catch (finalError) { - getLogger().error(`All cell update methods failed: ${finalError}`) + getLogger('smus').error(`All cell update methods failed: ${finalError}`) } } } diff --git a/packages/core/src/sagemakerunifiedstudio/explorer/activation.ts b/packages/core/src/sagemakerunifiedstudio/explorer/activation.ts index 8a686b48654..2179b767ffa 100644 --- a/packages/core/src/sagemakerunifiedstudio/explorer/activation.ts +++ b/packages/core/src/sagemakerunifiedstudio/explorer/activation.ts @@ -12,27 +12,29 @@ import { SageMakerUnifiedStudioRootNode, selectSMUSProject, } from './nodes/sageMakerUnifiedStudioRootNode' -import { DataZoneClient } from '../shared/client/datazoneClient' import { openRemoteConnect, stopSpace } from '../../awsService/sagemaker/commands' import { SagemakerUnifiedStudioSpaceNode } from './nodes/sageMakerUnifiedStudioSpaceNode' import { SageMakerUnifiedStudioProjectNode } from './nodes/sageMakerUnifiedStudioProjectNode' import { getLogger } from '../../shared/logger/logger' import { setSmusConnectedContext, SmusAuthenticationProvider } from '../auth/providers/smusAuthenticationProvider' +import { isSmusIamConnection } from '../auth/model' import { setupUserActivityMonitoring } from '../../awsService/sagemaker/sagemakerSpace' import { telemetry } from '../../shared/telemetry/telemetry' import { isSageMaker } from '../../shared/extensionUtilities' import { recordSpaceTelemetry } from '../shared/telemetry' +import { DataZoneClient } from '../shared/client/datazoneClient' +import { handleCredExpiredError } from '../shared/credentialExpiryHandler' export async function activate(extensionContext: vscode.ExtensionContext): Promise { // Initialize the SMUS authentication provider - const logger = getLogger() - logger.debug('SMUS: Initializing authentication provider') + const logger = getLogger('smus') + logger.debug('Initializing authentication provider') // Create the auth provider instance (this will trigger restore() in the constructor) const smusAuthProvider = SmusAuthenticationProvider.fromContext() await smusAuthProvider.restore() // Set initial auth context after restore void setSmusConnectedContext(smusAuthProvider.isConnected()) - logger.debug('SMUS: Authentication provider initialized') + logger.debug('Authentication provider initialized') // Create the SMUS projects tree view const smusRootNode = new SageMakerUnifiedStudioRootNode(smusAuthProvider, extensionContext) @@ -44,9 +46,9 @@ export async function activate(extensionContext: vscode.ExtensionContext): Promi // Register the commands extensionContext.subscriptions.push( - smusLoginCommand.register(), + smusLoginCommand.register(extensionContext), smusLearnMoreCommand.register(), - smusSignOutCommand.register(), + smusSignOutCommand.register(extensionContext), treeView, vscode.commands.registerCommand('aws.smus.rootView.refresh', () => { treeDataProvider.refresh() @@ -64,6 +66,10 @@ export async function activate(extensionContext: vscode.ExtensionContext): Promi await projectNode.refreshNode() }), + vscode.commands.registerCommand('aws.smus.refresh', async () => { + treeDataProvider.refresh() + }), + vscode.commands.registerCommand('aws.smus.switchProject', async () => { // Get the project node from the root node to ensure we're using the same instance const projectNode = smusRootNode.getProjectSelectNode() @@ -75,8 +81,13 @@ export async function activate(extensionContext: vscode.ExtensionContext): Promi return } await telemetry.smus_stopSpace.run(async (span) => { - await recordSpaceTelemetry(span, node) - await stopSpace(node.resource, extensionContext, node.resource.sageMakerClient) + try { + await recordSpaceTelemetry(span, node) + await stopSpace(node.resource, extensionContext, node.resource.sageMakerClient) + } catch (err) { + await handleCredExpiredError(err) + throw err + } }) }), @@ -87,8 +98,13 @@ export async function activate(extensionContext: vscode.ExtensionContext): Promi return } await telemetry.smus_openRemoteConnection.run(async (span) => { - await recordSpaceTelemetry(span, node) - await openRemoteConnect(node.resource, extensionContext, node.resource.sageMakerClient) + try { + await recordSpaceTelemetry(span, node) + await openRemoteConnect(node.resource, extensionContext, node.resource.sageMakerClient) + } catch (err) { + await handleCredExpiredError(err) + throw err + } }) } ), @@ -97,16 +113,33 @@ export async function activate(extensionContext: vscode.ExtensionContext): Promi if (connection) { try { await smusAuthProvider.reauthenticate(connection) - // Refresh the tree view after successful reauthentication treeDataProvider.refresh() - // Show success message - void vscode.window.showInformationMessage( - 'Successfully reauthenticated with SageMaker Unified Studio' - ) + + // IAM connections handle their own success messages + // Only show success message for SSO connections + if (!isSmusIamConnection(connection)) { + void vscode.window.showInformationMessage( + 'Successfully reauthenticated with SageMaker Unified Studio' + ) + } } catch (error) { - // Show error message if reauthentication fails - void vscode.window.showErrorMessage(`Failed to reauthenticate: ${error}`) - logger.error('SMUS: Reauthentication failed: %O', error) + // Extract the most detailed error message available + let errorMessage = 'Unknown error' + if (error instanceof Error) { + // Check if this is a ToolkitError with a cause chain + const cause = (error as any).cause + if (cause instanceof Error) { + // Use the cause's message as it contains the detailed validation error + errorMessage = cause.message + } else { + // Fall back to the error's own message + errorMessage = error.message + } + } + + // Show the detailed error message to the user + void vscode.window.showErrorMessage(`${errorMessage}`) + logger.error('Reauthentication failed: %O', error) } } }), diff --git a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/federatedConnectionStrategy.ts b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/federatedConnectionStrategy.ts new file mode 100644 index 00000000000..467509cefa7 --- /dev/null +++ b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/federatedConnectionStrategy.ts @@ -0,0 +1,333 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as vscode from 'vscode' +import { TreeNode } from '../../../shared/treeview/resourceTreeDataProvider' +import { getLogger } from '../../../shared/logger/logger' +import { DataZoneConnection } from '../../shared/client/datazoneClient' +import { GlueClient, ListEntitiesCommand, DescribeEntityCommand, Entity, Field } from '@aws-sdk/client-glue' +import { ConnectionCredentialsProvider } from '../../auth/providers/connectionCredentialsProvider' +import { getIcon } from '../../../shared/icons' +import { createPlaceholderItem } from '../../../shared/treeview/utils' +import { createErrorItem, createColumnTreeItem } from './utils' +import { NO_DATA_FOUND_MESSAGE, NodeType } from './types' +import { handleCredExpiredError } from '../../shared/credentialExpiryHandler' + +/** + * Creates a federated connection node + */ +export async function createFederatedConnectionNode( + connection: DataZoneConnection, + connectionCredentialsProvider: ConnectionCredentialsProvider, + region: string +): Promise { + const logger = getLogger('smus') + + // Check for error message in glue properties + // Create error node directly in this case + const connectionError = connection.props?.glueProperties?.errorMessage + if (connectionError) { + return createErrorItem(connectionError, 'glue-error', connection.connectionId) + } + + return { + id: `federated-${connection.connectionId}`, + resource: connection, + getTreeItem: () => { + const item = new vscode.TreeItem(connection.name, vscode.TreeItemCollapsibleState.Collapsed) + item.contextValue = 'federatedConnection' + item.iconPath = getIcon('aws-sagemakerunifiedstudio-catalog') + item.tooltip = `Federated Connection: ${connection.name}` + return item + }, + getChildren: async () => { + try { + return await getFederatedEntities(connection, connectionCredentialsProvider, region) + } catch (err) { + logger.error(`Failed to get federated entities: ${(err as Error).message}`) + const errorMessage = (err as Error).message + await handleCredExpiredError(err, true) + return [ + createErrorItem(`Failed to load entities - ${errorMessage}`, 'entities', connection.connectionId), + ] + } + }, + getParent: () => undefined, + } +} + +/** + * Gets federated entities from Glue API + */ +async function getFederatedEntities( + connection: DataZoneConnection, + connectionCredentialsProvider: ConnectionCredentialsProvider, + region: string +): Promise { + const awsCredentialProvider = async () => { + const credentials = await connectionCredentialsProvider.getCredentials() + return { + accessKeyId: credentials.accessKeyId, + secretAccessKey: credentials.secretAccessKey, + sessionToken: credentials.sessionToken, + expiration: credentials.expiration, + } + } + const glueClient = new GlueClient({ + region: region, + credentials: awsCredentialProvider, + }) + + const glueConnectionName = connection?.glueConnectionName + if (!glueConnectionName) { + return [createErrorItem('No Glue connection name found', 'glue-connection', connection.connectionId)] + } + + const allEntities: Entity[] = [] + let nextToken: string | undefined + + do { + const response = await glueClient.send( + new ListEntitiesCommand({ + ConnectionName: glueConnectionName, + NextToken: nextToken, + }) + ) + + if (response.Entities) { + allEntities.push(...response.Entities) + } + nextToken = response.NextToken + } while (nextToken) + + if (allEntities.length === 0) { + return [createPlaceholderItem(NO_DATA_FOUND_MESSAGE)] + } + + const entityNodes: TreeNode[] = [] + const tableNodes: TreeNode[] = [] + + for (const entity of allEntities) { + const nodeType = getGlueNodeType(entity.Category) + const isTable = nodeType === NodeType.GLUE_TABLE + + const entityNode = createGlueEntityNode(entity, connection, glueClient, glueConnectionName) + + if (isTable) { + tableNodes.push(entityNode) + } else { + entityNodes.push(entityNode) + } + } + + // Always group tables under a "Tables" container + if (tableNodes.length > 0) { + const tablesContainer = createTablesContainer(tableNodes, connection.connectionId) + return [...entityNodes, tablesContainer] + } + + return entityNodes +} + +/** + * Creates a Glue entity node + */ +function createGlueEntityNode( + entity: Entity, + connection: DataZoneConnection, + glueClient: GlueClient, + glueConnectionName: string +): TreeNode { + const logger = getLogger('smus') + const nodeType = getGlueNodeType(entity.Category) + const isTable = nodeType === NodeType.GLUE_TABLE + + return { + id: `${connection.connectionId}-${entity.EntityName}`, + resource: entity, + getTreeItem: () => { + const item = new vscode.TreeItem( + entity.Label || entity.EntityName || 'Unknown', + entity.IsParentEntity || (isTable && !entity.IsParentEntity) + ? vscode.TreeItemCollapsibleState.Collapsed + : vscode.TreeItemCollapsibleState.None + ) + item.contextValue = nodeType + item.iconPath = getGlueEntityIcon(nodeType) + item.tooltip = `${entity.Category}: ${entity.Label || entity.EntityName}` + return item + }, + getChildren: async () => { + try { + if (entity.IsParentEntity) { + return await getChildEntities(entity, connection, glueClient, glueConnectionName) + } else if (isTable) { + return await getTableColumns(entity, glueClient, glueConnectionName) + } + return [] + } catch (err) { + logger.error(`Failed to get children for entity ${entity.EntityName}: ${(err as Error).message}`) + const errorMessage = (err as Error).message + await handleCredExpiredError(err, true) + return [ + createErrorItem( + `Failed to load children - ${errorMessage}`, + 'entity-children', + entity.EntityName || 'unknown' + ), + ] + } + }, + getParent: () => undefined, + } +} + +/** + * Gets child entities for parent entities + */ +async function getChildEntities( + parentEntity: Entity, + connection: DataZoneConnection, + glueClient: GlueClient, + glueConnectionName: string +): Promise { + const allEntities: Entity[] = [] + let nextToken: string | undefined + + do { + const response = await glueClient.send( + new ListEntitiesCommand({ + ConnectionName: glueConnectionName, + ParentEntityName: parentEntity.EntityName, + NextToken: nextToken, + }) + ) + + if (response.Entities) { + allEntities.push(...response.Entities) + } + nextToken = response.NextToken + } while (nextToken) + + if (allEntities.length === 0) { + return [createPlaceholderItem(NO_DATA_FOUND_MESSAGE)] + } + + const entityNodes: TreeNode[] = [] + const tableNodes: TreeNode[] = [] + + for (const entity of allEntities) { + const nodeType = getGlueNodeType(entity.Category) + const isTable = nodeType === NodeType.GLUE_TABLE + const entityNode = createGlueEntityNode(entity, connection, glueClient, glueConnectionName) + + if (isTable) { + tableNodes.push(entityNode) + } else { + entityNodes.push(entityNode) + } + } + + // Always group tables under a "Tables" container if there are any + if (tableNodes.length > 0) { + const tablesContainer = createTablesContainer( + tableNodes, + `${connection.connectionId}-${parentEntity.EntityName}` + ) + return [...entityNodes, tablesContainer] + } + + return entityNodes +} + +/** + * Gets table columns using DescribeEntity + */ +async function getTableColumns( + entity: Entity, + glueClient: GlueClient, + glueConnectionName: string +): Promise { + const response = await glueClient.send( + new DescribeEntityCommand({ + ConnectionName: glueConnectionName, + EntityName: entity.EntityName, + }) + ) + + if (!response.Fields || response.Fields.length === 0) { + return [createPlaceholderItem('No columns found')] + } + + return response.Fields.map((field) => createColumnNode(field, entity.EntityName || 'unknown')) +} + +/** + * Creates a column node + */ +function createColumnNode(field: Field, tableName: string): TreeNode { + return { + id: `${tableName}-${field.FieldName}`, + resource: field, + getTreeItem: () => { + return createColumnTreeItem( + field.Label || field.FieldName || 'Unknown', + field.FieldType || 'unknown', + NodeType.REDSHIFT_COLUMN + ) + }, + getChildren: async () => [], + getParent: () => undefined, + } +} + +/** + * Creates a tables container node + */ +function createTablesContainer(tableNodes: TreeNode[], connectionId: string): TreeNode { + return { + id: `${connectionId}-tables`, + resource: {}, + getTreeItem: () => { + const item = new vscode.TreeItem('Tables', vscode.TreeItemCollapsibleState.Collapsed) + item.contextValue = NodeType.GLUE_TABLE + item.iconPath = new vscode.ThemeIcon('table') + return item + }, + getChildren: async () => tableNodes, + getParent: () => undefined, + } +} + +/** + * Maps Glue entity category to node type + */ +function getGlueNodeType(category?: string): NodeType { + const lowerCategory = category?.toLowerCase() + if (lowerCategory?.includes('schema')) { + return NodeType.GLUE_DATABASE + } else if (lowerCategory?.includes('table')) { + return NodeType.GLUE_TABLE + } else if (lowerCategory?.includes('database')) { + return NodeType.GLUE_DATABASE + } + return NodeType.GLUE_CATALOG +} + +/** + * Gets icon for Glue entity node type + */ +function getGlueEntityIcon(nodeType: NodeType): vscode.ThemeIcon | any { + switch (nodeType) { + case NodeType.GLUE_DATABASE: + return new vscode.ThemeIcon('database') + case NodeType.GLUE_TABLE: + return getIcon('aws-redshift-table') + case NodeType.GLUE_CATALOG: + return getIcon('aws-sagemakerunifiedstudio-catalog') + default: + return getIcon('vscode-circle-outline') + } +} diff --git a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/lakehouseStrategy.ts b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/lakehouseStrategy.ts index 546a73135c6..8e399e1bea6 100644 --- a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/lakehouseStrategy.ts +++ b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/lakehouseStrategy.ts @@ -21,6 +21,7 @@ import { DatabaseObjects, NO_DATA_FOUND_MESSAGE, } from './types' +import { handleCredExpiredError } from '../../shared/credentialExpiryHandler' import { getLabel, isLeafNode, @@ -42,7 +43,7 @@ import { recordDataConnectionTelemetry } from '../../shared/telemetry' export class LakehouseNode implements TreeNode { private childrenNodes: TreeNode[] | undefined private isLoading = false - private readonly logger = getLogger() + private readonly logger = getLogger('smus') constructor( public readonly data: NodeData, @@ -81,7 +82,7 @@ export class LakehouseNode implements TreeNode { this.logger.error(`Failed to get children for node ${this.data.id}: ${(err as Error).message}`) const errorMessage = (err as Error).message - void vscode.window.showErrorMessage(errorMessage) + await handleCredExpiredError(err, true) return [createErrorItem(errorMessage, 'getChildren', this.id) as LakehouseNode] } } @@ -129,7 +130,7 @@ export function createLakehouseConnectionNode( connectionCredentialsProvider: ConnectionCredentialsProvider, region: string ): LakehouseNode { - const logger = getLogger() + const logger = getLogger('smus') // Create Glue clients const clientStore = ConnectionClientStore.getInstance() @@ -177,15 +178,17 @@ export function createLakehouseConnectionNode( const errors: LakehouseNode[] = [] if (awsDataCatalogResult.status === 'rejected') { - const errorMessage = (awsDataCatalogResult.reason as Error).message - void vscode.window.showErrorMessage(errorMessage) + const error = awsDataCatalogResult.reason as Error + const errorMessage = error.message errors.push(createErrorItem(errorMessage, 'aws-data-catalog', node.id) as LakehouseNode) + await handleCredExpiredError(error, true) } if (catalogsResult.status === 'rejected') { - const errorMessage = (catalogsResult.reason as Error).message - void vscode.window.showErrorMessage(errorMessage) + const error = catalogsResult.reason as Error + const errorMessage = error.message errors.push(createErrorItem(errorMessage, 'catalogs', node.id) as LakehouseNode) + await handleCredExpiredError(error, true) } const allNodes = [...awsDataCatalog, ...apiCatalogs, ...errors] @@ -195,7 +198,7 @@ export function createLakehouseConnectionNode( } catch (err) { logger.error(`Failed to get Lakehouse catalogs: ${(err as Error).message}`) const errorMessage = (err as Error).message - void vscode.window.showErrorMessage(errorMessage) + await handleCredExpiredError(err, true) return [createErrorItem(errorMessage, 'lakehouse-catalogs', node.id) as LakehouseNode] } }) @@ -364,7 +367,7 @@ function createCatalogNode( parent: LakehouseNode, isParent: boolean = false ): LakehouseNode { - const logger = getLogger() + const logger = getLogger('smus') return new LakehouseNode( { @@ -409,7 +412,7 @@ function createCatalogNode( } catch (err) { logger.error(`Failed to get databases for catalog ${catalogId}: ${(err as Error).message}`) const errorMessage = (err as Error).message - void vscode.window.showErrorMessage(errorMessage) + await handleCredExpiredError(err, true) return [createErrorItem(errorMessage, 'catalog-databases', node.id) as LakehouseNode] } } @@ -425,7 +428,7 @@ function createDatabaseNode( glueClient: GlueClient, parent: LakehouseNode ): LakehouseNode { - const logger = getLogger() + const logger = getLogger('smus') return new LakehouseNode( { @@ -482,7 +485,7 @@ function createDatabaseNode( } catch (err) { logger.error(`Failed to get tables for database ${databaseName}: ${(err as Error).message}`) const errorMessage = (err as Error).message - void vscode.window.showErrorMessage(errorMessage) + await handleCredExpiredError(err, true) return [createErrorItem(errorMessage, 'database-tables', node.id) as LakehouseNode] } } @@ -498,7 +501,7 @@ function createTableNode( glueClient: GlueClient, parent: LakehouseNode ): LakehouseNode { - const logger = getLogger() + const logger = getLogger('smus') return new LakehouseNode( { @@ -530,6 +533,7 @@ function createTableNode( : [createPlaceholderItem(NO_DATA_FOUND_MESSAGE) as LakehouseNode] } catch (err) { logger.error(`Failed to get columns for table ${tableName}: ${(err as Error).message}`) + await handleCredExpiredError(err) return [] } } diff --git a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/redshiftStrategy.ts b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/redshiftStrategy.ts index af0d7cfbbac..00e1e74f19c 100644 --- a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/redshiftStrategy.ts +++ b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/redshiftStrategy.ts @@ -23,6 +23,7 @@ import { import { createPlaceholderItem } from '../../../shared/treeview/utils' import { ConnectionCredentialsProvider } from '../../auth/providers/connectionCredentialsProvider' import { GlueCatalog } from '../../shared/client/glueCatalogClient' +import { handleCredExpiredError } from '../../shared/credentialExpiryHandler' import { telemetry } from '../../../shared/telemetry/telemetry' import { recordDataConnectionTelemetry } from '../../shared/telemetry' @@ -32,7 +33,7 @@ import { recordDataConnectionTelemetry } from '../../shared/telemetry' export class RedshiftNode implements TreeNode { private childrenNodes: TreeNode[] | undefined private isLoading = false - private readonly logger = getLogger() + private readonly logger = getLogger('smus') constructor( public readonly data: NodeData, @@ -71,7 +72,7 @@ export class RedshiftNode implements TreeNode { this.logger.error(`Failed to get children for node ${this.data.id}: ${(err as Error).message}`) const errorMessage = (err as Error).message - void vscode.window.showErrorMessage(errorMessage) + await handleCredExpiredError(err, true) return [createErrorItem(errorMessage, 'getChildren', this.id) as RedshiftNode] } } @@ -119,7 +120,7 @@ export function createRedshiftConnectionNode( connection: DataZoneConnection, connectionCredentialsProvider: ConnectionCredentialsProvider ): RedshiftNode { - const logger = getLogger() + const logger = getLogger('smus') return new RedshiftNode( { id: connection.connectionId, @@ -209,9 +210,10 @@ export function createRedshiftConnectionNode( // Add database nodes if (filteredDatabases.length === 0) { if (databasesResult.status === 'rejected') { + const error = databasesResult.reason as Error const errorMessage = `Failed to fetch databases - ${databasesResult.reason?.message || databasesResult.reason}.` - void vscode.window.showErrorMessage(errorMessage) allNodes.push(createErrorItem(errorMessage, 'databases', node.id) as RedshiftNode) + await handleCredExpiredError(error, true) } else { allNodes.push(createPlaceholderItem(NO_DATA_FOUND_MESSAGE) as RedshiftNode) } @@ -226,9 +228,10 @@ export function createRedshiftConnectionNode( // Add catalog nodes if (filteredCatalogs.length === 0) { if (catalogsResult.status === 'rejected') { + const error = catalogsResult.reason as Error const errorMessage = `Failed to fetch catalogs - ${catalogsResult.reason?.message || catalogsResult.reason}` - void vscode.window.showErrorMessage(errorMessage) allNodes.push(createErrorItem(errorMessage, 'catalogs', node.id) as RedshiftNode) + await handleCredExpiredError(error, true) } else { allNodes.push(createPlaceholderItem(NO_DATA_FOUND_MESSAGE) as RedshiftNode) } @@ -289,7 +292,7 @@ async function wakeUpDatabase( connectionCredentialsProvider: ConnectionCredentialsProvider, connection: DataZoneConnection ) { - const logger = getLogger() + const logger = getLogger('smus') const clientStore = ConnectionClientStore.getInstance() const sqlClient = clientStore.getSQLWorkbenchClient(connection.connectionId, region, connectionCredentialsProvider) try { @@ -307,7 +310,7 @@ function createDatabaseNode( connectionConfig: ConnectionConfig, parent: RedshiftNode ): RedshiftNode { - const logger = getLogger() + const logger = getLogger('smus') return new RedshiftNode( { @@ -384,7 +387,7 @@ function createDatabaseNode( } catch (err) { logger.error(`Failed to get schemas: ${(err as Error).message}`) const errorMessage = (err as Error).message - void vscode.window.showErrorMessage(errorMessage) + await handleCredExpiredError(err, true) return [createErrorItem(errorMessage, 'schemas', node.id) as RedshiftNode] } } @@ -395,7 +398,7 @@ function createDatabaseNode( * Creates a schema node */ function createSchemaNode(schemaName: string, connectionConfig: ConnectionConfig, parent: RedshiftNode): RedshiftNode { - const logger = getLogger() + const logger = getLogger('smus') return new RedshiftNode( { @@ -521,7 +524,7 @@ function createSchemaNode(schemaName: string, connectionConfig: ConnectionConfig } catch (err) { logger.error(`Failed to get schema contents: ${(err as Error).message}`) const errorMessage = (err as Error).message - void vscode.window.showErrorMessage(errorMessage) + await handleCredExpiredError(err, true) return [createErrorItem(errorMessage, 'schema-contents', node.id) as RedshiftNode] } } @@ -578,7 +581,7 @@ function createObjectNode( connectionConfig: ConnectionConfig, parent: RedshiftNode ): RedshiftNode { - const logger = getLogger() + const logger = getLogger('smus') return new RedshiftNode( { @@ -689,7 +692,7 @@ function createObjectNode( } catch (err) { logger.error(`Failed to get columns: ${(err as Error).message}`) const errorMessage = (err as Error).message - void vscode.window.showErrorMessage(errorMessage) + await handleCredExpiredError(err, true) return [createErrorItem(errorMessage, 'columns', node.id) as RedshiftNode] } } @@ -870,7 +873,7 @@ function createCatalogDatabaseNode( return [createContainerNode(NodeType.REDSHIFT_TABLE, tables, connectionConfig, node)] } catch (err) { const errorMessage = (err as Error).message - void vscode.window.showErrorMessage(errorMessage) + await handleCredExpiredError(err, true) return [createErrorItem(errorMessage, 'catalog-tables', node.id) as RedshiftNode] } } @@ -967,7 +970,7 @@ function createCatalogTableNode( : [createPlaceholderItem(NO_DATA_FOUND_MESSAGE) as RedshiftNode] } catch (err) { const errorMessage = (err as Error).message - void vscode.window.showErrorMessage(errorMessage) + await handleCredExpiredError(err, true) return [createErrorItem(errorMessage, 'catalog-columns', node.id) as RedshiftNode] } } @@ -1030,7 +1033,7 @@ function createCatalogNode( : [createPlaceholderItem(NO_DATA_FOUND_MESSAGE) as RedshiftNode] } catch (err) { const errorMessage = (err as Error).message - void vscode.window.showErrorMessage(errorMessage) + await handleCredExpiredError(err, true) return [createErrorItem(errorMessage, 'catalog-databases', node.id) as RedshiftNode] } } diff --git a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/s3Strategy.ts b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/s3Strategy.ts index 4106a0b4889..53d481bd2de 100644 --- a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/s3Strategy.ts +++ b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/s3Strategy.ts @@ -18,6 +18,7 @@ import { ListCallerAccessGrantsEntry, } from '@aws-sdk/client-s3-control' import { S3, ListObjectsV2Command } from '@aws-sdk/client-s3' +import { handleCredExpiredError } from '../../shared/credentialExpiryHandler' import { ConnectionCredentialsProvider } from '../../auth/providers/connectionCredentialsProvider' import { telemetry } from '../../../shared/telemetry/telemetry' import { recordDataConnectionTelemetry } from '../../shared/telemetry' @@ -30,7 +31,7 @@ export const DATA_DEFAULT_S3_CONNECTION_NAME_REGEXP = /^(project\.s3_default_fol * S3 data node for SageMaker Unified Studio */ export class S3Node implements TreeNode { - private readonly logger = getLogger() + private readonly logger = getLogger('smus') private childrenNodes: TreeNode[] | undefined private isLoading = false @@ -71,7 +72,7 @@ export class S3Node implements TreeNode { this.logger.error(`Failed to get children for node ${this.data.id}: ${(err as Error).message}`) const errorMessage = (err as Error).message - void vscode.window.showErrorMessage(errorMessage) + await handleCredExpiredError(err, true) return [createErrorItem(errorMessage, 'getChildren', this.id) as S3Node] } } @@ -112,7 +113,7 @@ export function createS3ConnectionNode( connectionCredentialsProvider: ConnectionCredentialsProvider, region: string ): S3Node { - const logger = getLogger() + const logger = getLogger('smus') // Parse S3 URI from connection const s3Info = parseS3Uri(connection) @@ -123,6 +124,9 @@ export function createS3ConnectionNode( return createErrorItem(errorMessage, 'connection', connection.connectionId) as S3Node } + // Handle case where s3Uri is "s3://" (all buckets access) + const isAllBucketsAccess = !s3Info.bucket + // Get S3 client from store const clientStore = ConnectionClientStore.getInstance() const s3Client = clientStore.getS3Client(connection.connectionId, region, connectionCredentialsProvider) @@ -146,7 +150,90 @@ export function createS3ConnectionNode( return telemetry.smus_renderS3Node.run(async (span) => { await recordDataConnectionTelemetry(span, connection, connectionCredentialsProvider) try { - if (isDefaultConnection && s3Info.prefix) { + if (isAllBucketsAccess) { + // For all buckets access (s3://), list all accessible buckets + try { + const buckets = await s3Client.listBuckets() + if (buckets.length === 0) { + return [createPlaceholderItem(NO_DATA_FOUND_MESSAGE) as S3Node] + } + + return buckets.map((bucket) => { + return new S3Node( + { + id: bucket.Name || 'unknown-bucket', + nodeType: NodeType.S3_BUCKET, + connectionType: ConnectionType.S3, + value: { bucket: bucket.Name }, + path: { + connection: connection.name, + bucket: bucket.Name, + }, + parent: node, + }, + async (bucketNode) => { + try { + const allPaths = [] + let nextToken: string | undefined + + do { + const result = await s3Client.listPaths( + bucket.Name || '', + undefined, + nextToken + ) + allPaths.push(...result.paths) + nextToken = result.nextToken + } while (nextToken) + + if (allPaths.length === 0) { + return [createPlaceholderItem(NO_DATA_FOUND_MESSAGE) as S3Node] + } + + return allPaths.map((path) => { + const nodeId = `${path.bucket}-${path.prefix || 'root'}` + + return new S3Node( + { + id: nodeId, + nodeType: path.isFolder ? NodeType.S3_FOLDER : NodeType.S3_FILE, + connectionType: ConnectionType.S3, + value: path, + path: { + connection: connection.name, + bucket: path.bucket, + key: path.prefix, + label: path.displayName, + }, + parent: bucketNode, + }, + path.isFolder + ? createFolderChildrenProvider(s3Client, path) + : undefined + ) + }) + } catch (err) { + logger.error(`Failed to list bucket contents: ${(err as Error).message}`) + const errorMessage = (err as Error).message + await handleCredExpiredError(err, true) + return [ + createErrorItem( + errorMessage, + 'bucket-contents-all-access', + bucketNode.id + ) as S3Node, + ] + } + } + ) + }) + } catch (err) { + logger.error(`Failed to list buckets: ${(err as Error).message}`) + const errorMessage = (err as Error).message + await handleCredExpiredError(err, true) + return [createErrorItem(errorMessage, 'list-buckets', node.id) as S3Node] + } + } else if (isDefaultConnection && s3Info.prefix) { // For default connections, show the full path as the first node const fullPath = `${s3Info.bucket}/${s3Info.prefix}` return [ @@ -208,7 +295,7 @@ export function createS3ConnectionNode( } catch (err) { logger.error(`Failed to list bucket contents: ${(err as Error).message}`) const errorMessage = (err as Error).message - void vscode.window.showErrorMessage(errorMessage) + await handleCredExpiredError(err, true) return [ createErrorItem( errorMessage, @@ -279,7 +366,7 @@ export function createS3ConnectionNode( } catch (err) { logger.error(`Failed to list bucket contents: ${(err as Error).message}`) const errorMessage = (err as Error).message - void vscode.window.showErrorMessage(errorMessage) + await handleCredExpiredError(err, true) return [ createErrorItem( errorMessage, @@ -295,7 +382,7 @@ export function createS3ConnectionNode( } catch (err) { logger.error(`Failed to create bucket node: ${(err as Error).message}`) const errorMessage = (err as Error).message - void vscode.window.showErrorMessage(errorMessage) + await handleCredExpiredError(err, true) return [createErrorItem(errorMessage, 'bucket-node', node.id) as S3Node] } }) @@ -323,7 +410,7 @@ export async function createS3AccessGrantNodes( * Creates a children provider function for a folder node */ function createFolderChildrenProvider(s3Client: S3Client, folderPath: any): (node: S3Node) => Promise { - const logger = getLogger() + const logger = getLogger('smus') return async (node: S3Node) => { try { @@ -365,7 +452,7 @@ function createFolderChildrenProvider(s3Client: S3Client, folderPath: any): (nod } catch (err) { logger.error(`Failed to list folder contents: ${(err as Error).message}`) const errorMessage = (err as Error).message - void vscode.window.showErrorMessage(errorMessage) + await handleCredExpiredError(err, true) return [createErrorItem(errorMessage, 'folder-contents', node.id) as S3Node] } } @@ -382,11 +469,22 @@ function parseS3Uri(connection: DataZoneConnection): { bucket: string; prefix?: return undefined } + // Handle case where s3Uri is just "s3://" (all buckets access) + if (s3Uri === 's3://') { + return { bucket: '', prefix: undefined } + } + // Parse S3 URI: s3://bucket-name/prefix/path/ const uriWithoutPrefix = s3Uri.replace('s3://', '') + + // Handle empty URI after removing prefix + if (!uriWithoutPrefix) { + return { bucket: '', prefix: undefined } + } + // Since the URI ends with a slash, the last item will be an empty string, so ignore it in the parts. const parts = uriWithoutPrefix.split('/').slice(0, -1) - const bucket = parts[0] + const bucket = parts[0] || '' // If parts only contains 1 item, then only a bucket was provided, and the key is empty. const prefix = parts.length > 1 ? parts.slice(1).join('/') + '/' : undefined @@ -400,7 +498,7 @@ async function listCallerAccessGrants( accountId: string, connectionId: string ): Promise { - const logger = getLogger() + const logger = getLogger('smus') try { const clientStore = ConnectionClientStore.getInstance() const s3ControlClient = clientStore.getS3ControlClient(connectionId, region, connectionCredentialsProvider) @@ -428,6 +526,7 @@ async function listCallerAccessGrants( return accessGrantNodes } catch (error) { logger.error(`Failed to list caller access grants: ${(error as Error).message}`) + await handleCredExpiredError(error) return [] } } @@ -484,7 +583,7 @@ async function fetchAccessGrantChildren( connectionCredentialsProvider: ConnectionCredentialsProvider, connectionId: string ): Promise { - const logger = getLogger() + const logger = getLogger('smus') const path = node.data.path try { @@ -593,7 +692,7 @@ async function fetchAccessGrantChildren( } catch (error) { logger.error(`Failed to fetch access grant children: ${(error as Error).message}`) const errorMessage = (error as Error).message - void vscode.window.showErrorMessage(errorMessage) + await handleCredExpiredError(error, true) return [createErrorItem(errorMessage, 'access-grant-children', node.id) as S3Node] } } diff --git a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioAuthInfoNode.ts b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioAuthInfoNode.ts index ff25f64cf74..cca19378c22 100644 --- a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioAuthInfoNode.ts +++ b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioAuthInfoNode.ts @@ -7,6 +7,9 @@ import * as vscode from 'vscode' import { TreeNode } from '../../../shared/treeview/resourceTreeDataProvider' import { SageMakerUnifiedStudioRootNode } from './sageMakerUnifiedStudioRootNode' import { SmusAuthenticationProvider } from '../../auth/providers/smusAuthenticationProvider' +import { SmusIamConnection } from '../../auth/model' +import { getContext } from '../../../shared/vscode/setContext' +import { loadSharedConfigFiles } from '@smithy/shared-ini-file-loader' /** * Node representing the SageMaker Unified Studio authentication information @@ -26,9 +29,12 @@ export class SageMakerUnifiedStudioAuthInfoNode implements TreeNode { this.authProvider.onDidChange(() => { this.onDidChangeEmitter.fire() }) + this.authProvider.onDidChangeActiveConnection(() => { + this.onDidChangeEmitter.fire() + }) } - public getTreeItem(): vscode.TreeItem { + public async getTreeItem(): Promise { // Use the cached authentication provider to check connection status const isConnected = this.authProvider.isConnected() const isValid = this.authProvider.isConnectionValid() @@ -38,37 +44,61 @@ export class SageMakerUnifiedStudioAuthInfoNode implements TreeNode { let region = 'Unknown' if (isConnected && this.authProvider.activeConnection) { - const conn = this.authProvider.activeConnection - domainId = conn.domainId || 'Unknown' - region = conn.ssoRegion || 'Unknown' + domainId = this.authProvider.getDomainId() || 'Unknown' + region = this.authProvider.getDomainRegion() || 'Unknown' } // Create display based on connection status let label: string let iconPath: vscode.ThemeIcon let tooltip: string + let description: string | undefined + + // Get profile name for IAM mode + const isIamMode = getContext('aws.smus.isIamMode') + let profileName: string | undefined + if (isIamMode) { + const activeConnection = this.authProvider.activeConnection! + const { configFile } = await loadSharedConfigFiles() + profileName = + (activeConnection as SmusIamConnection).profileName || (configFile['default'] ? 'default' : undefined) + } if (isConnected && isValid) { - label = `Domain: ${domainId}` + // Get session name and role ARN dynamically for IAM connections in IAM mode + let sessionName: string | undefined + let roleArn: string | undefined + if (isIamMode) { + sessionName = await this.authProvider.getSessionName() + roleArn = await this.authProvider.getIamPrincipalArn() + } + + // Format label with session name if available + const sessionSuffix = sessionName ? ` (session: ${sessionName})` : '' + label = isIamMode ? `Connected with profile: ${profileName}${sessionSuffix}` : `Domain: ${domainId}` iconPath = new vscode.ThemeIcon('key', new vscode.ThemeColor('charts.green')) - tooltip = `Connected to SageMaker Unified Studio\nDomain ID: ${domainId}\nRegion: ${region}\nStatus: Connected` + + // Add role ARN and session name to tooltip if available (role ARN before session) + const roleArnTooltip = roleArn ? `\nRole ARN: ${roleArn}` : '' + const sessionTooltip = sessionName ? `\nSession: ${sessionName}` : '' + tooltip = `Connected to SageMaker Unified Studio\n${isIamMode ? `Profile: ${profileName}` : `Domain ID: ${domainId}`}\nRegion: ${region}${roleArnTooltip}${sessionTooltip}\nStatus: Connected` + description = region } else if (isConnected && !isValid) { - label = `Domain: ${domainId} (Expired) - Click to reauthenticate` + label = isIamMode + ? `Profile: ${profileName} (Expired) - Click to reauthenticate` + : `Domain: ${domainId} (Expired) - Click to reauthenticate` iconPath = new vscode.ThemeIcon('warning', new vscode.ThemeColor('charts.yellow')) - tooltip = `Connection to SageMaker Unified Studio has expired\nDomain ID: ${domainId}\nRegion: ${region}\nStatus: Expired - Click to reauthenticate` + tooltip = `Connection to SageMaker Unified Studio has expired\n${isIamMode ? `Profile: ${profileName}` : `Domain ID: ${domainId}`}\nRegion: ${region}\nStatus: Expired - Click to reauthenticate` + description = region } else { label = 'Not Connected' iconPath = new vscode.ThemeIcon('circle-slash', new vscode.ThemeColor('charts.red')) tooltip = 'Not connected to SageMaker Unified Studio\nPlease sign in to access your projects' + description = undefined } const item = new vscode.TreeItem(label, vscode.TreeItemCollapsibleState.None) - // Add region as description (appears to the right) if connected - if (isConnected) { - item.description = region - } - // Add command for reauthentication when connection is expired if (isConnected && !isValid) { item.command = { @@ -81,6 +111,7 @@ export class SageMakerUnifiedStudioAuthInfoNode implements TreeNode { item.tooltip = tooltip item.contextValue = 'smusAuthInfo' item.iconPath = iconPath + item.description = description return item } diff --git a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioComputeNode.ts b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioComputeNode.ts index 01293e7e523..eab4a58fbfb 100644 --- a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioComputeNode.ts +++ b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioComputeNode.ts @@ -12,6 +12,7 @@ import { SagemakerClient } from '../../../shared/clients/sagemaker' import { SmusAuthenticationProvider } from '../../auth/providers/smusAuthenticationProvider' import { SageMakerUnifiedStudioConnectionParentNode } from './sageMakerUnifiedStudioConnectionParentNode' import { ConnectionType } from '@aws-sdk/client-datazone' +import { getContext } from '../../../shared/vscode/setContext' export class SageMakerUnifiedStudioComputeNode implements TreeNode { public readonly id = 'smusComputeNode' @@ -37,12 +38,14 @@ export class SageMakerUnifiedStudioComputeNode implements TreeNode { const projectId = this.parent.getProject()?.id if (projectId) { - childrenNodes.push( - new SageMakerUnifiedStudioConnectionParentNode(this, ConnectionType.REDSHIFT, 'Data warehouse') - ) - childrenNodes.push( - new SageMakerUnifiedStudioConnectionParentNode(this, ConnectionType.SPARK, 'Data processing') - ) + if (!getContext('aws.smus.isIamMode')) { + childrenNodes.push( + new SageMakerUnifiedStudioConnectionParentNode(this, ConnectionType.REDSHIFT, 'Data warehouse') + ) + childrenNodes.push( + new SageMakerUnifiedStudioConnectionParentNode(this, ConnectionType.SPARK, 'Data processing') + ) + } this.spacesNode = new SageMakerUnifiedStudioSpacesParentNode( this, projectId, diff --git a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioConnectionNode.ts b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioConnectionNode.ts index 969efa9823d..588d4f42062 100644 --- a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioConnectionNode.ts +++ b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioConnectionNode.ts @@ -12,7 +12,7 @@ import { ConnectionSummary, ConnectionType } from '@aws-sdk/client-datazone' export class SageMakerUnifiedStudioConnectionNode implements TreeNode { public resource: SageMakerUnifiedStudioConnectionNode contextValue: string - private readonly logger = getLogger() + private readonly logger = getLogger('smus') id: string public constructor( private readonly parent: SageMakerUnifiedStudioConnectionParentNode, diff --git a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioConnectionParentNode.ts b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioConnectionParentNode.ts index a04377f0133..3bb7fa80222 100644 --- a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioConnectionParentNode.ts +++ b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioConnectionParentNode.ts @@ -8,7 +8,7 @@ import { SageMakerUnifiedStudioComputeNode } from './sageMakerUnifiedStudioCompu import { TreeNode } from '../../../shared/treeview/resourceTreeDataProvider' import { ListConnectionsCommandOutput, ConnectionType } from '@aws-sdk/client-datazone' import { SageMakerUnifiedStudioConnectionNode } from './sageMakerUnifiedStudioConnectionNode' -import { DataZoneClient } from '../../shared/client/datazoneClient' +import { createDZClientBaseOnDomainMode } from './utils' // eslint-disable-next-line id-length export class SageMakerUnifiedStudioConnectionParentNode implements TreeNode { @@ -31,7 +31,7 @@ export class SageMakerUnifiedStudioConnectionParentNode implements TreeNode { } public async getChildren(): Promise { - const client = await DataZoneClient.getInstance(this.parent.authProvider) + const client = await createDZClientBaseOnDomainMode(this.parent.authProvider) this.connections = await client.fetchConnections( this.parent.parent.project?.domainId, this.parent.parent.project?.id, diff --git a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioDataNode.ts b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioDataNode.ts index 4294a3e42f4..20d181bd1c3 100644 --- a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioDataNode.ts +++ b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioDataNode.ts @@ -8,15 +8,24 @@ import { TreeNode } from '../../../shared/treeview/resourceTreeDataProvider' import { getIcon } from '../../../shared/icons' import { getLogger } from '../../../shared/logger/logger' -import { DataZoneClient, DataZoneConnection, DataZoneProject } from '../../shared/client/datazoneClient' +import { DataZoneConnection, DataZoneProject } from '../../shared/client/datazoneClient' import { createS3ConnectionNode, createS3AccessGrantNodes } from './s3Strategy' import { createRedshiftConnectionNode } from './redshiftStrategy' import { createLakehouseConnectionNode } from './lakehouseStrategy' import { SageMakerUnifiedStudioProjectNode } from './sageMakerUnifiedStudioProjectNode' import { isFederatedConnection, createErrorItem } from './utils' import { createPlaceholderItem } from '../../../shared/treeview/utils' -import { ConnectionType, NO_DATA_FOUND_MESSAGE } from './types' +import { + ConnectionType, + DATA_DEFAULT_S3_CONNECTION_NAME_REGEXP, + NO_DATA_FOUND_MESSAGE, + S3_PROJECT_NON_GIT_PROJECT_REPOSITORY_LOCATION_NAME_REGEXP, +} from './types' import { SmusAuthenticationProvider } from '../../auth/providers/smusAuthenticationProvider' +import { createFederatedConnectionNode } from './federatedConnectionStrategy' +import { createDZClientForProject } from './utils' +import { getContext } from '../../../shared/vscode/setContext' +import { handleCredExpiredError } from '../../shared/credentialExpiryHandler' /** * Tree node representing a Data folder that contains S3 and Redshift connections @@ -24,7 +33,7 @@ import { SmusAuthenticationProvider } from '../../auth/providers/smusAuthenticat export class SageMakerUnifiedStudioDataNode implements TreeNode { public readonly id = 'smusDataExplorer' public readonly resource = {} - private readonly logger = getLogger() + private readonly logger = getLogger('smus') private childrenNodes: TreeNode[] | undefined private readonly authProvider: SmusAuthenticationProvider @@ -57,7 +66,8 @@ export class SageMakerUnifiedStudioDataNode implements TreeNode { return [createErrorItem(errorMessage, 'project', this.id)] } - const datazoneClient = await DataZoneClient.getInstance(this.authProvider) + const datazoneClient = await createDZClientForProject(this.authProvider, project.id) + const connections = await datazoneClient.listConnections(project.domainId, undefined, project.id) this.logger.info(`Found ${connections.length} connections for project ${project.id}`) @@ -74,7 +84,7 @@ export class SageMakerUnifiedStudioDataNode implements TreeNode { const projectInfo = project ? `project: ${project.id}, domain: ${project.domainId}` : 'unknown project' const errorMessage = 'Failed to get connections' this.logger.error(`Failed to get connections for ${projectInfo}: ${(err as Error).message}`) - void vscode.window.showErrorMessage(errorMessage) + await handleCredExpiredError(err, true) return [createErrorItem(errorMessage, 'connections', this.id)] } } @@ -105,15 +115,23 @@ export class SageMakerUnifiedStudioDataNode implements TreeNode { } // Add Redshift nodes second - for (const connection of redshiftConnections) { - if (connection.name.startsWith('project.lakehouse')) { - continue + if (!getContext('aws.smus.isIamMode')) { + for (const connection of redshiftConnections) { + if (connection.name.startsWith('project.lakehouse')) { + continue + } + if (isFederatedConnection(connection)) { + continue + } + const node = await this.createRedshiftNode(project, connection, region) + dataNodes.push(node) } - if (isFederatedConnection(connection)) { - continue + } else { + const federatedConnections = connections.filter((conn) => isFederatedConnection(conn)) + if (federatedConnections.length > 0) { + const connectionsNode = this.createConnectionsParentNode(project, federatedConnections, region) + dataNodes.push(connectionsNode) } - const node = await this.createRedshiftNode(project, connection, region) - dataNodes.push(node) } // Add S3 Bucket parent node last @@ -132,37 +150,30 @@ export class SageMakerUnifiedStudioDataNode implements TreeNode { region: string ): Promise { try { - const datazoneClient = await DataZoneClient.getInstance(this.authProvider) - const getConnectionResponse = await datazoneClient.getConnection({ - domainIdentifier: project.domainId, - identifier: connection.connectionId, - withSecret: true, - }) - const connectionCredentialsProvider = await this.authProvider.getConnectionCredentialsProvider( connection.connectionId, project.id, - getConnectionResponse.location?.awsRegion || region + connection.location?.awsRegion || region ) const s3ConnectionNode = createS3ConnectionNode( connection, connectionCredentialsProvider, - getConnectionResponse.location?.awsRegion || region + connection.location?.awsRegion || region ) const accessGrantNodes = await createS3AccessGrantNodes( connection, connectionCredentialsProvider, - getConnectionResponse.location?.awsRegion || region, - getConnectionResponse.location?.awsAccountId + connection.location?.awsRegion || region, + connection.location?.awsAccountId ) return [s3ConnectionNode, ...accessGrantNodes] } catch (connErr) { const errorMessage = `Failed to get S3 connection - ${(connErr as Error).message}` this.logger.error(`Failed to get S3 connection details: ${(connErr as Error).message}`) - void vscode.window.showErrorMessage(errorMessage) + await handleCredExpiredError(connErr, true) return [createErrorItem(errorMessage, `s3-${connection.connectionId}`, this.id)] } } @@ -173,7 +184,7 @@ export class SageMakerUnifiedStudioDataNode implements TreeNode { region: string ): Promise { try { - const datazoneClient = await DataZoneClient.getInstance(this.authProvider) + const datazoneClient = await createDZClientForProject(this.authProvider, project.id) const getConnectionResponse = await datazoneClient.getConnection({ domainIdentifier: project.domainId, identifier: connection.connectionId, @@ -190,7 +201,7 @@ export class SageMakerUnifiedStudioDataNode implements TreeNode { } catch (connErr) { const errorMessage = `Failed to get Redshift connection - ${(connErr as Error).message}` this.logger.error(`Failed to get Redshift connection details: ${(connErr as Error).message}`) - void vscode.window.showErrorMessage(errorMessage) + await handleCredExpiredError(connErr, true) return createErrorItem(errorMessage, `redshift-${connection.connectionId}`, this.id) } } @@ -201,24 +212,17 @@ export class SageMakerUnifiedStudioDataNode implements TreeNode { region: string ): Promise { try { - const datazoneClient = await DataZoneClient.getInstance(this.authProvider) - const getConnectionResponse = await datazoneClient.getConnection({ - domainIdentifier: project.domainId, - identifier: connection.connectionId, - withSecret: true, - }) - const connectionCredentialsProvider = await this.authProvider.getConnectionCredentialsProvider( connection.connectionId, project.id, - getConnectionResponse.location?.awsRegion || region + connection.location?.awsRegion || region ) return createLakehouseConnectionNode(connection, connectionCredentialsProvider, region) } catch (connErr) { const errorMessage = `Failed to get Lakehouse connection - ${(connErr as Error).message}` this.logger.error(`Failed to get Lakehouse connection details: ${(connErr as Error).message}`) - void vscode.window.showErrorMessage(errorMessage) + await handleCredExpiredError(connErr, true) return createErrorItem(errorMessage, `lakehouse-${connection.connectionId}`, this.id) } } @@ -237,8 +241,26 @@ export class SageMakerUnifiedStudioDataNode implements TreeNode { return item }, getChildren: async () => { + // Filter connections inside the bucket parent node + const defaultS3Connection = s3Connections.find((conn) => + DATA_DEFAULT_S3_CONNECTION_NAME_REGEXP.test(conn.name) + ) + const otherS3Connections = s3Connections.filter( + (conn) => + !DATA_DEFAULT_S3_CONNECTION_NAME_REGEXP.test(conn.name) && + !S3_PROJECT_NON_GIT_PROJECT_REPOSITORY_LOCATION_NAME_REGEXP.test(conn.name) + ) + const s3Nodes: TreeNode[] = [] - for (const connection of s3Connections) { + + // Add default connections first + if (defaultS3Connection) { + const defaultS3Node = await this.createS3Node(project, defaultS3Connection, region) + s3Nodes.push(...defaultS3Node) + } + + // Add other connections + for (const connection of otherS3Connections) { const nodes = await this.createS3Node(project, connection, region) s3Nodes.push(...nodes) } @@ -247,4 +269,47 @@ export class SageMakerUnifiedStudioDataNode implements TreeNode { getParent: () => this, } } + + private createConnectionsParentNode( + project: DataZoneProject, + federatedConnections: DataZoneConnection[], + region: string + ): TreeNode { + return { + id: 'connections-parent', + resource: {}, + getTreeItem: () => { + const item = new vscode.TreeItem('Connections', vscode.TreeItemCollapsibleState.Collapsed) + item.contextValue = 'connectionsFolder' + return item + }, + getChildren: async () => { + const nodes: TreeNode[] = [] + for (const connection of federatedConnections) { + try { + const connectionCredentialsProvider = await this.authProvider.getConnectionCredentialsProvider( + connection.connectionId, + project.id, + connection.location?.awsRegion || region + ) + const node = await createFederatedConnectionNode( + connection, + connectionCredentialsProvider, + region + ) + nodes.push(node) + } catch (err) { + const errorMessage = `Failed to create federated connection - ${(err as Error).message}` + this.logger.error( + `Failed to create federated connection ${connection.name}: ${(err as Error).message}` + ) + nodes.push(createErrorItem(errorMessage, `federated-${connection.connectionId}`, this.id)) + await handleCredExpiredError(err) + } + } + return nodes + }, + getParent: () => this, + } + } } diff --git a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioProjectNode.ts b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioProjectNode.ts index 114ffe77212..6b41189b78b 100644 --- a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioProjectNode.ts +++ b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioProjectNode.ts @@ -17,6 +17,11 @@ import { SageMakerUnifiedStudioComputeNode } from './sageMakerUnifiedStudioCompu import { getIcon } from '../../../shared/icons' import { getResourceMetadata } from '../../shared/utils/resourceMetadataUtils' import { getContext } from '../../../shared/vscode/setContext' +import { ToolkitError } from '../../../shared/errors' +import { SmusErrorCodes } from '../../shared/smusUtils' +import { handleCredExpiredError } from '../../shared/credentialExpiryHandler' +import { SmusIamConnection } from '../../auth/model' +import { createDZClientBaseOnDomainMode, createErrorItem } from './utils' /** * Tree node representing a SageMaker Unified Studio project @@ -28,7 +33,7 @@ export class SageMakerUnifiedStudioProjectNode implements TreeNode { public readonly onDidChangeTreeItem = this.onDidChangeEmitter.event public readonly onDidChangeChildren = this.onDidChangeEmitter.event public project?: DataZoneProject - private logger = getLogger() + private logger = getLogger('smus') private sagemakerClient?: SagemakerClient private hasShownFirstTimeMessage = false private isFirstTimeSelection = false @@ -82,6 +87,10 @@ export class SageMakerUnifiedStudioProjectNode implements TreeNode { return telemetry.smus_renderProjectChildrenNode.run(async (span) => { try { const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment') + + // Get auth mode directly from connection type + const authMode = this.authProvider.activeConnection?.type + const accountId = await this.authProvider.getDomainAccountId() span.record({ smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local', @@ -89,26 +98,34 @@ export class SageMakerUnifiedStudioProjectNode implements TreeNode { smusDomainAccountId: accountId, smusProjectId: this.project?.id, smusDomainRegion: this.authProvider.getDomainRegion(), + ...(authMode && { smusAuthMode: authMode }), }) // Skip access check if we're in SMUS space environment (already in project space) if (!getContext('aws.smus.inSmusSpaceEnvironment')) { - const hasAccess = await this.checkProjectCredsAccess(this.project!.id) - if (!hasAccess) { - return [ - { - id: 'smusProjectAccessDenied', - resource: {}, - getTreeItem: () => { - const item = new vscode.TreeItem( - 'You do not have access to this project. Contact your administrator.', - vscode.TreeItemCollapsibleState.None - ) - return item + try { + const hasAccess = await this.checkProjectCredsAccess(this.project!.id) + if (!hasAccess) { + return [ + { + id: 'smusProjectAccessDenied', + resource: {}, + getTreeItem: () => { + const item = new vscode.TreeItem( + 'You do not have access to this project. Contact your administrator.', + vscode.TreeItemCollapsibleState.None + ) + return item + }, + getParent: () => this, }, - getParent: () => this, - }, - ] + ] + } + } catch (err) { + const errorMessage = (err as Error).message + this.logger.error('Failed to check project credentials: %s', errorMessage) + await handleCredExpiredError(err, true) + return [createErrorItem(`Failed to load the project`, this.project?.id || '', this.id)] } } @@ -119,7 +136,7 @@ export class SageMakerUnifiedStudioProjectNode implements TreeNode { return [dataNode] } - const dzClient = await DataZoneClient.getInstance(this.authProvider) + const dzClient = await createDZClientBaseOnDomainMode(this.authProvider) if (!this.project?.id) { throw new Error('Project ID is required') } @@ -145,6 +162,7 @@ export class SageMakerUnifiedStudioProjectNode implements TreeNode { return [dataNode, computeNode] } catch (err) { this.logger.error('Failed to select project: %s', (err as Error).message) + await handleCredExpiredError(err) throw err } }) @@ -201,8 +219,9 @@ export class SageMakerUnifiedStudioProjectNode implements TreeNode { this.logger.debug( 'Access denied when obtaining project credentials, user likely lacks project access or role permissions' ) + return false } - return false + throw err } } @@ -212,7 +231,7 @@ export class SageMakerUnifiedStudioProjectNode implements TreeNode { } try { - const dzClient = await DataZoneClient.getInstance(this.authProvider) + const dzClient = await createDZClientBaseOnDomainMode(this.authProvider) const projectDetails = await dzClient.getProject(this.project.id) if (projectDetails && projectDetails.name) { @@ -231,10 +250,35 @@ export class SageMakerUnifiedStudioProjectNode implements TreeNode { if (!this.project) { throw new Error('No project selected for initializing SageMaker client') } - const projectProvider = await this.authProvider.getProjectCredentialProvider(this.project.id) - this.logger.info(`Successfully obtained project credentials provider for project ${this.project.id}`) - const awsCredentialProvider = async (): Promise => { - return await projectProvider.getCredentials() + let awsCredentialProvider + if (getContext('aws.smus.isIamMode')) { + const datazoneClient = DataZoneClient.createWithCredentials( + this.authProvider.getDomainRegion(), + this.authProvider.getDomainId(), + await this.authProvider.getCredentialsProviderForIamProfile( + (this.authProvider.activeConnection as SmusIamConnection).profileName + ) + ) + const projectId = this.project.id + awsCredentialProvider = async (): Promise => { + const creds = await datazoneClient.getProjectDefaultEnvironmentCreds(projectId) + if (!creds.accessKeyId || !creds.secretAccessKey) { + throw new ToolkitError('Missing default environment credentials', { + code: SmusErrorCodes.CredentialRetrievalFailed, + }) + } + return { + accessKeyId: creds.accessKeyId!, + secretAccessKey: creds.secretAccessKey!, + sessionToken: creds.sessionToken, + } + } + } else { + const projectProvider = await this.authProvider.getProjectCredentialProvider(this.project.id) + this.logger.info(`Successfully obtained project credentials provider for project ${this.project.id}`) + awsCredentialProvider = async (): Promise => { + return await projectProvider.getCredentials() + } } const sagemakerClient = new SagemakerClient(regionCode, awsCredentialProvider) return sagemakerClient diff --git a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioRootNode.ts b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioRootNode.ts index db3f6959969..2051ac7c52b 100644 --- a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioRootNode.ts +++ b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioRootNode.ts @@ -7,15 +7,22 @@ import * as vscode from 'vscode' import { TreeNode } from '../../../shared/treeview/resourceTreeDataProvider' import { getIcon } from '../../../shared/icons' import { getLogger } from '../../../shared/logger/logger' -import { DataZoneClient, DataZoneProject } from '../../shared/client/datazoneClient' +import { DataZoneProject, DataZoneClient } from '../../shared/client/datazoneClient' import { Commands } from '../../../shared/vscode/commands2' import { telemetry } from '../../../shared/telemetry/telemetry' import { createQuickPick } from '../../../shared/ui/pickerPrompter' import { SageMakerUnifiedStudioProjectNode } from './sageMakerUnifiedStudioProjectNode' import { SageMakerUnifiedStudioAuthInfoNode } from './sageMakerUnifiedStudioAuthInfoNode' import { SmusErrorCodes, SmusUtils } from '../../shared/smusUtils' +import { handleCredExpiredError } from '../../shared/credentialExpiryHandler' import { SmusAuthenticationProvider } from '../../auth/providers/smusAuthenticationProvider' import { ToolkitError } from '../../../../src/shared/errors' +import { SmusAuthenticationMethod } from '../../auth/ui/authenticationMethodSelection' +import { SmusAuthenticationOrchestrator } from '../../auth/authenticationOrchestrator' +import { isSmusSsoConnection, isSmusIamConnection } from '../../auth/model' +import { getContext } from '../../../shared/vscode/setContext' +import { createDZClientBaseOnDomainMode } from './utils' +import { DataZoneCustomClientHelper } from '../../shared/client/datazoneCustomClientHelper' import { recordAuthTelemetry } from '../../shared/telemetry' const contextValueSmusRoot = 'sageMakerUnifiedStudioRoot' @@ -27,7 +34,7 @@ const projectPickerPlaceholder = 'Select project' export class SageMakerUnifiedStudioRootNode implements TreeNode { public readonly id = 'smusRootNode' public readonly resource = this - private readonly logger = getLogger() + private readonly logger = getLogger('smus') private readonly projectNode: SageMakerUnifiedStudioProjectNode private readonly authInfoNode: SageMakerUnifiedStudioAuthInfoNode private readonly onDidChangeEmitter = new vscode.EventEmitter() @@ -132,7 +139,7 @@ export class SageMakerUnifiedStudioRootNode implements TreeNode { ] } - // When authenticated, show auth info and projects + // When authenticated, show auth info and projects (same for both IAM and non-IAM mode) return [this.authInfoNode, this.projectNode] } @@ -156,7 +163,7 @@ export class SageMakerUnifiedStudioRootNode implements TreeNode { try { // Check if the connection is valid using the authentication provider const result = this.authProvider.isConnectionValid() - this.logger.debug(`SMUS Root Node: Authentication check result: ${result}`) + this.logger.debug(`Authentication check result: ${result}`) return result } catch (err) { this.logger.debug('Authentication check failed: %s', (err as Error).message) @@ -177,9 +184,14 @@ export class SageMakerUnifiedStudioRootNode implements TreeNode { const hasExpiredConnection = activeConnection && !isConnectionValid if (hasExpiredConnection) { - this.logger.debug('SMUS Root Node: Connection is expired, showing reauthentication prompt') - // Show reauthentication prompt to user - void this.authProvider.showReauthenticationPrompt(activeConnection as any) + this.logger.debug('Connection is expired') + // Only show reauthentication prompt for SSO connections, not IAM connections + if (isSmusSsoConnection(activeConnection)) { + this.logger.debug('Showing reauthentication prompt for SSO connection') + void this.authProvider.showReauthenticationPrompt(activeConnection) + } else { + this.logger.debug('Skipping reauthentication prompt for non-SSO connection') + } return true } return false @@ -194,7 +206,7 @@ export class SageMakerUnifiedStudioRootNode implements TreeNode { * Command to open the SageMaker Unified Studio documentation */ export const smusLearnMoreCommand = Commands.declare('aws.smus.learnMore', () => async () => { - const logger = getLogger() + const logger = getLogger('smus') try { // Open the SageMaker Unified Studio documentation await vscode.env.openExternal(vscode.Uri.parse('https://aws.amazon.com/sagemaker/unified-studio/')) @@ -220,89 +232,124 @@ export const smusLearnMoreCommand = Commands.declare('aws.smus.learnMore', () => /** * Command to login to SageMaker Unified Studio */ -export const smusLoginCommand = Commands.declare('aws.smus.login', () => async () => { - const logger = getLogger() +export const smusLoginCommand = Commands.declare('aws.smus.login', (context: vscode.ExtensionContext) => async () => { + const logger = getLogger('smus') return telemetry.smus_login.run(async (span) => { try { - // Get DataZoneClient instance for URL validation - - // Show domain URL input dialog - const domainUrl = await vscode.window.showInputBox({ - title: 'SageMaker Unified Studio Authentication', - prompt: 'Enter your SageMaker Unified Studio Domain URL', - placeHolder: 'https://.sagemaker..on.aws', - validateInput: (value) => SmusUtils.validateDomainUrl(value), - }) - - if (!domainUrl) { - // User cancelled - logger.debug('User cancelled domain URL input') - throw new ToolkitError('User cancelled domain URL input', { - cancelled: true, - code: SmusErrorCodes.UserCancelled, - }) - } - - // Show a simple status bar message instead of progress dialog - vscode.window.setStatusBarMessage('Connecting to SageMaker Unified Studio...', 10000) - - try { - // Get the authentication provider instance - const authProvider = SmusAuthenticationProvider.fromContext() + // Get the authentication provider instance + const authProvider = SmusAuthenticationProvider.fromContext() - // Connect to SMUS using the authentication provider - const connection = await authProvider.connectToSmus(domainUrl) + // Import authentication method selection components + const { SmusAuthenticationMethodSelector } = await import('../../auth/ui/authenticationMethodSelection.js') + const { SmusAuthenticationPreferencesManager } = await import( + '../../auth/preferences/authenticationPreferences.js' + ) - if (!connection) { - throw new ToolkitError('Failed to establish connection', { - code: SmusErrorCodes.FailedAuthConnecton, - }) + // Check for preferred authentication method + const preferredMethod = SmusAuthenticationPreferencesManager.getPreferredMethod(context) + logger.debug(`Retrieved preferred method: ${preferredMethod}`) + + let selectedMethod: SmusAuthenticationMethod | undefined = preferredMethod + let authCompleted = false + + // Main authentication loop - handles back navigation + while (!authCompleted) { + // Check if we should skip method selection (user has a remembered preference) + if (selectedMethod) { + logger.debug(`Using authentication method: ${selectedMethod}`) + } else { + // Show authentication method selection dialog + logger.debug('Showing authentication method selection dialog') + const methodSelection = await SmusAuthenticationMethodSelector.showAuthenticationMethodSelection() + selectedMethod = methodSelection.method } - // Extract domain account ID, domain ID, and region for logging - const domainId = connection.domainId - const region = connection.ssoRegion - - logger.info(`Connected to SageMaker Unified Studio domain: ${domainId} in region ${region}`) - await recordAuthTelemetry(span, authProvider, domainId, region) + // Handle the selected authentication method + logger.debug(`Processing authentication method: ${selectedMethod}`) + if (selectedMethod === 'sso') { + // SSO Authentication - use SSO flow + const ssoResult = await SmusAuthenticationOrchestrator.handleSsoAuthentication( + authProvider, + span, + context + ) + + if (ssoResult.status === 'BACK') { + // User wants to go back to authentication method selection + selectedMethod = undefined // Reset to show method selection again + continue // Restart the loop + } + + authCompleted = true + } else { + // IAM Authentication - use new IAM profile selection flow + const iamResult = await SmusAuthenticationOrchestrator.handleIamAuthentication( + authProvider, + span, + context + ) + + if (iamResult.status === 'BACK') { + // User wants to go back to authentication method selection + selectedMethod = undefined // Reset to show method selection again + continue // Restart the loop + } + + if (iamResult.status === 'EDITING') { + // User is editing credentials, show helpful message with option to return to profile selection + const action = await vscode.window.showInformationMessage( + 'Complete your AWS credential setup and try again, or return to profile selection.', + 'Select Profile', + 'Done' + ) - // Show success message - void vscode.window.showInformationMessage( - `Successfully connected to SageMaker Unified Studio domain: ${domainId}` - ) + if (action === 'Select Profile') { + // User wants to return to profile selection, continue the loop + continue + } else { + // User chose "Done" or dismissed, exit the authentication flow + throw new ToolkitError('User cancelled credential setup', { + code: SmusErrorCodes.UserCancelled, + cancelled: true, + }) + } + } + + if (iamResult.status === 'INVALID_PROFILE') { + // Profile validation failed, show error with option to select another profile + const action = await vscode.window.showErrorMessage( + `${iamResult.error}`, + 'Select Another Profile', + 'Cancel' + ) - // Clear the status bar message - vscode.window.setStatusBarMessage('Connected to SageMaker Unified Studio', 3000) + if (action === 'Select Another Profile') { + // User wants to select a different profile, continue the loop + continue + } else { + // User chose "Cancel" or dismissed, exit the authentication flow + throw new ToolkitError('User cancelled profile selection', { + code: SmusErrorCodes.UserCancelled, + cancelled: true, + }) + } + } - // Immediately refresh the tree view to show authenticated state - try { - await vscode.commands.executeCommand('aws.smus.rootView.refresh') - } catch (refreshErr) { - logger.debug(`Failed to refresh views after login: ${(refreshErr as Error).message}`) + authCompleted = true } - } catch (connectionErr) { - // Clear the status bar message - vscode.window.setStatusBarMessage('Connection to SageMaker Unified Studio Failed') - - // Log the error and re-throw to be handled by the outer catch block - logger.error('Connection failed: %s', (connectionErr as Error).message) - throw new ToolkitError('Connection failed.', { - cause: connectionErr as Error, - code: (connectionErr as Error).name, - }) } + + // Record telemetry with connection details after successful login + const domainId = authProvider.getDomainId?.() + const region = authProvider.getDomainRegion?.() + await recordAuthTelemetry(span, authProvider, domainId, region) } catch (err) { const isUserCancelled = err instanceof ToolkitError && err.code === SmusErrorCodes.UserCancelled if (!isUserCancelled) { - void vscode.window.showErrorMessage( - `SageMaker Unified Studio: Failed to initiate login: ${(err as Error).message}` - ) + void vscode.window.showErrorMessage(`Failed to initiate login: ${(err as Error).message}`) + logger.error('Failed to initiate login: %s', (err as Error).message) } - logger.error('Failed to initiate login: %s', (err as Error).message) - throw new ToolkitError('Failed to initiate login.', { - cause: err as Error, - code: (err as Error).name, - }) + throw err } }) }) @@ -310,66 +357,71 @@ export const smusLoginCommand = Commands.declare('aws.smus.login', () => async ( /** * Command to sign out from SageMaker Unified Studio */ -export const smusSignOutCommand = Commands.declare('aws.smus.signOut', () => async () => { - const logger = getLogger() - return telemetry.smus_signOut.run(async (span) => { - try { - // Get the authentication provider instance - const authProvider = SmusAuthenticationProvider.fromContext() +export const smusSignOutCommand = Commands.declare( + 'aws.smus.signOut', + (context: vscode.ExtensionContext) => async () => { + const logger = getLogger('smus') + return telemetry.smus_signOut.run(async (span) => { + try { + // Get the authentication provider instance + const authProvider = SmusAuthenticationProvider.fromContext() - // Check if there's an active connection to sign out from - if (!authProvider.isConnected()) { - void vscode.window.showInformationMessage( - 'No active SageMaker Unified Studio connection to sign out from.' - ) - return - } + // Check if there's an active connection to sign out from + if (!authProvider.isConnected()) { + void vscode.window.showInformationMessage( + 'No active SageMaker Unified Studio connection to sign out from.' + ) + return + } - // Get connection details for logging - const activeConnection = authProvider.activeConnection - const domainId = activeConnection?.domainId - const region = activeConnection?.ssoRegion + // Capture connection details BEFORE signing out (for telemetry) + const activeConnection = authProvider.activeConnection + const domainId = authProvider.getDomainId?.() + const region = authProvider.getDomainRegion?.() - // Show status message - vscode.window.setStatusBarMessage('Signing out from SageMaker Unified Studio...', 5000) - await recordAuthTelemetry(span, authProvider, domainId, region) + // Record telemetry with captured values BEFORE signing out + await recordAuthTelemetry(span, authProvider, domainId, region) - // Delete the connection (this will also invalidate tokens and clear cache) - if (activeConnection) { - await authProvider.secondaryAuth.deleteConnection() - logger.info(`Signed out from SageMaker Unified Studio${domainId}`) - } + // Sign out from SMUS (behavior depends on connection type) + if (activeConnection) { + await authProvider.signOut() + logger.info(`Signed out from SageMaker Unified Studio: ${domainId}`) - // Show success message - void vscode.window.showInformationMessage('Successfully signed out from SageMaker Unified Studio.') + // Clear connection-specific preferences on sign out (but keep auth method preference) + const { SmusAuthenticationPreferencesManager } = await import( + '../../auth/preferences/authenticationPreferences.js' + ) + await SmusAuthenticationPreferencesManager.clearConnectionPreferences(context) + } - // Clear the status bar message - vscode.window.setStatusBarMessage('Signed out from SageMaker Unified Studio', 3000) + // Show success message + void vscode.window.showInformationMessage('Successfully signed out from SageMaker Unified Studio.') - // Refresh the tree view to show the sign-in state - try { - await vscode.commands.executeCommand('aws.smus.rootView.refresh') - } catch (refreshErr) { - logger.debug(`Failed to refresh views after sign out: ${(refreshErr as Error).message}`) - throw new ToolkitError('Failed to refresh views after sign out.', { - cause: refreshErr as Error, - code: (refreshErr as Error).name, + // Refresh the tree view to show the sign-in state + try { + await vscode.commands.executeCommand('aws.smus.rootView.refresh') + } catch (refreshErr) { + logger.debug(`Failed to refresh views after sign out: ${(refreshErr as Error).message}`) + throw new ToolkitError('Failed to refresh views after sign out.', { + cause: refreshErr as Error, + code: (refreshErr as Error).name, + }) + } + } catch (err) { + void vscode.window.showErrorMessage( + `SageMaker Unified Studio: Failed to sign out: ${(err as Error).message}` + ) + logger.error('Failed to sign out: %s', (err as Error).message) + + // Log failure telemetry + throw new ToolkitError('Failed to sign out.', { + cause: err as Error, + code: (err as Error).name, }) } - } catch (err) { - void vscode.window.showErrorMessage( - `SageMaker Unified Studio: Failed to sign out: ${(err as Error).message}` - ) - logger.error('Failed to sign out: %s', (err as Error).message) - - // Log failure telemetry - throw new ToolkitError('Failed to sign out.', { - cause: err as Error, - code: (err as Error).name, - }) - } - }) -}) + }) + } +) function isAccessDenied(error: Error): boolean { return error.name.includes('AccessDenied') @@ -399,8 +451,81 @@ async function showQuickPick(items: any[]) { return await quickPick.prompt() } +/** + * Fetches projects filtered by IAM principal + * For IAM users: filters by user profile using userIdentifier + * For IAM role sessions: filters by group profile using groupIdentifier + * @param authProvider The SMUS authentication provider + * @param datazoneClient The DataZone client instance + * @returns Promise resolving to filtered projects array + * @throws Error if profile retrieval fails + */ +async function fetchProjectsByIamProfile( + authProvider: SmusAuthenticationProvider, + datazoneClient: DataZoneClient +): Promise { + const logger = getLogger('smus') + + // Get credentials provider for IAM profile + const activeConnection = authProvider.activeConnection + if (!isSmusIamConnection(activeConnection)) { + throw new Error('Active connection is not a valid IAM connection') + } + + // Use cached caller identity ARN from auth provider + const callerIdentityArn = await authProvider.getIamPrincipalArn() + if (!callerIdentityArn) { + throw new Error('Unable to retrieve caller identity ARN from cache') + } + + // Determine if this is an IAM user or IAM role session using utility method + const isIamUser = SmusUtils.isIamUserArn(callerIdentityArn) + logger.debug( + `Using cached caller identity ARN: ${callerIdentityArn}. Identity type: ${isIamUser ? 'IAM User' : 'IAM Role Session'}` + ) + + let projects: DataZoneProject[] + + if (isIamUser) { + // IAM User flow - use GetUserProfile and filter by userIdentifier + logger.debug('Using IAM user flow with GetUserProfile API') + + // Get user profile ID for the IAM user using DataZone client + const userProfileId = await datazoneClient.getUserProfileIdForIamPrincipal( + callerIdentityArn, + authProvider.getDomainId() + ) + logger.info(`Retrieved user profile ID: ${userProfileId} for IAM principal ${callerIdentityArn}`) + + // Fetch projects filtered by user profile + projects = await datazoneClient.fetchAllProjects({ userIdentifier: userProfileId }) + logger.debug(`Fetched ${projects.length} projects for user profile ${userProfileId}`) + } else { + const credentialsProvider = await authProvider.getCredentialsProviderForIamProfile(activeConnection.profileName) + const datazoneCustomClientHelper = DataZoneCustomClientHelper.getInstance( + credentialsProvider, + authProvider.getDomainRegion() + ) + + // IAM Role Session flow - use SearchGroupProfile and filter by groupIdentifier + // The cached ARN needs conversion for role sessions + const roleArn = SmusUtils.convertAssumedRoleArnToIamRoleArn(callerIdentityArn) + logger.debug(`Using IAM role ARN: ${roleArn}`) + + // Get group profile ID for the current role + const groupProfileId = await datazoneCustomClientHelper.getGroupProfileId(authProvider.getDomainId(), roleArn) + logger.info(`Retrieved group profile ID: ${groupProfileId}`) + + // Fetch projects filtered by group profile + projects = await datazoneClient.fetchAllProjects({ groupIdentifier: groupProfileId }) + logger.debug(`Fetched ${projects.length} projects for group profile ${groupProfileId}`) + } + + return projects +} + export async function selectSMUSProject(projectNode?: SageMakerUnifiedStudioProjectNode) { - const logger = getLogger() + const logger = getLogger('smus') return telemetry.smus_accessProject.run(async (span) => { try { @@ -410,22 +535,73 @@ export async function selectSMUSProject(projectNode?: SageMakerUnifiedStudioProj return } - const client = await DataZoneClient.getInstance(authProvider) + const datazoneClient = await createDZClientBaseOnDomainMode(authProvider) logger.debug('DataZone client instance obtained successfully') - const allProjects = await client.fetchAllProjects() + let allProjects: DataZoneProject[] + + if (getContext('aws.smus.isIamMode')) { + // Filter projects by IAM profile (user or role session) + try { + allProjects = await fetchProjectsByIamProfile(authProvider, datazoneClient) + } catch (err) { + const error = err as Error + + // Handle no profile found (user or group) + if ( + error instanceof ToolkitError && + (error.code === SmusErrorCodes.NoGroupProfileFound || + error.code === SmusErrorCodes.NoUserProfileFound) + ) { + logger.error('No profile found for IAM principal: %s', error.message) + + const principalArn = await authProvider.getIamPrincipalArn() + const arnSuffix = principalArn ? `: ${principalArn}` : '' + void vscode.window.showErrorMessage( + `No resources found for IAM principal${arnSuffix}. Ensure SageMaker Unified Studio resources exist for this IAM principal.` + ) + return error + } + + // Handle access denied + if (isAccessDenied(error)) { + logger.error('Access denied when retrieving profile: %s', error.message) + void vscode.window.showErrorMessage( + "You don't have permissions to access this resource. Please contact your administrator" + ) + return error + } + + // Handle other errors + logger.error('Failed to retrieve profile information: %s', error.message) + void vscode.window.showErrorMessage('Failed to fetch IAM principal information. Try again.') + return error + } + } else { + // In non-IAM mode, fetch all projects without filtering + allProjects = await datazoneClient.fetchAllProjects() + } + const items = createProjectQuickPickItems(allProjects) + // Handle no projects scenario if (items.length === 0) { - logger.info('No projects found in the domain') - void vscode.window.showInformationMessage('No projects found in the domain') - await showQuickPick([{ label: 'No projects found', detail: '', description: '', data: {} }]) + if (getContext('aws.smus.isIamMode')) { + logger.debug('No accessible projects found for IAM principal') + void vscode.window.showInformationMessage('No accessible projects found for your IAM principal') + } else { + logger.debug('No projects found in the domain') + void vscode.window.showInformationMessage('No projects found in the domain') + } return } + // Show project picker const selectedProject = await showQuickPick(items) + const accountId = await authProvider.getDomainAccountId() span.record({ + smusAuthMode: authProvider.activeConnection?.type, smusDomainId: authProvider.getDomainId(), smusProjectId: (selectedProject as DataZoneProject).id as string | undefined, smusDomainRegion: authProvider.getDomainRegion(), @@ -446,7 +622,9 @@ export async function selectSMUSProject(projectNode?: SageMakerUnifiedStudioProj } catch (err) { const error = err as Error + // Handle access denied scenarios if (isAccessDenied(error)) { + logger.error('Access denied when fetching projects: %s', error.message) await showQuickPick([ { label: '$(error)', @@ -456,8 +634,9 @@ export async function selectSMUSProject(projectNode?: SageMakerUnifiedStudioProj return } + // Handle network/API failures logger.error('Failed to select project: %s', error.message) - void vscode.window.showErrorMessage(`Failed to select project: ${error.message}`) + await handleCredExpiredError(err, true) } }) } diff --git a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioSpacesParentNode.ts b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioSpacesParentNode.ts index 728ab03127d..d2d7e4a8173 100644 --- a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioSpacesParentNode.ts +++ b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioSpacesParentNode.ts @@ -6,7 +6,6 @@ import * as vscode from 'vscode' import { SageMakerUnifiedStudioComputeNode } from './sageMakerUnifiedStudioComputeNode' import { updateInPlace } from '../../../shared/utilities/collectionUtils' -import { DataZoneClient } from '../../shared/client/datazoneClient' import { DescribeDomainResponse } from '@amzn/sagemaker-client' import { getDomainUserProfileKey } from '../../../awsService/sagemaker/utils' import { getLogger } from '../../../shared/logger/logger' @@ -16,9 +15,14 @@ import { UserProfileMetadata } from '../../../awsService/sagemaker/explorer/sage import { SagemakerUnifiedStudioSpaceNode } from './sageMakerUnifiedStudioSpaceNode' import { PollingSet } from '../../../shared/utilities/pollingSet' import { SmusAuthenticationProvider } from '../../auth/providers/smusAuthenticationProvider' -import { SmusUtils } from '../../shared/smusUtils' +import { SmusUtils, SmusErrorCodes } from '../../shared/smusUtils' import { getIcon } from '../../../shared/icons' import { PENDING_NODE_POLLING_INTERVAL_MS } from './utils' +import { getContext } from '../../../shared/vscode/setContext' +import { createDZClientBaseOnDomainMode } from './utils' +import { SmusIamConnection } from '../../auth/model' +import { DataZoneCustomClientHelper } from '../../shared/client/datazoneCustomClientHelper' +import { ToolkitError } from '../../../shared/errors' export class SageMakerUnifiedStudioSpacesParentNode implements TreeNode { public readonly id = 'smusSpacesParentNode' @@ -26,7 +30,7 @@ export class SageMakerUnifiedStudioSpacesParentNode implements TreeNode { private readonly sagemakerSpaceNodes: Map = new Map() private spaceApps: Map = new Map() private domainUserProfiles: Map = new Map() - private readonly logger = getLogger() + private readonly logger = getLogger('smus') private readonly onDidChangeEmitter = new vscode.EventEmitter() public readonly onDidChangeTreeItem = this.onDidChangeEmitter.event public readonly onDidChangeChildren = this.onDidChangeEmitter.event @@ -70,6 +74,16 @@ export class SageMakerUnifiedStudioSpacesParentNode implements TreeNode { if (error.name === 'AccessDeniedException') { return this.getAccessDeniedChildren() } + // Handle no profile found (user or group) using error codes + if ( + error instanceof ToolkitError && + (error.code === SmusErrorCodes.NoGroupProfileFound || error.code === SmusErrorCodes.NoUserProfileFound) + ) { + return await this.getNoUserProfileChildren() + } + if (error.message.includes('Failed to retrieve user profile information')) { + return this.getUserProfileErrorChildren(error.message) + } return this.getNoSpacesFoundChildren() } const nodes = [...this.sagemakerSpaceNodes.values()] @@ -108,6 +122,48 @@ export class SageMakerUnifiedStudioSpacesParentNode implements TreeNode { ] } + private async getNoUserProfileChildren(): Promise { + // Log the IAM principal ARN for debugging + const principalArn = await this.authProvider.getIamPrincipalArn() + if (principalArn) { + this.logger.error(`No spaces found for IAM principal: ${principalArn}`) + } + + return [ + { + id: 'smusNoUserProfile', + resource: {}, + getTreeItem: () => { + const item = new vscode.TreeItem( + 'No spaces found for IAM principal', + vscode.TreeItemCollapsibleState.None + ) + item.iconPath = getIcon('vscode-error') + return item + }, + getParent: () => this, + }, + ] + } + + private getUserProfileErrorChildren(message: string): TreeNode[] { + return [ + { + id: 'smusUserProfileError', + resource: {}, + getTreeItem: () => { + const item = new vscode.TreeItem( + 'Failed to retrieve spaces. Please try again.', + vscode.TreeItemCollapsibleState.None + ) + item.iconPath = getIcon('vscode-error') + return item + }, + getParent: () => this, + }, + ] + } + public getParent(): TreeNode | undefined { return this.parent } @@ -144,8 +200,8 @@ export class SageMakerUnifiedStudioSpacesParentNode implements TreeNode { throw new Error('No active connection found to get SageMaker domain ID') } - this.logger.debug('SMUS: Getting DataZone client instance') - const datazoneClient = await DataZoneClient.getInstance(this.authProvider) + this.logger.debug('Getting DataZone client instance') + const datazoneClient = await createDZClientBaseOnDomainMode(this.authProvider) if (!datazoneClient) { throw new Error('DataZone client is not initialized') } @@ -158,7 +214,7 @@ export class SageMakerUnifiedStudioSpacesParentNode implements TreeNode { if (!resource.value) { throw new Error('SageMaker domain ID not found in tooling environment') } - getLogger().debug(`Found SageMaker domain ID: ${resource.value}`) + getLogger('smus').debug(`Found SageMaker domain ID: ${resource.value}`) return resource.value } } @@ -181,21 +237,102 @@ export class SageMakerUnifiedStudioSpacesParentNode implements TreeNode { } } + /** + * Retrieves the user profile ID for IAM mode (IAM authentication) + * @returns Promise resolving to the user profile ID + * @throws Error if user profile retrieval fails + */ + private async getUserProfileIdForIamAuthMode(): Promise { + try { + // Get cached caller IAM identity ARN from auth provider + const callerArn = await this.authProvider.getIamPrincipalArn() + + if (!callerArn) { + throw new Error('Unable to retrieve caller identity ARN') + } + // Determine if this is an IAM user or role session based on ARN format + if (SmusUtils.isIamUserArn(callerArn)) { + // For IAM users, use GetUserProfile API directly via DataZoneClient + this.logger.debug(`Detected IAM user, using GetUserProfile API with ARN: ${callerArn}`) + + const datazoneClient = await createDZClientBaseOnDomainMode(this.authProvider) + const userProfileId = await datazoneClient.getUserProfileIdForIamPrincipal( + callerArn, + this.authProvider.getDomainId() + ) + + if (!userProfileId) { + throw new ToolkitError('No user profile found for IAM user') + } + + this.logger.debug(`Retrieved user profile ID for IAM user: ${userProfileId}`) + return userProfileId + } else { + // For IAM role sessions, use SearchUserProfile API via DataZoneCustomClientHelper + // Need to get the full assumed role ARN (with session) for filtering + const assumedRoleArn = await this.authProvider.getCachedIamCallerIdentityArn() + + if (!assumedRoleArn) { + throw new Error('Unable to retrieve assumed role ARN with session') + } + + this.logger.debug( + `SMUS: Detected IAM role session, using SearchUserProfile API with ARN: ${assumedRoleArn}` + ) + + // Get credentials provider for the IAM profile + const credentialsProvider = await this.authProvider.getCredentialsProviderForIamProfile( + (this.authProvider.activeConnection as SmusIamConnection).profileName + ) + + const datazoneCustomClientHelper = DataZoneCustomClientHelper.getInstance( + credentialsProvider, + this.authProvider.getDomainRegion() + ) + + const userProfileId = await datazoneCustomClientHelper.getUserProfileIdForSession( + this.authProvider.getDomainId(), + assumedRoleArn + ) + + this.logger.debug(`Retrieved user profile ID for role session: ${userProfileId}`) + return userProfileId + } + } catch (err) { + const error = err as Error + this.logger.error(`Failed to retrieve user profile information: ${error.message}`) + + if (error.name === 'AccessDeniedException') { + throw new Error("You don't have permissions to access this resource. Please contact your administrator") + } + throw err + } + } + private async updateChildren(): Promise { - const datazoneClient = await DataZoneClient.getInstance(this.authProvider) - // Will be of format: 'ABCA4NU3S7PEOLDQPLXYZ:user-12345678-d061-70a4-0bf2-eeee67a6ab12' - const userId = await datazoneClient.getUserId() - const ssoUserProfileId = SmusUtils.extractSSOIdFromUserId(userId || '') + const datazoneClient = await createDZClientBaseOnDomainMode(this.authProvider) + + let userProfileId + if (getContext('aws.smus.isIamMode')) { + userProfileId = await this.getUserProfileIdForIamAuthMode() + } else { + // Will be of format: 'ABCA4NU3S7PEOLDQPLXYZ:user-12345678-d061-70a4-0bf2-eeee67a6ab12' + const userId = await datazoneClient.getUserId() + userProfileId = SmusUtils.extractSSOIdFromUserId(userId || '') + } + const sagemakerDomainId = await this.getSageMakerDomainId() const [spaceApps, domains] = await this.sagemakerClient.fetchSpaceAppsAndDomains( sagemakerDomainId, false /* filterSmusDomains */ ) + // Filter spaceApps to only show spaces owned by current user + this.logger.debug(`Filtering spaces for user profile ID: ${userProfileId}`) const filteredSpaceApps = new Map() for (const [key, app] of spaceApps.entries()) { const userProfile = app.OwnershipSettingsSummary?.OwnerUserProfileName - if (ssoUserProfileId === userProfile) { + if (userProfileId === userProfile) { filteredSpaceApps.set(key, app) } } diff --git a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/types.ts b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/types.ts index a94d25fccc4..73da925a8f7 100644 --- a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/types.ts +++ b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/types.ts @@ -17,6 +17,9 @@ export const DATA_DEFAULT_LAKEHOUSE_CONNECTION_NAME_REGEXP = /^(project\.default export const DATA_DEFAULT_ATHENA_CONNECTION_NAME_REGEXP = /^(project\.athena)|(default\.sql)$/ // eslint-disable-next-line @typescript-eslint/naming-convention export const DATA_DEFAULT_S3_CONNECTION_NAME_REGEXP = /^(project\.s3_default_folder)|(default\.s3)$/ +// eslint-disable-next-line @typescript-eslint/naming-convention, id-length +export const S3_PROJECT_NON_GIT_PROJECT_REPOSITORY_LOCATION_NAME_REGEXP = + /^(project\.non_git_project_repository_location)|(default\.s3_shared)$/ // Database object types export enum DatabaseObjects { @@ -205,3 +208,22 @@ export const LEAF_NODE_TYPES = [ // eslint-disable-next-line @typescript-eslint/naming-convention export const NO_DATA_FOUND_MESSAGE = '[No data found]' + +/** + * Glue connection types + */ +export const glueConnectionTypes = [ + 'BIGQUERY', + 'DOCUMENTDB', + 'DYNAMODB', + 'MYSQL', + 'OPENSEARCH', + 'ORACLE', + 'POSTGRESQL', + 'REDSHIFT', + 'SAPHANA', + 'SNOWFLAKE', + 'SQLSERVER', + 'TERADATA', + 'VERTICA', +] diff --git a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/utils.ts b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/utils.ts index 32924ad3d9f..ddf0ed5f3ea 100644 --- a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/utils.ts +++ b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/utils.ts @@ -17,8 +17,13 @@ import { DATA_DEFAULT_LAKEHOUSE_CONNECTION_NAME_REGEXP, redshiftColumnTypes, lakeHouseColumnTypes, + glueConnectionTypes, } from './types' -import { DataZoneConnection } from '../../shared/client/datazoneClient' +import { DataZoneClient, DataZoneConnection } from '../../shared/client/datazoneClient' +import { getContext } from '../../../shared/vscode/setContext' +import { SmusAuthenticationProvider } from '../../auth/providers/smusAuthenticationProvider' +import { SmusIamConnection } from '../../auth/model' +import { ConnectionStatus } from '@aws-sdk/client-datazone' /** * Polling interval in milliseconds for checking space status updates @@ -47,6 +52,9 @@ export function getLabel(data: { data.value?.connection?.type === ConnectionType.LAKEHOUSE && DATA_DEFAULT_LAKEHOUSE_CONNECTION_NAME_REGEXP.test(data.value?.connection?.name) ) { + if (getContext('aws.smus.isIamMode')) { + return 'Catalogs' + } return 'Lakehouse' } const formattedType = data.value?.connection?.type?.replace(/([A-Z]+(?:_[A-Z]+)*)/g, (match: string) => { @@ -230,6 +238,14 @@ function getColumnIcon(columnType: string): vscode.ThemeIcon | IconPath { return getIcon('vscode-calendar') } + // Check if it's a boolean type + if ( + lakeHouseColumnTypes.BOOLEAN.some((type) => upperType.includes(type)) || + redshiftColumnTypes.BOOLEAN.some((type) => upperType.includes(type)) + ) { + return getIcon('vscode-symbol-boolean') + } + // Default icon for unknown types return new vscode.ThemeIcon('symbol-field') } @@ -374,6 +390,69 @@ export function getRedshiftTypeFromHost(host?: string): RedshiftType | undefined } } +/** + * This function searches for property keys that end with "Properties" (like "snowflakeProperties", + * "redshiftProperties", "athenaProperties") and returns the actual property object, not just the key name. + * It only works for connections that have a glueConnectionName, indicating they are federated connections. + * + * @param connection - The DataZone connection object to search + * @returns The property object (not the key name) if found, undefined otherwise + * + * @example + * ```typescript + * // Redshift connection + * const redshiftConnection = { + * glueConnectionName: 'my-redshift-glue-conn', + * props: { + * redshiftProperties: { + * status: 'FAILED', + * errorMessage: 'Connection timeout' + * } + * } + * } + * const result = getGluePropertiesKey(redshiftConnection) + * // Returns: { status: 'FAILED', errorMessage: 'Connection timeout' } + */ +export function getGluePropertiesKey(connection: DataZoneConnection) { + if (!connection?.props) { + return undefined + } + if (!connection.glueConnectionName) { + return undefined + } + // Check for other properties that might contain glue connection info + const propertiesKey = Object.keys(connection.props).find( + (key) => + key.endsWith('Properties') && + typeof connection.props![key] === 'object' && + !Array.isArray(connection.props![key]) + ) + + return propertiesKey ? connection.props[propertiesKey] : undefined +} + +/** + * This function handles the refactor where connections moved from a single `glueProperties` object to + * connector-specific property bags (like `snowflakeProperties`, `redshiftProperties`, `athenaProperties`). + * It first checks for the legacy `glueProperties` field, then falls back to connector-specific properties. + * + * @param connection - The DataZone connection object to extract properties from + * @returns Object with optional status and errorMessage fields, or undefined if no properties found + */ +export function getGlueProperties(connection?: DataZoneConnection) { + if (!connection?.props) { + return undefined + } + // Check for direct glueProperties + if ('glueProperties' in connection.props) { + return connection.props.glueProperties + } + + return connection?.props?.[getGluePropertiesKey(connection)!] as + | { status?: ConnectionStatus; errorMessage?: string } + | undefined +} + /** * Determines if a connection is a federated connection by checking its type. * A connection is considered federated if it's either: @@ -385,7 +464,57 @@ export function getRedshiftTypeFromHost(host?: string): RedshiftType | undefined */ export function isFederatedConnection(connection?: DataZoneConnection): boolean { if (connection?.type === ConnectionType.REDSHIFT) { - return !!connection?.props?.glueProperties + return !!getGlueProperties(connection) } - return false + + // Check if connection type exists in GlueConnectionType enum values + return glueConnectionTypes.includes(connection?.type || '') +} + +/** + * Creates a DataZoneClient with appropriate credentials provider based on domain mode + * If domain mode is IAM mode, use the credential profile credential provider + * If domain mode is not IAM mode, use the DER credential provider + * @param smusAuthProvider The SMUS authentication provider + * @returns Promise resolving to DataZoneClient instance + */ +export async function createDZClientBaseOnDomainMode( + smusAuthProvider: SmusAuthenticationProvider +): Promise { + let credentialsProvider + if (getContext('aws.smus.isIamMode') && !getContext('aws.smus.inSmusSpaceEnvironment')) { + credentialsProvider = await smusAuthProvider.getCredentialsProviderForIamProfile( + (smusAuthProvider.activeConnection as SmusIamConnection).profileName + ) + } else { + credentialsProvider = await smusAuthProvider.getDerCredentialsProvider() + } + return DataZoneClient.createWithCredentials( + smusAuthProvider.getDomainRegion(), + smusAuthProvider.getDomainId(), + credentialsProvider + ) +} + +/** + * Creates a DataZoneClient with appropriate credentials provider for a specific project + * If domain mode is IAM mode, use the project credential provider + * If domain mode is not IAM mode, use the DER credential provider + * @param smusAuthProvider The SMUS authentication provider + * @param projectId The project ID for project-specific credentials + * @returns Promise resolving to DataZoneClient instance + */ +export async function createDZClientForProject( + smusAuthProvider: SmusAuthenticationProvider, + projectId: string +): Promise { + const credentialsProvider = getContext('aws.smus.isIamMode') + ? await smusAuthProvider.getProjectCredentialProvider(projectId) + : await smusAuthProvider.getDerCredentialsProvider() + + return DataZoneClient.createWithCredentials( + smusAuthProvider.getDomainRegion(), + smusAuthProvider.getDomainId(), + credentialsProvider + ) } diff --git a/packages/core/src/sagemakerunifiedstudio/shared/client/connectionClientStore.ts b/packages/core/src/sagemakerunifiedstudio/shared/client/connectionClientStore.ts index edf317f6479..e48f96f6bdb 100644 --- a/packages/core/src/sagemakerunifiedstudio/shared/client/connectionClientStore.ts +++ b/packages/core/src/sagemakerunifiedstudio/shared/client/connectionClientStore.ts @@ -132,7 +132,7 @@ export class ConnectionClientStore { * Clears all cached clients */ public clearAll(): void { - getLogger().info('SMUS Connection: Clearing all cached clients') + getLogger('smus').info('SMUS Connection: Clearing all cached clients') this.clientCache = {} } } diff --git a/packages/core/src/sagemakerunifiedstudio/shared/client/credentialsAdapter.ts b/packages/core/src/sagemakerunifiedstudio/shared/client/credentialsAdapter.ts index 88d08c93b86..ee8fef5a5d3 100644 --- a/packages/core/src/sagemakerunifiedstudio/shared/client/credentialsAdapter.ts +++ b/packages/core/src/sagemakerunifiedstudio/shared/client/credentialsAdapter.ts @@ -6,12 +6,13 @@ import * as AWS from 'aws-sdk' import { ConnectionCredentialsProvider } from '../../auth/providers/connectionCredentialsProvider' import { getLogger } from '../../../shared/logger/logger' +import { CredentialsProvider } from '../../../auth/providers/credentials' /** * Adapts a ConnectionCredentialsProvider (SDK v3) to work with SDK v2's CredentialProviderChain */ export function adaptConnectionCredentialsProvider( - connectionCredentialsProvider: ConnectionCredentialsProvider + connectionCredentialsProvider: ConnectionCredentialsProvider | CredentialsProvider ): AWS.CredentialProviderChain { const provider = () => { // Create SDK v2 Credentials that will resolve the provider when needed @@ -23,12 +24,12 @@ export function adaptConnectionCredentialsProvider( // Override the get method to use the connection credentials provider credentials.get = (callback) => { - getLogger().debug('Attempting to get credentials from ConnectionCredentialsProvider') + getLogger('smus').debug('Attempting to get credentials from ConnectionCredentialsProvider') connectionCredentialsProvider .getCredentials() .then((creds) => { - getLogger().debug('Successfully got credentials') + getLogger('smus').debug('Successfully got credentials') credentials.accessKeyId = creds.accessKeyId as string credentials.secretAccessKey = creds.secretAccessKey as string @@ -37,7 +38,7 @@ export function adaptConnectionCredentialsProvider( callback() }) .catch((err) => { - getLogger().debug(`Failed to get credentials: ${err}`) + getLogger('smus').debug(`Failed to get credentials: ${err}`) callback(err) }) diff --git a/packages/core/src/sagemakerunifiedstudio/shared/client/datazoneClient.ts b/packages/core/src/sagemakerunifiedstudio/shared/client/datazoneClient.ts index ffa0e7bfbf3..be612ce459b 100644 --- a/packages/core/src/sagemakerunifiedstudio/shared/client/datazoneClient.ts +++ b/packages/core/src/sagemakerunifiedstudio/shared/client/datazoneClient.ts @@ -18,8 +18,24 @@ import { GetEnvironmentCommandOutput, } from '@aws-sdk/client-datazone' import { getLogger } from '../../../shared/logger/logger' -import type { SmusAuthenticationProvider } from '../../auth/providers/smusAuthenticationProvider' import { DefaultStsClient } from '../../../shared/clients/stsClient' +import { getContext } from '../../../shared/vscode/setContext' +import { CredentialsProvider } from '../../../auth/providers/credentials' +import { DevSettings } from '../../../shared/settings' +import { ToolkitError } from '../../../shared/errors' +import { SmusErrorCodes } from '../smusUtils' + +/** + * Represents a DataZone domain + */ +export interface DataZoneDomain { + id: string + name: string + description?: string + status?: string + createdAt?: Date + updatedAt?: Date +} /** * Represents a DataZone project @@ -29,6 +45,7 @@ export interface DataZoneProject { name: string description?: string domainId: string + createdBy?: string createdAt?: Date updatedAt?: Date } @@ -82,6 +99,10 @@ export interface DataZoneConnection { awsAccountId?: string iamConnectionId?: string } + /** + * Glue connection name + */ + glueConnectionName?: string } // Constants for DataZone environment configuration @@ -89,14 +110,64 @@ const toolingBlueprintName = 'Tooling' const sageMakerProviderName = 'Amazon SageMaker' /** - * Client for interacting with AWS DataZone API with DER credential support - * - * This client integrates with SmusAuthenticationProvider to provide authenticated - * DataZone operations using Domain Execution Role (DER) credentials. + * Client for interacting with AWS DataZone API * - * One instance per connection/domainId is maintained to avoid duplication. + * This client can be used with different credential providers */ export class DataZoneClient { + private datazoneClient: DataZone | undefined + private static instances = new Map() + private readonly logger = getLogger('smus') + + private constructor( + private readonly region: string, + private readonly domainId: string, + private readonly credentialsProvider?: CredentialsProvider + ) {} + + /** + * Creates a new DataZoneClient instance with specific credentials + * @param region AWS region + * @param domainId DataZone domain ID + * @param credentialsProvider Credentials provider + * @returns DataZoneClient instance with credentials + */ + public static createWithCredentials( + region: string, + domainId: string, + credentialsProvider: CredentialsProvider + ): DataZoneClient { + const instanceKey = credentialsProvider.getHashCode() + + if (DataZoneClient.instances.has(instanceKey)) { + const existingInstance = DataZoneClient.instances.get(instanceKey)! + getLogger('smus').debug(`DataZoneClient: Using existing instance, instance key is ${instanceKey}`) + return existingInstance + } + + // Create new instance + getLogger('smus').debug(`DataZoneClient: Creating new instance with instance key ${instanceKey}`) + const instance = new DataZoneClient(region, domainId, credentialsProvider) + DataZoneClient.instances.set(instanceKey, instance) + + return instance + } + + /** + * Disposes all cached DataZoneClient instances + */ + public static dispose(): void { + const logger = getLogger('smus') + getLogger('smus').debug('DataZoneClient: Disposing all cached instances') + + for (const [key, instance] of DataZoneClient.instances.entries()) { + instance.datazoneClient = undefined + logger.debug(`DataZoneClient: Disposed instance for: ${key}`) + } + + DataZoneClient.instances.clear() + } + /** * Parse a Redshift connection info object from JDBC URL * @param jdbcURL Example JDBC URL: jdbc:redshift://redshift-serverless-workgroup-3zzw0fjmccdixz.123456789012.us-east-1.redshift-serverless.amazonaws.com:5439/dev @@ -140,73 +211,6 @@ export class DataZoneClient { } } - private datazoneClient: DataZone | undefined - private static instances = new Map() - private readonly logger = getLogger() - - private constructor( - private readonly authProvider: SmusAuthenticationProvider, - private readonly domainId: string, - private readonly region: string - ) {} - - /** - * Gets an authenticated DataZoneClient instance using DER credentials - * One instance per connection/domainId is maintained - * @param authProvider The SMUS authentication provider - * @returns Promise resolving to authenticated DataZoneClient instance - */ - public static async getInstance(authProvider: SmusAuthenticationProvider): Promise { - const logger = getLogger() - - if (!authProvider.isConnected()) { - throw new Error('SMUS authentication provider is not connected') - } - - const activeConnection = authProvider.activeConnection! - const instanceKey = `${activeConnection.domainId}:${activeConnection.ssoRegion}` - - logger.debug(`DataZoneClient: Getting instance for domain: ${instanceKey}`) - - // Check if we already have an instance for this domain/region - if (DataZoneClient.instances.has(instanceKey)) { - const existingInstance = DataZoneClient.instances.get(instanceKey)! - logger.debug('DataZoneClient: Using existing instance') - return existingInstance - } - - // Create new instance - logger.debug('DataZoneClient: Creating new instance') - const instance = new DataZoneClient(authProvider, activeConnection.domainId, activeConnection.ssoRegion) - DataZoneClient.instances.set(instanceKey, instance) - - // Set up cleanup when connection changes - const disposable = authProvider.onDidChangeActiveConnection(() => { - logger.debug(`DataZoneClient: Connection changed, cleaning up instance for: ${instanceKey}`) - DataZoneClient.instances.delete(instanceKey) - instance.datazoneClient = undefined - disposable.dispose() - }) - - logger.info(`DataZoneClient: Created instance for domain ${activeConnection.domainId}`) - return instance - } - - /** - * Disposes all instances and cleans up resources - */ - public static dispose(): void { - const logger = getLogger() - logger.debug('DataZoneClient: Disposing all instances') - - for (const [key, instance] of DataZoneClient.instances.entries()) { - instance.datazoneClient = undefined - logger.debug(`DataZoneClient: Disposed instance for: ${key}`) - } - - DataZoneClient.instances.clear() - } - /** * Gets the DataZone domain ID * @returns DataZone domain ID @@ -240,9 +244,8 @@ export class DataZoneClient { const domainBlueprints = await datazoneClient.listEnvironmentBlueprints({ domainIdentifier: this.domainId, managed: true, - name: toolingBlueprintName, + name: this.getToolingBlueprintName(), }) - const toolingBlueprint = domainBlueprints.items?.[0] if (!toolingBlueprint) { this.logger.error('Failed to get tooling blueprint') @@ -257,7 +260,7 @@ export class DataZoneClient { provider: sageMakerProviderName, }) - const defaultEnv = listEnvs.items?.find((env) => env.name === toolingBlueprintName) + const defaultEnv = listEnvs.items?.[0] if (!defaultEnv) { this.logger.error('Failed to find default Tooling environment') throw new Error('Failed to find default Tooling environment') @@ -282,23 +285,38 @@ export class DataZoneClient { private async getDataZoneClient(): Promise { if (!this.datazoneClient) { try { - this.logger.debug('DataZoneClient: Creating authenticated DataZone client with DER credentials') - - const credentialsProvider = async () => { - const credentials = await (await this.authProvider.getDerCredentialsProvider()).getCredentials() - return { - accessKeyId: credentials.accessKeyId, - secretAccessKey: credentials.secretAccessKey, - sessionToken: credentials.sessionToken, - expiration: credentials.expiration, + if (this.credentialsProvider) { + const awsCredentialProvider = async () => { + const credentials = await this.credentialsProvider!.getCredentials() + return { + accessKeyId: credentials.accessKeyId, + secretAccessKey: credentials.secretAccessKey, + sessionToken: credentials.sessionToken, + expiration: credentials.expiration, + } + } + + const clientConfig: any = { + region: this.region, + credentials: awsCredentialProvider, } + + // Use user setting for endpoint if provided + const devSettings = DevSettings.instance + const customEndpoint = devSettings.get('endpoints', {})['datazone'] + if (customEndpoint) { + clientConfig.endpoint = customEndpoint + this.logger.debug( + `DataZoneClient: Using custom DataZone endpoint from settings: ${customEndpoint}` + ) + } + + this.datazoneClient = new DataZone(clientConfig) + } else { + throw new Error('No credentials provider provided') } - this.datazoneClient = new DataZone({ - region: this.region, - credentials: credentialsProvider, - }) - this.logger.debug('DataZoneClient: Successfully created authenticated DataZone client') + this.logger.info('DataZoneClient: Successfully created authenticated DataZone client') } catch (err) { this.logger.error('DataZoneClient: Failed to create DataZone client: %s', err as Error) throw err @@ -414,6 +432,7 @@ export class DataZoneClient { name: project.name || '', description: project.description, domainId: this.domainId, + createdBy: project.createdBy, createdAt: project.createdAt ? new Date(project.createdAt) : undefined, updatedAt: project.updatedAt ? new Date(project.updatedAt) : undefined, })) @@ -531,6 +550,22 @@ export class DataZoneClient { return undefined } + /** + * Parses glueConnectionName from physical endpoints + * @param physicalEndpoints Array of physical endpoints + * @returns glueConnectionName or undefined + */ + // eslint-disable-next-line id-length + private parseGlueConnectionNameFromPhysicalEndpoints( + physicalEndpoints?: PhysicalEndpoint[] + ): DataZoneConnection['glueConnectionName'] { + if (physicalEndpoints && physicalEndpoints.length > 0) { + const physicalEndpoint = physicalEndpoints[0] + return physicalEndpoint.glueConnectionName + } + return undefined + } + /** * Gets a specific connection by ID * @param params Parameters for getting a connection @@ -561,6 +596,8 @@ export class DataZoneClient { // Parse location from physical endpoints const location = this.parseLocationFromPhysicalEndpoints(response.physicalEndpoints) + const glueConnectionName = this.parseGlueConnectionNameFromPhysicalEndpoints(response.physicalEndpoints) + // Return as DataZoneConnection, currently only required fields are added // Can always include new fields in DataZoneConnection when needed const connection: DataZoneConnection = { @@ -573,6 +610,7 @@ export class DataZoneClient { props: response.props || {}, connectionCredentials: response.connectionCredentials, location, + glueConnectionName, } return connection @@ -634,6 +672,10 @@ export class DataZoneClient { // Parse location from physical endpoints const location = this.parseLocationFromPhysicalEndpoints(connection.physicalEndpoints) + const glueConnectionName = this.parseGlueConnectionNameFromPhysicalEndpoints( + connection.physicalEndpoints + ) + return { connectionId: connection.connectionId || '', name: connection.name || '', @@ -644,6 +686,7 @@ export class DataZoneClient { projectId, props: connection.props || {}, location, + glueConnectionName, } }) allConnections = [...allConnections, ...connections] @@ -676,7 +719,7 @@ export class DataZoneClient { domainBlueprints = await datazoneClient.listEnvironmentBlueprints({ domainIdentifier: domainId, managed: true, - name: toolingBlueprintName, + name: this.getToolingBlueprintName(), }) } catch (err) { this.logger.error( @@ -713,7 +756,7 @@ export class DataZoneClient { throw err } - const defaultEnv = listEnvs.items?.find((env) => env.name === toolingBlueprintName) + const defaultEnv = listEnvs.items?.[0] if (!defaultEnv || !defaultEnv.id) { this.logger.error( 'No default Tooling environment found for domainId: %s, projectId: %s', @@ -728,7 +771,6 @@ export class DataZoneClient { /** * Gets environment details - * @param domainId The DataZone domain identifier * @param environmentId The environment identifier * @returns Promise resolving to environment details */ @@ -760,33 +802,72 @@ export class DataZoneClient { * @returns The tooling environment details */ public async getToolingEnvironment(projectId: string): Promise { - const logger = getLogger() - - const datazoneClient = await DataZoneClient.getInstance(this.authProvider) - if (!datazoneClient) { - throw new Error('DataZone client is not initialized') - } - - const toolingEnvId = await datazoneClient - .getToolingEnvironmentId(datazoneClient.getDomainId(), projectId) - .catch((err) => { - logger.error('Failed to get tooling environment ID for project %s', projectId) - throw new Error(`Failed to get tooling environment ID: ${err.message}`) - }) - + const toolingEnvId = await this.getToolingEnvironmentId(this.getDomainId(), projectId) if (!toolingEnvId) { throw new Error('No default environment found for project') } - - return await datazoneClient.getEnvironmentDetails(toolingEnvId) + return await this.getEnvironmentDetails(toolingEnvId) } public async getUserId(): Promise { - const derCredProvider = await this.authProvider.getDerCredentialsProvider() - this.logger.debug(`Calling STS GetCallerIdentity using DER credentials of ${this.getDomainId()}`) - const stsClient = new DefaultStsClient(this.getRegion(), await derCredProvider.getCredentials()) + if (!this.credentialsProvider) { + throw new Error('Credentials provider is required for getUserId') + } + const callerCredentials = await this.credentialsProvider.getCredentials() + const stsClient = new DefaultStsClient(this.getRegion(), callerCredentials) const callerIdentity = await stsClient.getCallerIdentity() this.logger.debug(`Retrieved caller identity, UserId: ${callerIdentity.UserId}`) return callerIdentity.UserId } + + /** + * Gets the user profile ID for a given IAM principal + * @param userIdentifier IAM user or role ARN + * @param domainIdentifier Optional domain identifier. If not provided, uses the client's domain ID + * @returns Promise resolving to the user profile ID + * @throws ToolkitError with appropriate error code + */ + public async getUserProfileIdForIamPrincipal( + userIdentifier: string, + domainIdentifier?: string + ): Promise { + try { + this.logger.debug(`DataZoneClient: Getting user profile for IAM ARN: ${userIdentifier}`) + + const datazoneClient = await this.getDataZoneClient() + + const params = { + domainIdentifier: domainIdentifier || this.getDomainId(), + userIdentifier: userIdentifier, + } + + const userProfile = await datazoneClient.getUserProfile(params) + + if (!userProfile.id) { + this.logger.error(`DataZoneClient: No user profile ID returned for ARN: ${userIdentifier}`) + throw new ToolkitError(`No user profile found for IAM principal: ${userIdentifier}`, { + code: SmusErrorCodes.NoUserProfileFound, + }) + } + + this.logger.debug(`DataZoneClient: Retrieved user profile ID: ${userProfile.id}`) + return userProfile.id + } catch (err) { + // Re-throw if it's already a ToolkitError + if (err instanceof ToolkitError) { + throw err + } + + // Log and wrap other errors + this.logger.error('DataZoneClient: Failed to get user profile ID: %s', (err as Error).message) + throw ToolkitError.chain(err, 'Failed to get user profile ID') + } + } + + /** + * Gets the correct tooling blueprint name + */ + private getToolingBlueprintName(): string { + return getContext('aws.smus.isIamMode') ? 'ToolingLite' : toolingBlueprintName + } } diff --git a/packages/core/src/sagemakerunifiedstudio/shared/client/datazoneCustomClientHelper.ts b/packages/core/src/sagemakerunifiedstudio/shared/client/datazoneCustomClientHelper.ts new file mode 100644 index 00000000000..3823886506d --- /dev/null +++ b/packages/core/src/sagemakerunifiedstudio/shared/client/datazoneCustomClientHelper.ts @@ -0,0 +1,540 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +import { getLogger } from '../../../shared/logger/logger' +import apiConfig = require('./datazonecustomclient.json') +import globals from '../../../shared/extensionGlobals' +import { Service } from 'aws-sdk' +import { ServiceConfigurationOptions } from 'aws-sdk/lib/service' +import * as DataZoneCustomClient from './datazonecustomclient' +import { adaptConnectionCredentialsProvider } from './credentialsAdapter' +import { CredentialsProvider } from '../../../auth/providers/credentials' +import { ToolkitError } from '../../../shared/errors' +import { SmusUtils } from '../smusUtils' +import { DevSettings } from '../../../shared/settings' + +import { SmusErrorCodes } from '../smusUtils' + +/** + * Error codes for DataZone operations + * @deprecated Use SmusErrorCodes instead + */ +export const DataZoneErrorCode = { + NoGroupProfileFound: SmusErrorCodes.NoGroupProfileFound, + NoUserProfileFound: SmusErrorCodes.NoUserProfileFound, +} as const + +/** + * Helper client for interacting with AWS DataZone Custom API + */ +export class DataZoneCustomClientHelper { + private datazoneCustomClient: DataZoneCustomClient | undefined + private static instances = new Map() + private readonly logger = getLogger('smus') + + private constructor( + private readonly credentialProvider: CredentialsProvider, + private readonly region: string + ) {} + + /** + * Gets a singleton instance of the DataZoneCustomClientHelper + * @returns DataZoneCustomClientHelper instance + */ + public static getInstance(credentialProvider: CredentialsProvider, region: string): DataZoneCustomClientHelper { + const logger = getLogger('smus') + + const instanceKey = `${region}` + + // Check if we already have an instance for this instanceKey + if (DataZoneCustomClientHelper.instances.has(instanceKey)) { + const existingInstance = DataZoneCustomClientHelper.instances.get(instanceKey)! + logger.debug(`DataZoneCustomClientHelper: Using existing instance for instanceKey ${instanceKey}`) + return existingInstance + } + + // Create new instance + logger.debug('DataZoneCustomClientHelper: Creating new instance') + const instance = new DataZoneCustomClientHelper(credentialProvider, region) + DataZoneCustomClientHelper.instances.set(instanceKey, instance) + + logger.debug(`DataZoneCustomClientHelper: Created instance with instanceKey ${instanceKey}`) + + return instance + } + + /** + * Disposes all instances and cleans up resources + */ + public static dispose(): void { + const logger = getLogger('smus') + logger.debug('DataZoneCustomClientHelper: Disposing all instances') + + for (const [key, instance] of DataZoneCustomClientHelper.instances.entries()) { + instance.datazoneCustomClient = undefined + logger.debug(`DataZoneCustomClientHelper: Disposed instance for: ${key}`) + } + + DataZoneCustomClientHelper.instances.clear() + } + + /** + * Gets the AWS region + * @returns AWS region + */ + public getRegion(): string { + return this.region + } + + /** + * Gets the DataZone client, initializing it if necessary + */ + private async getDataZoneCustomClient(): Promise { + if (!this.datazoneCustomClient) { + try { + this.logger.info('DataZoneCustomClientHelper: Creating authenticated DataZone client') + + // Use user setting for endpoint if provided, otherwise use default + const devSettings = DevSettings.instance + const customEndpoint = devSettings.get('endpoints', {})['datazone'] + const endpoint = customEndpoint || `https://datazone.${this.region}.api.aws` + + if (customEndpoint) { + this.logger.debug( + `DataZoneCustomClientHelper: Using custom DataZone endpoint from settings: ${endpoint}` + ) + } + + this.datazoneCustomClient = (await globals.sdkClientBuilder.createAwsService( + Service, + { + apiConfig: apiConfig, + endpoint: endpoint, + region: this.region, + credentialProvider: adaptConnectionCredentialsProvider(this.credentialProvider), + } as ServiceConfigurationOptions, + undefined, + false + )) as DataZoneCustomClient + + this.logger.info('DataZoneCustomClientHelper: Successfully created authenticated DataZone client') + } catch (err) { + this.logger.error('DataZoneCustomClientHelper: Failed to create DataZone client: %s', err as Error) + throw err + } + } + return this.datazoneCustomClient + } + + /** + * Lists domains in DataZone with pagination support + * @param options Options for listing domains + * @returns Paginated list of DataZone domains with nextToken + */ + public async listDomains(options?: { + maxResults?: number + status?: string + nextToken?: string + }): Promise<{ domains: DataZoneCustomClient.Types.DomainSummary[]; nextToken?: string }> { + try { + this.logger.info(`DataZoneCustomClientHelper: Listing domains in region ${this.region}`) + + const datazoneCustomClient = await this.getDataZoneCustomClient() + + // Call DataZone API to list domains with pagination + const response = await datazoneCustomClient + .listDomains({ + maxResults: options?.maxResults, + status: options?.status, + nextToken: options?.nextToken, + }) + .promise() + + const domains = response.items || [] + + if (domains.length === 0) { + this.logger.info(`DataZoneCustomClientHelper: No domains found`) + } else { + this.logger.debug(`DataZoneCustomClientHelper: Found ${domains.length} domains`) + } + + return { domains, nextToken: response.nextToken } + } catch (err) { + this.logger.error('DataZoneCustomClientHelper: Failed to list domains: %s', (err as Error).message) + throw err + } + } + + /** + * Fetches all domains by handling pagination automatically + * @param options Options for listing domains (excluding nextToken which is handled internally) + * @returns Promise resolving to an array of all DataZone domains + */ + public async fetchAllDomains(options?: { status?: string }): Promise { + try { + let allDomains: DataZoneCustomClient.Types.DomainSummary[] = [] + let nextToken: string | undefined + do { + const maxResultsPerPage = 25 + const response = await this.listDomains({ + ...options, + nextToken, + maxResults: maxResultsPerPage, + }) + allDomains = [...allDomains, ...response.domains] + nextToken = response.nextToken + } while (nextToken) + + this.logger.debug(`DataZoneCustomClientHelper: Fetched a total of ${allDomains.length} domains`) + return allDomains + } catch (err) { + this.logger.error('DataZoneCustomClientHelper: Failed to fetch all domains: %s', (err as Error).message) + throw err + } + } + + /** + * Gets the domain with IAM authentication mode in preferences using pagination with early termination + * @returns Promise resolving to the DataZone domain or undefined if not found + */ + public async getIamDomain(): Promise { + const logger = getLogger('smus') + + try { + logger.info('DataZoneCustomClientHelper: Getting the domain info') + + let nextToken: string | undefined + let totalDomainsChecked = 0 + const maxResultsPerPage = 25 + + // Paginate through domains and check each page for IAM-based domain + do { + const response = await this.listDomains({ + status: 'AVAILABLE', + nextToken, + maxResults: maxResultsPerPage, + }) + + const { domains } = response + totalDomainsChecked += domains.length + + logger.debug( + `DataZoneCustomClientHelper: Checking ${domains.length} domains in current page (total checked: ${totalDomainsChecked})` + ) + + // Check each domain in the current page for IAM authentication mode + for (const domain of domains) { + if (domain.preferences && domain.preferences.DOMAIN_MODE === 'EXPRESS') { + logger.info( + `DataZoneCustomClientHelper: Found IAM-based domain, id: ${domain.id} (${domain.name})` + ) + return domain + } + } + + nextToken = response.nextToken + } while (nextToken) + + logger.info( + `DataZoneCustomClientHelper: No domain with IAM authentication (DOMAIN_MODE: EXPRESS) found after checking all ${totalDomainsChecked} domains` + ) + return undefined + } catch (err) { + logger.error('DataZoneCustomClientHelper: Failed to get domain info: %s', err as Error) + throw new Error(`Failed to get domain info: ${(err as Error).message}`) + } + } + + /** + * Gets a specific domain by its ID + * @param domainId The ID of the domain to retrieve + * @returns Promise resolving to the GetDomainOutput + */ + public async getDomain(domainId: string): Promise { + try { + this.logger.debug(`DataZoneCustomClientHelper: Getting domain with ID: ${domainId}`) + + const datazoneCustomClient = await this.getDataZoneCustomClient() + + const response = await datazoneCustomClient + .getDomain({ + identifier: domainId, + }) + .promise() + + this.logger.debug(`DataZoneCustomClientHelper: Successfully retrieved domain: ${domainId}`) + return response + } catch (err) { + this.logger.error('DataZoneCustomClientHelper: Failed to get domain: %s', (err as Error).message) + throw err + } + } + + /** + * Checks if a specific domain is an IAM-based domain + * @param domainId The ID of the domain to check + * @returns Promise resolving to true if the domain is IAM-based, false otherwise + */ + public async isIamDomain(domainId: string): Promise { + try { + this.logger.debug(`DataZoneCustomClientHelper: Checking if domain ${domainId} is IAM-based`) + + const domain = await this.getDomain(domainId) + const isIamMode = domain.preferences?.DOMAIN_MODE === 'EXPRESS' || false + + this.logger.debug( + `DataZoneCustomClientHelper: Domain ${domainId} is ${isIamMode ? 'IAM-based' : 'not IAM-based'}` + ) + return isIamMode + } catch (err) { + this.logger.error('DataZoneCustomClientHelper: Failed to check if domain is IAM-based: %s', err as Error) + throw err + } + } + + /** + * Searches for group profiles in the DataZone domain + * @param domainIdentifier The domain identifier to search in + * @param options Options for searching group profiles + * @returns Promise resolving to group profile search results with pagination + */ + public async searchGroupProfiles( + domainIdentifier: string, + options?: { + groupType?: string + searchText?: string + maxResults?: number + nextToken?: string + } + ): Promise<{ items: DataZoneCustomClient.Types.GroupProfileSummary[]; nextToken?: string }> { + try { + this.logger.debug( + `DataZoneCustomClientHelper: Searching group profiles in domain ${domainIdentifier} with groupType: ${options?.groupType}, searchText: ${options?.searchText}` + ) + + const datazoneCustomClient = await this.getDataZoneCustomClient() + + // Build the request parameters + const params: DataZoneCustomClient.Types.SearchGroupProfilesInput = { + domainIdentifier, + groupType: options?.groupType as DataZoneCustomClient.Types.GroupSearchType, + searchText: options?.searchText, + maxResults: options?.maxResults, + nextToken: options?.nextToken, + } + + // Call DataZone API to search group profiles + const response = await datazoneCustomClient.searchGroupProfiles(params).promise() + + const items = response.items || [] + + if (items.length === 0) { + this.logger.debug(`DataZoneCustomClientHelper: No group profiles found`) + } else { + this.logger.debug(`DataZoneCustomClientHelper: Found ${items.length} group profiles`) + } + + return { items, nextToken: response.nextToken } + } catch (err) { + this.logger.error('DataZoneCustomClientHelper: Failed to search group profiles: %s', (err as Error).message) + throw err + } + } + + /** + * Searches for user profiles in the DataZone domain + * @param domainIdentifier The domain identifier to search in + * @param options Options for searching user profiles + * @returns Promise resolving to user profile search results with pagination + */ + public async searchUserProfiles( + domainIdentifier: string, + options: { + userType: string + searchText?: string + maxResults?: number + nextToken?: string + } + ): Promise<{ items: DataZoneCustomClient.Types.UserProfileSummary[]; nextToken?: string }> { + try { + this.logger.debug( + `DataZoneCustomClientHelper: Searching user profiles in domain ${domainIdentifier} with userType: ${options.userType}, searchText: ${options.searchText}` + ) + + const datazoneCustomClient = await this.getDataZoneCustomClient() + + // Build the request parameters + const params: DataZoneCustomClient.Types.SearchUserProfilesInput = { + domainIdentifier, + userType: options.userType as DataZoneCustomClient.Types.UserSearchType, + searchText: options.searchText, + maxResults: options.maxResults, + nextToken: options.nextToken, + } + + // Call DataZone API to search user profiles + const response = await datazoneCustomClient.searchUserProfiles(params).promise() + + const items = response.items || [] + + if (items.length === 0) { + this.logger.debug(`DataZoneCustomClientHelper: No user profiles found`) + } else { + this.logger.debug(`DataZoneCustomClientHelper: Found ${items.length} user profiles`) + } + + return { items, nextToken: response.nextToken } + } catch (err) { + this.logger.error('DataZoneCustomClientHelper: Failed to search user profiles: %s', (err as Error).message) + throw err + } + } + + /** + * Gets the group profile ID for a given IAM role ARN + * @param domainIdentifier The domain identifier to search in + * @param roleArn The base IAM role ARN (format: arn:aws:iam::ACCOUNT:role/ROLE_NAME) + * @returns Promise resolving to the group profile ID + * @throws ToolkitError with appropriate error code + */ + public async getGroupProfileId(domainIdentifier: string, roleArn: string): Promise { + try { + this.logger.debug( + `DataZoneCustomClientHelper: Getting group profile ID for role ARN: ${roleArn} in domain ${domainIdentifier}` + ) + + // Use searchText to filter server-side for better performance + const response = await this.searchGroupProfiles(domainIdentifier, { + groupType: 'IAM_ROLE_SESSION_GROUP', + searchText: roleArn, + maxResults: 50, + }) + + this.logger.debug( + `DataZoneCustomClientHelper: Received ${response.items.length} group profiles from search` + ) + + // Find exact match in filtered results + for (const profile of response.items) { + this.logger.debug( + `DataZoneCustomClientHelper: Checking group profile - ID: ${profile.id}, rolePrincipalArn: ${profile.rolePrincipalArn}, status: ${profile.status}` + ) + + if (profile.rolePrincipalArn === roleArn) { + this.logger.info(`DataZoneCustomClientHelper: Found matching group profile with ID: ${profile.id}`) + return profile.id! + } + } + + // No matching profile found + this.logger.error(`DataZoneCustomClientHelper: No group profile found for IAM role: ${roleArn}`) + throw new ToolkitError(`No group profile found for IAM role: ${roleArn}`, { + code: SmusErrorCodes.NoGroupProfileFound, + }) + } catch (err) { + // Re-throw if it's already a ToolkitError + if (err instanceof ToolkitError) { + throw err + } + + // Log and wrap other errors + this.logger.error('DataZoneCustomClientHelper: Failed to get group profile ID: %s', (err as Error).message) + throw ToolkitError.chain(err, 'Failed to get group profile ID') + } + } + + /** + * Gets the user profile ID for a given IAM role session + * @param domainIdentifier The domain identifier to search in + * @param roleArnWithSession The assumed role ARN with session name (format: arn:aws:sts::ACCOUNT:assumed-role/ROLE_NAME/SESSION_NAME) + * @returns Promise resolving to the user profile ID + * @throws ToolkitError with appropriate error code + */ + public async getUserProfileIdForSession(domainIdentifier: string, roleArnWithSession: string): Promise { + try { + this.logger.debug( + `DataZoneCustomClientHelper: Getting user profile ID for role ARN with session: ${roleArnWithSession} in domain ${domainIdentifier}` + ) + + // Extract session name from the assumed role ARN + // Format: arn:aws:sts::ACCOUNT:assumed-role/ROLE_NAME/SESSION_NAME + const sessionName = SmusUtils.extractSessionNameFromArn(roleArnWithSession) + if (!sessionName) { + throw new ToolkitError(`Unable to extract session name from ARN: ${roleArnWithSession}`, { + code: SmusErrorCodes.NoUserProfileFound, + }) + } + + // Convert assumed role ARN to IAM role ARN for matching + // Format: arn:aws:sts::ACCOUNT:assumed-role/ROLE_NAME/SESSION_NAME -> arn:aws:iam::ACCOUNT:role/ROLE_NAME + const iamRoleArn = SmusUtils.convertAssumedRoleArnToIamRoleArn(roleArnWithSession) + if (!iamRoleArn) { + throw new ToolkitError(`Unable to convert assumed role ARN to IAM role ARN: ${roleArnWithSession}`, { + code: SmusErrorCodes.NoUserProfileFound, + }) + } + + this.logger.debug( + `DataZoneCustomClientHelper: Extracted session name: ${sessionName}, IAM role ARN: ${iamRoleArn}` + ) + + // Use searchText to filter by role ARN on server side, then filter by session name on client side + let nextToken: string | undefined + let totalProfilesChecked = 0 + + do { + this.logger.debug( + `DataZoneCustomClientHelper: Calling searchUserProfiles with searchText: ${iamRoleArn}` + ) + + const response = await this.searchUserProfiles(domainIdentifier, { + userType: 'DATAZONE_IAM_USER', + searchText: iamRoleArn, // Server-side filter by role ARN + maxResults: 50, + nextToken, + }) + + totalProfilesChecked += response.items.length + this.logger.debug( + `DataZoneCustomClientHelper: Received ${response.items.length} user profiles matching role ARN in current page (total checked: ${totalProfilesChecked})` + ) + + // Find exact match in current page using client-side filtering for session name + // Server-side filtering by role ARN should have already reduced the result set significantly + for (const profile of response.items) { + // Match based on session name (role ARN already filtered by searchText) + // principalId format: PRINCIPAL_ID:SESSION_NAME + const matchesSession = profile.details?.iam?.principalId?.includes(sessionName) + + if (matchesSession) { + this.logger.info( + `DataZoneCustomClientHelper: Found matching user profile with ID: ${profile.id} (role: ${iamRoleArn}, session: ${sessionName}) after checking ${totalProfilesChecked} profiles` + ) + return profile.id! + } + } + + nextToken = response.nextToken + } while (nextToken) + + // No matching profile found after checking all pages + this.logger.error( + `DataZoneCustomClientHelper: No user profile found for role: ${iamRoleArn} with session: ${sessionName} after checking ${totalProfilesChecked} profiles` + ) + throw new ToolkitError(`No user profile found for role: ${iamRoleArn} with session: ${sessionName}`, { + code: SmusErrorCodes.NoUserProfileFound, + }) + } catch (err) { + // Re-throw if it's already a ToolkitError + if (err instanceof ToolkitError) { + throw err + } + + // Log and wrap other errors + this.logger.error('DataZoneCustomClientHelper: Failed to get user profile ID: %s', (err as Error).message) + throw ToolkitError.chain(err, 'Failed to get user profile ID') + } + } +} diff --git a/packages/core/src/sagemakerunifiedstudio/shared/client/datazonecustomclient.json b/packages/core/src/sagemakerunifiedstudio/shared/client/datazonecustomclient.json new file mode 100644 index 00000000000..8414d3340b8 --- /dev/null +++ b/packages/core/src/sagemakerunifiedstudio/shared/client/datazonecustomclient.json @@ -0,0 +1,919 @@ +{ + "version": "2.0", + "metadata": { + "apiVersion": "2018-05-10", + "auth": ["aws.auth#sigv4"], + "endpointPrefix": "datazone", + "protocol": "rest-json", + "protocols": ["rest-json"], + "serviceFullName": "Amazon DataZone", + "serviceId": "DataZone", + "signatureVersion": "v4", + "signingName": "datazone", + "uid": "datazone-2018-05-10" + }, + "operations": { + "GetDomain": { + "name": "GetDomain", + "http": { + "method": "GET", + "requestUri": "/v2/domains/{identifier}", + "responseCode": 200 + }, + "input": { + "shape": "GetDomainInput" + }, + "output": { + "shape": "GetDomainOutput" + }, + "errors": [ + { + "shape": "InternalServerException" + }, + { + "shape": "ResourceNotFoundException" + }, + { + "shape": "AccessDeniedException" + }, + { + "shape": "ThrottlingException" + }, + { + "shape": "ServiceQuotaExceededException" + }, + { + "shape": "ValidationException" + }, + { + "shape": "UnauthorizedException" + } + ], + "readonly": true + }, + "ListDomains": { + "name": "ListDomains", + "http": { + "method": "GET", + "requestUri": "/v2/domains", + "responseCode": 200 + }, + "input": { + "shape": "ListDomainsInput" + }, + "output": { + "shape": "ListDomainsOutput" + }, + "errors": [ + { + "shape": "InternalServerException" + }, + { + "shape": "ResourceNotFoundException" + }, + { + "shape": "AccessDeniedException" + }, + { + "shape": "ThrottlingException" + }, + { + "shape": "ServiceQuotaExceededException" + }, + { + "shape": "ConflictException" + }, + { + "shape": "ValidationException" + }, + { + "shape": "UnauthorizedException" + } + ], + "readonly": true + }, + "SearchGroupProfiles": { + "name": "SearchGroupProfiles", + "http": { + "method": "POST", + "requestUri": "/v2/domains/{domainIdentifier}/search-group-profiles", + "responseCode": 200 + }, + "input": { + "shape": "SearchGroupProfilesInput" + }, + "output": { + "shape": "SearchGroupProfilesOutput" + }, + "errors": [ + { + "shape": "InternalServerException" + }, + { + "shape": "ResourceNotFoundException" + }, + { + "shape": "AccessDeniedException" + }, + { + "shape": "ThrottlingException" + }, + { + "shape": "ConflictException" + }, + { + "shape": "ValidationException" + }, + { + "shape": "UnauthorizedException" + } + ] + }, + "SearchUserProfiles": { + "name": "SearchUserProfiles", + "http": { + "method": "POST", + "requestUri": "/v2/domains/{domainIdentifier}/search-user-profiles", + "responseCode": 200 + }, + "input": { + "shape": "SearchUserProfilesInput" + }, + "output": { + "shape": "SearchUserProfilesOutput" + }, + "errors": [ + { + "shape": "InternalServerException" + }, + { + "shape": "ResourceNotFoundException" + }, + { + "shape": "AccessDeniedException" + }, + { + "shape": "ThrottlingException" + }, + { + "shape": "ConflictException" + }, + { + "shape": "ValidationException" + }, + { + "shape": "UnauthorizedException" + } + ] + } + }, + "shapes": { + "GetDomainInput": { + "type": "structure", + "required": ["identifier"], + "members": { + "identifier": { + "shape": "DomainId", + "location": "uri", + "locationName": "identifier" + } + } + }, + "DomainId": { + "type": "string", + "pattern": "dzd[-_][a-zA-Z0-9_-]{1,36}" + }, + "GetDomainOutput": { + "type": "structure", + "required": ["id", "status"], + "members": { + "id": { + "shape": "DomainId" + }, + "rootDomainUnitId": { + "shape": "DomainUnitId" + }, + "name": { + "shape": "String" + }, + "description": { + "shape": "String" + }, + "singleSignOn": { + "shape": "SingleSignOn" + }, + "domainExecutionRole": { + "shape": "RoleArn" + }, + "arn": { + "shape": "String" + }, + "kmsKeyIdentifier": { + "shape": "KmsKeyArn" + }, + "status": { + "shape": "DomainStatus" + }, + "failureReasons": { + "shape": "FailureReasonsList" + }, + "portalUrl": { + "shape": "String" + }, + "createdAt": { + "shape": "CreatedAt" + }, + "lastUpdatedAt": { + "shape": "UpdatedAt" + }, + "tags": { + "shape": "Tags" + }, + "provisionStatus": { + "shape": "ProvisionStatus", + "internalonly": true + }, + "domainVersion": { + "shape": "DomainVersion" + }, + "domainServiceRole": { + "shape": "RoleArn", + "deprecated": true, + "internalonly": true + }, + "serviceRole": { + "shape": "RoleArn" + }, + "supportedDomainVersions": { + "shape": "SupportedDomainVersions", + "internalonly": true + }, + "iamSignIns": { + "shape": "IamSignIns", + "internalonly": true + }, + "preferences": { + "shape": "Preferences", + "internalonly": true + } + } + }, + "DomainUnitId": { + "type": "string", + "max": 256, + "min": 1, + "pattern": "[a-z0-9_\\-]+" + }, + "String": { + "type": "string" + }, + "SingleSignOn": { + "type": "structure", + "members": { + "type": { + "shape": "AuthType" + }, + "userAssignment": { + "shape": "UserAssignment" + }, + "idcInstanceArn": { + "shape": "SingleSignOnIdcInstanceArnString" + }, + "ssoUrl": { + "shape": "String", + "internalonly": true + }, + "idcApplicationArn": { + "shape": "String", + "internalonly": true + } + } + }, + "AuthType": { + "type": "string", + "enum": ["IAM_IDC", "DISABLED", "SAML"] + }, + "UserAssignment": { + "type": "string", + "enum": ["AUTOMATIC", "MANUAL"] + }, + "SingleSignOnIdcInstanceArnString": { + "type": "string", + "pattern": ".*arn:(aws|aws-us-gov|aws-cn|aws-iso|aws-iso-b):sso:::instance/(sso)?ins-[a-zA-Z0-9-.]{16}.*" + }, + "RoleArn": { + "type": "string", + "pattern": "arn:aws[^:]*:iam::\\d{12}:role(/[a-zA-Z0-9+=,.@_-]+)*/[a-zA-Z0-9+=,.@_-]+" + }, + "KmsKeyArn": { + "type": "string", + "max": 1024, + "min": 1, + "pattern": "arn:aws(|-cn|-us-gov):kms:[a-zA-Z0-9-]*:[0-9]{12}:key/[a-zA-Z0-9-]{36}" + }, + "DomainStatus": { + "type": "string", + "enum": ["CREATING", "AVAILABLE", "CREATION_FAILED", "DELETING", "DELETED", "DELETION_FAILED"] + }, + "FailureReasonsList": { + "type": "list", + "member": { + "shape": "FailureReason" + } + }, + "FailureReason": { + "type": "structure", + "members": { + "code": { + "shape": "String" + }, + "message": { + "shape": "String" + } + } + }, + "CreatedAt": { + "type": "timestamp" + }, + "UpdatedAt": { + "type": "timestamp" + }, + "Tags": { + "type": "map", + "key": { + "shape": "TagKey" + }, + "value": { + "shape": "TagValue" + } + }, + "TagKey": { + "type": "string", + "max": 128, + "min": 1, + "pattern": "[\\w \\.:/=+@-]+" + }, + "TagValue": { + "type": "string", + "max": 256, + "min": 0, + "pattern": "[\\w \\.:/=+@-]*" + }, + "ProvisionStatus": { + "type": "string", + "enum": [ + "PROVISIONING", + "PROVISIONING_PROJECT_PROFILES", + "PROVISIONING_MODEL_ASSETS", + "PROVISION_FAILED", + "PROVISION_COMPLETE" + ] + }, + "DomainVersion": { + "type": "string", + "enum": ["V1", "V2"] + }, + "SupportedDomainVersions": { + "type": "list", + "member": { + "shape": "DomainVersion" + } + }, + "IamSignIns": { + "type": "list", + "member": { + "shape": "IamSignIn" + }, + "internalonly": true + }, + "IamSignIn": { + "type": "string", + "enum": ["IAM_ROLE", "IAM_USER"], + "internalonly": true + }, + "Preferences": { + "type": "map", + "key": { + "shape": "PreferenceKey" + }, + "value": { + "shape": "PreferenceValue" + }, + "max": 10, + "min": 0 + }, + "PreferenceKey": { + "type": "string", + "max": 128, + "min": 1, + "pattern": "[\\w \\.:/=+@-]+" + }, + "PreferenceValue": { + "type": "string", + "max": 256, + "min": 0, + "pattern": "[\\w \\.:/=+@-]*" + }, + "InternalServerException": { + "type": "structure", + "required": ["message"], + "members": { + "message": { + "shape": "ErrorMessage" + } + }, + "error": { + "httpStatusCode": 500 + }, + "exception": true, + "fault": true, + "retryable": { + "throttling": false + } + }, + "ErrorMessage": { + "type": "string" + }, + "ResourceNotFoundException": { + "type": "structure", + "required": ["message"], + "members": { + "message": { + "shape": "ErrorMessage" + } + }, + "error": { + "httpStatusCode": 404, + "senderFault": true + }, + "exception": true + }, + "AccessDeniedException": { + "type": "structure", + "required": ["message"], + "members": { + "message": { + "shape": "ErrorMessage" + } + }, + "error": { + "httpStatusCode": 403, + "senderFault": true + }, + "exception": true + }, + "ThrottlingException": { + "type": "structure", + "required": ["message"], + "members": { + "message": { + "shape": "ErrorMessage" + } + }, + "error": { + "httpStatusCode": 429, + "senderFault": true + }, + "exception": true, + "retryable": { + "throttling": false + } + }, + "ServiceQuotaExceededException": { + "type": "structure", + "required": ["message"], + "members": { + "message": { + "shape": "ErrorMessage" + } + }, + "error": { + "httpStatusCode": 402, + "senderFault": true + }, + "exception": true + }, + "ValidationException": { + "type": "structure", + "required": ["message"], + "members": { + "message": { + "shape": "ErrorMessage" + } + }, + "error": { + "httpStatusCode": 400, + "senderFault": true + }, + "exception": true + }, + "UnauthorizedException": { + "type": "structure", + "required": ["message"], + "members": { + "message": { + "shape": "ErrorMessage" + } + }, + "error": { + "httpStatusCode": 401, + "senderFault": true + }, + "exception": true + }, + "ListDomainsInput": { + "type": "structure", + "members": { + "status": { + "shape": "DomainStatus", + "location": "querystring", + "locationName": "status" + }, + "maxResults": { + "shape": "MaxResultsForListDomains", + "location": "querystring", + "locationName": "maxResults" + }, + "nextToken": { + "shape": "PaginationToken", + "location": "querystring", + "locationName": "nextToken" + } + } + }, + "MaxResultsForListDomains": { + "type": "integer", + "box": true, + "max": 25, + "min": 1 + }, + "PaginationToken": { + "type": "string", + "max": 8192, + "min": 1 + }, + "ListDomainsOutput": { + "type": "structure", + "required": ["items"], + "members": { + "items": { + "shape": "DomainSummaries" + }, + "domains": { + "shape": "DomainSummaries", + "internalonly": true + }, + "nextToken": { + "shape": "PaginationToken" + } + } + }, + "DomainSummaries": { + "type": "list", + "member": { + "shape": "DomainSummary" + } + }, + "DomainSummary": { + "type": "structure", + "required": ["id", "name", "arn", "managedAccountId", "status", "createdAt"], + "members": { + "id": { + "shape": "DomainId" + }, + "name": { + "shape": "DomainName" + }, + "description": { + "shape": "DomainDescription" + }, + "arn": { + "shape": "String" + }, + "managedAccountId": { + "shape": "String" + }, + "status": { + "shape": "DomainStatus" + }, + "portalUrl": { + "shape": "String" + }, + "createdAt": { + "shape": "CreatedAt" + }, + "lastUpdatedAt": { + "shape": "UpdatedAt" + }, + "domainVersion": { + "shape": "DomainVersion" + }, + "iamSignIns": { + "shape": "IamSignIns", + "internalonly": true + }, + "preferences": { + "shape": "Preferences", + "internalonly": true + } + } + }, + "DomainName": { + "type": "string", + "sensitive": true + }, + "DomainDescription": { + "type": "string", + "sensitive": true + }, + "ConflictException": { + "type": "structure", + "required": ["message"], + "members": { + "message": { + "shape": "ErrorMessage" + }, + "reason": { + "shape": "ConflictReason" + }, + "details": { + "shape": "ConflictDetails" + } + }, + "error": { + "httpStatusCode": 409, + "senderFault": true + }, + "exception": true + }, + "ConflictReason": { + "type": "string", + "enum": ["RESOURCE_LOCKED"] + }, + "ConflictDetails": { + "type": "structure", + "members": { + "lock": { + "shape": "LockDetails" + } + }, + "union": true + }, + "LockDetails": { + "type": "structure", + "required": ["lockedBy", "lockedAt", "lockExpiresAt"], + "members": { + "lockedBy": { + "shape": "String" + }, + "lockedAt": { + "shape": "Timestamp" + }, + "lockExpiresAt": { + "shape": "Timestamp" + } + } + }, + "Timestamp": { + "type": "timestamp" + }, + "SearchGroupProfilesInput": { + "type": "structure", + "required": ["domainIdentifier", "groupType"], + "members": { + "domainIdentifier": { + "shape": "DomainId", + "location": "uri", + "locationName": "domainIdentifier" + }, + "groupType": { + "shape": "GroupSearchType" + }, + "searchText": { + "shape": "GroupSearchText" + }, + "maxResults": { + "shape": "MaxResults" + }, + "nextToken": { + "shape": "PaginationToken" + } + } + }, + "GroupSearchType": { + "type": "string", + "enum": ["SSO_GROUP", "DATAZONE_SSO_GROUP", "IAM_ROLE_SESSION_GROUP"] + }, + "GroupSearchText": { + "type": "string", + "max": 1024, + "min": 0, + "sensitive": true + }, + "MaxResults": { + "type": "integer", + "box": true, + "max": 50, + "min": 1 + }, + "SearchGroupProfilesOutput": { + "type": "structure", + "members": { + "items": { + "shape": "GroupProfileSummaries" + }, + "nextToken": { + "shape": "PaginationToken" + } + } + }, + "GroupProfileSummaries": { + "type": "list", + "member": { + "shape": "GroupProfileSummary" + } + }, + "GroupProfileSummary": { + "type": "structure", + "members": { + "domainId": { + "shape": "DomainId" + }, + "id": { + "shape": "GroupProfileId" + }, + "status": { + "shape": "GroupProfileStatus" + }, + "groupName": { + "shape": "GroupProfileName" + }, + "rolePrincipalArn": { + "shape": "String", + "internalonly": true + }, + "rolePrincipalId": { + "shape": "String", + "internalonly": true + } + } + }, + "GroupProfileId": { + "type": "string", + "pattern": "([0-9a-f]{10}-|)[A-Fa-f0-9]{8}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{12}" + }, + "GroupProfileStatus": { + "type": "string", + "enum": ["ASSIGNED", "NOT_ASSIGNED"] + }, + "GroupProfileName": { + "type": "string", + "max": 1024, + "min": 1, + "pattern": "[a-zA-Z_0-9+=,.@-]+", + "sensitive": true + }, + "SearchUserProfilesInput": { + "type": "structure", + "required": ["domainIdentifier", "userType"], + "members": { + "domainIdentifier": { + "shape": "DomainId", + "location": "uri", + "locationName": "domainIdentifier" + }, + "userType": { + "shape": "UserSearchType" + }, + "searchText": { + "shape": "UserSearchText" + }, + "maxResults": { + "shape": "MaxResults" + }, + "nextToken": { + "shape": "PaginationToken" + } + } + }, + "UserSearchType": { + "type": "string", + "enum": ["SSO_USER", "DATAZONE_USER", "DATAZONE_SSO_USER", "DATAZONE_IAM_USER"] + }, + "UserSearchText": { + "type": "string", + "max": 1024, + "min": 0, + "sensitive": true + }, + "SearchUserProfilesOutput": { + "type": "structure", + "members": { + "items": { + "shape": "UserProfileSummaries" + }, + "nextToken": { + "shape": "PaginationToken" + } + } + }, + "UserProfileSummaries": { + "type": "list", + "member": { + "shape": "UserProfileSummary" + } + }, + "UserProfileSummary": { + "type": "structure", + "members": { + "domainId": { + "shape": "DomainId" + }, + "id": { + "shape": "UserProfileId" + }, + "type": { + "shape": "UserProfileType" + }, + "status": { + "shape": "UserProfileStatus" + }, + "details": { + "shape": "UserProfileDetails" + } + } + }, + "UserProfileId": { + "type": "string", + "pattern": "([0-9a-f]{10}-|)[A-Fa-f0-9]{8}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{12}" + }, + "UserProfileType": { + "type": "string", + "enum": ["IAM", "SSO", "SAML"] + }, + "UserProfileStatus": { + "type": "string", + "enum": ["ASSIGNED", "NOT_ASSIGNED", "ACTIVATED", "DEACTIVATED", "ARCHIVED"] + }, + "UserProfileDetails": { + "type": "structure", + "members": { + "iam": { + "shape": "IamUserProfileDetails" + }, + "sso": { + "shape": "SsoUserProfileDetails" + } + }, + "union": true + }, + "IamUserProfileDetails": { + "type": "structure", + "members": { + "arn": { + "shape": "String" + }, + "principalId": { + "shape": "String" + } + } + }, + "SsoUserProfileDetails": { + "type": "structure", + "members": { + "username": { + "shape": "UserProfileName" + }, + "firstName": { + "shape": "FirstName" + }, + "lastName": { + "shape": "LastName" + }, + "email": { + "shape": "String", + "internalonly": true + }, + "userId": { + "shape": "String", + "internalonly": true + } + } + }, + "UserProfileName": { + "type": "string", + "max": 1024, + "min": 1, + "pattern": "[a-zA-Z_0-9+=,.@-]+", + "sensitive": true + }, + "FirstName": { + "type": "string", + "sensitive": true + }, + "LastName": { + "type": "string", + "sensitive": true + } + } +} diff --git a/packages/core/src/sagemakerunifiedstudio/shared/client/glueCatalogClient.ts b/packages/core/src/sagemakerunifiedstudio/shared/client/glueCatalogClient.ts index bbd3c440478..5eaccd2b5d5 100644 --- a/packages/core/src/sagemakerunifiedstudio/shared/client/glueCatalogClient.ts +++ b/packages/core/src/sagemakerunifiedstudio/shared/client/glueCatalogClient.ts @@ -23,7 +23,7 @@ export type GlueCatalog = GlueCatalogApi.Types.Catalog export class GlueCatalogClient { private glueClient: GlueCatalogApi | undefined private static instance: GlueCatalogClient | undefined - private readonly logger = getLogger() + private readonly logger = getLogger('smus') private constructor( private readonly region: string, diff --git a/packages/core/src/sagemakerunifiedstudio/shared/client/glueClient.ts b/packages/core/src/sagemakerunifiedstudio/shared/client/glueClient.ts index 15034a488cf..cad84d23332 100644 --- a/packages/core/src/sagemakerunifiedstudio/shared/client/glueClient.ts +++ b/packages/core/src/sagemakerunifiedstudio/shared/client/glueClient.ts @@ -22,7 +22,7 @@ import { ConnectionCredentialsProvider } from '../../auth/providers/connectionCr */ export class GlueClient { private glueClient: Glue | undefined - private readonly logger = getLogger() + private readonly logger = getLogger('smus') constructor( private readonly region: string, diff --git a/packages/core/src/sagemakerunifiedstudio/shared/client/s3Client.ts b/packages/core/src/sagemakerunifiedstudio/shared/client/s3Client.ts index d86c3904a07..5375c3d5e92 100644 --- a/packages/core/src/sagemakerunifiedstudio/shared/client/s3Client.ts +++ b/packages/core/src/sagemakerunifiedstudio/shared/client/s3Client.ts @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { S3 } from '@aws-sdk/client-s3' +import { S3, ListBucketsCommand } from '@aws-sdk/client-s3' import { getLogger } from '../../../shared/logger/logger' import { ConnectionCredentialsProvider } from '../../auth/providers/connectionCredentialsProvider' @@ -24,7 +24,7 @@ export interface S3Path { */ export class S3Client { private s3Client: S3 | undefined - private readonly logger = getLogger() + private readonly logger = getLogger('smus') constructor( private readonly region: string, @@ -45,7 +45,7 @@ export class S3Client { continuationToken?: string ): Promise<{ paths: S3Path[]; nextToken?: string }> { try { - this.logger.info(`S3Client: Listing paths in bucket ${bucket} with prefix ${prefix || 'root'}`) + this.logger.info(`S3Client: Listing paths in bucket ${bucket} with prefix ${prefix}`) const s3Client = await this.getS3Client() @@ -116,6 +116,40 @@ export class S3Client { } } + /** + * Lists all S3 buckets accessible to the current credentials + * @returns Array of bucket objects + */ + public async listBuckets(): Promise> { + try { + this.logger.debug('S3Client: Listing all accessible buckets') + + const s3Client = await this.getS3Client() + const allBuckets: Array<{ Name?: string; CreationDate?: Date }> = [] + let continuationToken: string | undefined + + do { + const response = await s3Client.send( + new ListBucketsCommand({ + ContinuationToken: continuationToken, + BucketRegion: this.region, + }) + ) + + if (response.Buckets) { + allBuckets.push(...response.Buckets) + } + continuationToken = response.ContinuationToken + } while (continuationToken) + + this.logger.debug(`S3Client: Found ${allBuckets.length} accessible buckets`) + return allBuckets + } catch (err) { + this.logger.error('S3Client: Failed to list buckets: %s', err as Error) + throw err + } + } + /** * Gets the S3 client, initializing it if necessary */ diff --git a/packages/core/src/sagemakerunifiedstudio/shared/client/sqlWorkbenchClient.ts b/packages/core/src/sagemakerunifiedstudio/shared/client/sqlWorkbenchClient.ts index 5513f139d2b..76527d1d622 100644 --- a/packages/core/src/sagemakerunifiedstudio/shared/client/sqlWorkbenchClient.ts +++ b/packages/core/src/sagemakerunifiedstudio/shared/client/sqlWorkbenchClient.ts @@ -135,7 +135,7 @@ export async function createRedshiftConnectionConfig( export class SQLWorkbenchClient { private sqlClient: SQLWorkbench | undefined private static instance: SQLWorkbenchClient | undefined - private readonly logger = getLogger() + private readonly logger = getLogger('smus') private constructor( private readonly region: string, diff --git a/packages/core/src/sagemakerunifiedstudio/shared/credentialExpiryHandler.ts b/packages/core/src/sagemakerunifiedstudio/shared/credentialExpiryHandler.ts new file mode 100644 index 00000000000..e5169207976 --- /dev/null +++ b/packages/core/src/sagemakerunifiedstudio/shared/credentialExpiryHandler.ts @@ -0,0 +1,33 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +import * as vscode from 'vscode' +import { isCredentialExpirationError } from './smusUtils' +import { SmusAuthenticationProvider } from '../auth/providers/smusAuthenticationProvider' + +/** + * + * If the provided error indicates expired credentials, it marks the connection as invalid. + * This refreshes the SmusAuthInfo node to reflect the updated authentication state. + * + * @param err The error + * @param showError If true, shows error message to user. If false, silently handles the error. + */ +export async function handleCredExpiredError(err: any, showError: boolean = false): Promise { + const errorMessage = (err as Error).message + if (isCredentialExpirationError(err)) { + if (showError) { + void vscode.window.showErrorMessage( + 'Connection to SageMaker Unified Studio has expired. Please try again after reauthentication.' + ) + } + const smusAuthProvider = SmusAuthenticationProvider.fromContext() + await smusAuthProvider.invalidateConnection() + smusAuthProvider.dispose() + } else { + if (showError) { + void vscode.window.showErrorMessage(errorMessage) + } + } +} diff --git a/packages/core/src/sagemakerunifiedstudio/shared/smusUtils.ts b/packages/core/src/sagemakerunifiedstudio/shared/smusUtils.ts index 35858f0dc5a..75de3d0dde3 100644 --- a/packages/core/src/sagemakerunifiedstudio/shared/smusUtils.ts +++ b/packages/core/src/sagemakerunifiedstudio/shared/smusUtils.ts @@ -8,6 +8,9 @@ import { ToolkitError } from '../../shared/errors' import { isSageMaker } from '../../shared/extensionUtilities' import { getResourceMetadata } from './utils/resourceMetadataUtils' import fetch from 'node-fetch' +import { CredentialsProvider, CredentialsProviderType } from '../../auth/providers/credentials' +import { CredentialType } from '../../shared/telemetry/telemetry' +import { AwsCredentialIdentity } from '@aws-sdk/types' /** * Represents SSO instance information retrieved from DataZone @@ -64,6 +67,32 @@ export const SmusErrorCodes = { GetProjectAccountIdFailed: 'GetProjectAccountIdFailed', /** Error code for when region is missing */ RegionNotFound: 'RegionNotFound', + /** Error code for when IAM-based domain is not found in the specified region */ + IamDomainNotFound: 'IamDomainNotFound', + /** Error code for when IAM profile is not found */ + ProfileNotFound: 'ProfileNotFound', + /** Error code for when IAM credential retrieval fails */ + CredentialRetrievalFailed: 'CredentialRetrievalFailed', + /** Error code for when IAM credential provider initialization fails */ + CredentialProviderInitFailed: 'CredentialProviderInitFailed', + /** Error code for when IAM profile type is invalid */ + InvalidProfileType: 'InvalidProfileType', + /** Error code for when IAM credential validation fails */ + IamValidationFailed: 'IamValidationFailed', + /** Error code for when sign out operation fails */ + SignOutFailed: 'SignOutFailed', + /** Error code for when domain URL format is invalid */ + InvalidDomainUrl: 'InvalidDomainUrl', + /** Error code for when connection to SMUS fails */ + FailedToConnect: 'FailedToConnect', + /** Error code for when connection is not found */ + ConnectionNotFound: 'ConnectionNotFound', + /** Error code for when connection type is invalid for the operation */ + InvalidConnectionType: 'InvalidConnectionType', + /** Error code for when no group profile is found for IAM role */ + NoGroupProfileFound: 'NoGroupProfileFound', + /** Error code for when no user profile is found for IAM principal */ + NoUserProfileFound: 'NoUserProfileFound', } as const /** @@ -74,6 +103,11 @@ export const SmusTimeouts = { apiCallTimeoutMs: 10 * 1000, } as const +/** + * DataZone service ID used for filtering regions + */ +export const DataZoneServiceId = 'datazone' + /** * Interface for AWS credential objects that need validation */ @@ -125,7 +159,7 @@ export function validateCredentialFields( * Utility class for SageMaker Unified Studio domain URL parsing and validation */ export class SmusUtils { - private static readonly logger = getLogger() + private static readonly logger = getLogger('smus') /** * Extracts the domain ID from a SageMaker Unified Studio domain URL @@ -284,7 +318,7 @@ export class SmusUtils { */ public static async getSsoInstanceInfo(domainUrl: string): Promise { try { - this.logger.info(`SMUS Auth: Getting SSO instance info from DataZone for domainurl: ${domainUrl}`) + this.logger.info(`Getting SSO instance info from DataZone for domainurl: ${domainUrl}`) // Extract domain ID from the domain URL const domainId = this.extractDomainIdFromUrl(domainUrl) @@ -318,7 +352,7 @@ export class SmusUtils { // Extract region from domain URL const region = this.extractRegionFromUrl(domainUrl) - this.logger.info('SMUS Auth: Extracted SSO instance info: %s', ssoInstanceId) + this.logger.info('Extracted SSO instance info: %s', ssoInstanceId) return { issuerUrl, @@ -328,7 +362,7 @@ export class SmusUtils { } } catch (error) { const errorMsg = error instanceof Error ? error.message : 'Unknown error' - this.logger.error('SMUS Auth: Failed to get SSO instance info: %s', errorMsg) + this.logger.error('Failed to get SSO instance info: %s', errorMsg) if (error instanceof ToolkitError) { throw error @@ -364,6 +398,96 @@ export class SmusUtils { const resourceMetadata = getResourceMetadata() return isSMUSspace && !!resourceMetadata?.AdditionalMetadata?.DataZoneDomainId } + + /** + * Extracts the session name from an assumed role ARN. + * + * Note: This function ONLY works for assumed role ARNs (arn:aws:sts::*:assumed-role/*). + * It will return undefined for other IAM principal types such as: + * - IAM users (arn:aws:iam::*:user/*) + * - IAM roles (arn:aws:iam::*:role/*) + * + * @param arn The assumed role ARN (format: arn:aws:sts::ACCOUNT:assumed-role/ROLE_NAME/SESSION_NAME) + * @returns The session name if the ARN is a valid assumed role ARN, undefined otherwise + */ + public static extractSessionNameFromArn(arn: string): string | undefined { + try { + // Expected format: arn:aws:sts::ACCOUNT:assumed-role/ROLE_NAME/SESSION_NAME + const parts = arn.split(':') + if (parts.length < 6) { + return undefined + } + + // The resource part is after the 5th colon + const resourcePart = parts.slice(5).join(':') + + // Split by '/' to get assumed-role, ROLE_NAME, and SESSION_NAME + const resourceParts = resourcePart.split('/') + if (resourceParts.length < 3 || resourceParts[0] !== 'assumed-role') { + return undefined + } + + // Session name is the last part + return resourceParts[2] + } catch (err) { + return undefined + } + } + + /** + * Determines if an ARN represents an IAM user (vs IAM role session) + * @param arn The ARN to check (format: arn:aws:iam::ACCOUNT:user/USER_NAME for IAM users, + * arn:aws:sts::ACCOUNT:assumed-role/ROLE_NAME/SESSION_NAME for role sessions) + * @returns True if the ARN is an IAM user, false otherwise + */ + public static isIamUserArn(arn: string | undefined): boolean { + if (!arn) { + return false + } + + // IAM user ARN format: arn:aws:iam::ACCOUNT:user/USER_NAME + // IAM role session ARN format: arn:aws:sts::ACCOUNT:assumed-role/ROLE_NAME/SESSION_NAME + return arn.includes(':iam::') && arn.includes(':user/') + } + + /** + * Converts an STS assumed-role ARN to its corresponding IAM role ARN, or returns IAM user ARN as-is. + * Supports all AWS partitions (aws, aws-cn, aws-us-gov, etc.) + * Examples: + * Input: arn:aws:sts::123456789012:assumed-role/MyRole/MySession + * Output: arn:aws:iam::123456789012:role/MyRole + * + * Input: arn:aws:iam::123456789012:user/MyUser + * Output: arn:aws:iam::123456789012:user/MyUser + * + * Input: arn:aws-cn:sts::123456789012:assumed-role/MyRole/MySession + * Output: arn:aws-cn:iam::123456789012:role/MyRole + */ + public static convertAssumedRoleArnToIamRoleArn(stsArn: string): string { + // Check if it's already an IAM user ARN - return as-is + // Supports all AWS partitions: aws, aws-cn, aws-us-gov, etc. + const iamUserRegex = /^arn:(aws[a-z-]*):iam::(\d{12}):user\/([A-Za-z0-9+=,.@_\/-]+)$/ + if (iamUserRegex.test(stsArn)) { + return stsArn + } + + // Check if it's already an IAM role ARN - return as-is + const iamRoleRegex = /^arn:(aws[a-z-]*):iam::(\d{12}):role\/([A-Za-z0-9+=,.@_\/-]+)$/ + if (iamRoleRegex.test(stsArn)) { + return stsArn + } + + // Try to convert STS assumed-role ARN to IAM role ARN + const arnRegex = /^arn:(aws[a-z-]*):sts::(\d{12}):assumed-role\/([A-Za-z0-9+=,.@_\/-]+)\/([A-Za-z0-9+=,.@_-]+)$/ + const match = stsArn.match(arnRegex) + if (!match) { + throw new Error(`Invalid STS ARN format: ${stsArn}`) + } + + const [, partition, accountId, roleName] = match + + return `arn:${partition}:iam::${accountId}:role/${roleName}` + } } /** @@ -393,10 +517,10 @@ export function extractAccountIdFromSageMakerArn(arn: string): string { * @throws ToolkitError if unable to extract account ID */ export async function extractAccountIdFromResourceMetadata(): Promise { - const logger = getLogger() + const logger = getLogger('smus') try { - logger.debug('SMUS: Extracting account ID from ResourceArn in resource-metadata file') + logger.debug('Extracting account ID from ResourceArn in resource-metadata file') const resourceMetadata = getResourceMetadata()! const resourceArn = resourceMetadata.ResourceArn @@ -414,3 +538,68 @@ export async function extractAccountIdFromResourceMetadata(): Promise { throw new Error('Failed to extract AWS account ID from ResourceArn in SMUS space environment') } } + +/** + * Creates a CredentialsProvider from an AWS credentials function + * @param credentialsFunction Function that returns AWS credentials + * @param credentialTypeId Identifier for the credential type + * @param hashCode Unique hash code for caching + * @param region Domain region + * @returns Complete CredentialsProvider object + */ +export function convertToToolkitCredentialProvider( + credentialsFunction: () => Promise, + credentialTypeId: string, + hashCode: string, + region: string +): CredentialsProvider { + return { + getCredentials: credentialsFunction, + getCredentialsId: () => ({ credentialSource: 'temp' as const, credentialTypeId }), + getProviderType: () => 'temp' as CredentialsProviderType, + getTelemetryType: () => 'other' as CredentialType, + getDefaultRegion: () => region, + getHashCode: () => hashCode, + canAutoConnect: () => Promise.resolve(false), + isAvailable: () => Promise.resolve(true), + } +} + +/** + * Checks if an error indicates credential/token expiration + * + * @param error The error to check (can be any type) + * @returns true if the error indicates expired credentials, false otherwise + * + */ +export function isCredentialExpirationError(error: any): boolean { + if (!error) { + return false + } + + const errorName = (error.name || '') as string + const errorMessage = (error.message || '') as string + const errorNameLower = errorName.toLowerCase() + const errorMessageLower = errorMessage.toLowerCase() + + const expirationErrorNames = ['ExpiredTokenException'] + + const expirationErrorMessages = ['The security token included in the request is expired'] + + // Return true if error name matches any expiration error names (case-insensitive) + if (expirationErrorNames.some((name) => name.toLowerCase() === errorNameLower)) { + return true + } + + // Return true if error message contains any expiration error names (case-insensitive) + if (expirationErrorNames.some((errorName) => errorMessageLower.includes(errorName.toLowerCase()))) { + return true + } + + // Return true if error message contains any expiration error messages + if (expirationErrorMessages.some((keyword) => errorMessageLower.includes(keyword.toLowerCase()))) { + return true + } + + return false +} diff --git a/packages/core/src/sagemakerunifiedstudio/shared/telemetry.ts b/packages/core/src/sagemakerunifiedstudio/shared/telemetry.ts index ceeb4828b83..ff518ec3b8e 100644 --- a/packages/core/src/sagemakerunifiedstudio/shared/telemetry.ts +++ b/packages/core/src/sagemakerunifiedstudio/shared/telemetry.ts @@ -18,7 +18,8 @@ import { SmusAuthenticationProvider } from '../auth/providers/smusAuthentication import { getLogger } from '../../shared/logger/logger' import { getContext } from '../../shared/vscode/setContext' import { ConnectionCredentialsProvider } from '../auth/providers/connectionCredentialsProvider' -import { DataZoneConnection, DataZoneClient } from './client/datazoneClient' +import { DataZoneConnection } from './client/datazoneClient' +import { createDZClientBaseOnDomainMode } from '../explorer/nodes/utils' /** * Records space telemetry @@ -27,7 +28,7 @@ export async function recordSpaceTelemetry( span: Span | Span, node: SagemakerUnifiedStudioSpaceNode ) { - const logger = getLogger() + const logger = getLogger('smus') try { const parent = node.resource.getParent() as SageMakerUnifiedStudioSpacesParentNode @@ -43,15 +44,16 @@ export async function recordSpaceTelemetry( projectAccountId = await authProvider.getProjectAccountId(projectId) // Get project region from tooling environment - const dzClient = await DataZoneClient.getInstance(authProvider) + const dzClient = await createDZClientBaseOnDomainMode(authProvider) const toolingEnv = await dzClient.getToolingEnvironment(projectId) projectRegion = toolingEnv.awsAccountRegion } span.record({ + smusAuthMode: authProvider.activeConnection?.type, smusSpaceKey: node.resource.DomainSpaceKey, smusDomainRegion: node.resource.regionCode, - smusDomainId: parent?.getAuthProvider()?.activeConnection?.domainId, + smusDomainId: parent?.getAuthProvider()?.getDomainId(), smusDomainAccountId: accountId, smusProjectId: projectId, smusProjectAccountId: projectAccountId, @@ -71,9 +73,10 @@ export async function recordAuthTelemetry( domainId: string | undefined, region: string | undefined ) { - const logger = getLogger() + const logger = getLogger('smus') span.record({ + smusAuthMode: authProvider.activeConnection?.type, smusDomainId: domainId, awsRegion: region, }) @@ -101,12 +104,15 @@ export async function recordDataConnectionTelemetry( connection: DataZoneConnection, connectionCredentialsProvider: ConnectionCredentialsProvider ) { - const logger = getLogger() + const logger = getLogger('smus') try { const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment') + const authProvider = SmusAuthenticationProvider.fromContext() const accountId = await connectionCredentialsProvider.getDomainAccountId() + span.record({ + smusAuthMode: authProvider.activeConnection?.type, smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local', smusDomainId: connection.domainId, smusDomainAccountId: accountId, diff --git a/packages/core/src/sagemakerunifiedstudio/shared/utils/resourceMetadataUtils.ts b/packages/core/src/sagemakerunifiedstudio/shared/utils/resourceMetadataUtils.ts index 61ce0430ecd..29c372b7968 100644 --- a/packages/core/src/sagemakerunifiedstudio/shared/utils/resourceMetadataUtils.ts +++ b/packages/core/src/sagemakerunifiedstudio/shared/utils/resourceMetadataUtils.ts @@ -51,7 +51,7 @@ export function getResourceMetadata(): ResourceMetadata | undefined { * Initializes resource metadata by reading and parsing the resource-metadata.json file */ export async function initializeResourceMetadata(): Promise { - const logger = getLogger() + const logger = getLogger('smus') if (!isSageMaker('SMUS') && !isSageMaker('SMUS-SPACE-REMOTE-ACCESS')) { logger.debug(`Not in SageMaker Unified Studio space, skipping initialization of resource metadata`) @@ -79,7 +79,7 @@ export async function resourceMetadataFileExists(): Promise { try { return await fs.existsFile(resourceMetadataPath) } catch (error) { - const logger = getLogger() + const logger = getLogger('smus') logger.error(`Failed to check if resource metadata file exists: ${error as Error}`) return false } diff --git a/packages/core/src/sagemakerunifiedstudio/uriHandlers.ts b/packages/core/src/sagemakerunifiedstudio/uriHandlers.ts index b4dc0dbcc24..ec313d50152 100644 --- a/packages/core/src/sagemakerunifiedstudio/uriHandlers.ts +++ b/packages/core/src/sagemakerunifiedstudio/uriHandlers.ts @@ -8,6 +8,7 @@ import { SearchParams } from '../shared/vscode/uriHandler' import { ExtContext } from '../shared/extensions' import { deeplinkConnect } from '../awsService/sagemaker/commands' import { telemetry } from '../shared/telemetry/telemetry' +import { SmusAuthMode } from '../shared/telemetry/telemetry.gen' /** * Registers the SMUS deeplink URI handler at path `/connect/smus`. * @@ -65,6 +66,7 @@ export function register(ctx: ExtContext) { * - smus_domain_account_id: SMUS domain account ID * - smus_project_id: SMUS project identifier * - smus_domain_region: SMUS domain region + * - smus_auth_mode: Authentication mode (sso or iam) * * Note: The ws_url from startSession API originally includes cell-number as a query parameter. * However, when the deeplink URL is processed, the URI handler extracts cell-number as a @@ -89,7 +91,8 @@ export function parseConnectParams(query: SearchParams) { 'smus_domain_id', 'smus_domain_account_id', 'smus_project_id', - 'smus_domain_region' + 'smus_domain_region', + 'smus_auth_mode' ) return { ...requiredParams, ...optionalParams } @@ -112,6 +115,10 @@ function extractTelemetryMetadata(params: ReturnType) const domainIdFromArn = resourceParts?.[1] // domain-id from ARN const spaceName = resourceParts?.[2] // space-name from ARN + // Validate and cast smusAuthMode to the expected type + const authMode = params.smus_auth_mode + const smusAuthMode: SmusAuthMode | undefined = authMode === 'sso' || authMode === 'iam' ? authMode : undefined + return { smusDomainId: params.smus_domain_id, smusDomainAccountId: params.smus_domain_account_id, @@ -120,5 +127,6 @@ function extractTelemetryMetadata(params: ReturnType) smusProjectRegion: projectRegion, smusProjectAccountId: projectAccountId, smusSpaceKey: domainIdFromArn && spaceName ? `${domainIdFromArn}/${spaceName}` : undefined, + smusAuthMode: smusAuthMode, } } diff --git a/packages/core/src/shared/clients/kubectlClient.ts b/packages/core/src/shared/clients/kubectlClient.ts index 67733bc0e0a..a374927bda4 100644 --- a/packages/core/src/shared/clients/kubectlClient.ts +++ b/packages/core/src/shared/clients/kubectlClient.ts @@ -73,13 +73,13 @@ export class KubectlClient { void vscode.window.showErrorMessage( `You do not have permission to view ${eksCluster.name} or its spaces. Please contact your administrator.` ) - throw new Error( - `Error: User has insufficient permissions to view EKS cluster (${eksCluster.name}) or its spaces.` + getLogger().warn( + `[Warning]: User has insufficient permissions to view EKS cluster (${eksCluster.name}) or its spaces.` ) } - getLogger().error( - `Error: Unavailable spaces for EKS Cluster (${eksCluster.name}): ${error}\nStack trace: ${(error as Error).stack}` + getLogger().warn( + `[Warning]: Unavailable spaces for EKS Cluster (${eksCluster.name}): ${error}\nStack trace: ${(error as Error).stack}` ) } return [] diff --git a/packages/core/src/shared/clients/sagemaker.ts b/packages/core/src/shared/clients/sagemaker.ts index 62c57a474a5..3b68c86cfda 100644 --- a/packages/core/src/shared/clients/sagemaker.ts +++ b/packages/core/src/shared/clients/sagemaker.ts @@ -60,6 +60,7 @@ import { AwsCredentialIdentity } from '@aws-sdk/types' import globals from '../extensionGlobals' import { HyperpodCluster } from './kubectlClient' import { EKSClient } from '@aws-sdk/client-eks' +import { DevSettings } from '../settings' const appTypeSettingsMap: Record = { [AppType.JupyterLab as string]: 'JupyterLabAppSettings', @@ -87,11 +88,14 @@ export class SagemakerClient extends ClientWrapper { protected override getClient(ignoreCache: boolean = false) { if (!this.client || ignoreCache) { + const devSettings = DevSettings.instance + const customEndpoint = devSettings.get('endpoints', {})['sagemaker'] + const endpoint = customEndpoint || `https://sagemaker.${this.regionCode}.amazonaws.com` const args = { serviceClient: SageMakerClient, region: this.regionCode, clientOptions: { - endpoint: `https://sagemaker.${this.regionCode}.amazonaws.com`, + endpoint: endpoint, region: this.regionCode, ...(this.credentialsProvider && { credentials: this.credentialsProvider }), }, diff --git a/packages/core/src/shared/errors.ts b/packages/core/src/shared/errors.ts index 24e3f1ba93a..292a6cc5f5a 100644 --- a/packages/core/src/shared/errors.ts +++ b/packages/core/src/shared/errors.ts @@ -375,6 +375,7 @@ export function getTelemetryResult(error: unknown | undefined): Result { * Examples: * - "Failed to save c:/fooß/bar/baz.txt" => "Failed to save c:/xß/x/x.txt" * - "EPERM for dir c:/Users/user1/.aws/sso/cache/abc123.json" => "EPERM for dir c:/Users/x/.aws/sso/cache/x.json" + * - "Error with profile my-profile" => "Error with profile [REDACTED]" */ export function scrubNames(s: string, username?: string) { let r = '' @@ -405,6 +406,10 @@ export function scrubNames(s: string, username?: string) { s = s.replaceAll(username, 'x') } + // Remove profile names that might appear in error messages + // Matches "profile" followed by optional punctuation and the profile name + s = s.replace(/(profile)\s*[:'"]?\s*([\w-]+)['"']?/gi, '$1 [REDACTED]') + // Replace contiguous whitespace with 1 space. s = s.replace(/\s+/g, ' ') diff --git a/packages/core/src/shared/globalState.ts b/packages/core/src/shared/globalState.ts index deef9f047c5..9e68836cbf1 100644 --- a/packages/core/src/shared/globalState.ts +++ b/packages/core/src/shared/globalState.ts @@ -48,6 +48,7 @@ export type globalKey = | 'aws.toolkit.lsp.versions' | 'aws.toolkit.lsp.manifest' | 'aws.amazonq.customization.overrideV2' + | 'aws.smus.authenticationPreferences' | 'aws.amazonq.regionProfiles' | 'aws.amazonq.regionProfiles.cache' // Deprecated/legacy names. New keys should start with "aws.". diff --git a/packages/core/src/shared/logger/logger.ts b/packages/core/src/shared/logger/logger.ts index cacbe260ffe..6a2c4a00511 100644 --- a/packages/core/src/shared/logger/logger.ts +++ b/packages/core/src/shared/logger/logger.ts @@ -24,6 +24,7 @@ export type LogTopic = | 'telemetry' | 'proxyUtil' | 'sagemaker' + | 'smus' class ErrorLog { constructor( diff --git a/packages/core/src/shared/settings.ts b/packages/core/src/shared/settings.ts index abdddf636a3..20bce5f21ea 100644 --- a/packages/core/src/shared/settings.ts +++ b/packages/core/src/shared/settings.ts @@ -783,6 +783,7 @@ const devSettings = { autofillStartUrl: String, webAuth: Boolean, notificationsPollInterval: Number, + datazoneScope: String, } type ResolvedDevSettings = FromDescriptor type AwsDevSetting = keyof ResolvedDevSettings diff --git a/packages/core/src/shared/telemetry/vscodeTelemetry.json b/packages/core/src/shared/telemetry/vscodeTelemetry.json index 129c7ffc702..fed66c6438a 100644 --- a/packages/core/src/shared/telemetry/vscodeTelemetry.json +++ b/packages/core/src/shared/telemetry/vscodeTelemetry.json @@ -319,6 +319,12 @@ "name": "smusProjectAccountId", "type": "string", "description": "SMUS project account id" + }, + { + "name": "smusAuthMode", + "type": "string", + "allowedValues": ["sso", "iam"], + "description": "SMUS authentication mode (SSO or IAM)" } ], "metrics": [ @@ -1414,6 +1420,10 @@ { "type": "smusDomainAccountId", "required": false + }, + { + "type": "smusAuthMode", + "required": false } ] }, @@ -1428,6 +1438,10 @@ { "type": "smusDomainAccountId", "required": false + }, + { + "type": "smusAuthMode", + "required": false } ] }, @@ -1450,6 +1464,10 @@ { "type": "smusDomainRegion", "required": false + }, + { + "type": "smusAuthMode", + "required": false } ] }, @@ -1476,6 +1494,10 @@ { "type": "smusDomainRegion", "required": false + }, + { + "type": "smusAuthMode", + "required": false } ], "passive": true @@ -1511,6 +1533,10 @@ { "type": "smusProjectAccountId", "required": false + }, + { + "type": "smusAuthMode", + "required": false } ] }, @@ -1545,6 +1571,10 @@ { "type": "smusProjectAccountId", "required": false + }, + { + "type": "smusAuthMode", + "required": false } ] }, @@ -1583,6 +1613,10 @@ { "type": "smusConnectionType", "required": false + }, + { + "type": "smusAuthMode", + "required": false } ] }, @@ -1621,6 +1655,10 @@ { "type": "smusConnectionType", "required": false + }, + { + "type": "smusAuthMode", + "required": false } ] }, @@ -1659,6 +1697,10 @@ { "type": "smusConnectionType", "required": false + }, + { + "type": "smusAuthMode", + "required": false } ] }, @@ -1700,6 +1742,10 @@ { "type": "smusSpaceKey", "required": false + }, + { + "type": "smusAuthMode", + "required": false } ] } diff --git a/packages/core/src/shared/vscode/setContext.ts b/packages/core/src/shared/vscode/setContext.ts index fb56be98f3e..ca594f9e9f6 100644 --- a/packages/core/src/shared/vscode/setContext.ts +++ b/packages/core/src/shared/vscode/setContext.ts @@ -45,6 +45,7 @@ export type contextKey = | 'aws.smus.connected' | 'aws.smus.inSmusSpaceEnvironment' | 'aws.cloudFormation.serviceEnabled' + | 'aws.smus.isIamMode' // Deprecated/legacy names. New keys should start with "aws.". | 'codewhisperer.activeLine' | 'gumby.isPlanAvailable' diff --git a/packages/core/src/test/awsService/sagemaker/credentialMapping.test.ts b/packages/core/src/test/awsService/sagemaker/credentialMapping.test.ts index c114c8b0bba..6b699dbaab4 100644 --- a/packages/core/src/test/awsService/sagemaker/credentialMapping.test.ts +++ b/packages/core/src/test/awsService/sagemaker/credentialMapping.test.ts @@ -13,7 +13,7 @@ import { saveMappings, setSpaceIamProfile, setSpaceSsoProfile, - setSmusSpaceSsoProfile, + setSmusSpaceProfile, setSpaceCredentials, } from '../../../awsService/sagemaker/credentialMapping' import { Auth } from '../../../auth' @@ -319,6 +319,7 @@ describe('credentialMapping', () => { mockNode.getParent.returns(mockParent as any) mockParent.getAuthProvider.returns(mockAuthProvider as any) mockParent.getProjectId.returns(projectId) + sandbox.stub(require('../../../sagemakerunifiedstudio/auth/model'), 'isSmusSsoConnection').returns(true) sandbox.stub(fs, 'existsFile').resolves(false) const writeStub = sandbox.stub(fs, 'writeFile').resolves() @@ -457,7 +458,7 @@ describe('credentialMapping', () => { }) }) - describe('setSmusSpaceSsoProfile', () => { + describe('setSmusSpaceProfile', () => { let sandbox: sinon.SinonSandbox beforeEach(() => { @@ -472,7 +473,7 @@ describe('credentialMapping', () => { sandbox.stub(fs, 'existsFile').resolves(false) const writeStub = sandbox.stub(fs, 'writeFile').resolves() - await setSmusSpaceSsoProfile('test-space', 'project-id') + await setSmusSpaceProfile('test-space', 'project-id', 'sso') const raw = writeStub.firstCall.args[1] const data = JSON.parse(typeof raw === 'string' ? raw : raw.toString()) diff --git a/packages/core/src/test/sagemakerunifiedstudio/auth/authenticationOrchestrator.test.ts b/packages/core/src/test/sagemakerunifiedstudio/auth/authenticationOrchestrator.test.ts new file mode 100644 index 00000000000..095922dfc6f --- /dev/null +++ b/packages/core/src/test/sagemakerunifiedstudio/auth/authenticationOrchestrator.test.ts @@ -0,0 +1,76 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as assert from 'assert' +import { SmusAuthenticationOrchestrator } from '../../../sagemakerunifiedstudio/auth/authenticationOrchestrator' + +describe('SmusAuthenticationOrchestrator', function () { + // Note: Due to AWS Toolkit test framework restrictions on mocking vscode.window, + // these tests focus on the interface and behavior rather than deep mocking. + // The actual authentication flows are tested through integration tests. + + describe('handleIamAuthentication', function () { + it('should export the correct interface', function () { + // Verify the class exists and has the expected static method + assert.ok('handleIamAuthentication' in SmusAuthenticationOrchestrator) + assert.strictEqual(typeof SmusAuthenticationOrchestrator.handleIamAuthentication, 'function') + }) + + it('should be callable without throwing', function () { + // Verify the method exists and is accessible + assert.doesNotThrow(() => { + assert.ok('handleIamAuthentication' in SmusAuthenticationOrchestrator) + }) + }) + }) + + describe('handleSsoAuthentication', function () { + it('should export the correct interface', function () { + // Verify the class exists and has the expected static method + assert.ok('handleSsoAuthentication' in SmusAuthenticationOrchestrator) + assert.strictEqual(typeof SmusAuthenticationOrchestrator.handleSsoAuthentication, 'function') + }) + + it('should be callable without throwing', function () { + // Verify the method exists and is accessible + assert.doesNotThrow(() => { + assert.ok('handleSsoAuthentication' in SmusAuthenticationOrchestrator) + }) + }) + }) + + describe('return types', function () { + it('should handle SUCCESS and BACK return types correctly', function () { + // Test that the return types are properly defined for both methods + const testResult1: 'SUCCESS' | 'BACK' = 'SUCCESS' + const testResult2: 'SUCCESS' | 'BACK' = 'BACK' + + assert.strictEqual(testResult1, 'SUCCESS') + assert.strictEqual(testResult2, 'BACK') + }) + }) + + describe('class structure', function () { + it('should be a class with static methods', function () { + // Verify the orchestrator is properly structured + assert.strictEqual(typeof SmusAuthenticationOrchestrator, 'function') + assert.ok(SmusAuthenticationOrchestrator.prototype) + }) + + it('should have both required authentication methods', function () { + // Verify both authentication methods exist + const methods = ['handleIamAuthentication', 'handleSsoAuthentication'] + + for (const method of methods) { + assert.ok(method in SmusAuthenticationOrchestrator, `Missing method: ${method}`) + assert.strictEqual( + typeof SmusAuthenticationOrchestrator[method as keyof typeof SmusAuthenticationOrchestrator], + 'function', + `${method} should be a function` + ) + } + }) + }) +}) diff --git a/packages/core/src/test/sagemakerunifiedstudio/auth/connectionCredentialsProvider.test.ts b/packages/core/src/test/sagemakerunifiedstudio/auth/connectionCredentialsProvider.test.ts index 951e391d181..8d14d21ba57 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/auth/connectionCredentialsProvider.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/auth/connectionCredentialsProvider.test.ts @@ -16,6 +16,7 @@ describe('ConnectionCredentialsProvider', function () { let connectionProvider: ConnectionCredentialsProvider let dataZoneClientStub: sinon.SinonStub + const testProjectId = 'proj-123456' const testConnectionId = 'conn-123456' const testDomainId = 'dzd_testdomain' const testRegion = 'us-east-2' @@ -42,6 +43,13 @@ describe('ConnectionCredentialsProvider', function () { isConnected: sinon.stub().returns(true), getDomainId: sinon.stub().returns(testDomainId), getDomainRegion: sinon.stub().returns(testRegion), + getDerCredentialsProvider: sinon.stub().resolves({ + getCredentials: sinon.stub().resolves({ + accessKeyId: 'test-key', + secretAccessKey: 'test-secret', + sessionToken: 'test-token', + }), + }), activeConnection: { ssoRegion: testRegion, }, @@ -52,10 +60,10 @@ describe('ConnectionCredentialsProvider', function () { getConnection: sinon.stub().resolves(mockGetConnectionResponse), } as any - // Stub DataZoneClient.getInstance - dataZoneClientStub = sinon.stub(DataZoneClient, 'getInstance').resolves(mockDataZoneClient as any) + // Stub DataZoneClient.createWithCredentials + dataZoneClientStub = sinon.stub(DataZoneClient, 'createWithCredentials').returns(mockDataZoneClient as any) - connectionProvider = new ConnectionCredentialsProvider(mockAuthProvider as any, testConnectionId) + connectionProvider = new ConnectionCredentialsProvider(mockAuthProvider as any, testConnectionId, testProjectId) }) afterEach(function () { diff --git a/packages/core/src/test/sagemakerunifiedstudio/auth/model.test.ts b/packages/core/src/test/sagemakerunifiedstudio/auth/model.test.ts index a6ca72736e9..e1d030ea825 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/auth/model.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/auth/model.test.ts @@ -3,230 +3,343 @@ * SPDX-License-Identifier: Apache-2.0 */ -import assert from 'assert' +import * as assert from 'assert' +import { Credentials } from '@aws-sdk/types' +import * as sinon from 'sinon' import { - createSmusProfile, + SmusSsoConnection, + SmusIamConnection, + isSmusIamConnection, + isSmusSsoConnection, isValidSmusConnection, + createSmusProfile, scopeSmus, - SmusConnection, + getDataZoneSsoScope, } from '../../../sagemakerunifiedstudio/auth/model' -import { SsoConnection } from '../../../auth/connection' +import { DevSettings } from '../../../shared/settings' -describe('SMUS Auth Model', function () { - const testDomainUrl = 'https://dzd_domainId.sagemaker.us-east-2.on.aws' - const testDomainId = 'dzd_domainId' - const testStartUrl = 'https://identitycenter.amazonaws.com/ssoins-testInstanceId' - const testRegion = 'us-east-2' +describe('SMUS Connection Model', function () { + let sandbox: sinon.SinonSandbox - describe('scopeSmus', function () { - it('should have correct scope value', function () { - assert.strictEqual(scopeSmus, 'datazone:domain:access') - }) + beforeEach(function () { + sandbox = sinon.createSandbox() }) - describe('createSmusProfile', function () { - it('should create profile with default scopes', function () { - const profile = createSmusProfile(testDomainUrl, testDomainId, testStartUrl, testRegion) + afterEach(function () { + sandbox.restore() + }) - assert.strictEqual(profile.domainUrl, testDomainUrl) - assert.strictEqual(profile.domainId, testDomainId) - assert.strictEqual(profile.startUrl, testStartUrl) - assert.strictEqual(profile.ssoRegion, testRegion) - assert.strictEqual(profile.type, 'sso') - assert.deepStrictEqual(profile.scopes, [scopeSmus]) - }) + const mockCredentials: Credentials = { + accessKeyId: 'test-access-key', + secretAccessKey: 'test-secret-key', + } - it('should create profile with custom scopes', function () { - const customScopes = ['custom:scope', 'another:scope'] - const profile = createSmusProfile(testDomainUrl, testDomainId, testStartUrl, testRegion, customScopes) + const mockCredentialsProvider = async (): Promise => mockCredentials - assert.strictEqual(profile.domainUrl, testDomainUrl) - assert.strictEqual(profile.domainId, testDomainId) - assert.strictEqual(profile.startUrl, testStartUrl) - assert.strictEqual(profile.ssoRegion, testRegion) - assert.strictEqual(profile.type, 'sso') - assert.deepStrictEqual(profile.scopes, customScopes) - }) + const mockGetToken = async () => ({ + accessToken: 'mock-access-token', + expiresAt: new Date(Date.now() + 3600000), // 1 hour from now + }) - it('should create profile with all required properties', function () { - const profile = createSmusProfile(testDomainUrl, testDomainId, testStartUrl, testRegion) + const mockGetRegistration = async () => ({ + clientId: 'mock-client-id', + clientSecret: 'mock-client-secret', + expiresAt: new Date(Date.now() + 86400000), // 24 hours from now + startUrl: 'https://test.sagemaker.us-east-1.on.aws/', + }) - // Check SsoProfile properties - assert.strictEqual(profile.type, 'sso') - assert.strictEqual(profile.startUrl, testStartUrl) - assert.strictEqual(profile.ssoRegion, testRegion) - assert.ok(Array.isArray(profile.scopes)) + describe('isSmusIamConnection', function () { + it('should return true for valid SMUS IAM connection', function () { + const connection: SmusIamConnection = { + type: 'iam', + profileName: 'test-profile', + region: 'us-east-1', + domainUrl: 'https://test.sagemaker.us-east-1.on.aws/', + domainId: 'test-domain-id', + id: 'test-id', + label: 'Test IAM Connection', + endpointUrl: undefined, + getCredentials: mockCredentialsProvider, + } - // Check SmusProfile properties - assert.strictEqual(profile.domainUrl, testDomainUrl) - assert.strictEqual(profile.domainId, testDomainId) + assert.strictEqual(isSmusIamConnection(connection), true) }) - }) - describe('isValidSmusConnection', function () { - it('should return true for valid SMUS connection', function () { - const validConnection = { - id: 'test-connection-id', + it('should return false for SSO connection', function () { + const connection: SmusSsoConnection = { type: 'sso', - startUrl: testStartUrl, - ssoRegion: testRegion, + startUrl: 'https://test.awsapps.com/start', + ssoRegion: 'us-east-1', scopes: [scopeSmus], - label: 'Test SMUS Connection', - domainUrl: testDomainUrl, - domainId: testDomainId, - } as SmusConnection + domainUrl: 'https://test.sagemaker.us-east-1.on.aws/', + domainId: 'test-domain-id', + id: 'test-id', + label: 'Test SSO Connection', + getToken: mockGetToken, + getRegistration: mockGetRegistration, + } - assert.strictEqual(isValidSmusConnection(validConnection), true) + assert.strictEqual(isSmusIamConnection(connection), false) }) - it('should return false for connection without SMUS scope', function () { - const connectionWithoutScope = { - id: 'test-connection-id', - type: 'sso', - startUrl: testStartUrl, - ssoRegion: testRegion, - scopes: ['sso:account:access'], + it('should return false for connection missing required IAM properties', function () { + const connection = { + type: 'iam', + profileName: 'test-profile', + // Missing region, domainUrl, domainId, getCredentials + id: 'test-id', + label: 'Test IAM Connection', + endpointUrl: undefined, + } + + assert.strictEqual(isSmusIamConnection(connection as any), false) + }) + + it('should return false for undefined connection', function () { + assert.strictEqual(isSmusIamConnection(undefined), false) + }) + + it('should return false for connection with wrong type', function () { + const connection = { + type: 'other', + profileName: 'test-profile', + region: 'us-east-1', + domainUrl: 'https://test.sagemaker.us-east-1.on.aws/', + domainId: 'test-domain-id', + id: 'test-id', label: 'Test Connection', - domainUrl: testDomainUrl, - domainId: testDomainId, - } as any + } - assert.strictEqual(isValidSmusConnection(connectionWithoutScope), false) + assert.strictEqual(isSmusIamConnection(connection as any), false) }) + }) - it('should return false for connection without SMUS properties', function () { - const connectionWithoutSmusProps = { - id: 'test-connection-id', + describe('isSmusSsoConnection', function () { + it('should return true for valid SMUS SSO connection', function () { + const connection: SmusSsoConnection = { type: 'sso', - startUrl: testStartUrl, - ssoRegion: testRegion, + startUrl: 'https://test.awsapps.com/start', + ssoRegion: 'us-east-1', scopes: [scopeSmus], - label: 'Test Connection', - } as SsoConnection + domainUrl: 'https://test.sagemaker.us-east-1.on.aws/', + domainId: 'test-domain-id', + id: 'test-id', + label: 'Test SSO Connection', + getToken: mockGetToken, + getRegistration: mockGetRegistration, + } - assert.strictEqual(isValidSmusConnection(connectionWithoutSmusProps), false) + assert.strictEqual(isSmusSsoConnection(connection), true) }) - it('should return false for non-SSO connection', function () { - const nonSsoConnection = { - id: 'test-connection-id', + it('should return false for IAM connection', function () { + const connection: SmusIamConnection = { type: 'iam', + profileName: 'test-profile', + region: 'us-east-1', + domainUrl: 'https://test.sagemaker.us-east-1.on.aws/', + domainId: 'test-domain-id', + id: 'test-id', label: 'Test IAM Connection', - domainUrl: testDomainUrl, - domainId: testDomainId, - scopes: [scopeSmus], + endpointUrl: undefined, + getCredentials: mockCredentialsProvider, } - assert.strictEqual(isValidSmusConnection(nonSsoConnection), false) + assert.strictEqual(isSmusSsoConnection(connection), false) }) - it('should return false for undefined connection', function () { - assert.strictEqual(isValidSmusConnection(undefined), false) - }) + it('should return false for SSO connection without SMUS scope', function () { + const connection = { + type: 'sso', + startUrl: 'https://test.awsapps.com/start', + ssoRegion: 'us-east-1', + scopes: ['other:scope'], + domainUrl: 'https://test.sagemaker.us-east-1.on.aws/', + domainId: 'test-domain-id', + id: 'test-id', + label: 'Test SSO Connection', + } - it('should return false for null connection', function () { - assert.strictEqual(isValidSmusConnection(undefined), false) + assert.strictEqual(isSmusSsoConnection(connection as any), false) }) - it('should return false for connection without scopes', function () { - const connectionWithoutScopes = { - id: 'test-connection-id', + it('should return false for SSO connection missing SMUS properties', function () { + const connection = { type: 'sso', - startUrl: testStartUrl, - ssoRegion: testRegion, - label: 'Test Connection', - domainUrl: testDomainUrl, - domainId: testDomainId, + startUrl: 'https://test.awsapps.com/start', + ssoRegion: 'us-east-1', + scopes: [scopeSmus], + // Missing domainUrl and domainId + id: 'test-id', + label: 'Test SSO Connection', } - assert.strictEqual(isValidSmusConnection(connectionWithoutScopes), false) + assert.strictEqual(isSmusSsoConnection(connection as any), false) + }) + + it('should return false for undefined connection', function () { + assert.strictEqual(isSmusSsoConnection(undefined), false) }) + }) - it('should return false for connection with empty scopes array', function () { - const connectionWithEmptyScopes = { - id: 'test-connection-id', + describe('isValidSmusConnection', function () { + it('should return true for valid SMUS SSO connection', function () { + const connection: SmusSsoConnection = { type: 'sso', - startUrl: testStartUrl, - ssoRegion: testRegion, - scopes: [], - label: 'Test Connection', - domainUrl: testDomainUrl, - domainId: testDomainId, + startUrl: 'https://test.awsapps.com/start', + ssoRegion: 'us-east-1', + scopes: [scopeSmus], + domainUrl: 'https://test.sagemaker.us-east-1.on.aws/', + domainId: 'test-domain-id', + id: 'test-id', + label: 'Test SSO Connection', + getToken: mockGetToken, + getRegistration: mockGetRegistration, } - assert.strictEqual(isValidSmusConnection(connectionWithEmptyScopes), false) + assert.strictEqual(isValidSmusConnection(connection), true) }) - it('should return true for connection with SMUS scope among other scopes', function () { - const connectionWithMultipleScopes = { - id: 'test-connection-id', - type: 'sso', - startUrl: testStartUrl, - ssoRegion: testRegion, - scopes: ['sso:account:access', scopeSmus, 'other:scope'], - label: 'Test SMUS Connection', - domainUrl: testDomainUrl, - domainId: testDomainId, - } as SmusConnection + it('should return true for valid SMUS IAM connection', function () { + const connection: SmusIamConnection = { + type: 'iam', + profileName: 'test-profile', + region: 'us-east-1', + domainUrl: 'https://test.sagemaker.us-east-1.on.aws/', + domainId: 'test-domain-id', + id: 'test-id', + label: 'Test IAM Connection', + endpointUrl: undefined, + getCredentials: mockCredentialsProvider, + } - assert.strictEqual(isValidSmusConnection(connectionWithMultipleScopes), true) + assert.strictEqual(isValidSmusConnection(connection), true) }) - it('should return false for connection missing domainUrl', function () { - const connectionMissingDomainUrl = { - id: 'test-connection-id', - type: 'sso', - startUrl: testStartUrl, - ssoRegion: testRegion, - scopes: [scopeSmus], + it('should return false for invalid connection', function () { + const connection = { + type: 'other', + id: 'test-id', label: 'Test Connection', - domainId: testDomainId, } - assert.strictEqual(isValidSmusConnection(connectionMissingDomainUrl), false) + assert.strictEqual(isValidSmusConnection(connection), false) }) - it('should return false for connection missing domainId', function () { - const connectionMissingDomainId = { - id: 'test-connection-id', - type: 'sso', - startUrl: testStartUrl, - ssoRegion: testRegion, - scopes: [scopeSmus], - label: 'Test Connection', - domainUrl: testDomainUrl, - } + it('should return false for undefined connection', function () { + assert.strictEqual(isValidSmusConnection(undefined), false) + }) + }) + + describe('getDataZoneSsoScope', function () { + it('should return default scope when no custom setting is provided', function () { + // When get() is called with default value, it returns the default (scopeSmus) + // This simulates the behavior when aws.dev.datazoneScope is not set + sandbox.stub(DevSettings.instance, 'get').withArgs('datazoneScope', scopeSmus).returns(scopeSmus) + + const scope = getDataZoneSsoScope() + + assert.strictEqual(scope, scopeSmus) + }) + + it('should return custom scope when setting is configured', function () { + const customScope = 'custom:datazone:scope' + // When get() is called, it returns the custom value from settings + // This simulates the behavior when aws.dev.datazoneScope is set to customScope + sandbox.stub(DevSettings.instance, 'get').withArgs('datazoneScope', scopeSmus).returns(customScope) + + const scope = getDataZoneSsoScope() + + assert.strictEqual(scope, customScope) + }) + }) + + describe('createSmusProfile', function () { + it('should create a valid SMUS profile with default scope', function () { + sandbox.stub(DevSettings.instance, 'get').withArgs('datazoneScope', scopeSmus).returns(scopeSmus) + + const domainUrl = 'https://test.sagemaker.us-east-1.on.aws/' + const domainId = 'test-domain-id' + const startUrl = 'https://test.awsapps.com/start' + const region = 'us-east-1' + + const profile = createSmusProfile(domainUrl, domainId, startUrl, region) + + assert.strictEqual(profile.domainUrl, domainUrl) + assert.strictEqual(profile.domainId, domainId) + assert.strictEqual(profile.startUrl, startUrl) + assert.strictEqual(profile.ssoRegion, region) + assert.strictEqual(profile.type, 'sso') + assert.deepStrictEqual(profile.scopes, [scopeSmus]) + }) + + it('should create a valid SMUS profile with custom scope from settings', function () { + const customScope = 'custom:datazone:scope' + sandbox.stub(DevSettings.instance, 'get').withArgs('datazoneScope', scopeSmus).returns(customScope) + + const domainUrl = 'https://test.sagemaker.us-east-1.on.aws/' + const domainId = 'test-domain-id' + const startUrl = 'https://test.awsapps.com/start' + const region = 'us-east-1' + + const profile = createSmusProfile(domainUrl, domainId, startUrl, region) - assert.strictEqual(isValidSmusConnection(connectionMissingDomainId), false) + assert.deepStrictEqual(profile.scopes, [customScope]) + }) + + it('should create a valid SMUS profile with custom scopes parameter', function () { + const domainUrl = 'https://test.sagemaker.us-east-1.on.aws/' + const domainId = 'test-domain-id' + const startUrl = 'https://test.awsapps.com/start' + const region = 'us-east-1' + const customScopes = ['custom:scope1', 'custom:scope2'] + + const profile = createSmusProfile(domainUrl, domainId, startUrl, region, customScopes) + + assert.deepStrictEqual(profile.scopes, customScopes) }) }) - describe('SmusConnection interface', function () { - it('should extend both SmusProfile and SsoConnection', function () { + describe('isSmusSsoConnection with custom scope', function () { + it('should return true for connection with custom scope from settings', function () { + const customScope = 'custom:datazone:scope' + sandbox.stub(DevSettings.instance, 'get').withArgs('datazoneScope', scopeSmus).returns(customScope) + const connection = { - id: 'test-connection-id', type: 'sso', - startUrl: testStartUrl, - ssoRegion: testRegion, - scopes: [scopeSmus], - label: 'Test SMUS Connection', - domainUrl: testDomainUrl, - domainId: testDomainId, - } as SmusConnection - - // Should have Connection properties - assert.strictEqual(connection.id, 'test-connection-id') - assert.strictEqual(connection.label, 'Test SMUS Connection') - - // Should have SsoConnection properties - assert.strictEqual(connection.type, 'sso') - assert.strictEqual(connection.startUrl, testStartUrl) - assert.strictEqual(connection.ssoRegion, testRegion) - assert.ok(Array.isArray(connection.scopes)) - - // Should have SmusProfile properties - assert.strictEqual(connection.domainUrl, testDomainUrl) - assert.strictEqual(connection.domainId, testDomainId) + startUrl: 'https://test.awsapps.com/start', + ssoRegion: 'us-east-1', + scopes: [customScope], + domainUrl: 'https://test.sagemaker.us-east-1.on.aws/', + domainId: 'test-domain-id', + id: 'test-id', + label: 'Test SSO Connection', + getToken: mockGetToken, + getRegistration: mockGetRegistration, + } as SmusSsoConnection + + assert.strictEqual(isSmusSsoConnection(connection), true) + }) + + it('should return true for connection with default scope even when custom scope is configured', function () { + const customScope = 'custom:datazone:scope' + sandbox.stub(DevSettings.instance, 'get').withArgs('datazoneScope', scopeSmus).returns(customScope) + + const connection = { + type: 'sso', + startUrl: 'https://test.awsapps.com/start', + ssoRegion: 'us-east-1', + scopes: [scopeSmus], // Using default scope + domainUrl: 'https://test.sagemaker.us-east-1.on.aws/', + domainId: 'test-domain-id', + id: 'test-id', + label: 'Test SSO Connection', + getToken: mockGetToken, + getRegistration: mockGetRegistration, + } as SmusSsoConnection + + // Should still work for backward compatibility + assert.strictEqual(isSmusSsoConnection(connection), true) }) }) }) diff --git a/packages/core/src/test/sagemakerunifiedstudio/auth/preferences/authenticationPreferences.test.ts b/packages/core/src/test/sagemakerunifiedstudio/auth/preferences/authenticationPreferences.test.ts new file mode 100644 index 00000000000..06549830934 --- /dev/null +++ b/packages/core/src/test/sagemakerunifiedstudio/auth/preferences/authenticationPreferences.test.ts @@ -0,0 +1,303 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as assert from 'assert' +import * as sinon from 'sinon' +import { + SmusAuthenticationPreferencesManager, + SmusAuthenticationPreferences, + SmusIamProfileConfig, +} from '../../../../sagemakerunifiedstudio/auth/preferences/authenticationPreferences' +import { globals } from '../../../../shared' + +describe('SmusAuthenticationPreferencesManager', function () { + let mockContext: any + let sandbox: sinon.SinonSandbox + let mockGlobalState: any + + beforeEach(function () { + sandbox = sinon.createSandbox() + + // Mock the globals.globalState instead of context.globalState directly + mockGlobalState = { + get: sandbox.stub(), + update: sandbox.stub().resolves(), + } + + // Mock VS Code extension context (still needed for the API) + mockContext = { + globalState: mockGlobalState, + } + + // Stub globals.globalState to use our mock + sandbox.stub(globals, 'globalState').value(mockGlobalState) + }) + + afterEach(function () { + sandbox.restore() + }) + + describe('getPreferences', function () { + it('should return default preferences when none are stored', function () { + // Setup + mockGlobalState.get.returns(undefined) + + // Act + const preferences = SmusAuthenticationPreferencesManager.getPreferences(mockContext) + + // Assert + assert.deepStrictEqual(preferences, { + rememberChoice: false, + }) + }) + + it('should return stored preferences when available', function () { + // Setup + const storedPreferences: SmusAuthenticationPreferences = { + preferredMethod: 'iam', + rememberChoice: true, + lastUsedSsoConnection: 'conn-123', + lastUsedIamProfile: { + profileName: 'default', + region: 'us-east-1', + lastUsed: new Date('2023-01-01'), + isDefault: true, + }, + } + mockGlobalState.get.returns(storedPreferences) + + // Act + const preferences = SmusAuthenticationPreferencesManager.getPreferences(mockContext) + + // Assert + assert.deepStrictEqual(preferences, storedPreferences) + }) + + it('should merge stored preferences with defaults', function () { + // Setup + const partialPreferences = { + preferredMethod: 'sso' as const, + } + mockGlobalState.get.returns(partialPreferences) + + // Act + const preferences = SmusAuthenticationPreferencesManager.getPreferences(mockContext) + + // Assert + assert.deepStrictEqual(preferences, { + preferredMethod: 'sso', + rememberChoice: false, + }) + }) + }) + + describe('updatePreferences', function () { + it('should update preferences correctly', async function () { + // Setup + const currentPreferences: SmusAuthenticationPreferences = { + preferredMethod: 'sso', + rememberChoice: true, + } + mockGlobalState.get.returns(currentPreferences) + + const updates = { + preferredMethod: 'iam' as const, + lastUsedSsoConnection: 'conn-456', + } + + // Act + await SmusAuthenticationPreferencesManager.updatePreferences(mockContext, updates) + + // Assert + assert.strictEqual(mockGlobalState.update.calledOnce, true) + const [key, updatedPreferences] = mockGlobalState.update.firstCall.args + assert.strictEqual(key, 'aws.smus.authenticationPreferences') + assert.deepStrictEqual(updatedPreferences, { + preferredMethod: 'iam', + rememberChoice: true, + lastUsedSsoConnection: 'conn-456', + }) + }) + }) + + describe('setPreferredMethod', function () { + it('should set preferred method and remember choice', async function () { + // Setup + mockGlobalState.get.returns({}) + + // Act + await SmusAuthenticationPreferencesManager.setPreferredMethod(mockContext, 'iam', true) + + // Assert + assert.strictEqual(mockGlobalState.update.calledOnce, true) + const [key, preferences] = mockGlobalState.update.firstCall.args + assert.strictEqual(key, 'aws.smus.authenticationPreferences') + assert.deepStrictEqual(preferences, { + preferredMethod: 'iam', + rememberChoice: true, + }) + }) + }) + + describe('getPreferredMethod', function () { + it('should return preferred method when remember choice is true', function () { + // Setup + const preferences: SmusAuthenticationPreferences = { + preferredMethod: 'iam', + rememberChoice: true, + } + mockGlobalState.get.returns(preferences) + + // Act + const method = SmusAuthenticationPreferencesManager.getPreferredMethod(mockContext) + + // Assert + assert.strictEqual(method, 'iam') + }) + + it('should return undefined when remember choice is false', function () { + // Setup + const preferences: SmusAuthenticationPreferences = { + preferredMethod: 'iam', + rememberChoice: false, + } + mockGlobalState.get.returns(preferences) + + // Act + const method = SmusAuthenticationPreferencesManager.getPreferredMethod(mockContext) + + // Assert + assert.strictEqual(method, undefined) + }) + + it('should return undefined when no preferred method is set', function () { + // Setup + const preferences: SmusAuthenticationPreferences = { + rememberChoice: true, + } + mockGlobalState.get.returns(preferences) + + // Act + const method = SmusAuthenticationPreferencesManager.getPreferredMethod(mockContext) + + // Assert + assert.strictEqual(method, undefined) + }) + }) + + describe('setLastUsedSsoConnection', function () { + it('should set last used SSO connection', async function () { + // Setup + mockGlobalState.get.returns({}) + + // Act + await SmusAuthenticationPreferencesManager.setLastUsedSsoConnection(mockContext, 'conn-789') + + // Assert + assert.strictEqual(mockGlobalState.update.calledOnce, true) + const [key, preferences] = mockGlobalState.update.firstCall.args + assert.strictEqual(key, 'aws.smus.authenticationPreferences') + assert.deepStrictEqual(preferences, { + rememberChoice: false, + lastUsedSsoConnection: 'conn-789', + }) + }) + }) + + describe('setLastUsedIamProfile', function () { + it('should set last used IAM profile with timestamp', async function () { + // Setup + mockGlobalState.get.returns({}) + const profileConfig: SmusIamProfileConfig = { + profileName: 'production', + region: 'us-west-2', + isDefault: false, + } + + // Act + await SmusAuthenticationPreferencesManager.setLastUsedIamProfile(mockContext, profileConfig) + + // Assert + assert.strictEqual(mockGlobalState.update.calledOnce, true) + const [key, preferences] = mockGlobalState.update.firstCall.args + assert.strictEqual(key, 'aws.smus.authenticationPreferences') + + assert.strictEqual(preferences.lastUsedIamProfile.profileName, 'production') + assert.strictEqual(preferences.lastUsedIamProfile.region, 'us-west-2') + assert.strictEqual(preferences.lastUsedIamProfile.isDefault, false) + assert.ok(preferences.lastUsedIamProfile.lastUsed instanceof Date) + }) + }) + + describe('getLastUsedIamProfile', function () { + it('should return last used IAM profile when available', function () { + // Setup + const profileConfig: SmusIamProfileConfig = { + profileName: 'test-profile', + region: 'eu-west-1', + lastUsed: new Date('2023-06-01'), + isDefault: true, + } + const preferences: SmusAuthenticationPreferences = { + rememberChoice: false, + lastUsedIamProfile: profileConfig, + } + mockGlobalState.get.returns(preferences) + + // Act + const result = SmusAuthenticationPreferencesManager.getLastUsedIamProfile(mockContext) + + // Assert + assert.deepStrictEqual(result, profileConfig) + }) + + it('should return undefined when no IAM profile is stored', function () { + // Setup + mockGlobalState.get.returns({}) + + // Act + const result = SmusAuthenticationPreferencesManager.getLastUsedIamProfile(mockContext) + + // Assert + assert.strictEqual(result, undefined) + }) + }) + + describe('clearPreferences', function () { + it('should clear all preferences', async function () { + // Act + await SmusAuthenticationPreferencesManager.clearPreferences(mockContext) + + // Assert + assert.strictEqual(mockGlobalState.update.calledOnce, true) + const [key, value] = mockGlobalState.update.firstCall.args + assert.strictEqual(key, 'aws.smus.authenticationPreferences') + assert.strictEqual(value, undefined) + }) + }) + + describe('switchAuthenticationMethod', function () { + it('should switch authentication method', async function () { + // Setup + const currentPreferences: SmusAuthenticationPreferences = { + preferredMethod: 'sso', + rememberChoice: true, + } + mockGlobalState.get.returns(currentPreferences) + + // Act + await SmusAuthenticationPreferencesManager.switchAuthenticationMethod(mockContext, 'iam') + + // Assert + assert.strictEqual(mockGlobalState.update.calledOnce, true) + const [key, preferences] = mockGlobalState.update.firstCall.args + assert.strictEqual(key, 'aws.smus.authenticationPreferences') + assert.deepStrictEqual(preferences, { + preferredMethod: 'iam', + rememberChoice: true, + }) + }) + }) +}) diff --git a/packages/core/src/test/sagemakerunifiedstudio/auth/projectRoleCredentialsProvider.test.ts b/packages/core/src/test/sagemakerunifiedstudio/auth/projectRoleCredentialsProvider.test.ts index 6dd206593f8..bef7cc1885d 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/auth/projectRoleCredentialsProvider.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/auth/projectRoleCredentialsProvider.test.ts @@ -36,6 +36,25 @@ describe('ProjectRoleCredentialsProvider', function () { getDomainId: sinon.stub().returns(testDomainId), getDomainRegion: sinon.stub().returns(testRegion), isConnected: sinon.stub().returns(true), + activeConnection: { + profileName: 'test-profile', + domainId: testDomainId, + ssoRegion: testRegion, + }, + getDerCredentialsProvider: sinon.stub().resolves({ + getCredentials: sinon.stub().resolves({ + accessKeyId: 'test-key', + secretAccessKey: 'test-secret', + sessionToken: 'test-token', + }), + }), + getCredentialsProviderForIamProfile: sinon.stub().resolves({ + getCredentials: sinon.stub().resolves({ + accessKeyId: 'profile-key', + secretAccessKey: 'profile-secret', + sessionToken: 'profile-token', + }), + }), } as any // Mock DataZone client @@ -43,8 +62,7 @@ describe('ProjectRoleCredentialsProvider', function () { getProjectDefaultEnvironmentCreds: sinon.stub().resolves(mockGetEnvironmentCredentialsResponse), } as any - // Stub DataZoneClient.getInstance - dataZoneClientStub = sinon.stub(DataZoneClient, 'getInstance').resolves(mockDataZoneClient as any) + dataZoneClientStub = sinon.stub(DataZoneClient, 'createWithCredentials').returns(mockDataZoneClient as any) projectProvider = new ProjectRoleCredentialsProvider(mockSmusAuthProvider, testProjectId) }) @@ -111,8 +129,8 @@ describe('ProjectRoleCredentialsProvider', function () { it('should fetch and cache project credentials', async function () { const credentials = await projectProvider.getCredentials() - // Verify DataZone client getInstance was called - assert.ok(dataZoneClientStub.calledWith(mockSmusAuthProvider)) + // Verify DataZone client createWithCredentials was called with correct parameters + assert.ok(dataZoneClientStub.calledWith(testRegion, testDomainId, sinon.match.any)) // Verify getProjectDefaultEnvironmentCreds was called assert.ok(mockDataZoneClient.getProjectDefaultEnvironmentCreds.called) diff --git a/packages/core/src/test/sagemakerunifiedstudio/auth/smusAuthenticationProvider.test.ts b/packages/core/src/test/sagemakerunifiedstudio/auth/smusAuthenticationProvider.test.ts index 7cd2662f467..3b53eddb678 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/auth/smusAuthenticationProvider.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/auth/smusAuthenticationProvider.test.ts @@ -6,9 +6,11 @@ import assert from 'assert' import sinon from 'sinon' import * as vscode from 'vscode' - -// Mock the setContext function BEFORE importing modules that use it -const setContextModule = require('../../../shared/vscode/setContext') +import * as setContextModule from '../../../shared/vscode/setContext' +import * as secondaryAuthModule from '../../../auth/secondaryAuth' +import * as sharedCredentialsModule from '../../../auth/credentials/sharedCredentials' +import * as stsClientModule from '../../../shared/clients/stsClient' +import { DataZoneCustomClientHelper } from '../../../sagemakerunifiedstudio/shared/client/datazoneCustomClientHelper' import { SmusAuthenticationProvider } from '../../../sagemakerunifiedstudio/auth/providers/smusAuthenticationProvider' import { SmusConnection } from '../../../sagemakerunifiedstudio/auth/model' @@ -31,6 +33,7 @@ describe('SmusAuthenticationProvider', function () { let isInSmusSpaceEnvironmentStub: sinon.SinonStub let executeCommandStub: sinon.SinonStub let setContextStubGlobal: sinon.SinonStub + let getResourceMetadataStub: sinon.SinonStub let mockSecondaryAuthState: { activeConnection: SmusConnection | undefined hasSavedConnection: boolean @@ -88,6 +91,10 @@ describe('SmusAuthenticationProvider', function () { get isConnectionExpired() { return mockSecondaryAuthState.isConnectionExpired }, + state: { + get: sinon.stub().returns({}), + update: sinon.stub().resolves(), + }, onDidChangeActiveConnection: sinon.stub().returns({ dispose: sinon.stub() }), restoreConnection: sinon.stub().resolves(), useNewConnection: sinon.stub().resolves(mockSmusConnection), @@ -99,16 +106,16 @@ describe('SmusAuthenticationProvider', function () { } as any // Stub static methods - sinon.stub(DataZoneClient, 'getInstance').returns(mockDataZoneClient as any) + sinon.stub(DataZoneClient, 'createWithCredentials').returns(mockDataZoneClient as any) extractDomainInfoStub = sinon .stub(SmusUtils, 'extractDomainInfoFromUrl') .returns({ domainId: testDomainId, region: testRegion }) getSsoInstanceInfoStub = sinon.stub(SmusUtils, 'getSsoInstanceInfo').resolves(testSsoInstanceInfo) isInSmusSpaceEnvironmentStub = sinon.stub(SmusUtils, 'isInSmusSpaceEnvironment').returns(false) executeCommandStub = sinon.stub(vscode.commands, 'executeCommand').resolves() - sinon.stub(require('../../../auth/secondaryAuth'), 'getSecondaryAuth').returns(mockSecondaryAuth) + sinon.stub(secondaryAuthModule, 'getSecondaryAuth').returns(mockSecondaryAuth) - smusAuthProvider = new SmusAuthenticationProvider(mockAuth, mockSecondaryAuth) + smusAuthProvider = new SmusAuthenticationProvider(mockAuth) // Reset the executeCommand stub for clean state executeCommandStub.resetHistory() @@ -187,17 +194,55 @@ describe('SmusAuthenticationProvider', function () { }) describe('restore', function () { - it('should call secondary auth restoreConnection', async function () { + let mockState: any + let loadSharedCredentialsProfilesStub: sinon.SinonStub + let validateIamProfileStub: sinon.SinonStub + beforeEach(function () { + mockState = { + get: sinon.stub(), + update: sinon.stub().resolves(), + } + mockSecondaryAuth.state = mockState + + loadSharedCredentialsProfilesStub = sinon.stub(sharedCredentialsModule, 'loadSharedCredentialsProfiles') + validateIamProfileStub = sinon.stub(smusAuthProvider, 'validateIamProfile') + }) + + it('should call secondary auth restoreConnection when no saved connection ID', async function () { + mockState.get.withArgs('smus.savedConnectionId').returns(undefined) + + await smusAuthProvider.restore() + + assert.ok(mockSecondaryAuth.restoreConnection.called) + assert.ok(loadSharedCredentialsProfilesStub.notCalled) + }) + + it('should validate IAM profile and restore connection', async function () { + const savedConnectionId = 'test-connection-id' + const connectionMetadata = { + profileName: 'test-profile', + domainId: 'old-domain-id', + region: 'us-west-1', + } + const smusConnections = { [savedConnectionId]: connectionMetadata } + + mockState.get.withArgs('smus.savedConnectionId').returns(savedConnectionId) + mockState.get.withArgs('smus.connections').returns(smusConnections) + loadSharedCredentialsProfilesStub.resolves({ 'test-profile': { region: 'us-east-1' } }) + validateIamProfileStub.resolves({ isValid: true }) + await smusAuthProvider.restore() + + assert.ok(validateIamProfileStub.calledWith('test-profile')) assert.ok(mockSecondaryAuth.restoreConnection.called) }) }) - describe('connectToSmus', function () { + describe('connectToSmusWithSso', function () { it('should create new connection when none exists', async function () { mockAuth.listConnections.resolves([]) - const result = await smusAuthProvider.connectToSmus(testDomainUrl) + const result = await smusAuthProvider.connectToSmusWithSso(testDomainUrl) assert.strictEqual(result, mockSmusConnection) assert.ok(extractDomainInfoStub.calledWith(testDomainUrl)) @@ -212,7 +257,7 @@ describe('SmusAuthenticationProvider', function () { mockAuth.listConnections.resolves([existingConnection]) mockAuth.getConnectionState.returns('valid') - const result = await smusAuthProvider.connectToSmus(testDomainUrl) + const result = await smusAuthProvider.connectToSmusWithSso(testDomainUrl) assert.strictEqual(result, mockSmusConnection) assert.ok(mockAuth.createConnection.notCalled) @@ -225,7 +270,7 @@ describe('SmusAuthenticationProvider', function () { mockAuth.listConnections.resolves([existingConnection]) mockAuth.getConnectionState.returns('invalid') - const result = await smusAuthProvider.connectToSmus(testDomainUrl) + const result = await smusAuthProvider.connectToSmusWithSso(testDomainUrl) assert.strictEqual(result, mockSmusConnection) assert.ok(mockAuth.reauthenticate.calledWith(existingConnection)) @@ -237,7 +282,7 @@ describe('SmusAuthenticationProvider', function () { extractDomainInfoStub.returns({ domainId: undefined, region: testRegion }) await assert.rejects( - () => smusAuthProvider.connectToSmus('invalid-url'), + () => smusAuthProvider.connectToSmusWithSso('invalid-url'), (err: ToolkitError) => { // The error is wrapped with FailedToConnect, but the original error should be in the cause return err.code === 'FailedToConnect' && (err.cause as any)?.code === 'InvalidDomainUrl' @@ -252,7 +297,7 @@ describe('SmusAuthenticationProvider', function () { getSsoInstanceInfoStub.rejects(error) await assert.rejects( - () => smusAuthProvider.connectToSmus(testDomainUrl), + () => smusAuthProvider.connectToSmusWithSso(testDomainUrl), (err: ToolkitError) => err.code === 'FailedToConnect' ) // Should not trigger project selection on error @@ -264,7 +309,7 @@ describe('SmusAuthenticationProvider', function () { mockAuth.createConnection.rejects(error) await assert.rejects( - () => smusAuthProvider.connectToSmus(testDomainUrl), + () => smusAuthProvider.connectToSmusWithSso(testDomainUrl), (err: ToolkitError) => err.code === 'FailedToConnect' ) // Should not trigger project selection on error @@ -275,7 +320,7 @@ describe('SmusAuthenticationProvider', function () { isInSmusSpaceEnvironmentStub.returns(true) mockAuth.listConnections.resolves([]) - const result = await smusAuthProvider.connectToSmus(testDomainUrl) + const result = await smusAuthProvider.connectToSmusWithSso(testDomainUrl) assert.strictEqual(result, mockSmusConnection) assert.ok(mockAuth.createConnection.called) @@ -289,7 +334,7 @@ describe('SmusAuthenticationProvider', function () { mockAuth.listConnections.resolves([existingConnection]) mockAuth.getConnectionState.returns('valid') - const result = await smusAuthProvider.connectToSmus(testDomainUrl) + const result = await smusAuthProvider.connectToSmusWithSso(testDomainUrl) assert.strictEqual(result, mockSmusConnection) assert.ok(mockSecondaryAuth.useNewConnection.calledWith(existingConnection)) @@ -302,7 +347,7 @@ describe('SmusAuthenticationProvider', function () { mockAuth.listConnections.resolves([existingConnection]) mockAuth.getConnectionState.returns('invalid') - const result = await smusAuthProvider.connectToSmus(testDomainUrl) + const result = await smusAuthProvider.connectToSmusWithSso(testDomainUrl) assert.strictEqual(result, mockSmusConnection) assert.ok(mockAuth.reauthenticate.calledWith(existingConnection)) @@ -312,10 +357,16 @@ describe('SmusAuthenticationProvider', function () { }) describe('reauthenticate', function () { - it('should call auth reauthenticate', async function () { + it('should call auth reauthenticate for SSO connection', async function () { const result = await smusAuthProvider.reauthenticate(mockSmusConnection) - assert.strictEqual(result, mockSmusConnection) + // Verify the result has the correct SMUS properties preserved + assert.strictEqual(result.id, mockSmusConnection.id) + assert.strictEqual(result.domainUrl, mockSmusConnection.domainUrl) + assert.strictEqual(result.domainId, mockSmusConnection.domainId) + assert.strictEqual(result.type, mockSmusConnection.type) + assert.strictEqual(result.startUrl, mockSmusConnection.startUrl) + assert.strictEqual(result.label, mockSmusConnection.label) assert.ok(mockAuth.reauthenticate.calledWith(mockSmusConnection)) }) @@ -677,7 +728,7 @@ describe('SmusAuthenticationProvider', function () { beforeEach(function () { getContextStub.withArgs('aws.smus.inSmusSpaceEnvironment').returns(false) // Stub the DefaultStsClient constructor to return our mock instance - const stsClientModule = require('../../../shared/clients/stsClient') + // stsClientModule imported at top stsConstructorStub = sinon.stub(stsClientModule, 'DefaultStsClient').callsFake(() => mockStsClient) }) @@ -694,7 +745,7 @@ describe('SmusAuthenticationProvider', function () { assert.strictEqual(smusAuthProvider['cachedProjectAccountIds'].get(testProjectId), testAccountId) assert.ok(getProjectCredentialProviderStub.calledWith(testProjectId)) assert.ok(mockProjectCredentialsProvider.getCredentials.called) - assert.ok((DataZoneClient.getInstance as sinon.SinonStub).called) + assert.ok((DataZoneClient.createWithCredentials as sinon.SinonStub).called) assert.ok(mockDataZoneClientForProject.getToolingEnvironment.calledWith(testProjectId)) assert.ok(mockStsClient.getCallerIdentity.called) }) @@ -757,4 +808,1033 @@ describe('SmusAuthenticationProvider', function () { }) }) }) + + describe('signOut', function () { + let mockState: any + + beforeEach(function () { + mockState = { + get: sinon.stub(), + update: sinon.stub().resolves(), + } + mockSecondaryAuth.state = mockState + mockSecondaryAuth.forgetConnection = sinon.stub().resolves() + }) + + it('should do nothing when no active connection exists', async function () { + mockSecondaryAuthState.activeConnection = undefined + + await smusAuthProvider.signOut() + + assert.ok(mockState.get.notCalled) + assert.ok(mockState.update.notCalled) + assert.ok(mockSecondaryAuth.deleteConnection.notCalled) + assert.ok(mockSecondaryAuth.forgetConnection.notCalled) + }) + + it('should delete SSO connection and clear metadata', async function () { + const ssoConnection = { + ...mockSmusConnection, + type: 'sso' as const, + id: 'sso-connection-id', + } + mockSecondaryAuthState.activeConnection = ssoConnection + + const smusConnections = { + 'sso-connection-id': { + domainUrl: testDomainUrl, + domainId: testDomainId, + }, + } + mockState.get.withArgs('smus.connections').returns(smusConnections) + + await smusAuthProvider.signOut() + + assert.ok(mockState.get.calledWith('smus.connections')) + assert.ok(mockState.update.calledWith('smus.connections', {})) + assert.ok(mockSecondaryAuth.deleteConnection.called) + assert.ok(mockSecondaryAuth.forgetConnection.notCalled) + }) + + it('should forget IAM connection without deleting and clear metadata', async function () { + const iamConnection = { + id: 'profile:test-profile', + type: 'iam' as const, + label: 'Test IAM Profile', + } + mockSecondaryAuthState.activeConnection = iamConnection as any + + const smusConnections = { + 'profile:test-profile': { + profileName: 'test-profile', + region: testRegion, + domainUrl: testDomainUrl, + domainId: testDomainId, + }, + } + mockState.get.withArgs('smus.connections').returns(smusConnections) + + await smusAuthProvider.signOut() + + assert.ok(mockState.get.calledWith('smus.connections')) + assert.ok(mockState.update.calledWith('smus.connections', {})) + assert.ok(mockSecondaryAuth.forgetConnection.called) + assert.ok(mockSecondaryAuth.deleteConnection.notCalled) + }) + + it('should handle mock connection in SMUS space environment', async function () { + const mockConnection = { + id: 'mock-connection-id', + // No 'type' property - simulates mock connection + } + mockSecondaryAuthState.activeConnection = mockConnection as any + + const smusConnections = { + 'mock-connection-id': { + domainUrl: testDomainUrl, + domainId: testDomainId, + }, + } + mockState.get.withArgs('smus.connections').returns(smusConnections) + + await smusAuthProvider.signOut() + + assert.ok(mockState.get.calledWith('smus.connections')) + assert.ok(mockState.update.calledWith('smus.connections', {})) + assert.ok(mockSecondaryAuth.deleteConnection.notCalled) + assert.ok(mockSecondaryAuth.forgetConnection.notCalled) + }) + + it('should handle missing metadata gracefully', async function () { + const ssoConnection = { + ...mockSmusConnection, + type: 'sso' as const, + id: 'sso-connection-id', + } + mockSecondaryAuthState.activeConnection = ssoConnection + + mockState.get.withArgs('smus.connections').returns({}) + + await smusAuthProvider.signOut() + + assert.ok(mockState.get.calledWith('smus.connections')) + // When there's no metadata to delete, update should not be called + assert.ok(mockState.update.notCalled) + assert.ok(mockSecondaryAuth.deleteConnection.called) + }) + + it('should throw ToolkitError when deleteConnection fails', async function () { + const ssoConnection = { + ...mockSmusConnection, + type: 'sso' as const, + id: 'sso-connection-id', + } + mockSecondaryAuthState.activeConnection = ssoConnection + + mockState.get.withArgs('smus.connections').returns({}) + mockSecondaryAuth.deleteConnection.rejects(new Error('Delete failed')) + + await assert.rejects( + () => smusAuthProvider.signOut(), + (err: ToolkitError) => { + return ( + err.code === 'SignOutFailed' && + err.message.includes('Failed to sign out from SageMaker Unified Studio') + ) + } + ) + }) + + it('should throw ToolkitError when forgetConnection fails', async function () { + const iamConnection = { + id: 'profile:test-profile', + type: 'iam' as const, + label: 'Test IAM Profile', + } + mockSecondaryAuthState.activeConnection = iamConnection as any + + mockState.get.withArgs('smus.connections').returns({}) + mockSecondaryAuth.forgetConnection.rejects(new Error('Forget failed')) + + await assert.rejects( + () => smusAuthProvider.signOut(), + (err: ToolkitError) => { + return ( + err.code === 'SignOutFailed' && + err.message.includes('Failed to sign out from SageMaker Unified Studio') + ) + } + ) + }) + }) + + describe('connectWithIamProfile', function () { + let mockState: any + const testProfileName = 'test-profile' + const testIamConnection = { + id: 'profile:test-profile', + type: 'iam' as const, + label: 'Test IAM Profile', + } + + beforeEach(function () { + mockState = { + get: sinon.stub(), + update: sinon.stub().resolves(), + } + mockSecondaryAuth.state = mockState + mockAuth.getConnection = sinon.stub() + mockAuth.refreshConnectionState = sinon.stub().resolves() + }) + + it('should connect with existing IAM profile and store metadata', async function () { + extractDomainInfoStub.returns({ domainId: testDomainId, region: testRegion }) + mockAuth.getConnection.withArgs({ id: `profile:${testProfileName}` }).resolves(testIamConnection) + mockState.get.withArgs('smus.connections').returns({}) + + const result = await smusAuthProvider.connectWithIamProfile(testProfileName, testRegion, testDomainUrl) + + assert.strictEqual(result.id, testIamConnection.id) + assert.strictEqual(result.type, 'iam') + assert.strictEqual(result.profileName, testProfileName) + assert.strictEqual(result.region, testRegion) + assert.strictEqual(result.domainUrl, testDomainUrl) + assert.strictEqual(result.domainId, testDomainId) + + assert.ok(mockAuth.getConnection.calledWith({ id: `profile:${testProfileName}` })) + assert.ok(mockSecondaryAuth.useNewConnection.calledWith(testIamConnection)) + assert.ok(mockAuth.refreshConnectionState.calledWith(testIamConnection)) + assert.ok( + mockState.update.calledWith('smus.connections', { + [testIamConnection.id]: { + profileName: testProfileName, + region: testRegion, + domainUrl: testDomainUrl, + domainId: testDomainId, + isIamDomain: false, + }, + }) + ) + }) + + it('should merge with existing SMUS connections metadata', async function () { + extractDomainInfoStub.returns({ domainId: testDomainId, region: testRegion }) + mockAuth.getConnection.withArgs({ id: `profile:${testProfileName}` }).resolves(testIamConnection) + + const existingConnections = { + 'other-connection-id': { + domainUrl: 'https://other-domain.sagemaker.us-west-2.on.aws', + domainId: 'other-domain-id', + }, + } + mockState.get.withArgs('smus.connections').returns(existingConnections) + + await smusAuthProvider.connectWithIamProfile(testProfileName, testRegion, testDomainUrl) + + assert.ok( + mockState.update.calledWith('smus.connections', { + 'other-connection-id': existingConnections['other-connection-id'], + [testIamConnection.id]: { + profileName: testProfileName, + region: testRegion, + domainUrl: testDomainUrl, + domainId: testDomainId, + isIamDomain: false, + }, + }) + ) + }) + + it('should throw error for invalid domain URL', async function () { + extractDomainInfoStub.returns({ domainId: undefined, region: testRegion }) + + await assert.rejects( + () => smusAuthProvider.connectWithIamProfile(testProfileName, testRegion, 'invalid-url'), + (err: ToolkitError) => { + return ( + err.code === 'FailedToConnect' && + err.message.includes('Failed to connect to SageMaker Unified Studio with IAM profile') + ) + } + ) + + assert.ok(mockAuth.getConnection.notCalled) + assert.ok(mockSecondaryAuth.useNewConnection.notCalled) + }) + + it('should throw error when IAM connection not found', async function () { + extractDomainInfoStub.returns({ domainId: testDomainId, region: testRegion }) + mockAuth.getConnection.withArgs({ id: `profile:${testProfileName}` }).resolves(undefined) + + await assert.rejects( + () => smusAuthProvider.connectWithIamProfile(testProfileName, testRegion, testDomainUrl), + (err: ToolkitError) => { + return ( + err.code === 'FailedToConnect' && + err.message.includes('Failed to connect to SageMaker Unified Studio with IAM profile') && + (err.cause as any)?.code === 'ConnectionNotFound' + ) + } + ) + + assert.ok(mockSecondaryAuth.useNewConnection.notCalled) + }) + + it('should throw error when connection is not IAM type', async function () { + extractDomainInfoStub.returns({ domainId: testDomainId, region: testRegion }) + const nonIamConnection = { + id: 'profile:test-profile', + type: 'sso' as const, + label: 'Test SSO Connection', + } + mockAuth.getConnection.withArgs({ id: `profile:${testProfileName}` }).resolves(nonIamConnection) + + await assert.rejects( + () => smusAuthProvider.connectWithIamProfile(testProfileName, testRegion, testDomainUrl), + (err: ToolkitError) => { + return ( + err.code === 'FailedToConnect' && + err.message.includes('Failed to connect to SageMaker Unified Studio with IAM profile') + ) + } + ) + }) + + it('should handle useNewConnection failure', async function () { + extractDomainInfoStub.returns({ domainId: testDomainId, region: testRegion }) + mockAuth.getConnection.withArgs({ id: `profile:${testProfileName}` }).resolves(testIamConnection) + mockState.get.withArgs('smus.connections').returns({}) + mockSecondaryAuth.useNewConnection.rejects(new Error('Failed to use connection')) + + await assert.rejects( + () => smusAuthProvider.connectWithIamProfile(testProfileName, testRegion, testDomainUrl), + (err: ToolkitError) => { + return ( + err.code === 'FailedToConnect' && + err.message.includes('Failed to connect to SageMaker Unified Studio with IAM profile') + ) + } + ) + }) + + it('should handle refreshConnectionState failure', async function () { + extractDomainInfoStub.returns({ domainId: testDomainId, region: testRegion }) + mockAuth.getConnection.withArgs({ id: `profile:${testProfileName}` }).resolves(testIamConnection) + mockState.get.withArgs('smus.connections').returns({}) + mockAuth.refreshConnectionState.rejects(new Error('Failed to refresh state')) + + await assert.rejects( + () => smusAuthProvider.connectWithIamProfile(testProfileName, testRegion, testDomainUrl), + (err: ToolkitError) => { + return ( + err.code === 'FailedToConnect' && + err.message.includes('Failed to connect to SageMaker Unified Studio with IAM profile') + ) + } + ) + }) + }) + + describe('activeConnection with IAM metadata', function () { + let mockState: any + + beforeEach(function () { + mockState = { + get: sinon.stub(), + update: sinon.stub().resolves(), + } + mockSecondaryAuth.state = mockState + }) + + it('should return IAM connection with SMUS metadata when available', function () { + const iamConnection = { + id: 'profile:test-profile', + type: 'iam' as const, + label: 'Test IAM Profile', + } + mockSecondaryAuthState.activeConnection = iamConnection as any + + const smusConnections = { + 'profile:test-profile': { + profileName: 'test-profile', + region: testRegion, + domainUrl: testDomainUrl, + domainId: testDomainId, + }, + } + mockState.get.withArgs('smus.connections').returns(smusConnections) + + const result = smusAuthProvider.activeConnection + + assert.strictEqual(result?.id, iamConnection.id) + assert.strictEqual((result as any)?.type, 'iam') + assert.strictEqual((result as any).profileName, 'test-profile') + assert.strictEqual((result as any).region, testRegion) + assert.strictEqual((result as any).domainUrl, testDomainUrl) + assert.strictEqual((result as any).domainId, testDomainId) + }) + + it('should return SSO connection with SMUS metadata when available', function () { + const ssoConnection = { + ...mockSmusConnection, + type: 'sso' as const, + } + mockSecondaryAuthState.activeConnection = ssoConnection + + const smusConnections = { + [ssoConnection.id]: { + domainUrl: testDomainUrl, + domainId: testDomainId, + }, + } + mockState.get.withArgs('smus.connections').returns(smusConnections) + + const result = smusAuthProvider.activeConnection + + assert.strictEqual(result?.id, ssoConnection.id) + assert.strictEqual((result as any)?.type, 'sso') + assert.strictEqual((result as any)?.domainUrl, testDomainUrl) + assert.strictEqual((result as any)?.domainId, testDomainId) + }) + + it('should return base connection when no metadata available', function () { + const iamConnection = { + id: 'profile:test-profile', + type: 'iam' as const, + label: 'Test IAM Profile', + } + mockSecondaryAuthState.activeConnection = iamConnection as any + + mockState.get.withArgs('smus.connections').returns({}) + + const result = smusAuthProvider.activeConnection + + assert.strictEqual(result?.id, iamConnection.id) + assert.strictEqual((result as any)?.type, 'iam') + assert.strictEqual((result as any).profileName, undefined) + assert.strictEqual((result as any).domainUrl, undefined) + }) + + it('should return undefined when no active connection', function () { + mockSecondaryAuthState.activeConnection = undefined + + const result = smusAuthProvider.activeConnection + + assert.strictEqual(result, undefined) + }) + + it('should handle missing smus.connections state gracefully', function () { + const iamConnection = { + id: 'profile:test-profile', + type: 'iam' as const, + label: 'Test IAM Profile', + } + mockSecondaryAuthState.activeConnection = iamConnection as any + + mockState.get.withArgs('smus.connections').returns(undefined) + + const result = smusAuthProvider.activeConnection + + assert.strictEqual(result?.id, iamConnection.id) + assert.strictEqual((result as any)?.type, 'iam') + }) + }) + + describe('getDerCredentialsProvider', function () { + let getContextStub: sinon.SinonStub + + beforeEach(function () { + getContextStub = sinon.stub(vscodeSetContext, 'getContext') + + // Clear cache + smusAuthProvider['credentialsProviderCache'].clear() + }) + + describe('in SMUS space environment', function () { + beforeEach(function () { + getContextStub.withArgs('aws.smus.inSmusSpaceEnvironment').returns(true) + + // Mock resource metadata for SMUS space environment + getResourceMetadataStub = sinon.stub(resourceMetadataUtils, 'getResourceMetadata').returns({ + ResourceArn: 'arn:aws:sagemaker:us-east-2:123456789012:app/dzd_domainId/test-app', + AdditionalMetadata: { + DataZoneDomainId: testDomainId, + DataZoneDomainRegion: testRegion, + }, + } as any) + }) + + afterEach(function () { + getResourceMetadataStub?.restore() + }) + + it('should return a credentials provider that can retrieve credentials', async function () { + // In SMUS space environment, the method should return a provider + // We can't easily test the internal branching logic without stubbing ES modules + // So we test that it returns a valid provider structure + const provider = await smusAuthProvider.getDerCredentialsProvider() + + assert.ok(provider, 'Provider should be returned') + assert.ok(typeof provider.getCredentials === 'function', 'Provider should have getCredentials method') + }) + + it('should not cache providers in SMUS space environment', async function () { + // Get provider twice + const provider1 = await smusAuthProvider.getDerCredentialsProvider() + const provider2 = await smusAuthProvider.getDerCredentialsProvider() + + // In SMUS space, providers are not cached (new provider each time) + // This is because the logic returns early before caching + assert.ok(provider1) + assert.ok(provider2) + }) + }) + + describe('in non-SMUS space environment', function () { + let getAccessTokenStub: sinon.SinonStub + + beforeEach(function () { + getContextStub.withArgs('aws.smus.inSmusSpaceEnvironment').returns(false) + mockSecondaryAuthState.activeConnection = mockSmusConnection + getAccessTokenStub = sinon.stub(smusAuthProvider, 'getAccessToken').resolves('mock-access-token') + }) + + it('should create and cache DomainExecRoleCredentialsProvider for SSO connection', async function () { + const provider = await smusAuthProvider.getDerCredentialsProvider() + + assert.ok(provider) + assert.ok(getAccessTokenStub.notCalled) // Not called until getCredentials is invoked + + // Verify caching + const cachedProvider = await smusAuthProvider.getDerCredentialsProvider() + assert.strictEqual(provider, cachedProvider) + }) + + it('should throw error when no active connection', async function () { + mockSecondaryAuthState.activeConnection = undefined + + await assert.rejects( + () => smusAuthProvider.getDerCredentialsProvider(), + (err: ToolkitError) => { + return ( + err.code === 'NoActiveConnection' && + err.message.includes('No active SMUS connection available') + ) + } + ) + }) + + it('should throw error for non-SSO connection', async function () { + const iamConnection = { + id: 'profile:test-profile', + type: 'iam' as const, + label: 'Test IAM Profile', + } + mockSecondaryAuthState.activeConnection = iamConnection as any + + await assert.rejects( + () => smusAuthProvider.getDerCredentialsProvider(), + (err: ToolkitError) => { + return ( + err.code === 'InvalidConnectionType' && + err.message.includes( + 'Domain Execution Role credentials are only available for SSO connections' + ) + ) + } + ) + }) + + it('should use cached provider for same connection', async function () { + const provider1 = await smusAuthProvider.getDerCredentialsProvider() + const provider2 = await smusAuthProvider.getDerCredentialsProvider() + + assert.strictEqual(provider1, provider2) + }) + + it('should create different providers for different connections', async function () { + const provider1 = await smusAuthProvider.getDerCredentialsProvider() + + // Change connection + const differentConnection = { + ...mockSmusConnection, + id: 'different-connection-id', + domainId: 'different-domain-id', + } + mockSecondaryAuthState.activeConnection = differentConnection + + const provider2 = await smusAuthProvider.getDerCredentialsProvider() + + assert.notStrictEqual(provider1, provider2) + }) + }) + }) + + describe('initIamModeContextInSpaceEnvironment', function () { + let getResourceMetadataStub: sinon.SinonStub + let getDerCredentialsProviderStub: sinon.SinonStub + let getInstanceStub: sinon.SinonStub + let isIamDomainStub: sinon.SinonStub + let mockCredentialsProvider: any + let mockClientHelper: any + + const testResourceMetadata = { + AdditionalMetadata: { + DataZoneDomainId: 'test-domain-id', + DataZoneDomainRegion: 'us-east-1', + DataZoneProjectId: 'test-project-id', + }, + } + + beforeEach(function () { + getResourceMetadataStub = sinon.stub(resourceMetadataUtils, 'getResourceMetadata') + + // Reset the global setContext stub history for clean test state + setContextStubGlobal.resetHistory() + + mockCredentialsProvider = { + getCredentials: sinon.stub().resolves({ + accessKeyId: 'test-key', + secretAccessKey: 'test-secret', + }), + } + + getDerCredentialsProviderStub = sinon + .stub(smusAuthProvider, 'getDerCredentialsProvider') + .resolves(mockCredentialsProvider) + + // Mock DataZoneCustomClientHelper + isIamDomainStub = sinon.stub() + mockClientHelper = { + isIamDomain: isIamDomainStub, + } + + getInstanceStub = sinon.stub(DataZoneCustomClientHelper, 'getInstance').returns(mockClientHelper) + }) + + afterEach(function () { + sinon.restore() + }) + + it('should set IAM mode context to true when domain is IAM mode', async function () { + getResourceMetadataStub.returns(testResourceMetadata) + isIamDomainStub.resolves(true) + + await smusAuthProvider['initIamModeContextInSpaceEnvironment']() + + assert.ok(getResourceMetadataStub.called) + assert.ok(getDerCredentialsProviderStub.called) + assert.ok( + getInstanceStub.calledWith( + mockCredentialsProvider, + testResourceMetadata.AdditionalMetadata.DataZoneDomainRegion + ) + ) + assert.ok(isIamDomainStub.calledWith(testResourceMetadata.AdditionalMetadata.DataZoneDomainId)) + assert.ok(setContextStubGlobal.calledWith('aws.smus.isIamMode', true)) + }) + + it('should set IAM mode context to false when domain is not IAM mode', async function () { + getResourceMetadataStub.returns(testResourceMetadata) + isIamDomainStub.resolves(false) + + await smusAuthProvider['initIamModeContextInSpaceEnvironment']() + + assert.ok(getResourceMetadataStub.called) + assert.ok(getDerCredentialsProviderStub.called) + assert.ok( + getInstanceStub.calledWith( + mockCredentialsProvider, + testResourceMetadata.AdditionalMetadata.DataZoneDomainRegion + ) + ) + assert.ok(isIamDomainStub.calledWith(testResourceMetadata.AdditionalMetadata.DataZoneDomainId)) + assert.ok(setContextStubGlobal.calledWith('aws.smus.isIamMode', false)) + }) + + it('should not call IAM mode check when resource metadata is missing', async function () { + getResourceMetadataStub.returns(undefined) + + await smusAuthProvider['initIamModeContextInSpaceEnvironment']() + + assert.ok(getResourceMetadataStub.called) + assert.ok(getDerCredentialsProviderStub.notCalled) + assert.ok(getInstanceStub.notCalled) + assert.ok(isIamDomainStub.notCalled) + assert.ok(setContextStubGlobal.notCalled) + }) + + it('should handle error when getDerCredentialsProvider fails', async function () { + getResourceMetadataStub.returns(testResourceMetadata) + const testError = new Error('Failed to get credentials provider') + getDerCredentialsProviderStub.rejects(testError) + + await smusAuthProvider['initIamModeContextInSpaceEnvironment']() + + assert.ok(getResourceMetadataStub.called) + assert.ok(getDerCredentialsProviderStub.called) + assert.ok(getInstanceStub.notCalled) + assert.ok(isIamDomainStub.notCalled) + assert.ok(setContextStubGlobal.calledWith('aws.smus.isIamMode', false)) + }) + }) + + describe('getSessionName', function () { + let mockStsClient: any + let mockCredentialsProvider: any + + beforeEach(function () { + // Mock STS client + mockStsClient = { + getCallerIdentity: sinon.stub(), + } + sinon + .stub(DefaultStsClient.prototype, 'getCallerIdentity') + .callsFake(() => mockStsClient.getCallerIdentity()) + + // Mock credentials provider + mockCredentialsProvider = { + getCredentials: sinon.stub().resolves({ + accessKeyId: 'test-access-key', + secretAccessKey: 'test-secret-key', + sessionToken: 'test-session-token', + }), + } + + sinon + .stub(smusAuthProvider as any, 'getCredentialsForIamProfile') + .resolves(mockCredentialsProvider.getCredentials()) + }) + + afterEach(function () { + sinon.restore() + }) + + it('should return session name for IAM connection with assumed role', async function () { + const iamConnection = { + id: 'profile:test-profile', + type: 'iam' as const, + label: 'Test IAM Profile', + profileName: 'test-profile', + region: testRegion, + domainUrl: testDomainUrl, + domainId: testDomainId, + endpointUrl: undefined, + getCredentials: sinon.stub().resolves(), + } + mockSecondaryAuthState.activeConnection = iamConnection as any + + // Mock STS response with assumed role ARN + const assumedRoleArn = 'arn:aws:sts::123456789012:assumed-role/MyRole/my-session-name' + mockStsClient.getCallerIdentity.resolves({ + Arn: assumedRoleArn, + Account: '123456789012', + UserId: 'AIDAI1234567890EXAMPLE:my-session-name', + }) + + // Mock connection metadata + const smusConnections = { + [iamConnection.id]: { + profileName: 'test-profile', + region: testRegion, + domainUrl: testDomainUrl, + domainId: testDomainId, + }, + } + mockSecondaryAuth.state.get.withArgs('smus.connections').returns(smusConnections) + + const sessionName = await smusAuthProvider.getSessionName() + + assert.strictEqual(sessionName, 'my-session-name') + assert.ok(mockStsClient.getCallerIdentity.calledOnce) + }) + + it('should return undefined for IAM connection without assumed role (IAM user)', async function () { + const iamConnection = { + id: 'profile:test-profile', + type: 'iam' as const, + label: 'Test IAM Profile', + profileName: 'test-profile', + region: testRegion, + domainUrl: testDomainUrl, + domainId: testDomainId, + endpointUrl: undefined, + getCredentials: sinon.stub().resolves(), + } + mockSecondaryAuthState.activeConnection = iamConnection as any + + // Mock STS response with IAM user ARN (no session name) + const iamUserArn = 'arn:aws:iam::123456789012:user/my-user' + mockStsClient.getCallerIdentity.resolves({ + Arn: iamUserArn, + Account: '123456789012', + UserId: 'AIDAI1234567890EXAMPLE', + }) + + // Mock connection metadata + const smusConnections = { + [iamConnection.id]: { + profileName: 'test-profile', + region: testRegion, + domainUrl: testDomainUrl, + domainId: testDomainId, + }, + } + mockSecondaryAuth.state.get.withArgs('smus.connections').returns(smusConnections) + + const sessionName = await smusAuthProvider.getSessionName() + + assert.strictEqual(sessionName, undefined) + assert.ok(mockStsClient.getCallerIdentity.calledOnce) + }) + + it('should return undefined for SSO connection', async function () { + mockSecondaryAuthState.activeConnection = mockSmusConnection + + const sessionName = await smusAuthProvider.getSessionName() + + assert.strictEqual(sessionName, undefined) + assert.ok(mockStsClient.getCallerIdentity.notCalled) + }) + + it('should return undefined when not connected', async function () { + mockSecondaryAuthState.activeConnection = undefined + + const sessionName = await smusAuthProvider.getSessionName() + + assert.strictEqual(sessionName, undefined) + assert.ok(mockStsClient.getCallerIdentity.notCalled) + }) + + it('should cache and reuse caller identity ARN', async function () { + const iamConnection = { + id: 'profile:test-profile', + type: 'iam' as const, + label: 'Test IAM Profile', + profileName: 'test-profile', + region: testRegion, + domainUrl: testDomainUrl, + domainId: testDomainId, + endpointUrl: undefined, + getCredentials: sinon.stub().resolves(), + } + mockSecondaryAuthState.activeConnection = iamConnection as any + + const assumedRoleArn = 'arn:aws:sts::123456789012:assumed-role/MyRole/my-session-name' + mockStsClient.getCallerIdentity.resolves({ + Arn: assumedRoleArn, + Account: '123456789012', + UserId: 'AIDAI1234567890EXAMPLE:my-session-name', + }) + + // Mock connection metadata + const smusConnections = { + [iamConnection.id]: { + profileName: 'test-profile', + region: testRegion, + domainUrl: testDomainUrl, + domainId: testDomainId, + }, + } + mockSecondaryAuth.state.get.withArgs('smus.connections').returns(smusConnections) + + // First call - should fetch from STS + const sessionName1 = await smusAuthProvider.getSessionName() + assert.strictEqual(sessionName1, 'my-session-name') + assert.ok(mockStsClient.getCallerIdentity.calledOnce) + + // Second call - should use cached value + const sessionName2 = await smusAuthProvider.getSessionName() + assert.strictEqual(sessionName2, 'my-session-name') + assert.ok(mockStsClient.getCallerIdentity.calledOnce) // Still only called once + }) + + it('should handle STS errors gracefully', async function () { + const iamConnection = { + id: 'profile:test-profile', + type: 'iam' as const, + label: 'Test IAM Profile', + profileName: 'test-profile', + region: testRegion, + domainUrl: testDomainUrl, + domainId: testDomainId, + endpointUrl: undefined, + getCredentials: sinon.stub().resolves(), + } + mockSecondaryAuthState.activeConnection = iamConnection as any + + mockStsClient.getCallerIdentity.rejects(new Error('STS call failed')) + + // Mock connection metadata + const smusConnections = { + [iamConnection.id]: { + profileName: 'test-profile', + region: testRegion, + domainUrl: testDomainUrl, + domainId: testDomainId, + }, + } + mockSecondaryAuth.state.get.withArgs('smus.connections').returns(smusConnections) + + const sessionName = await smusAuthProvider.getSessionName() + + assert.strictEqual(sessionName, undefined) + }) + + it('should return undefined when connection metadata is missing', async function () { + const iamConnection = { + id: 'profile:test-profile', + type: 'iam' as const, + label: 'Test IAM Profile', + profileName: 'test-profile', + region: testRegion, + domainUrl: testDomainUrl, + domainId: testDomainId, + endpointUrl: undefined, + getCredentials: sinon.stub().resolves(), + } + mockSecondaryAuthState.activeConnection = iamConnection as any + + // No connection metadata + mockSecondaryAuth.state.get.withArgs('smus.connections').returns({}) + + const sessionName = await smusAuthProvider.getSessionName() + + assert.strictEqual(sessionName, undefined) + assert.ok(mockStsClient.getCallerIdentity.notCalled) + }) + }) + + describe('getRoleArn', function () { + let mockStsClient: any + let mockCredentialsProvider: any + + beforeEach(function () { + // Mock STS client + mockStsClient = { + getCallerIdentity: sinon.stub(), + } + sinon + .stub(DefaultStsClient.prototype, 'getCallerIdentity') + .callsFake(() => mockStsClient.getCallerIdentity()) + + // Mock credentials provider + mockCredentialsProvider = { + getCredentials: sinon.stub().resolves({ + accessKeyId: 'test-access-key', + secretAccessKey: 'test-secret-key', + sessionToken: 'test-session-token', + }), + } + + sinon + .stub(smusAuthProvider as any, 'getCredentialsForIamProfile') + .resolves(mockCredentialsProvider.getCredentials()) + }) + + afterEach(function () { + sinon.restore() + }) + + it('should return IAM role ARN for IAM connection with assumed role', async function () { + const iamConnection = { + id: 'profile:test-profile', + type: 'iam' as const, + label: 'Test IAM Profile', + profileName: 'test-profile', + region: testRegion, + domainUrl: testDomainUrl, + domainId: testDomainId, + endpointUrl: undefined, + getCredentials: sinon.stub().resolves(), + } + mockSecondaryAuthState.activeConnection = iamConnection as any + + // Mock STS response with assumed role ARN + const assumedRoleArn = 'arn:aws:sts::123456789012:assumed-role/MyRole/my-session-name' + mockStsClient.getCallerIdentity.resolves({ + Arn: assumedRoleArn, + Account: '123456789012', + UserId: 'AIDAI1234567890EXAMPLE:my-session-name', + }) + + // Mock connection metadata + const smusConnections = { + [iamConnection.id]: { + profileName: 'test-profile', + region: testRegion, + domainUrl: testDomainUrl, + domainId: testDomainId, + }, + } + mockSecondaryAuth.state.get.withArgs('smus.connections').returns(smusConnections) + + const roleArn = await smusAuthProvider.getIamPrincipalArn() + + // Should convert assumed role ARN to IAM role ARN + assert.strictEqual(roleArn, 'arn:aws:iam::123456789012:role/MyRole') + assert.ok(mockStsClient.getCallerIdentity.calledOnce) + }) + + it('should return undefined for SSO connection', async function () { + mockSecondaryAuthState.activeConnection = mockSmusConnection + + const roleArn = await smusAuthProvider.getIamPrincipalArn() + + assert.strictEqual(roleArn, undefined) + assert.ok(mockStsClient.getCallerIdentity.notCalled) + }) + + it('should return undefined when not connected', async function () { + mockSecondaryAuthState.activeConnection = undefined + + const roleArn = await smusAuthProvider.getIamPrincipalArn() + + assert.strictEqual(roleArn, undefined) + assert.ok(mockStsClient.getCallerIdentity.notCalled) + }) + + it('should use cached caller identity ARN', async function () { + const iamConnection = { + id: 'profile:test-profile', + type: 'iam' as const, + label: 'Test IAM Profile', + profileName: 'test-profile', + region: testRegion, + domainUrl: testDomainUrl, + domainId: testDomainId, + endpointUrl: undefined, + getCredentials: sinon.stub().resolves(), + } + mockSecondaryAuthState.activeConnection = iamConnection as any + + const assumedRoleArn = 'arn:aws:sts::123456789012:assumed-role/MyRole/my-session-name' + mockStsClient.getCallerIdentity.resolves({ + Arn: assumedRoleArn, + Account: '123456789012', + UserId: 'AIDAI1234567890EXAMPLE:my-session-name', + }) + + // Mock connection metadata + const smusConnections = { + [iamConnection.id]: { + profileName: 'test-profile', + region: testRegion, + domainUrl: testDomainUrl, + domainId: testDomainId, + }, + } + mockSecondaryAuth.state.get.withArgs('smus.connections').returns(smusConnections) + + // First call - should fetch from STS + const roleArn1 = await smusAuthProvider.getIamPrincipalArn() + assert.strictEqual(roleArn1, 'arn:aws:iam::123456789012:role/MyRole') + assert.ok(mockStsClient.getCallerIdentity.calledOnce) + + // Second call - should use cached value + const roleArn2 = await smusAuthProvider.getIamPrincipalArn() + assert.strictEqual(roleArn2, 'arn:aws:iam::123456789012:role/MyRole') + assert.ok(mockStsClient.getCallerIdentity.calledOnce) // Still only called once + }) + }) }) diff --git a/packages/core/src/test/sagemakerunifiedstudio/auth/ui/authenticationMethodSelection.test.ts b/packages/core/src/test/sagemakerunifiedstudio/auth/ui/authenticationMethodSelection.test.ts new file mode 100644 index 00000000000..e0332326be1 --- /dev/null +++ b/packages/core/src/test/sagemakerunifiedstudio/auth/ui/authenticationMethodSelection.test.ts @@ -0,0 +1,39 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as assert from 'assert' +import { SmusAuthenticationMethodSelector } from '../../../../sagemakerunifiedstudio/auth/ui/authenticationMethodSelection' + +describe('SmusAuthenticationMethodSelector', function () { + // Note: Due to AWS Toolkit test framework restrictions on mocking vscode.window, + // these tests focus on the interface and behavior rather than deep mocking. + // The actual QuickPick functionality is tested through integration tests. + + describe('showAuthenticationMethodSelection', function () { + it('should export the correct interface', function () { + // Verify the class exists and has the expected static method + assert.ok('showAuthenticationMethodSelection' in SmusAuthenticationMethodSelector) + assert.strictEqual(typeof SmusAuthenticationMethodSelector.showAuthenticationMethodSelection, 'function') + }) + + it('should handle authentication method types correctly', function () { + // Test that the types are properly defined + const testMethod1: 'sso' | 'iam' = 'sso' + const testMethod2: 'sso' | 'iam' = 'iam' + + assert.strictEqual(testMethod1, 'sso') + assert.strictEqual(testMethod2, 'iam') + }) + + // The actual UI testing would be done manually or through E2E tests + it('should be callable without throwing', function () { + // Verify the method exists and is accessible + assert.doesNotThrow(() => { + // Just verify the method exists without calling it + assert.ok('showAuthenticationMethodSelection' in SmusAuthenticationMethodSelector) + }) + }) + }) +}) diff --git a/packages/core/src/test/sagemakerunifiedstudio/auth/ui/iamProfileSelection.test.ts b/packages/core/src/test/sagemakerunifiedstudio/auth/ui/iamProfileSelection.test.ts new file mode 100644 index 00000000000..2a56b2baf0c --- /dev/null +++ b/packages/core/src/test/sagemakerunifiedstudio/auth/ui/iamProfileSelection.test.ts @@ -0,0 +1,21 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +import assert from 'assert' +import { SmusIamProfileSelector } from '../../../../sagemakerunifiedstudio/auth/ui/iamProfileSelection' + +describe('SmusIamProfileSelector', function () { + describe('showRegionSelection', function () { + it('should be a static method', function () { + assert.strictEqual(typeof SmusIamProfileSelector.showRegionSelection, 'function') + }) + }) + + describe('showIamProfileSelection', function () { + it('should be a static method', function () { + assert.strictEqual(typeof SmusIamProfileSelector.showIamProfileSelection, 'function') + }) + }) +}) diff --git a/packages/core/src/test/sagemakerunifiedstudio/auth/ui/ssoAuthentication.test.ts b/packages/core/src/test/sagemakerunifiedstudio/auth/ui/ssoAuthentication.test.ts new file mode 100644 index 00000000000..6a4221ac7a4 --- /dev/null +++ b/packages/core/src/test/sagemakerunifiedstudio/auth/ui/ssoAuthentication.test.ts @@ -0,0 +1,40 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as assert from 'assert' +import { SmusSsoAuthenticationUI } from '../../../../sagemakerunifiedstudio/auth/ui/ssoAuthentication' + +describe('SmusSsoAuthenticationUI', function () { + // Note: Due to AWS Toolkit test framework restrictions on mocking vscode.window, + // these tests focus on the interface and behavior rather than deep mocking. + // The actual QuickPick functionality is tested through integration tests. + + describe('showDomainUrlInput', function () { + it('should export the correct interface', function () { + // Verify the class exists and has the expected static method + assert.ok('showDomainUrlInput' in SmusSsoAuthenticationUI) + assert.strictEqual(typeof SmusSsoAuthenticationUI.showDomainUrlInput, 'function') + }) + + it('should be callable without throwing', function () { + // Verify the method exists and is accessible + assert.doesNotThrow(() => { + // Just verify the method exists without calling it + assert.ok('showDomainUrlInput' in SmusSsoAuthenticationUI) + }) + }) + + it('should handle return type union correctly', function () { + // Test that the return types are properly defined + const testResult1: string | 'BACK' | undefined = 'https://example.com' + const testResult2: string | 'BACK' | undefined = 'BACK' + const testResult3: string | 'BACK' | undefined = undefined + + assert.strictEqual(testResult1, 'https://example.com') + assert.strictEqual(testResult2, 'BACK') + assert.strictEqual(testResult3, undefined) + }) + }) +}) diff --git a/packages/core/src/test/sagemakerunifiedstudio/explorer/activation.test.ts b/packages/core/src/test/sagemakerunifiedstudio/explorer/activation.test.ts index 982aa481bd3..c5c6932184b 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/explorer/activation.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/explorer/activation.test.ts @@ -11,7 +11,6 @@ import { SmusAuthenticationProvider, setSmusConnectedContext, } from '../../../sagemakerunifiedstudio/auth/providers/smusAuthenticationProvider' -import { DataZoneClient } from '../../../sagemakerunifiedstudio/shared/client/datazoneClient' import { ResourceTreeDataProvider } from '../../../shared/treeview/resourceTreeDataProvider' import { SageMakerUnifiedStudioRootNode } from '../../../sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioRootNode' import { getLogger } from '../../../shared/logger/logger' @@ -19,6 +18,8 @@ import { getTestWindow } from '../../shared/vscode/window' import { SeverityLevel } from '../../shared/vscode/message' import * as extensionUtilities from '../../../shared/extensionUtilities' import { createMockSpaceNode } from '../testUtils' +import { DataZoneClient } from '../../../sagemakerunifiedstudio/shared/client/datazoneClient' +import * as model from '../../../sagemakerunifiedstudio/auth/model' describe('SMUS Explorer Activation', function () { let mockExtensionContext: vscode.ExtensionContext @@ -41,6 +42,7 @@ describe('SMUS Explorer Activation', function () { isConnected: sinon.stub().returns(true), reauthenticate: sinon.stub().resolves(), onDidChange: sinon.stub().callsFake((_listener: () => void) => ({ dispose: sinon.stub() })), + onDidChangeActiveConnection: sinon.stub().callsFake((_listener: () => void) => ({ dispose: sinon.stub() })), activeConnection: { id: 'test-connection', domainId: 'test-domain', @@ -69,7 +71,7 @@ describe('SMUS Explorer Activation', function () { // Stub SmusAuthenticationProvider sinon.stub(SmusAuthenticationProvider, 'fromContext').returns(mockSmusAuthProvider as any) - // Stub DataZoneClient + // Stub DataZoneClient.dispose dataZoneDisposeStub = sinon.stub(DataZoneClient, 'dispose') // Stub SageMakerUnifiedStudioRootNode constructor @@ -151,7 +153,7 @@ describe('SMUS Explorer Activation', function () { it('should register DataZone client disposal', async function () { await activate(mockExtensionContext) - // Find the DataZone dispose subscription - it should be the last one added + // Find the DataZone dispose subscription const subscriptions = mockExtensionContext.subscriptions assert.ok(subscriptions.length > 0) @@ -261,7 +263,124 @@ describe('SMUS Explorer Activation', function () { // Check that an error message was shown const errorMessages = testWindow.shownMessages.filter((msg) => msg.severity === SeverityLevel.Error) assert.ok(errorMessages.length > 0, 'Should show error message') - assert.ok(errorMessages.some((msg) => msg.message.includes('Failed to reauthenticate'))) + assert.ok(errorMessages.some((msg) => msg.message.includes('Reauthentication failed'))) + }) + + it('should extract detailed error message from ToolkitError cause chain', async function () { + const reauthCommand = registerCommandStub + .getCalls() + .find((call) => call.args[0] === 'aws.smus.reauthenticate') + + assert.ok(reauthCommand) + + const mockConnection = { + id: 'test-connection', + type: 'sso', + startUrl: 'https://identitycenter.amazonaws.com/ssoins-testInstanceId', + ssoRegion: 'us-east-1', + scopes: ['datazone:domain:access'], + label: 'Test Connection', + } as any + + // Create a ToolkitError with a cause chain + const detailedError = new Error('Invalid profile - The security token is expired') + const wrapperError = new Error('Unable to reauthenticate SageMaker Unified Studio connection.') + ;(wrapperError as any).cause = detailedError + mockSmusAuthProvider.reauthenticate.rejects(wrapperError) + + const testWindow = getTestWindow() + + // Execute the command handler + await reauthCommand.args[1](mockConnection) + + // Check that the detailed error message from the cause was shown + const errorMessages = testWindow.shownMessages.filter((msg) => msg.severity === SeverityLevel.Error) + assert.ok(errorMessages.length > 0, 'Should show error message') + const hasDetailedError = errorMessages.some((msg) => + msg.message.includes('Invalid profile - The security token is expired') + ) + assert.ok(hasDetailedError, 'Should show detailed error from cause chain') + }) + + it('should not show success message for IAM connection reauthentication', async function () { + const reauthCommand = registerCommandStub + .getCalls() + .find((call) => call.args[0] === 'aws.smus.reauthenticate') + + assert.ok(reauthCommand) + + // Create an IAM connection + const mockIamConnection = { + id: 'test-iam-connection', + type: 'iam', + profileName: 'test-profile', + region: 'us-east-1', + label: 'Test IAM Connection', + } as any + + // Stub isSmusIamConnection to return true for IAM connection + sinon.stub(model, 'isSmusIamConnection').returns(true) + + // Mock the return value to return the connection (IAM connection handled its own message) + mockSmusAuthProvider.reauthenticate.resolves(mockIamConnection) + + const testWindow = getTestWindow() + + // Execute the command handler + await reauthCommand.args[1](mockIamConnection) + + assert.ok(mockSmusAuthProvider.reauthenticate.calledWith(mockIamConnection)) + assert.ok(mockTreeDataProvider.refresh.called) + + // Check that NO information message was shown (IAM handles its own) + const infoMessages = testWindow.shownMessages.filter( + (msg) => msg.severity === SeverityLevel.Information + ) + assert.ok( + !infoMessages.some((msg) => msg.message.includes('Successfully reauthenticated')), + 'Should not show success message for IAM connection' + ) + }) + + it('should show success message for SSO connection reauthentication', async function () { + const reauthCommand = registerCommandStub + .getCalls() + .find((call) => call.args[0] === 'aws.smus.reauthenticate') + + assert.ok(reauthCommand) + + const mockSsoConnection = { + id: 'test-sso-connection', + type: 'sso', + startUrl: 'https://identitycenter.amazonaws.com/ssoins-testInstanceId', + ssoRegion: 'us-east-1', + scopes: ['datazone:domain:access'], + label: 'Test SSO Connection', + } as any + + // Stub isSmusIamConnection to return false for SSO connection + sinon.stub(model, 'isSmusIamConnection').returns(false) + + // Mock the return value to indicate SSO connection (returns connection object) + mockSmusAuthProvider.reauthenticate.resolves(mockSsoConnection) + + const testWindow = getTestWindow() + + // Execute the command handler + await reauthCommand.args[1](mockSsoConnection) + + assert.ok(mockSmusAuthProvider.reauthenticate.calledWith(mockSsoConnection)) + assert.ok(mockTreeDataProvider.refresh.called) + + // Check that an information message was shown for SSO + const infoMessages = testWindow.shownMessages.filter( + (msg) => msg.severity === SeverityLevel.Information + ) + assert.ok(infoMessages.length > 0, 'Should show information message for SSO') + assert.ok( + infoMessages.some((msg) => msg.message.includes('Successfully reauthenticated')), + 'Should show success message for SSO connection' + ) }) it('should handle aws.smus.refreshProject command', async function () { diff --git a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/federatedConnectionStrategy.test.ts b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/federatedConnectionStrategy.test.ts new file mode 100644 index 00000000000..9aadf68d443 --- /dev/null +++ b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/federatedConnectionStrategy.test.ts @@ -0,0 +1,185 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as assert from 'assert' +import * as sinon from 'sinon' +import * as vscode from 'vscode' +import { createFederatedConnectionNode } from '../../../../sagemakerunifiedstudio/explorer/nodes/federatedConnectionStrategy' +import { GlueClient, ListEntitiesCommand, DescribeEntityCommand } from '@aws-sdk/client-glue' +import { ConnectionCredentialsProvider } from '../../../../sagemakerunifiedstudio/auth/providers/connectionCredentialsProvider' + +describe('FederatedConnectionStrategy', function () { + let sandbox: sinon.SinonSandbox + let mockGlueClient: sinon.SinonStubbedInstance + let mockCredentialsProvider: ConnectionCredentialsProvider + + const mockConnection = { + connectionId: 'federated-conn-123', + name: 'test-federated-connection', + glueConnectionName: 'test-glue-connection', + } + + beforeEach(function () { + sandbox = sinon.createSandbox() + + mockCredentialsProvider = { + getCredentials: sandbox.stub().resolves({ + accessKeyId: 'test-key', + secretAccessKey: 'test-secret', + }), + logger: {} as any, + smusAuthProvider: {} as any, + connectionId: 'test-connection', + projectId: 'test-project', + } as any + + mockGlueClient = sandbox.createStubInstance(GlueClient) + sandbox.stub(GlueClient.prototype, 'send').callsFake(mockGlueClient.send) + }) + + afterEach(function () { + sandbox.restore() + }) + + describe('createFederatedConnectionNode', function () { + it('should create connection node with correct properties', async function () { + const node = await createFederatedConnectionNode( + mockConnection as any, + mockCredentialsProvider, + 'us-east-1' + ) + + assert.strictEqual(node.id, 'federated-federated-conn-123') + assert.strictEqual(node.resource, mockConnection) + + const treeItem = await node.getTreeItem() + assert.strictEqual(treeItem.label, 'test-federated-connection') + assert.strictEqual(treeItem.contextValue, 'federatedConnection') + assert.strictEqual(treeItem.collapsibleState, vscode.TreeItemCollapsibleState.Collapsed) + }) + + it('should return error when no glue connection name', async function () { + const connectionWithoutGlue = { ...mockConnection, glueConnectionName: undefined } + + const node = await createFederatedConnectionNode( + connectionWithoutGlue as any, + mockCredentialsProvider, + 'us-east-1' + ) + + const children = await node.getChildren!() + assert.strictEqual(children.length, 1) + assert.ok(children[0].id.includes('error')) + }) + + it('should return placeholder when no entities found', async function () { + mockGlueClient.send.resolves({ Entities: [] }) + + const node = await createFederatedConnectionNode( + mockConnection as any, + mockCredentialsProvider, + 'us-east-1' + ) + + const children = await node.getChildren!() + assert.strictEqual(children.length, 1) + assert.strictEqual(children[0].resource, '[No data found]') + }) + + it('should group tables under Tables container', async function () { + mockGlueClient.send.resolves({ + Entities: [ + { EntityName: 'table1', Category: 'TABLE', Label: 'Table 1' }, + { EntityName: 'table2', Category: 'TABLE', Label: 'Table 2' }, + ], + }) + + const node = await createFederatedConnectionNode( + mockConnection as any, + mockCredentialsProvider, + 'us-east-1' + ) + + const children = await node.getChildren!() + assert.strictEqual(children.length, 1) + + const tablesContainer = children[0] + assert.ok(tablesContainer.id.includes('tables')) + + const tableChildren = await tablesContainer.getChildren!() + assert.strictEqual(tableChildren.length, 2) + }) + + it('should handle mixed entity types correctly', async function () { + mockGlueClient.send.resolves({ + Entities: [ + { EntityName: 'schema1', Category: 'SCHEMA', Label: 'Schema 1' }, + { EntityName: 'table1', Category: 'TABLE', Label: 'Table 1' }, + ], + }) + + const node = await createFederatedConnectionNode( + mockConnection as any, + mockCredentialsProvider, + 'us-east-1' + ) + + const children = await node.getChildren!() + assert.strictEqual(children.length, 2) // schema + tables container + }) + + it('should handle table columns', async function () { + const mockEntity = { EntityName: 'test-table', Category: 'TABLE' } + + mockGlueClient.send.callsFake((command) => { + if (command instanceof DescribeEntityCommand) { + return Promise.resolve({ + Fields: [ + { FieldName: 'col1', FieldType: 'string', Label: 'Column 1' }, + { FieldName: 'col2', FieldType: 'int', Label: 'Column 2' }, + ], + }) + } + if (command instanceof ListEntitiesCommand) { + return Promise.resolve({ + Entities: [mockEntity], + }) + } + return Promise.resolve({}) + }) + + const node = await createFederatedConnectionNode( + mockConnection as any, + mockCredentialsProvider, + 'us-east-1' + ) + + const children = await node.getChildren!() + const tablesContainer = children[0] + const tableNodes = await tablesContainer.getChildren!() + const tableNode = tableNodes[0] + + const columns = await tableNode.getChildren!() + assert.strictEqual(columns.length, 2) + + const columnTreeItem = await columns[0].getTreeItem() + assert.strictEqual(columnTreeItem.description, 'string') + }) + + it('should handle API errors gracefully', async function () { + mockGlueClient.send.rejects(new Error('API Error')) + + const node = await createFederatedConnectionNode( + mockConnection as any, + mockCredentialsProvider, + 'us-east-1' + ) + + const children = await node.getChildren!() + assert.strictEqual(children.length, 1) + assert.ok(children[0].id.includes('error')) + }) + }) +}) diff --git a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioAuthInfoNode.test.ts b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioAuthInfoNode.test.ts index ebf2eae2cb0..52d1d045403 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioAuthInfoNode.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioAuthInfoNode.test.ts @@ -8,12 +8,12 @@ import sinon from 'sinon' import * as vscode from 'vscode' import { SageMakerUnifiedStudioAuthInfoNode } from '../../../../sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioAuthInfoNode' import { SmusAuthenticationProvider } from '../../../../sagemakerunifiedstudio/auth/providers/smusAuthenticationProvider' -import { SmusConnection } from '../../../../sagemakerunifiedstudio/auth/model' +import { SmusConnection, SmusSsoConnection } from '../../../../sagemakerunifiedstudio/auth/model' describe('SageMakerUnifiedStudioAuthInfoNode', function () { let authInfoNode: SageMakerUnifiedStudioAuthInfoNode let mockAuthProvider: any - let mockConnection: SmusConnection + let mockConnection: SmusSsoConnection let currentActiveConnection: SmusConnection | undefined beforeEach(function () { @@ -39,6 +39,20 @@ describe('SageMakerUnifiedStudioAuthInfoNode', function () { isConnected: sinon.stub().returns(true), isConnectionValid: sinon.stub().returns(true), onDidChange: sinon.stub().callsFake((listener: () => void) => ({ dispose: sinon.stub() })), + onDidChangeActiveConnection: sinon.stub().callsFake((listener: () => void) => ({ dispose: sinon.stub() })), + getDomainId: sinon.stub().callsFake(() => { + return currentActiveConnection?.domainId + }), + getDomainRegion: sinon.stub().callsFake(() => { + if (currentActiveConnection?.type === 'sso') { + return (currentActiveConnection as any).ssoRegion + } else if (currentActiveConnection?.type === 'iam') { + return (currentActiveConnection as any).region + } + return undefined + }), + getSessionName: sinon.stub().resolves(undefined), + getRoleArn: sinon.stub().resolves(undefined), get activeConnection() { return currentActiveConnection }, @@ -47,6 +61,9 @@ describe('SageMakerUnifiedStudioAuthInfoNode', function () { }, } + // Stub getContext to return false for IAM mode by default (SSO connections) + sinon.stub(require('../../../../shared/vscode/setContext'), 'getContext').returns(false) + // Stub SmusAuthenticationProvider.fromContext sinon.stub(SmusAuthenticationProvider, 'fromContext').returns(mockAuthProvider as any) @@ -80,8 +97,8 @@ describe('SageMakerUnifiedStudioAuthInfoNode', function () { mockAuthProvider.activeConnection = mockConnection }) - it('should return connected tree item', function () { - const treeItem = authInfoNode.getTreeItem() + it('should return connected tree item', async function () { + const treeItem = await authInfoNode.getTreeItem() assert.strictEqual(treeItem.label, 'Domain: dzd_domainId') assert.strictEqual(treeItem.description, 'us-east-2') @@ -111,8 +128,8 @@ describe('SageMakerUnifiedStudioAuthInfoNode', function () { mockAuthProvider.activeConnection = mockConnection }) - it('should return expired tree item with reauthenticate command', function () { - const treeItem = authInfoNode.getTreeItem() + it('should return expired tree item with reauthenticate command', async function () { + const treeItem = await authInfoNode.getTreeItem() assert.strictEqual(treeItem.label, 'Domain: dzd_domainId (Expired) - Click to reauthenticate') assert.strictEqual(treeItem.description, 'us-east-2') @@ -142,8 +159,8 @@ describe('SageMakerUnifiedStudioAuthInfoNode', function () { mockAuthProvider.activeConnection = undefined }) - it('should return not connected tree item', function () { - const treeItem = authInfoNode.getTreeItem() + it('should return not connected tree item', async function () { + const treeItem = await authInfoNode.getTreeItem() assert.strictEqual(treeItem.label, 'Not Connected') assert.strictEqual(treeItem.description, undefined) @@ -176,8 +193,8 @@ describe('SageMakerUnifiedStudioAuthInfoNode', function () { mockAuthProvider.activeConnection = incompleteConnection }) - it('should handle missing domain ID and region gracefully', function () { - const treeItem = authInfoNode.getTreeItem() + it('should handle missing domain ID and region gracefully', async function () { + const treeItem = await authInfoNode.getTreeItem() assert.strictEqual(treeItem.label, 'Domain: Unknown') assert.strictEqual(treeItem.description, 'Unknown') @@ -220,32 +237,32 @@ describe('SageMakerUnifiedStudioAuthInfoNode', function () { }) describe('theme icon colors', function () { - it('should use green color for connected state', function () { + it('should use green color for connected state', async function () { mockAuthProvider.isConnected.returns(true) mockAuthProvider.isConnectionValid.returns(true) - const treeItem = authInfoNode.getTreeItem() + const treeItem = await authInfoNode.getTreeItem() const icon = treeItem.iconPath as vscode.ThemeIcon assert.ok(icon.color instanceof vscode.ThemeColor) assert.strictEqual((icon.color as any).id, 'charts.green') }) - it('should use yellow color for expired state', function () { + it('should use yellow color for expired state', async function () { mockAuthProvider.isConnected.returns(true) mockAuthProvider.isConnectionValid.returns(false) - const treeItem = authInfoNode.getTreeItem() + const treeItem = await authInfoNode.getTreeItem() const icon = treeItem.iconPath as vscode.ThemeIcon assert.ok(icon.color instanceof vscode.ThemeColor) assert.strictEqual((icon.color as any).id, 'charts.yellow') }) - it('should use red color for not connected state', function () { + it('should use red color for not connected state', async function () { mockAuthProvider.isConnected.returns(false) - const treeItem = authInfoNode.getTreeItem() + const treeItem = await authInfoNode.getTreeItem() const icon = treeItem.iconPath as vscode.ThemeIcon assert.ok(icon.color instanceof vscode.ThemeColor) @@ -254,11 +271,11 @@ describe('SageMakerUnifiedStudioAuthInfoNode', function () { }) describe('tooltip content', function () { - it('should include all relevant information for connected state', function () { + it('should include all relevant information for connected state', async function () { mockAuthProvider.isConnected.returns(true) mockAuthProvider.isConnectionValid.returns(true) - const treeItem = authInfoNode.getTreeItem() + const treeItem = await authInfoNode.getTreeItem() const tooltip = treeItem.tooltip as string assert.ok(tooltip.includes('Connected to SageMaker Unified Studio')) @@ -267,25 +284,143 @@ describe('SageMakerUnifiedStudioAuthInfoNode', function () { assert.ok(tooltip.includes('Status: Connected')) }) - it('should include expiration information for expired state', function () { + it('should include expiration information for expired state', async function () { mockAuthProvider.isConnected.returns(true) mockAuthProvider.isConnectionValid.returns(false) - const treeItem = authInfoNode.getTreeItem() + const treeItem = await authInfoNode.getTreeItem() const tooltip = treeItem.tooltip as string assert.ok(tooltip.includes('Connection to SageMaker Unified Studio has expired')) assert.ok(tooltip.includes('Status: Expired - Click to reauthenticate')) }) - it('should include sign-in prompt for not connected state', function () { + it('should include sign-in prompt for not connected state', async function () { mockAuthProvider.isConnected.returns(false) - const treeItem = authInfoNode.getTreeItem() + const treeItem = await authInfoNode.getTreeItem() const tooltip = treeItem.tooltip as string assert.ok(tooltip.includes('Not connected to SageMaker Unified Studio')) assert.ok(tooltip.includes('Please sign in to access your projects')) }) }) + + describe('IAM connections in IAM mode', function () { + let mockIamConnection: any + + beforeEach(function () { + mockIamConnection = { + id: 'profile:test-profile', + type: 'iam', + label: 'Test IAM Profile', + profileName: 'test-profile', + region: 'us-west-2', + domainUrl: 'https://dzd_domainId.sagemaker.us-west-2.on.aws', + domainId: 'dzd_domainId', + getCredentials: sinon.stub().resolves(), + } + + currentActiveConnection = mockIamConnection + + // Override getContext stub to return true for IAM mode + const getContextModule = require('../../../../shared/vscode/setContext') + const existingStub = getContextModule.getContext as sinon.SinonStub + existingStub.withArgs('aws.smus.isIamMode').returns(true) + }) + + it('should display profile name with session name for IAM connection', async function () { + mockAuthProvider.isConnected.returns(true) + mockAuthProvider.isConnectionValid.returns(true) + mockAuthProvider.getSessionName = sinon.stub().resolves('my-session-name') + mockAuthProvider.getIamPrincipalArn = sinon + .stub() + .resolves('arn:aws:sts::123456789012:assumed-role/MyRole/my-session-name') + + const treeItem = await authInfoNode.getTreeItem() + + assert.strictEqual(treeItem.label, 'Connected with profile: test-profile (session: my-session-name)') + assert.strictEqual(treeItem.description, 'us-west-2') + }) + + it('should display profile name without session name when unavailable', async function () { + mockAuthProvider.isConnected.returns(true) + mockAuthProvider.isConnectionValid.returns(true) + mockAuthProvider.getSessionName = sinon.stub().resolves(undefined) + mockAuthProvider.getIamPrincipalArn = sinon.stub().resolves(undefined) + + const treeItem = await authInfoNode.getTreeItem() + + assert.strictEqual(treeItem.label, 'Connected with profile: test-profile') + assert.strictEqual(treeItem.description, 'us-west-2') + }) + + it('should include session name and role ARN in tooltip when available', async function () { + mockAuthProvider.isConnected.returns(true) + mockAuthProvider.isConnectionValid.returns(true) + mockAuthProvider.getSessionName = sinon.stub().resolves('my-session-name') + mockAuthProvider.getIamPrincipalArn = sinon + .stub() + .resolves('arn:aws:sts::123456789012:assumed-role/MyRole/my-session-name') + + const treeItem = await authInfoNode.getTreeItem() + const tooltip = treeItem.tooltip as string + + assert.ok(tooltip.includes('Connected to SageMaker Unified Studio')) + assert.ok(tooltip.includes('Profile: test-profile')) + assert.ok(tooltip.includes('Region: us-west-2')) + assert.ok(tooltip.includes('Session: my-session-name')) + assert.ok(tooltip.includes('Role ARN: arn:aws:sts::123456789012:assumed-role/MyRole/my-session-name')) + assert.ok(tooltip.includes('Status: Connected')) + }) + + it('should not include session name or role ARN in tooltip when unavailable', async function () { + mockAuthProvider.isConnected.returns(true) + mockAuthProvider.isConnectionValid.returns(true) + mockAuthProvider.getSessionName = sinon.stub().resolves(undefined) + mockAuthProvider.getIamPrincipalArn = sinon.stub().resolves(undefined) + + const treeItem = await authInfoNode.getTreeItem() + const tooltip = treeItem.tooltip as string + + assert.ok(tooltip.includes('Connected to SageMaker Unified Studio')) + assert.ok(tooltip.includes('Profile: test-profile')) + assert.ok(tooltip.includes('Region: us-west-2')) + assert.ok(!tooltip.includes('Session:')) + assert.ok(!tooltip.includes('Role ARN:')) + assert.ok(tooltip.includes('Status: Connected')) + }) + + it('should handle getSessionName errors gracefully', async function () { + mockAuthProvider.isConnected.returns(true) + mockAuthProvider.isConnectionValid.returns(true) + mockAuthProvider.getSessionName = sinon.stub().resolves(undefined) // Return undefined instead of rejecting + mockAuthProvider.getIamPrincipalArn = sinon.stub().resolves(undefined) + + // Should not throw, just display without session name + const treeItem = await authInfoNode.getTreeItem() + + assert.strictEqual(treeItem.label, 'Connected with profile: test-profile') + assert.strictEqual(treeItem.description, 'us-west-2') + }) + + it('should display expired IAM connection with profile name', async function () { + mockAuthProvider.isConnected.returns(true) + mockAuthProvider.isConnectionValid.returns(false) + mockAuthProvider.getSessionName = sinon.stub().resolves('my-session-name') + + const treeItem = await authInfoNode.getTreeItem() + + assert.strictEqual(treeItem.label, 'Profile: test-profile (Expired) - Click to reauthenticate') + assert.strictEqual(treeItem.description, 'us-west-2') + + // Check icon + assert.ok(treeItem.iconPath instanceof vscode.ThemeIcon) + assert.strictEqual((treeItem.iconPath as vscode.ThemeIcon).id, 'warning') + + // Should have reauthenticate command + assert.ok(treeItem.command) + assert.strictEqual(treeItem.command.command, 'aws.smus.reauthenticate') + }) + }) }) diff --git a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioComputeNode.test.ts b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioComputeNode.test.ts index fc74eeab435..d1b05a547e6 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioComputeNode.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioComputeNode.test.ts @@ -11,6 +11,7 @@ import { SageMakerUnifiedStudioProjectNode } from '../../../../sagemakerunifieds import { SageMakerUnifiedStudioSpacesParentNode } from '../../../../sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioSpacesParentNode' import { SagemakerClient } from '../../../../shared/clients/sagemaker' import { SmusAuthenticationProvider } from '../../../../sagemakerunifiedstudio/auth/providers/smusAuthenticationProvider' +import * as setContext from '../../../../shared/vscode/setContext' describe('SageMakerUnifiedStudioComputeNode', function () { let computeNode: SageMakerUnifiedStudioComputeNode @@ -71,10 +72,13 @@ describe('SageMakerUnifiedStudioComputeNode', function () { assert.deepStrictEqual(children, []) }) - it('returns connection nodes and spaces node when project is selected', async function () { + it('returns connection nodes and spaces node when project is selected (non-IAM mode)', async function () { const mockProject = { id: 'project-123', name: 'Test Project' } ;(mockParent.getProject as sinon.SinonStub).returns(mockProject) + // Mock IAM mode to be false + sinon.stub(setContext, 'getContext').returns(false) + const children = await computeNode.getChildren() assert.strictEqual(children.length, 3) @@ -82,6 +86,19 @@ describe('SageMakerUnifiedStudioComputeNode', function () { assert.strictEqual(children[1].id, 'Data processing') assert.ok(children[2] instanceof SageMakerUnifiedStudioSpacesParentNode) }) + + it('returns only spaces node when project is selected (IAM mode)', async function () { + const mockProject = { id: 'project-123', name: 'Test Project' } + ;(mockParent.getProject as sinon.SinonStub).returns(mockProject) + + // Mock IAM mode to be true + sinon.stub(setContext, 'getContext').returns(true) + + const children = await computeNode.getChildren() + + assert.strictEqual(children.length, 1) + assert.ok(children[0] instanceof SageMakerUnifiedStudioSpacesParentNode) + }) }) describe('getParent', function () { diff --git a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioConnectionParentNode.test.ts b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioConnectionParentNode.test.ts index 686c85a0055..18778c52664 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioConnectionParentNode.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioConnectionParentNode.test.ts @@ -50,14 +50,24 @@ describe('SageMakerUnifiedStudioConnectionParentNode', function () { } as any mockComputeNode = { - authProvider: {} as any, + authProvider: { + getDerCredentialsProvider: sinon.stub().resolves({ + getCredentials: sinon.stub().resolves({ + accessKeyId: 'test-key', + secretAccessKey: 'test-secret', + sessionToken: 'test-token', + }), + }), + getDomainId: sinon.stub().returns('domain-123'), + getDomainRegion: sinon.stub().returns('us-east-1'), + } as any, parent: { project: mockProject, } as any, } as any // Stub static methods - sinon.stub(DataZoneClient, 'getInstance').resolves(mockDataZoneClient as any) + sinon.stub(DataZoneClient, 'createWithCredentials').resolves(mockDataZoneClient as any) sinon.stub(getLogger(), 'debug') connectionParentNode = new SageMakerUnifiedStudioConnectionParentNode( @@ -135,7 +145,17 @@ describe('SageMakerUnifiedStudioConnectionParentNode', function () { it('handles missing project information gracefully', async function () { const nodeWithoutProject = new SageMakerUnifiedStudioConnectionParentNode( { - authProvider: {} as any, + authProvider: { + getDerCredentialsProvider: sinon.stub().resolves({ + getCredentials: sinon.stub().resolves({ + accessKeyId: 'test-key', + secretAccessKey: 'test-secret', + sessionToken: 'test-token', + }), + }), + getDomainId: sinon.stub().returns('domain-123'), + getDomainRegion: sinon.stub().returns('us-east-1'), + } as any, parent: { project: undefined, } as any, @@ -164,7 +184,7 @@ describe('SageMakerUnifiedStudioConnectionParentNode', function () { describe('error handling', function () { it('handles DataZoneClient.getInstance error', async function () { sinon.restore() - sinon.stub(DataZoneClient, 'getInstance').rejects(new Error('Client error')) + sinon.stub(DataZoneClient, 'createWithCredentials').rejects(new Error('Client error')) sinon.stub(getLogger(), 'debug') try { diff --git a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioDataNode.test.ts b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioDataNode.test.ts index 991e5955989..3d3ad01108c 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioDataNode.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioDataNode.test.ts @@ -13,6 +13,7 @@ import { SmusAuthenticationProvider } from '../../../../sagemakerunifiedstudio/a import * as s3Strategy from '../../../../sagemakerunifiedstudio/explorer/nodes/s3Strategy' import * as redshiftStrategy from '../../../../sagemakerunifiedstudio/explorer/nodes/redshiftStrategy' import * as lakehouseStrategy from '../../../../sagemakerunifiedstudio/explorer/nodes/lakehouseStrategy' +import * as setContext from '../../../../shared/vscode/setContext' describe('SageMakerUnifiedStudioDataNode', function () { let sandbox: sinon.SinonSandbox @@ -50,6 +51,14 @@ describe('SageMakerUnifiedStudioDataNode', function () { getProjectCredentialProvider: sandbox.stub().resolves(mockProjectCredentialProvider), getConnectionCredentialsProvider: sandbox.stub().resolves(mockProjectCredentialProvider), getDomainRegion: sandbox.stub().returns('us-east-1'), + getDomainId: sandbox.stub().returns('domain-123'), + getDerCredentialsProvider: sandbox.stub().resolves({ + getCredentials: sandbox.stub().resolves({ + accessKeyId: 'test-key', + secretAccessKey: 'test-secret', + sessionToken: 'test-token', + }), + }), } as any mockDataZoneClient = { @@ -60,7 +69,7 @@ describe('SageMakerUnifiedStudioDataNode', function () { getRegion: sandbox.stub().returns('us-east-1'), } as any - sandbox.stub(DataZoneClient, 'getInstance').returns(mockDataZoneClient as any) + sandbox.stub(DataZoneClient, 'createWithCredentials').returns(mockDataZoneClient as any) sandbox.stub(SmusAuthenticationProvider, 'fromContext').returns(mockAuthProvider as any) sandbox.stub(s3Strategy, 'createS3ConnectionNode').returns({ id: 's3-node', @@ -161,6 +170,9 @@ describe('SageMakerUnifiedStudioDataNode', function () { { connectionId: 'redshift-conn', type: 'REDSHIFT', name: 'redshift-connection' }, ] + // Mock IAM mode to be false so Redshift connections are included + sandbox.stub(setContext, 'getContext').returns(false) + mockDataZoneClient.listConnections.resolves(mockConnections as any) mockDataZoneClient.getConnection .onFirstCall() @@ -209,7 +221,6 @@ describe('SageMakerUnifiedStudioDataNode', function () { const mockConnections = [{ connectionId: 's3-conn', type: 'S3', name: 's3-connection' }] mockDataZoneClient.listConnections.resolves(mockConnections as any) - mockDataZoneClient.getConnection.rejects(new Error('Connection error')) const children = await dataNode.getChildren() @@ -217,6 +228,9 @@ describe('SageMakerUnifiedStudioDataNode', function () { assert.strictEqual(children.length, 1) assert.strictEqual(children[0].id, 'bucket-parent') + // Mock connection credentials provider to reject when bucket is expanded + mockAuthProvider.getConnectionCredentialsProvider.rejects(new Error('Connection error')) + // Error should occur when expanding the Bucket node const bucketChildren = await children[0].getChildren!() assert.strictEqual(bucketChildren.length, 1) diff --git a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioProjectNode.test.ts b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioProjectNode.test.ts index 2fd8317fe06..8c71b102064 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioProjectNode.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioProjectNode.test.ts @@ -38,6 +38,14 @@ describe('SageMakerUnifiedStudioProjectNode', function () { getProjectCredentialProvider: sinon.stub(), getDomainRegion: sinon.stub().returns('us-west-2'), getDomainAccountId: sinon.stub().resolves('123456789012'), + getDomainId: sinon.stub().returns('test-domain'), + getDerCredentialsProvider: sinon.stub().resolves({ + getCredentials: sinon.stub().resolves({ + accessKeyId: 'test-key', + secretAccessKey: 'test-secret', + sessionToken: 'test-token', + }), + }), } as any // Create mock extension context @@ -63,7 +71,7 @@ describe('SageMakerUnifiedStudioProjectNode', function () { } as any // Stub DataZoneClient static methods - sinon.stub(DataZoneClient, 'getInstance').returns(mockDataZoneClient as any) + sinon.stub(DataZoneClient, 'createWithCredentials').returns(mockDataZoneClient as any) // Stub SagemakerClient constructor sinon.stub(SagemakerClient.prototype, 'dispose') @@ -293,21 +301,22 @@ describe('SageMakerUnifiedStudioProjectNode', function () { assert.strictEqual(hasAccess, false) }) - it('returns false when getCredentials fails', async function () { + it('throws error when getCredentials fails', async function () { const mockCredProvider = { getCredentials: sinon.stub().rejects(new Error('Credentials error')), } projectNode['authProvider'].getProjectCredentialProvider = sinon.stub().resolves(mockCredProvider) - const hasAccess = await projectNode['checkProjectCredsAccess']('project-123') - assert.strictEqual(hasAccess, false) + await assert.rejects( + async () => await projectNode['checkProjectCredsAccess']('project-123'), + /Credentials error/ + ) }) - it('returns false when access check throws non-AccessDeniedException error', async function () { + it('throws error when access check throws non-AccessDeniedException error', async function () { projectNode['authProvider'].getProjectCredentialProvider = sinon.stub().rejects(new Error('Other error')) - const hasAccess = await projectNode['checkProjectCredsAccess']('project-123') - assert.strictEqual(hasAccess, false) + await assert.rejects(async () => await projectNode['checkProjectCredsAccess']('project-123'), /Other error/) }) }) diff --git a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioRootNode.test.ts b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioRootNode.test.ts index 64b866c7704..facfa867aca 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioRootNode.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioRootNode.test.ts @@ -18,6 +18,12 @@ import * as pickerPrompter from '../../../../shared/ui/pickerPrompter' import { getTestWindow } from '../../../shared/vscode/window' import { assertTelemetry } from '../../../../../src/test/testUtil' import { createMockExtensionContext, createMockUnauthenticatedAuthProvider } from '../../testUtils' +import { DataZoneCustomClientHelper } from '../../../../sagemakerunifiedstudio/shared/client/datazoneCustomClientHelper' +import { DefaultStsClient } from '../../../../shared/clients/stsClient' +import { SmusUtils, SmusErrorCodes } from '../../../../sagemakerunifiedstudio/shared/smusUtils' +import { ToolkitError } from '../../../../shared/errors' +import * as utils from '../../../../sagemakerunifiedstudio/explorer/nodes/utils' +import * as setContextModule from '../../../../shared/vscode/setContext' describe('SmusRootNode', function () { let rootNode: SageMakerUnifiedStudioRootNode @@ -81,7 +87,7 @@ describe('SmusRootNode', function () { } as any // Stub DataZoneClient static methods - sinon.stub(DataZoneClient, 'getInstance').returns(mockDataZoneClient as any) + sinon.stub(DataZoneClient, 'createWithCredentials').returns(mockDataZoneClient as any) }) afterEach(function () { @@ -194,7 +200,13 @@ describe('SmusRootNode', function () { const mockAuthProvider = { isConnected: sinon.stub().returns(true), isConnectionValid: sinon.stub().returns(false), - activeConnection: { domainId: testDomainId, ssoRegion: 'us-west-2' }, + activeConnection: { + type: 'sso', + domainId: testDomainId, + ssoRegion: 'us-west-2', + domainUrl: 'https://test-domain.datazone.aws.amazon.com', + scopes: ['datazone:domain:access'], + }, onDidChange: sinon.stub().returns({ dispose: sinon.stub() }), showReauthenticationPrompt: sinon.stub(), } as any @@ -231,6 +243,8 @@ describe('SelectSMUSProject', function () { let mockProjectNode: sinon.SinonStubbedInstance let createQuickPickStub: sinon.SinonStub let executeCommandStub: sinon.SinonStub + let getContextStub: sinon.SinonStub + let createDZClientStub: sinon.SinonStub const testDomainId = 'test-domain-123' const mockProject: DataZoneProject = { @@ -264,8 +278,10 @@ describe('SelectSMUSProject', function () { project: undefined, } as any - // Stub DataZoneClient static methods - sinon.stub(DataZoneClient, 'getInstance').returns(mockDataZoneClient as any) + // Stub createDZClientBaseOnDomainMode to return our mock client + createDZClientStub = sinon.stub() + createDZClientStub.resolves(mockDataZoneClient) + sinon.replace(utils, 'createDZClientBaseOnDomainMode', createDZClientStub) // Stub SmusAuthenticationProvider sinon.stub(SmusAuthenticationProvider, 'fromContext').returns({ @@ -275,8 +291,21 @@ describe('SelectSMUSProject', function () { getDomainAccountId: sinon.stub().resolves('123456789012'), getDomainId: sinon.stub().returns(testDomainId), getDomainRegion: sinon.stub().returns('us-west-2'), + getDerCredentialsProvider: sinon.stub().resolves({ + getCredentials: sinon.stub().resolves({ + accessKeyId: 'test-key', + secretAccessKey: 'test-secret', + sessionToken: 'test-token', + }), + }), } as any) + // Stub getContext to return false for IAM mode by default (non-IAM mode) + getContextStub = sinon.stub() + getContextStub.withArgs('aws.smus.isIamMode').returns(false) + getContextStub.callThrough() + sinon.replace(setContextModule, 'getContext', getContextStub) + // Stub quickPick - return the project directly (not wrapped in an item) const mockQuickPick = { prompt: sinon.stub().resolves(mockProject), @@ -432,6 +461,8 @@ describe('selectSMUSProject - Additional Tests', function () { let mockProjectNode: sinon.SinonStubbedInstance let createQuickPickStub: sinon.SinonStub let executeCommandStub: sinon.SinonStub + let getContextStub: sinon.SinonStub + let createDZClientStub: sinon.SinonStub const testDomainId = 'test-domain-123' const mockProject: DataZoneProject = { @@ -452,14 +483,30 @@ describe('selectSMUSProject - Additional Tests', function () { setProject: sinon.stub(), } as any - sinon.stub(DataZoneClient, 'getInstance').returns(mockDataZoneClient as any) + // Stub createDZClientBaseOnDomainMode to return our mock client + createDZClientStub = sinon.stub() + createDZClientStub.resolves(mockDataZoneClient) + sinon.replace(utils, 'createDZClientBaseOnDomainMode', createDZClientStub) sinon.stub(SmusAuthenticationProvider, 'fromContext').returns({ activeConnection: { domainId: testDomainId, ssoRegion: 'us-west-2' }, getDomainAccountId: sinon.stub().resolves('123456789012'), getDomainId: sinon.stub().returns(testDomainId), getDomainRegion: sinon.stub().returns('us-west-2'), + getDerCredentialsProvider: sinon.stub().resolves({ + getCredentials: sinon.stub().resolves({ + accessKeyId: 'test-key', + secretAccessKey: 'test-secret', + sessionToken: 'test-token', + }), + }), } as any) + // Stub getContext to return false for IAM mode by default (non-IAM mode) + getContextStub = sinon.stub() + getContextStub.withArgs('aws.smus.isIamMode').returns(false) + getContextStub.callThrough() + sinon.replace(setContextModule, 'getContext', getContextStub) + const mockQuickPick = { prompt: sinon.stub().resolves(mockProject), } @@ -497,16 +544,8 @@ describe('selectSMUSProject - Additional Tests', function () { assert.strictEqual(result, undefined) const testWindow = getTestWindow() assert.ok(testWindow.shownMessages.some((msg) => msg.message === 'No projects found in the domain')) - assert.ok( - createQuickPickStub.calledWith([ - { - label: 'No projects found', - detail: '', - description: '', - data: {}, - }, - ]) - ) + // When no projects are found, createQuickPick should not be called + assert.ok(!createQuickPickStub.called) }) it('handles invalid selected project object', async function () { @@ -525,3 +564,552 @@ describe('selectSMUSProject - Additional Tests', function () { assert.ok(!executeCommandStub.called) }) }) + +describe('selectSMUSProject - Express Mode', function () { + let mockDataZoneClient: sinon.SinonStubbedInstance + let mockProjectNode: sinon.SinonStubbedInstance + let createQuickPickStub: sinon.SinonStub + let executeCommandStub: sinon.SinonStub + let getContextStub: sinon.SinonStub + let getInstanceStub: sinon.SinonStub + let createDZClientStub: sinon.SinonStub + + const testDomainId = 'test-domain-123' + const testUserProfileId = 'user-profile-123' + + const userProject: DataZoneProject = { + id: 'project-123', + name: 'User Project', + description: 'Project created by user', + domainId: testDomainId, + createdBy: testUserProfileId, + updatedAt: new Date(), + } + + const otherUserProject: DataZoneProject = { + id: 'project-456', + name: 'Other User Project', + description: 'Project created by another user', + domainId: testDomainId, + createdBy: 'other-user-profile-456', + updatedAt: new Date(Date.now() - 86400000), + } + + beforeEach(function () { + const mockGroupProfileId = 'group-profile-123' + + mockDataZoneClient = { + getDomainId: sinon.stub().returns(testDomainId), + fetchAllProjects: sinon.stub(), + } as any + + mockProjectNode = { + setProject: sinon.stub(), + } as any + + // Stub createDZClientBaseOnDomainMode to return our mock client + createDZClientStub = sinon.stub() + createDZClientStub.resolves(mockDataZoneClient) + sinon.replace(utils, 'createDZClientBaseOnDomainMode', createDZClientStub) + + // Mock credentials provider + const mockCredentialsProvider = { + getCredentials: sinon.stub().resolves({ + accessKeyId: 'test-key', + secretAccessKey: 'test-secret', + sessionToken: 'test-token', + }), + } + + const mockAuthProvider = { + activeConnection: { + type: 'iam' as const, + profileName: 'test-profile', + region: 'us-west-2', + domainId: testDomainId, + domainUrl: `https://${testDomainId}.sagemaker.us-west-2.on.aws/`, + }, + getDomainAccountId: sinon.stub().resolves('123456789012'), + getDomainId: sinon.stub().returns(testDomainId), + getDomainRegion: sinon.stub().returns('us-west-2'), + getCredentialsProviderForIamProfile: sinon.stub().resolves(mockCredentialsProvider), + getIamPrincipalArn: sinon.stub().resolves('arn:aws:sts::123456789012:assumed-role/TestRole/test-session'), + } + sinon.stub(SmusAuthenticationProvider, 'fromContext').returns(mockAuthProvider as any) + + // Mock DataZoneCustomClientHelper + const mockDataZoneCustomClientHelper = { + getGroupProfileId: sinon.stub().resolves(mockGroupProfileId), + } + getInstanceStub = sinon + .stub(DataZoneCustomClientHelper, 'getInstance') + .returns(mockDataZoneCustomClientHelper as any) + + // Mock STS client + sinon.stub(DefaultStsClient.prototype, 'getCallerIdentity').resolves({ + Arn: 'arn:aws:sts::123456789012:assumed-role/TestRole/test-session', + UserId: 'AIDAI123456789EXAMPLE:test-session', + Account: '123456789012', + }) + + // Mock SmusUtils - simulate IAM role session (not IAM user) + sinon.stub(SmusUtils, 'isIamUserArn').returns(false) + sinon.stub(SmusUtils, 'convertAssumedRoleArnToIamRoleArn').returns('arn:aws:iam::123456789012:role/TestRole') + + const mockQuickPick = { + prompt: sinon.stub().resolves(userProject), + } + createQuickPickStub = sinon.stub(pickerPrompter, 'createQuickPick').returns(mockQuickPick as any) + executeCommandStub = sinon.stub(vscode.commands, 'executeCommand') + + // Stub getContext to simulate IAM mode + getContextStub = sinon.stub() + getContextStub.withArgs('aws.smus.isIamMode').returns(true) + getContextStub.callThrough() + sinon.replace(setContextModule, 'getContext', getContextStub) + }) + + afterEach(function () { + sinon.restore() + }) + + it('filters projects to show only user-created projects in IAM mode', async function () { + mockDataZoneClient.fetchAllProjects.resolves([userProject, otherUserProject]) + + const result = await selectSMUSProject(mockProjectNode as any) + + // Verify DataZoneCustomClientHelper.getInstance was called + assert.ok(getInstanceStub.called) + + // Verify projects were fetched with group identifier + assert.ok(mockDataZoneClient.fetchAllProjects.calledOnce) + const fetchCallArgs = mockDataZoneClient.fetchAllProjects.getCall(0).args[0] + assert.ok(fetchCallArgs?.groupIdentifier) + + // Verify the project was selected and set + assert.strictEqual(result, userProject) + assert.ok(mockProjectNode.setProject.calledOnce) + assert.ok(executeCommandStub.calledWith('aws.smus.rootView.refresh')) + }) + + it('shows message when no user-created projects found in IAM mode', async function () { + mockDataZoneClient.fetchAllProjects.resolves([]) + + const result = await selectSMUSProject(mockProjectNode as any) + + // Verify DataZoneCustomClientHelper.getInstance was called + assert.ok(getInstanceStub.called) + + // Verify no projects were shown in quick pick + assert.ok(!createQuickPickStub.called) + + // Verify appropriate message was shown + const testWindow = getTestWindow() + assert.ok( + testWindow.shownMessages.some( + (msg) => msg.message === 'No accessible projects found for your IAM principal' + ) + ) + + // Verify no project was set + assert.strictEqual(result, undefined) + assert.ok(!mockProjectNode.setProject.called) + }) + + it('shows all user-created projects when multiple exist in IAM mode', async function () { + const userProject2: DataZoneProject = { + id: 'project-789', + name: 'Another User Project', + description: 'Another project created by user', + domainId: testDomainId, + createdBy: testUserProfileId, + updatedAt: new Date(Date.now() - 172800000), // 2 days ago + } + + // In IAM mode, fetchAllProjects is called with groupIdentifier filter + // So the API returns only projects for that group (already filtered) + mockDataZoneClient.fetchAllProjects.resolves([userProject, userProject2]) + + await selectSMUSProject(mockProjectNode as any) + + // Verify projects were fetched with group identifier + assert.ok(mockDataZoneClient.fetchAllProjects.calledOnce) + const fetchCallArgs = mockDataZoneClient.fetchAllProjects.getCall(0).args[0] + assert.ok(fetchCallArgs?.groupIdentifier) + + // Verify all returned projects are shown in quick pick + const quickPickCall = createQuickPickStub.getCall(0) + const items = quickPickCall.args[0] + assert.strictEqual(items.length, 2) + assert.ok(items.some((item: any) => item.data.id === userProject.id)) + assert.ok(items.some((item: any) => item.data.id === userProject2.id)) + }) + + it('does not filter projects in non-IAM mode', async function () { + // Stub getContext to return false for IAM mode + getContextStub.withArgs('aws.smus.isIamMode').returns(false) + + mockDataZoneClient.fetchAllProjects.resolves([userProject, otherUserProject]) + + await selectSMUSProject(mockProjectNode as any) + + // Verify DataZoneCustomClientHelper.getInstance was NOT called in non-IAM mode + assert.ok(!getInstanceStub.called) + + // Verify all projects are shown in quick pick + const quickPickCall = createQuickPickStub.getCall(0) + const items = quickPickCall.args[0] + assert.strictEqual(items.length, 2) + assert.ok(items.some((item: any) => item.data.id === userProject.id)) + assert.ok(items.some((item: any) => item.data.id === otherUserProject.id)) + }) +}) + +describe('selectSMUSProject - Error Handling', function () { + let mockDataZoneClient: sinon.SinonStubbedInstance + let mockProjectNode: sinon.SinonStubbedInstance + let createQuickPickStub: sinon.SinonStub + let getContextStub: sinon.SinonStub + let createDZClientStub: sinon.SinonStub + + const testDomainId = 'test-domain-123' + const testUserProfileId = 'user-profile-123' + + const mockProject: DataZoneProject = { + id: 'project-123', + name: 'Test Project', + description: 'Test Description', + domainId: testDomainId, + createdBy: testUserProfileId, + updatedAt: new Date(), + } + + beforeEach(function () { + mockDataZoneClient = { + getDomainId: sinon.stub().returns(testDomainId), + fetchAllProjects: sinon.stub(), + getUserProfileId: sinon.stub().resolves(testUserProfileId), + } as any + + mockProjectNode = { + setProject: sinon.stub(), + } as any + + createDZClientStub = sinon.stub() + createDZClientStub.resolves(mockDataZoneClient) + sinon.replace(utils, 'createDZClientBaseOnDomainMode', createDZClientStub) + + sinon.stub(SmusAuthenticationProvider, 'fromContext').returns({ + activeConnection: { domainId: testDomainId, ssoRegion: 'us-west-2' }, + getDomainAccountId: sinon.stub().resolves('123456789012'), + getDomainId: sinon.stub().returns(testDomainId), + getDomainRegion: sinon.stub().returns('us-west-2'), + } as any) + + getContextStub = sinon.stub() + getContextStub.withArgs('aws.smus.isIamMode').returns(false) + getContextStub.callThrough() + sinon.replace(setContextModule, 'getContext', getContextStub) + + const mockQuickPick = { + prompt: sinon.stub().resolves(mockProject), + } + createQuickPickStub = sinon.stub(pickerPrompter, 'createQuickPick').returns(mockQuickPick as any) + }) + + afterEach(function () { + sinon.restore() + }) + + describe('No projects scenario', function () { + it('displays "No projects found in the domain" message when user has no projects', async function () { + mockDataZoneClient.fetchAllProjects.resolves([]) + + const result = await selectSMUSProject(mockProjectNode as any) + + assert.strictEqual(result, undefined) + const testWindow = getTestWindow() + assert.ok(testWindow.shownMessages.some((msg) => msg.message === 'No projects found in the domain')) + // createQuickPick should NOT be called when there are no projects + assert.ok(!createQuickPickStub.called) + assert.ok(!mockProjectNode.setProject.called) + }) + }) + + describe('No accessible projects in IAM mode', function () { + beforeEach(function () { + getContextStub.withArgs('aws.smus.isIamMode').returns(true) + + // Override the SSO connection with IAM connection for IAM mode tests + sinon.restore() + + // Re-setup mocks with IAM connection + mockDataZoneClient = { + getDomainId: sinon.stub().returns(testDomainId), + fetchAllProjects: sinon.stub(), + getUserProfileId: sinon.stub().resolves(testUserProfileId), + } as any + + mockProjectNode = { + setProject: sinon.stub(), + } as any + + createDZClientStub = sinon.stub() + createDZClientStub.resolves(mockDataZoneClient) + sinon.replace(utils, 'createDZClientBaseOnDomainMode', createDZClientStub) + + // Mock credentials provider + const mockCredentialsProvider = { + getCredentials: sinon.stub().resolves({ + accessKeyId: 'test-key', + secretAccessKey: 'test-secret', + sessionToken: 'test-token', + }), + } + + const mockAuthProvider = { + activeConnection: { + type: 'iam' as const, + profileName: 'test-profile', + region: 'us-west-2', + domainId: testDomainId, + domainUrl: `https://${testDomainId}.sagemaker.us-west-2.on.aws/`, + }, + getDomainAccountId: sinon.stub().resolves('123456789012'), + getDomainId: sinon.stub().returns(testDomainId), + getDomainRegion: sinon.stub().returns('us-west-2'), + getCredentialsProviderForIamProfile: sinon.stub().resolves(mockCredentialsProvider), + getIamPrincipalArn: sinon + .stub() + .resolves('arn:aws:sts::123456789012:assumed-role/TestRole/test-session'), + } + sinon.stub(SmusAuthenticationProvider, 'fromContext').returns(mockAuthProvider as any) + + // Mock DataZoneCustomClientHelper + const mockDataZoneCustomClientHelper = { + getGroupProfileId: sinon.stub().resolves('group-profile-123'), + } + sinon.stub(DataZoneCustomClientHelper, 'getInstance').returns(mockDataZoneCustomClientHelper as any) + + // Mock STS client + sinon.stub(DefaultStsClient.prototype, 'getCallerIdentity').resolves({ + Arn: 'arn:aws:sts::123456789012:assumed-role/TestRole/test-session', + UserId: 'AIDAI123456789EXAMPLE:test-session', + Account: '123456789012', + }) + + // Mock SmusUtils - simulate IAM role session (not IAM user) + sinon.stub(SmusUtils, 'isIamUserArn').returns(false) + sinon + .stub(SmusUtils, 'convertAssumedRoleArnToIamRoleArn') + .returns('arn:aws:iam::123456789012:role/TestRole') + + const mockQuickPick = { + prompt: sinon.stub().resolves(mockProject), + } + createQuickPickStub = sinon.stub(pickerPrompter, 'createQuickPick').returns(mockQuickPick as any) + + getContextStub = sinon.stub() + getContextStub.withArgs('aws.smus.isIamMode').returns(true) + getContextStub.callThrough() + sinon.replace(setContextModule, 'getContext', getContextStub) + }) + + it('displays "No accessible projects found" when user has no projects they created', async function () { + // In IAM mode, fetchAllProjects is called with groupIdentifier filter + // which should return empty array when no projects match + mockDataZoneClient.fetchAllProjects.resolves([]) + + const result = await selectSMUSProject(mockProjectNode as any) + + assert.strictEqual(result, undefined) + const testWindow = getTestWindow() + assert.ok( + testWindow.shownMessages.some( + (msg) => msg.message === 'No accessible projects found for your IAM principal' + ) + ) + assert.ok(!mockProjectNode.setProject.called) + }) + + it('handles getGroupProfileId failure with appropriate error message', async function () { + const testRoleArn = 'arn:aws:iam::123456789012:role/TestRole' + + // Mock getGroupProfileId to throw a ToolkitError with NoGroupProfileFound code + const groupProfileError = new ToolkitError(`No group profile found for IAM role: ${testRoleArn}`, { + code: SmusErrorCodes.NoGroupProfileFound, + name: 'ToolkitError', + }) + const mockDataZoneCustomClientHelper = { + getGroupProfileId: sinon.stub().rejects(groupProfileError), + } + sinon.restore() + sinon.stub(DataZoneCustomClientHelper, 'getInstance').returns(mockDataZoneCustomClientHelper as any) + + // Re-stub other dependencies + const mockAuthProvider = { + activeConnection: { + type: 'iam' as const, + profileName: 'test-profile', + region: 'us-west-2', + domainId: testDomainId, + domainUrl: `https://${testDomainId}.sagemaker.us-west-2.on.aws/`, + }, + getDomainAccountId: sinon.stub().resolves('123456789012'), + getDomainId: sinon.stub().returns(testDomainId), + getDomainRegion: sinon.stub().returns('us-west-2'), + getCredentialsProviderForIamProfile: sinon.stub().resolves({ + getCredentials: sinon.stub().resolves({ + accessKeyId: 'test-key', + secretAccessKey: 'test-secret', + sessionToken: 'test-token', + }), + }), + getIamPrincipalArn: sinon.stub().resolves(testRoleArn), + } + sinon.stub(SmusAuthenticationProvider, 'fromContext').returns(mockAuthProvider as any) + + sinon.stub(DefaultStsClient.prototype, 'getCallerIdentity').resolves({ + Arn: 'arn:aws:sts::123456789012:assumed-role/TestRole/test-session', + UserId: 'AIDAI123456789EXAMPLE:test-session', + Account: '123456789012', + }) + + const getContextStub = sinon.stub() + getContextStub.withArgs('aws.smus.isIamMode').returns(true) + getContextStub.callThrough() + sinon.replace(setContextModule, 'getContext', getContextStub) + + sinon.replace(utils, 'createDZClientBaseOnDomainMode', sinon.stub().resolves(mockDataZoneClient)) + + const result = await selectSMUSProject(mockProjectNode as any) + + assert.ok(result instanceof Error) + const testWindow = getTestWindow() + assert.ok( + testWindow.shownMessages.some((msg) => + msg.message.includes(`No resources found for IAM principal: ${testRoleArn}`) + ) + ) + }) + + it('handles getUserProfileIdForIamPrincipal failure with appropriate error message', async function () { + const testUserArn = 'arn:aws:iam::123456789012:user/test-user' + + // Mock getUserProfileIdForIamPrincipal to throw a ToolkitError with NoUserProfileFound code + const userProfileError = new ToolkitError(`No user profile found for IAM principal: ${testUserArn}`, { + code: SmusErrorCodes.NoUserProfileFound, + name: 'ToolkitError', + }) + + sinon.restore() + + // Re-stub dependencies for IAM user flow + const mockAuthProvider = { + activeConnection: { + type: 'iam' as const, + profileName: 'test-profile', + region: 'us-west-2', + domainId: testDomainId, + domainUrl: `https://${testDomainId}.sagemaker.us-west-2.on.aws/`, + }, + getDomainAccountId: sinon.stub().resolves('123456789012'), + getDomainId: sinon.stub().returns(testDomainId), + getDomainRegion: sinon.stub().returns('us-west-2'), + getCredentialsProviderForIamProfile: sinon.stub().resolves({ + getCredentials: sinon.stub().resolves({ + accessKeyId: 'test-key', + secretAccessKey: 'test-secret', + sessionToken: 'test-token', + }), + }), + getIamPrincipalArn: sinon.stub().resolves(testUserArn), + } + sinon.stub(SmusAuthenticationProvider, 'fromContext').returns(mockAuthProvider as any) + + // Mock SmusUtils to indicate this is an IAM user + sinon.stub(SmusUtils, 'isIamUserArn').returns(true) + + const getContextStub = sinon.stub() + getContextStub.withArgs('aws.smus.isIamMode').returns(true) + getContextStub.callThrough() + sinon.replace(setContextModule, 'getContext', getContextStub) + + // Create a new mock client with the failing method + const failingMockDataZoneClient = { + getDomainId: sinon.stub().returns(testDomainId), + fetchAllProjects: sinon.stub(), + getUserProfileIdForIamPrincipal: sinon.stub().rejects(userProfileError), + } as any + + sinon.replace(utils, 'createDZClientBaseOnDomainMode', sinon.stub().resolves(failingMockDataZoneClient)) + + const result = await selectSMUSProject(mockProjectNode as any) + + assert.ok(result instanceof Error) + const testWindow = getTestWindow() + assert.ok( + testWindow.shownMessages.some((msg) => + msg.message.includes(`No resources found for IAM principal: ${testUserArn}`) + ) + ) + }) + }) + + describe('Access denied scenarios', function () { + it('displays appropriate error message when user lacks permissions to view projects', async function () { + const accessDeniedError = new Error('Access denied to list projects') + accessDeniedError.name = 'AccessDeniedException' + mockDataZoneClient.fetchAllProjects.rejects(accessDeniedError) + + const result = await selectSMUSProject(mockProjectNode as any) + + assert.strictEqual(result, undefined) + assert.ok( + createQuickPickStub.calledWith([ + { + label: '$(error)', + description: "You don't have permissions to view projects. Please contact your administrator", + }, + ]) + ) + assert.ok(!mockProjectNode.setProject.called) + }) + + it('handles AccessDenied error name variations', async function () { + const accessDeniedError = new Error('Access denied') + accessDeniedError.name = 'AccessDeniedError' + mockDataZoneClient.fetchAllProjects.rejects(accessDeniedError) + + const result = await selectSMUSProject(mockProjectNode as any) + + assert.strictEqual(result, undefined) + assert.ok( + createQuickPickStub.calledWith([ + { + label: '$(error)', + description: "You don't have permissions to view projects. Please contact your administrator", + }, + ]) + ) + }) + + it('handles UnauthorizedOperation error as access denied', async function () { + const unauthorizedError = new Error('Unauthorized operation') + unauthorizedError.name = 'UnauthorizedOperationAccessDenied' + mockDataZoneClient.fetchAllProjects.rejects(unauthorizedError) + + const result = await selectSMUSProject(mockProjectNode as any) + + assert.strictEqual(result, undefined) + assert.ok( + createQuickPickStub.calledWith([ + { + label: '$(error)', + description: "You don't have permissions to view projects. Please contact your administrator", + }, + ]) + ) + }) + }) +}) diff --git a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioSpacesParentNode.test.ts b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioSpacesParentNode.test.ts index 31481e70953..6ac3d7e0147 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioSpacesParentNode.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioSpacesParentNode.test.ts @@ -13,7 +13,11 @@ import { DataZoneClient } from '../../../../sagemakerunifiedstudio/shared/client import { SagemakerClient } from '../../../../shared/clients/sagemaker' import { SmusAuthenticationProvider } from '../../../../sagemakerunifiedstudio/auth/providers/smusAuthenticationProvider' import { getLogger } from '../../../../shared/logger/logger' -import { SmusUtils } from '../../../../sagemakerunifiedstudio/shared/smusUtils' +import { SmusUtils, SmusErrorCodes } from '../../../../sagemakerunifiedstudio/shared/smusUtils' +import { ToolkitError } from '../../../../shared/errors' +import * as vscodeUtils from '../../../../shared/vscode/setContext' +import * as utils from '../../../../sagemakerunifiedstudio/explorer/nodes/utils' +import { DataZoneCustomClientHelper } from '../../../../sagemakerunifiedstudio/shared/client/datazoneCustomClientHelper' describe('SageMakerUnifiedStudioSpacesParentNode', function () { let spacesNode: SageMakerUnifiedStudioSpacesParentNode @@ -29,7 +33,24 @@ describe('SageMakerUnifiedStudioSpacesParentNode', function () { extensionUri: vscode.Uri.file('/test'), } as any mockAuthProvider = { - activeConnection: { domainId: 'test-domain', ssoRegion: 'us-west-2' }, + activeConnection: { domainId: 'test-domain', ssoRegion: 'us-west-2', profileName: 'test-profile' }, + getDomainId: sinon.stub().returns('test-domain'), + getDomainRegion: sinon.stub().returns('us-west-2'), + getIamPrincipalArn: sinon.stub().resolves(undefined), + getDerCredentialsProvider: sinon.stub().resolves({ + getCredentials: sinon.stub().resolves({ + accessKeyId: 'test-key', + secretAccessKey: 'test-secret', + sessionToken: 'test-token', + }), + }), + getCredentialsProviderForIamProfile: sinon.stub().resolves({ + getCredentials: sinon.stub().resolves({ + accessKeyId: 'test-key', + secretAccessKey: 'test-secret', + sessionToken: 'test-token', + }), + }), } as any mockSagemakerClient = sinon.createStubInstance(SagemakerClient) mockSagemakerClient.fetchSpaceAppsAndDomains.resolves([new Map(), new Map()]) @@ -44,10 +65,11 @@ describe('SageMakerUnifiedStudioSpacesParentNode', function () { getToolingEnvironment: sinon.stub(), } as any - sinon.stub(DataZoneClient, 'getInstance').resolves(mockDataZoneClient as any) + sinon.stub(DataZoneClient, 'createWithCredentials').resolves(mockDataZoneClient as any) sinon.stub(getLogger(), 'debug') sinon.stub(getLogger(), 'error') sinon.stub(SmusUtils, 'extractSSOIdFromUserId').returns('user-12345') + sinon.stub(vscodeUtils, 'getContext').returns(false) spacesNode = new SageMakerUnifiedStudioSpacesParentNode( mockParent, @@ -154,7 +176,7 @@ describe('SageMakerUnifiedStudioSpacesParentNode', function () { }) it('throws error when DataZone client not initialized', async function () { - ;(DataZoneClient.getInstance as sinon.SinonStub).resolves(undefined) + ;(DataZoneClient.createWithCredentials as sinon.SinonStub).resolves(undefined) await assert.rejects( async () => await spacesNode.getSageMakerDomainId(), @@ -418,4 +440,179 @@ describe('SageMakerUnifiedStudioSpacesParentNode', function () { await assert.rejects(async () => await spacesNode['updateChildren'](), /Access denied to spaces/) }) }) + + describe('IAM mode error handling', function () { + beforeEach(function () { + // Add getIamPrincipalArn stub to mockAuthProvider + mockAuthProvider.getIamPrincipalArn = sinon.stub().resolves('arn:aws:iam::123456789012:user/test-user') + }) + + it('should return no user profile error node when NoUserProfileFound error is thrown', async function () { + const noProfileError = new ToolkitError('No user profile found for IAM principal', { + code: SmusErrorCodes.NoUserProfileFound, + }) + const updateChildrenStub = sinon.stub(spacesNode as any, 'updateChildren') + updateChildrenStub.rejects(noProfileError) + + const children = await spacesNode.getChildren() + + assert.strictEqual(children.length, 1) + assert.strictEqual(children[0].id, 'smusNoUserProfile') + + const treeItem = await children[0].getTreeItem() + assert.strictEqual(treeItem.label, 'No spaces found for IAM principal') + }) + + it('should return no user profile error node when NoGroupProfileFound error is thrown', async function () { + const noProfileError = new ToolkitError('No group profile found for IAM role', { + code: SmusErrorCodes.NoGroupProfileFound, + }) + const updateChildrenStub = sinon.stub(spacesNode as any, 'updateChildren') + updateChildrenStub.rejects(noProfileError) + + const children = await spacesNode.getChildren() + + assert.strictEqual(children.length, 1) + assert.strictEqual(children[0].id, 'smusNoUserProfile') + + const treeItem = await children[0].getTreeItem() + assert.strictEqual(treeItem.label, 'No spaces found for IAM principal') + }) + + it('should return access denied error node when IAM mode returns AccessDeniedException', async function () { + const accessDeniedError = new Error("You don't have permissions to access this resource") + accessDeniedError.name = 'AccessDeniedException' + const updateChildrenStub = sinon.stub(spacesNode as any, 'updateChildren') + updateChildrenStub.rejects(accessDeniedError) + + const children = await spacesNode.getChildren() + + assert.strictEqual(children.length, 1) + assert.strictEqual(children[0].id, 'smusAccessDenied') + }) + + it('should return user profile error node when IAM mode returns generic error', async function () { + const genericError = new Error('Failed to retrieve user profile information') + const updateChildrenStub = sinon.stub(spacesNode as any, 'updateChildren') + updateChildrenStub.rejects(genericError) + + const children = await spacesNode.getChildren() + + assert.strictEqual(children.length, 1) + assert.strictEqual(children[0].id, 'smusUserProfileError') + + const treeItem = await children[0].getTreeItem() + assert.strictEqual(treeItem.label, 'Failed to retrieve spaces. Please try again.') + }) + }) + + describe('getUserProfileIdForIamAuthMode - IAM user flow', function () { + let createDZClientStub: sinon.SinonStub + let getContextStub: sinon.SinonStub + + beforeEach(function () { + getContextStub = vscodeUtils.getContext as sinon.SinonStub + getContextStub.withArgs('aws.smus.isIamMode').returns(true) + createDZClientStub = sinon.stub(utils, 'createDZClientBaseOnDomainMode') + }) + + afterEach(function () { + createDZClientStub.restore() + }) + + it('should use GetUserProfile API for IAM user', async function () { + const mockUserArn = 'arn:aws:iam::123456789012:user/test-user' + const mockUserProfileId = 'up_user123' + + mockAuthProvider.getIamPrincipalArn = sinon.stub().resolves(mockUserArn) + mockAuthProvider.getDomainId = sinon.stub().returns('domain-123') + + const mockGetUserProfileId = sinon.stub().resolves(mockUserProfileId) + mockDataZoneClient.getUserProfileIdForIamPrincipal = mockGetUserProfileId as any + createDZClientStub.resolves(mockDataZoneClient) + + const result = await spacesNode['getUserProfileIdForIamAuthMode']() + + assert.strictEqual(result, mockUserProfileId) + assert(mockGetUserProfileId.calledWith(mockUserArn, 'domain-123')) + }) + + it('should throw error when IAM user profile not found', async function () { + const mockUserArn = 'arn:aws:iam::123456789012:user/test-user' + + mockAuthProvider.getIamPrincipalArn = sinon.stub().resolves(mockUserArn) + mockAuthProvider.getDomainId = sinon.stub().returns('domain-123') + + mockDataZoneClient.getUserProfileIdForIamPrincipal = sinon.stub().resolves(undefined) as any + createDZClientStub.resolves(mockDataZoneClient) + + await assert.rejects( + async () => await spacesNode['getUserProfileIdForIamAuthMode'](), + /No user profile found for IAM user/ + ) + }) + + it('should throw error when caller ARN cannot be retrieved', async function () { + mockAuthProvider.getIamPrincipalArn = sinon.stub().resolves(undefined) + + await assert.rejects( + async () => await spacesNode['getUserProfileIdForIamAuthMode'](), + /Unable to retrieve caller identity ARN/ + ) + }) + }) + + describe('getUserProfileIdForIamAuthMode - IAM role session flow', function () { + let mockDataZoneCustomClientHelper: any + let getInstanceStub: sinon.SinonStub + let getContextStub: sinon.SinonStub + + beforeEach(function () { + getContextStub = vscodeUtils.getContext as sinon.SinonStub + getContextStub.withArgs('aws.smus.isIamMode').returns(true) + + mockDataZoneCustomClientHelper = { + getUserProfileIdForSession: sinon.stub(), + } + + // Mock the DataZoneCustomClientHelper.getInstance + getInstanceStub = sinon + .stub(DataZoneCustomClientHelper, 'getInstance') + .returns(mockDataZoneCustomClientHelper) + }) + + afterEach(function () { + getInstanceStub.restore() + }) + + it('should use SearchUserProfile API for IAM role session', async function () { + const mockRoleArn = 'arn:aws:iam::123456789012:role/TestRole' + const mockAssumedRoleArn = 'arn:aws:sts::123456789012:assumed-role/TestRole/test-session' + const mockUserProfileId = 'up_session123' + + mockAuthProvider.getIamPrincipalArn = sinon.stub().resolves(mockRoleArn) + mockAuthProvider.getCachedIamCallerIdentityArn = sinon.stub().resolves(mockAssumedRoleArn) + mockAuthProvider.getDomainId = sinon.stub().returns('domain-123') + mockDataZoneCustomClientHelper.getUserProfileIdForSession.resolves(mockUserProfileId) + + const result = await spacesNode['getUserProfileIdForIamAuthMode']() + + assert.strictEqual(result, mockUserProfileId) + assert( + mockDataZoneCustomClientHelper.getUserProfileIdForSession.calledWith('domain-123', mockAssumedRoleArn) + ) + }) + + it('should throw error when assumed role ARN cannot be retrieved', async function () { + const mockRoleArn = 'arn:aws:iam::123456789012:role/TestRole' + + mockAuthProvider.getIamPrincipalArn = sinon.stub().resolves(mockRoleArn) + mockAuthProvider.getCachedIamCallerIdentityArn = sinon.stub().resolves(undefined) + + await assert.rejects( + async () => await spacesNode['getUserProfileIdForIamAuthMode'](), + /Unable to retrieve assumed role ARN with session/ + ) + }) + }) }) diff --git a/packages/core/src/test/sagemakerunifiedstudio/shared/client/datazoneClient.test.ts b/packages/core/src/test/sagemakerunifiedstudio/shared/client/datazoneClient.test.ts index 38dbd5e33f5..91a6a092bde 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/shared/client/datazoneClient.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/shared/client/datazoneClient.test.ts @@ -8,6 +8,9 @@ import * as sinon from 'sinon' import { DataZoneClient } from '../../../../sagemakerunifiedstudio/shared/client/datazoneClient' import { SmusAuthenticationProvider } from '../../../../sagemakerunifiedstudio/auth/providers/smusAuthenticationProvider' import { GetEnvironmentCommandOutput } from '@aws-sdk/client-datazone/dist-types/commands/GetEnvironmentCommand' +import { DefaultStsClient } from '../../../../shared/clients/stsClient' +import { SmusUtils, SmusErrorCodes } from '../../../../sagemakerunifiedstudio/shared/smusUtils' +import { ToolkitError } from '../../../../shared/errors' describe('DataZoneClient', () => { let dataZoneClient: DataZoneClient @@ -18,6 +21,7 @@ describe('DataZoneClient', () => { beforeEach(async () => { // Create mock connection object const mockConnection = { + id: 'connection-id', domainId: testDomainId, ssoRegion: testRegion, } @@ -31,42 +35,62 @@ describe('DataZoneClient', () => { onDidChangeActiveConnection: sinon.stub().returns({ dispose: sinon.stub(), }), + secondaryAuth: { + state: { + get: sinon.stub().returns({ + 'connection-id': { + profileName: 'test-profile', + }, + }), + }, + }, + getCredentialsProviderForIamProfile: sinon.stub(), } as any - // Set up the DataZoneClient using getInstance since constructor is private - DataZoneClient.dispose() - dataZoneClient = await DataZoneClient.getInstance(mockAuthProvider) + // Create mock credentials provider + const mockCredentialsProvider = { + getCredentials: sinon.stub().resolves({ + accessKeyId: 'test-key', + secretAccessKey: 'test-secret', + sessionToken: 'test-token', + }), + getCredentialsId: () => ({ credentialSource: 'temp' as const, credentialTypeId: 'test' }), + getProviderType: () => 'temp' as const, + getTelemetryType: () => 'other' as any, + getDefaultRegion: () => testRegion, + getHashCode: () => 'test-hash', + canAutoConnect: () => Promise.resolve(false), + isAvailable: () => Promise.resolve(true), + } + + // Set up the DataZoneClient using createWithCredentials + dataZoneClient = DataZoneClient.createWithCredentials(testRegion, testDomainId, mockCredentialsProvider) }) afterEach(() => { sinon.restore() }) - describe('getInstance', () => { - it('should return singleton instance', async () => { - const instance1 = await DataZoneClient.getInstance(mockAuthProvider) - const instance2 = await DataZoneClient.getInstance(mockAuthProvider) - - assert.strictEqual(instance1, instance2) - }) - - it('should create new instance after dispose', async () => { - const instance1 = await DataZoneClient.getInstance(mockAuthProvider) - DataZoneClient.dispose() - const instance2 = await DataZoneClient.getInstance(mockAuthProvider) - - assert.notStrictEqual(instance1, instance2) - }) - }) - - describe('dispose', () => { - it('should clear singleton instance', async () => { - const instance = await DataZoneClient.getInstance(mockAuthProvider) - DataZoneClient.dispose() + describe('createWithCredentials', () => { + it('should create new instance with credentials', () => { + const mockCredentialsProvider = { + getCredentials: sinon.stub().resolves({ + accessKeyId: 'test-key', + secretAccessKey: 'test-secret', + }), + getCredentialsId: () => ({ credentialSource: 'temp' as const, credentialTypeId: 'test' }), + getProviderType: () => 'temp' as const, + getTelemetryType: () => 'other' as any, + getDefaultRegion: () => testRegion, + getHashCode: () => 'test-hash', + canAutoConnect: () => Promise.resolve(false), + isAvailable: () => Promise.resolve(true), + } - // Should create new instance after dispose - const newInstance = await DataZoneClient.getInstance(mockAuthProvider) - assert.notStrictEqual(instance, newInstance) + const instance = DataZoneClient.createWithCredentials(testRegion, testDomainId, mockCredentialsProvider) + assert.ok(instance) + assert.strictEqual(instance.getRegion(), testRegion) + assert.strictEqual(instance.getDomainId(), testDomainId) }) }) @@ -151,6 +175,9 @@ describe('DataZoneClient', () => { getEnvironmentCredentials: sinon.stub().resolves(mockCredentials), } + // Mock getToolingBlueprintName to return 'Tooling' + sinon.stub(dataZoneClient as any, 'getToolingBlueprintName').returns('Tooling') + sinon.stub(dataZoneClient as any, 'getDataZoneClient').resolves(mockDataZone) const result = await dataZoneClient.getProjectDefaultEnvironmentCreds('project-1') @@ -215,9 +242,7 @@ describe('DataZoneClient', () => { describe('fetchAllProjects', function () { it('fetches all projects by handling pagination', async function () { - const client = await DataZoneClient.getInstance(mockAuthProvider) - - // Create a stub for listProjects that returns paginated results + // Create a stub for listProjects that returns paginated resultssults const listProjectsStub = sinon.stub() // First call returns first page with nextToken @@ -247,10 +272,10 @@ describe('DataZoneClient', () => { }) // Replace the listProjects method with our stub - client.listProjects = listProjectsStub + dataZoneClient.listProjects = listProjectsStub // Call fetchAllProjects - const result = await client.fetchAllProjects() + const result = await dataZoneClient.fetchAllProjects() // Verify results assert.strictEqual(result.length, 2) @@ -270,8 +295,6 @@ describe('DataZoneClient', () => { }) it('returns empty array when no projects found', async function () { - const client = await DataZoneClient.getInstance(mockAuthProvider) - // Create a stub for listProjects that returns empty results const listProjectsStub = sinon.stub().resolves({ projects: [], @@ -279,10 +302,10 @@ describe('DataZoneClient', () => { }) // Replace the listProjects method with our stub - client.listProjects = listProjectsStub + dataZoneClient.listProjects = listProjectsStub // Call fetchAllProjects - const result = await client.fetchAllProjects() + const result = await dataZoneClient.fetchAllProjects() // Verify results assert.strictEqual(result.length, 0) @@ -290,16 +313,14 @@ describe('DataZoneClient', () => { }) it('handles errors gracefully', async function () { - const client = await DataZoneClient.getInstance(mockAuthProvider) - // Create a stub for listProjects that throws an error const listProjectsStub = sinon.stub().rejects(new Error('API error')) // Replace the listProjects method with our stub - client.listProjects = listProjectsStub + dataZoneClient.listProjects = listProjectsStub // Call fetchAllProjects and expect it to throw - await assert.rejects(() => client.fetchAllProjects(), /API error/) + await assert.rejects(() => dataZoneClient.fetchAllProjects(), /API error/) }) }) @@ -314,6 +335,9 @@ describe('DataZoneClient', () => { }), } + // Mock getToolingBlueprintName to return 'Tooling' + sinon.stub(dataZoneClient as any, 'getToolingBlueprintName').returns('Tooling') + sinon.stub(dataZoneClient as any, 'getDataZoneClient').resolves(mockDataZone) const result = await dataZoneClient.getToolingEnvironmentId('domain-1', 'project-1') @@ -374,6 +398,9 @@ describe('DataZoneClient', () => { getEnvironment: sinon.stub().resolves(mockEnvironment), } + // Mock getToolingBlueprintName to return 'Tooling' + sinon.stub(dataZoneClient as any, 'getToolingBlueprintName').returns('Tooling') + sinon.stub(dataZoneClient as any, 'getDataZoneClient').resolves(mockDataZone) const result = await dataZoneClient.getToolingEnvironment('project-123') @@ -395,7 +422,7 @@ describe('DataZoneClient', () => { await assert.rejects( () => dataZoneClient.getToolingEnvironment('project-123'), - /Failed to get tooling environment ID: No default Tooling environment found for project/ + /No default Tooling environment found for project/ ) }) @@ -480,4 +507,130 @@ describe('DataZoneClient', () => { await assert.rejects(() => dataZoneClient.fetchAllProjectMemberships('project-1'), error) }) }) + + describe('getUserProfileId', () => { + let stsClientStub: sinon.SinonStub + let convertAssumedRoleArnStub: sinon.SinonStub + let mockCredentialsProvider: any + + beforeEach(() => { + // Mock connection with ID + mockAuthProvider.activeConnection = { id: 'connection-id' } + + // Mock credentials provider + mockCredentialsProvider = { + getCredentials: sinon.stub().resolves({ + accessKeyId: 'id', + secretAccessKey: 'secret', + sessionToken: 'token', + }), + } + + mockAuthProvider.getCredentialsProviderForIamProfile.resolves(mockCredentialsProvider) + + // Stub STS client + stsClientStub = sinon.stub(DefaultStsClient.prototype, 'getCallerIdentity') + + // Stub SmusUtils method + convertAssumedRoleArnStub = sinon.stub(SmusUtils as any, 'convertAssumedRoleArnToIamRoleArn') + }) + + afterEach(() => { + stsClientStub.restore() + convertAssumedRoleArnStub.restore() + }) + + it('should successfully get user profile ID with role ARN', async () => { + const mockRoleArn = 'arn:aws:iam::123456789012:role/service-role/MyRole' + const mockUserProfileId = 'user-profile-123' + + const mockDataZone = { + getUserProfile: sinon.stub().resolves({ + id: mockUserProfileId, + userIdentifier: mockRoleArn, + }), + } + + sinon.stub(dataZoneClient as any, 'getDataZoneClient').resolves(mockDataZone) + + const result = await dataZoneClient.getUserProfileIdForIamPrincipal(mockRoleArn) + + assert.strictEqual(result, mockUserProfileId) + assert.ok( + mockDataZone.getUserProfile.calledWith({ + domainIdentifier: testDomainId, + userIdentifier: mockRoleArn, + }) + ) + }) + + it('should handle DataZone getUserProfile API failure', async () => { + const mockRoleArn = 'arn:aws:iam::123456789012:role/service-role/MyRole' + const datazoneError = new Error('DataZone API Error') + + const mockDataZone = { + getUserProfile: sinon.stub().rejects(datazoneError), + } + + sinon.stub(dataZoneClient as any, 'getDataZoneClient').resolves(mockDataZone) + + await assert.rejects( + async () => { + await dataZoneClient.getUserProfileIdForIamPrincipal(mockRoleArn) + }, + (error: Error) => { + assert.ok(error instanceof ToolkitError) + assert.ok(error.message.includes('Failed to get user profile ID')) + return true + } + ) + }) + + it('should get user profile ID for IAM user ARN', async () => { + const mockUserArn = 'arn:aws:iam::123456789012:user/test-user' + const mockUserProfileId = 'user-profile-456' + + const mockDataZone = { + getUserProfile: sinon.stub().resolves({ + id: mockUserProfileId, + userIdentifier: mockUserArn, + }), + } + + sinon.stub(dataZoneClient as any, 'getDataZoneClient').resolves(mockDataZone) + + const result = await dataZoneClient.getUserProfileIdForIamPrincipal(mockUserArn) + + assert.strictEqual(result, mockUserProfileId) + assert.ok( + mockDataZone.getUserProfile.calledWith({ + domainIdentifier: testDomainId, + userIdentifier: mockUserArn, + }) + ) + }) + + it('should throw error when user profile ID is not returned', async () => { + const mockUserArn = 'arn:aws:iam::123456789012:user/test-user' + + const mockDataZone = { + getUserProfile: sinon.stub().resolves({ + // No id field + }), + } + + sinon.stub(dataZoneClient as any, 'getDataZoneClient').resolves(mockDataZone) + + await assert.rejects( + async () => { + await dataZoneClient.getUserProfileIdForIamPrincipal(mockUserArn) + }, + (error: Error) => { + assert.ok(error instanceof ToolkitError) + assert.strictEqual((error as ToolkitError).code, SmusErrorCodes.NoUserProfileFound) + return true + } + ) + }) + }) }) diff --git a/packages/core/src/test/sagemakerunifiedstudio/shared/client/datazoneCustomClientHelper.test.ts b/packages/core/src/test/sagemakerunifiedstudio/shared/client/datazoneCustomClientHelper.test.ts new file mode 100644 index 00000000000..5d0f8f29ade --- /dev/null +++ b/packages/core/src/test/sagemakerunifiedstudio/shared/client/datazoneCustomClientHelper.test.ts @@ -0,0 +1,1276 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as assert from 'assert' +import * as sinon from 'sinon' +import { DataZoneCustomClientHelper } from '../../../../sagemakerunifiedstudio/shared/client/datazoneCustomClientHelper' +import * as DataZoneCustomClient from '../../../../sagemakerunifiedstudio/shared/client/datazonecustomclient' + +type DataZoneDomain = DataZoneCustomClient.Types.DomainSummary + +describe('DataZoneCustomClientHelper', () => { + let client: DataZoneCustomClientHelper + let mockAuthProvider: any + const testRegion = 'us-east-1' + + beforeEach(() => { + // Create mock auth provider + mockAuthProvider = { + isConnected: sinon.stub().returns(true), + onDidChangeActiveConnection: sinon.stub().returns({ + dispose: sinon.stub(), + }), + } as any + + // Clear instances and create new client + DataZoneCustomClientHelper.dispose() + client = DataZoneCustomClientHelper.getInstance(mockAuthProvider, testRegion) + }) + + afterEach(() => { + sinon.restore() + DataZoneCustomClientHelper.dispose() + }) + + describe('getInstance', () => { + it('should return singleton instance for same region', () => { + const instance1 = DataZoneCustomClientHelper.getInstance(mockAuthProvider, testRegion) + const instance2 = DataZoneCustomClientHelper.getInstance(mockAuthProvider, testRegion) + + assert.strictEqual(instance1, instance2) + }) + + it('should create different instances for different regions', () => { + const instance1 = DataZoneCustomClientHelper.getInstance(mockAuthProvider, 'us-east-1') + const instance2 = DataZoneCustomClientHelper.getInstance(mockAuthProvider, 'us-west-2') + + assert.notStrictEqual(instance1, instance2) + }) + + it('should create new instance after dispose', () => { + const instance1 = DataZoneCustomClientHelper.getInstance(mockAuthProvider, testRegion) + DataZoneCustomClientHelper.dispose() + const instance2 = DataZoneCustomClientHelper.getInstance(mockAuthProvider, testRegion) + + assert.notStrictEqual(instance1, instance2) + }) + }) + + describe('dispose', () => { + it('should clear all instances', () => { + const instance1 = DataZoneCustomClientHelper.getInstance(mockAuthProvider, 'us-east-1') + const instance2 = DataZoneCustomClientHelper.getInstance(mockAuthProvider, 'us-west-2') + + DataZoneCustomClientHelper.dispose() + + // Should create new instance after dispose + const newInstance1 = DataZoneCustomClientHelper.getInstance(mockAuthProvider, 'us-east-1') + const newInstance2 = DataZoneCustomClientHelper.getInstance(mockAuthProvider, 'us-west-2') + + assert.notStrictEqual(instance1, newInstance1) + assert.notStrictEqual(instance2, newInstance2) + }) + }) + + describe('getRegion', () => { + it('should return configured region', () => { + const result = client.getRegion() + assert.strictEqual(result, testRegion) + }) + }) + + describe('listDomains', () => { + it('should list domains with pagination', async () => { + const mockResponse = { + items: [ + { + id: 'dzd_domain1', + name: 'Test Domain 1', + description: 'First test domain', + arn: 'arn:aws:datazone:us-east-1:123456789012:domain/dzd_domain1', + managedAccountId: '123456789012', + status: 'AVAILABLE', + portalUrl: 'https://domain1.datazone.aws', + createdAt: new Date('2023-01-01T00:00:00Z'), + lastUpdatedAt: new Date('2023-01-02T00:00:00Z'), + domainVersion: '1.0', + preferences: { DOMAIN_MODE: 'STANDARD' }, + }, + ], + nextToken: 'next-token', + } + + const mockDataZoneClient = { + listDomains: sinon.stub().returns({ + promise: () => Promise.resolve(mockResponse), + }), + } + + sinon.stub(client as any, 'getDataZoneCustomClient').resolves(mockDataZoneClient) + + const result = await client.listDomains({ + maxResults: 10, + status: 'AVAILABLE', + }) + + assert.strictEqual(result.domains.length, 1) + assert.strictEqual(result.domains[0].id, 'dzd_domain1') + assert.strictEqual(result.domains[0].name, 'Test Domain 1') + assert.strictEqual(result.domains[0].arn, 'arn:aws:datazone:us-east-1:123456789012:domain/dzd_domain1') + assert.strictEqual(result.domains[0].managedAccountId, '123456789012') + assert.strictEqual(result.domains[0].status, 'AVAILABLE') + assert.strictEqual(result.nextToken, 'next-token') + assert.ok(result.domains[0].createdAt instanceof Date) + assert.ok(result.domains[0].lastUpdatedAt instanceof Date) + }) + + it('should handle empty results', async () => { + const mockResponse = { + items: [], + nextToken: undefined, + } + + const mockDataZoneClient = { + listDomains: sinon.stub().returns({ + promise: () => Promise.resolve(mockResponse), + }), + } + + sinon.stub(client as any, 'getDataZoneCustomClient').resolves(mockDataZoneClient) + + const result = await client.listDomains() + + assert.strictEqual(result.domains.length, 0) + assert.strictEqual(result.nextToken, undefined) + }) + + it('should handle API errors', async () => { + const error = new Error('API Error') + sinon.stub(client as any, 'getDataZoneCustomClient').rejects(error) + + await assert.rejects(() => client.listDomains(), error) + }) + }) + + describe('fetchAllDomains', () => { + it('should fetch all domains by handling pagination', async () => { + const listDomainsStub = sinon.stub() + + // First call returns first page with nextToken + listDomainsStub.onFirstCall().resolves({ + domains: [ + { + id: 'dzd_domain1', + name: 'Domain 1', + arn: 'arn:aws:datazone:us-east-1:123456789012:domain/dzd_domain1', + managedAccountId: '123456789012', + status: 'AVAILABLE', + createdAt: new Date(), + } as DataZoneDomain, + ], + nextToken: 'next-page-token', + }) + + // Second call returns second page with no nextToken + listDomainsStub.onSecondCall().resolves({ + domains: [ + { + id: 'dzd_domain2', + name: 'Domain 2', + arn: 'arn:aws:datazone:us-east-1:123456789012:domain/dzd_domain2', + managedAccountId: '123456789012', + status: 'AVAILABLE', + createdAt: new Date(), + } as DataZoneDomain, + ], + nextToken: undefined, + }) + + // Replace the listDomains method with our stub + client.listDomains = listDomainsStub + + const result = await client.fetchAllDomains({ status: 'AVAILABLE' }) + + assert.strictEqual(result.length, 2) + assert.strictEqual(result[0].id, 'dzd_domain1') + assert.strictEqual(result[1].id, 'dzd_domain2') + + // Verify listDomains was called correctly + assert.strictEqual(listDomainsStub.callCount, 2) + assert.deepStrictEqual(listDomainsStub.firstCall.args[0], { + status: 'AVAILABLE', + maxResults: 25, + nextToken: undefined, + }) + assert.deepStrictEqual(listDomainsStub.secondCall.args[0], { + status: 'AVAILABLE', + maxResults: 25, + nextToken: 'next-page-token', + }) + }) + + it('should return empty array when no domains found', async () => { + const listDomainsStub = sinon.stub().resolves({ + domains: [], + nextToken: undefined, + }) + + client.listDomains = listDomainsStub + + const result = await client.fetchAllDomains() + + assert.strictEqual(result.length, 0) + assert.strictEqual(listDomainsStub.callCount, 1) + }) + + it('should handle errors gracefully', async () => { + const listDomainsStub = sinon.stub().rejects(new Error('API error')) + + client.listDomains = listDomainsStub + + await assert.rejects(() => client.fetchAllDomains(), /API error/) + }) + }) + + describe('getDomain', () => { + it('should find EXPRESS domain', async () => { + const listDomainsStub = sinon.stub() + + listDomainsStub.onFirstCall().resolves({ + domains: [ + { + id: 'dzd_standard', + name: 'Standard Domain', + arn: 'arn:aws:datazone:us-east-1:123456789012:domain/dzd_standard', + managedAccountId: '123456789012', + status: 'AVAILABLE', + createdAt: new Date(), + preferences: { DOMAIN_MODE: 'STANDARD' }, + }, + { + id: 'dzd_express', + name: 'Express Domain', + arn: 'arn:aws:datazone:us-east-1:123456789012:domain/dzd_express', + managedAccountId: '123456789012', + status: 'AVAILABLE', + createdAt: new Date(), + preferences: { DOMAIN_MODE: 'EXPRESS' }, + }, + ] as DataZoneDomain[], + nextToken: 'next-token', + }) + + client.listDomains = listDomainsStub + + const result = await client.getIamDomain() + + assert.ok(result) + assert.strictEqual(result.id, 'dzd_express') + assert.strictEqual(result.name, 'Express Domain') + assert.strictEqual(result.preferences?.DOMAIN_MODE, 'EXPRESS') + + // Should only call once since EXPRESS domain found on first page + assert.strictEqual(listDomainsStub.callCount, 1) + }) + + it('should return undefined when no EXPRESS domain found', async () => { + const listDomainsStub = sinon.stub() + + listDomainsStub.onFirstCall().resolves({ + domains: [ + { + id: 'dzd_standard', + name: 'Standard Domain', + arn: 'arn:aws:datazone:us-east-1:123456789012:domain/dzd_standard', + managedAccountId: '123456789012', + status: 'AVAILABLE', + createdAt: new Date(), + preferences: { DOMAIN_MODE: 'STANDARD' }, + }, + ] as DataZoneDomain[], + nextToken: undefined, + }) + + client.listDomains = listDomainsStub + + const result = await client.getIamDomain() + + assert.strictEqual(result, undefined) + assert.strictEqual(listDomainsStub.callCount, 1) + }) + + it('should return undefined when no domains found', async () => { + const listDomainsStub = sinon.stub().resolves({ + domains: [], + nextToken: undefined, + }) + + client.listDomains = listDomainsStub + + const result = await client.getIamDomain() + + assert.strictEqual(result, undefined) + assert.strictEqual(listDomainsStub.callCount, 1) + }) + + it('should handle domains without preferences', async () => { + const listDomainsStub = sinon.stub() + + listDomainsStub.onFirstCall().resolves({ + domains: [ + { + id: 'dzd_no_prefs', + name: 'Domain Without Preferences', + arn: 'arn:aws:datazone:us-east-1:123456789012:domain/dzd_no_prefs', + managedAccountId: '123456789012', + status: 'AVAILABLE', + createdAt: new Date(), + // No preferences field + }, + ] as DataZoneDomain[], + nextToken: undefined, + }) + + client.listDomains = listDomainsStub + + const result = await client.getIamDomain() + + assert.strictEqual(result, undefined) + }) + + it('should handle API errors', async () => { + const listDomainsStub = sinon.stub().rejects(new Error('API error')) + + client.listDomains = listDomainsStub + + await assert.rejects(() => client.getIamDomain(), /Failed to get domain info: API error/) + }) + }) + + describe('getDomain', () => { + it('should get domain by ID successfully', async () => { + const mockDomainId = 'dzd_test123' + const mockResponse = { + id: mockDomainId, + name: 'Test Domain', + description: 'A test domain', + arn: `arn:aws:datazone:us-east-1:123456789012:domain/${mockDomainId}`, + status: 'AVAILABLE', + portalUrl: 'https://test.datazone.aws', + createdAt: '2023-01-01T00:00:00Z', + lastUpdatedAt: '2023-01-02T00:00:00Z', + domainVersion: '1.0', + preferences: { DOMAIN_MODE: 'EXPRESS' }, + } + const mockDataZoneClient = { + getDomain: sinon.stub().returns({ + promise: () => Promise.resolve(mockResponse), + }), + } + + sinon.stub(client as any, 'getDataZoneCustomClient').resolves(mockDataZoneClient) + + const result = await client.getDomain(mockDomainId) + + assert.strictEqual(result.id, mockDomainId) + assert.strictEqual(result.name, 'Test Domain') + assert.strictEqual(result.description, 'A test domain') + assert.strictEqual(result.arn, `arn:aws:datazone:us-east-1:123456789012:domain/${mockDomainId}`) + assert.strictEqual(result.status, 'AVAILABLE') + assert.strictEqual(result.portalUrl, 'https://test.datazone.aws') + assert.strictEqual(result.domainVersion, '1.0') + assert.deepStrictEqual(result.preferences, { DOMAIN_MODE: 'EXPRESS' }) + + // Verify the API was called with correct parameters + assert.ok(mockDataZoneClient.getDomain.calledOnce) + assert.deepStrictEqual(mockDataZoneClient.getDomain.firstCall.args[0], { + identifier: mockDomainId, + }) + }) + + it('should handle API errors when getting domain', async () => { + const mockDomainId = 'dzd_test123' + const error = new Error('Domain not found') + + const mockDataZoneClient = { + getDomain: sinon.stub().returns({ + promise: () => Promise.reject(error), + }), + } + + sinon.stub(client as any, 'getDataZoneCustomClient').resolves(mockDataZoneClient) + + await assert.rejects(() => client.getDomain(mockDomainId), error) + + // Verify the API was called with correct parameters + assert.ok(mockDataZoneClient.getDomain.calledOnce) + assert.deepStrictEqual(mockDataZoneClient.getDomain.firstCall.args[0], { + identifier: mockDomainId, + }) + }) + }) + + describe('isIamDomain', () => { + it('should return true for EXPRESS domain', async () => { + const mockDomainId = 'dzd_express123' + const mockResponse = { + id: mockDomainId, + name: 'Express Domain', + arn: `arn:aws:datazone:us-east-1:123456789012:domain/${mockDomainId}`, + status: 'AVAILABLE', + preferences: { DOMAIN_MODE: 'EXPRESS' }, + } + + const getDomainStub = sinon.stub(client, 'getDomain').resolves(mockResponse) + + const result = await client.isIamDomain(mockDomainId) + + assert.strictEqual(result, true) + assert.ok(getDomainStub.calledOnce) + assert.strictEqual(getDomainStub.firstCall.args[0], mockDomainId) + }) + + it('should return false for STANDARD domain', async () => { + const mockDomainId = 'dzd_standard123' + const mockResponse = { + id: mockDomainId, + name: 'Standard Domain', + arn: `arn:aws:datazone:us-east-1:123456789012:domain/${mockDomainId}`, + status: 'AVAILABLE', + preferences: { DOMAIN_MODE: 'STANDARD' }, + } + + const getDomainStub = sinon.stub(client, 'getDomain').resolves(mockResponse) + + const result = await client.isIamDomain(mockDomainId) + + assert.strictEqual(result, false) + assert.ok(getDomainStub.calledOnce) + assert.strictEqual(getDomainStub.firstCall.args[0], mockDomainId) + }) + + it('should return false for domain without preferences', async () => { + const mockDomainId = 'dzd_no_prefs123' + const mockResponse = { + id: mockDomainId, + name: 'Domain Without Preferences', + arn: `arn:aws:datazone:us-east-1:123456789012:domain/${mockDomainId}`, + status: 'AVAILABLE', + // No preferences field + } + + const getDomainStub = sinon.stub(client, 'getDomain').resolves(mockResponse) + + const result = await client.isIamDomain(mockDomainId) + + assert.strictEqual(result, false) + assert.ok(getDomainStub.calledOnce) + assert.strictEqual(getDomainStub.firstCall.args[0], mockDomainId) + }) + }) + + describe('searchGroupProfiles', () => { + const mockDomainId = 'dzd_test123' + + it('should search group profiles successfully', async () => { + const mockResponse = { + items: [ + { + domainId: mockDomainId, + id: 'gp_profile1', + status: 'ACTIVATED', + groupName: 'AdminGroup', + rolePrincipalArn: 'arn:aws:iam::123456789012:role/AdminRole', + rolePrincipalId: 'AIDAI123456789EXAMPLE', + }, + { + domainId: mockDomainId, + id: 'gp_profile2', + status: 'ACTIVATED', + groupName: 'DeveloperGroup', + rolePrincipalArn: 'arn:aws:iam::123456789012:role/DeveloperRole', + rolePrincipalId: 'AIDAI987654321EXAMPLE', + }, + ], + nextToken: 'next-token', + } + + const mockDataZoneClient = { + searchGroupProfiles: sinon.stub().returns({ + promise: () => Promise.resolve(mockResponse), + }), + } + + sinon.stub(client as any, 'getDataZoneCustomClient').resolves(mockDataZoneClient) + + const result = await client.searchGroupProfiles(mockDomainId, { + groupType: 'IAM_ROLE_SESSION_GROUP', + maxResults: 50, + }) + + assert.strictEqual(result.items.length, 2) + assert.strictEqual(result.items[0].id, 'gp_profile1') + assert.strictEqual(result.items[0].rolePrincipalArn, 'arn:aws:iam::123456789012:role/AdminRole') + assert.strictEqual(result.items[1].id, 'gp_profile2') + assert.strictEqual(result.nextToken, 'next-token') + + // Verify API was called with correct parameters + assert.ok(mockDataZoneClient.searchGroupProfiles.calledOnce) + const callArgs = mockDataZoneClient.searchGroupProfiles.firstCall.args[0] + assert.strictEqual(callArgs.domainIdentifier, mockDomainId) + assert.strictEqual(callArgs.groupType, 'IAM_ROLE_SESSION_GROUP') + assert.strictEqual(callArgs.maxResults, 50) + }) + + it('should handle empty results', async () => { + const mockResponse = { + items: [], + nextToken: undefined, + } + + const mockDataZoneClient = { + searchGroupProfiles: sinon.stub().returns({ + promise: () => Promise.resolve(mockResponse), + }), + } + + sinon.stub(client as any, 'getDataZoneCustomClient').resolves(mockDataZoneClient) + + const result = await client.searchGroupProfiles(mockDomainId) + + assert.strictEqual(result.items.length, 0) + assert.strictEqual(result.nextToken, undefined) + }) + + it('should handle API errors', async () => { + const error = new Error('API Error') + const mockDataZoneClient = { + searchGroupProfiles: sinon.stub().returns({ + promise: () => Promise.reject(error), + }), + } + + sinon.stub(client as any, 'getDataZoneCustomClient').resolves(mockDataZoneClient) + + await assert.rejects(() => client.searchGroupProfiles(mockDomainId), error) + }) + + it('should support pagination with nextToken', async () => { + const mockResponse = { + items: [ + { + domainId: mockDomainId, + id: 'gp_profile3', + status: 'ACTIVATED', + groupName: 'TestGroup', + rolePrincipalArn: 'arn:aws:iam::123456789012:role/TestRole', + rolePrincipalId: 'AIDAI111111111EXAMPLE', + }, + ], + nextToken: undefined, + } + + const mockDataZoneClient = { + searchGroupProfiles: sinon.stub().returns({ + promise: () => Promise.resolve(mockResponse), + }), + } + + sinon.stub(client as any, 'getDataZoneCustomClient').resolves(mockDataZoneClient) + + const result = await client.searchGroupProfiles(mockDomainId, { + nextToken: 'previous-token', + }) + + assert.strictEqual(result.items.length, 1) + assert.strictEqual(result.nextToken, undefined) + + // Verify nextToken was passed + const callArgs = mockDataZoneClient.searchGroupProfiles.firstCall.args[0] + assert.strictEqual(callArgs.nextToken, 'previous-token') + }) + }) + + describe('searchUserProfiles', () => { + const mockDomainId = 'dzd_test123' + + it('should search user profiles successfully', async () => { + const mockResponse = { + items: [ + { + domainId: mockDomainId, + id: 'up_user1', + type: 'IAM', + status: 'ACTIVATED', + details: { + iam: { + arn: 'arn:aws:iam::123456789012:role/AdminRole', + principalId: 'AIDAI123456789EXAMPLE:session1', + }, + }, + }, + { + domainId: mockDomainId, + id: 'up_user2', + type: 'IAM', + status: 'ACTIVATED', + details: { + iam: { + arn: 'arn:aws:iam::123456789012:role/DeveloperRole', + principalId: 'AIDAI987654321EXAMPLE:session2', + }, + }, + }, + ], + nextToken: 'next-token', + } + + const mockDataZoneClient = { + searchUserProfiles: sinon.stub().returns({ + promise: () => Promise.resolve(mockResponse), + }), + } + + sinon.stub(client as any, 'getDataZoneCustomClient').resolves(mockDataZoneClient) + + const result = await client.searchUserProfiles(mockDomainId, { + userType: 'DATAZONE_IAM_USER', + maxResults: 50, + }) + + assert.strictEqual(result.items.length, 2) + assert.strictEqual(result.items[0].id, 'up_user1') + assert.strictEqual(result.items[0].details?.iam?.principalId, 'AIDAI123456789EXAMPLE:session1') + assert.strictEqual(result.items[1].id, 'up_user2') + assert.strictEqual(result.nextToken, 'next-token') + + // Verify API was called with correct parameters + assert.ok(mockDataZoneClient.searchUserProfiles.calledOnce) + const callArgs = mockDataZoneClient.searchUserProfiles.firstCall.args[0] + assert.strictEqual(callArgs.domainIdentifier, mockDomainId) + assert.strictEqual(callArgs.userType, 'DATAZONE_IAM_USER') + assert.strictEqual(callArgs.maxResults, 50) + }) + + it('should handle SSO user profiles', async () => { + const mockResponse = { + items: [ + { + domainId: mockDomainId, + id: 'up_sso_user', + type: 'SSO', + status: 'ACTIVATED', + details: { + sso: { + firstName: 'John', + lastName: 'Doe', + username: 'jdoe', + }, + }, + }, + ], + nextToken: undefined, + } + + const mockDataZoneClient = { + searchUserProfiles: sinon.stub().returns({ + promise: () => Promise.resolve(mockResponse), + }), + } + + sinon.stub(client as any, 'getDataZoneCustomClient').resolves(mockDataZoneClient) + + const result = await client.searchUserProfiles(mockDomainId, { + userType: 'SSO_USER', + }) + + assert.strictEqual(result.items.length, 1) + assert.strictEqual(result.items[0].details?.sso?.username, 'jdoe') + assert.strictEqual(result.items[0].details?.sso?.firstName, 'John') + }) + + it('should handle empty results', async () => { + const mockResponse = { + items: [], + nextToken: undefined, + } + + const mockDataZoneClient = { + searchUserProfiles: sinon.stub().returns({ + promise: () => Promise.resolve(mockResponse), + }), + } + + sinon.stub(client as any, 'getDataZoneCustomClient').resolves(mockDataZoneClient) + + const result = await client.searchUserProfiles(mockDomainId, { + userType: 'DATAZONE_IAM_USER', + }) + + assert.strictEqual(result.items.length, 0) + assert.strictEqual(result.nextToken, undefined) + }) + + it('should handle API errors', async () => { + const error = new Error('API Error') + const mockDataZoneClient = { + searchUserProfiles: sinon.stub().returns({ + promise: () => Promise.reject(error), + }), + } + + sinon.stub(client as any, 'getDataZoneCustomClient').resolves(mockDataZoneClient) + + await assert.rejects( + () => + client.searchUserProfiles(mockDomainId, { + userType: 'DATAZONE_IAM_USER', + }), + error + ) + }) + }) + + describe('getGroupProfileId', () => { + const mockDomainId = 'dzd_test123' + const mockRoleArn = 'arn:aws:iam::123456789012:role/AdminRole' + + it('should find matching group profile on first page', async () => { + const searchStub = sinon.stub(client, 'searchGroupProfiles') + searchStub.onFirstCall().resolves({ + items: [ + { + id: 'gp_profile1', + rolePrincipalArn: mockRoleArn, + status: 'ACTIVATED', + }, + ], + nextToken: undefined, + }) + + const result = await client.getGroupProfileId(mockDomainId, mockRoleArn) + + assert.strictEqual(result, 'gp_profile1') + assert.ok(searchStub.calledOnce) + assert.strictEqual(searchStub.firstCall.args[0], mockDomainId) + assert.strictEqual(searchStub.firstCall.args[1]?.groupType, 'IAM_ROLE_SESSION_GROUP') + }) + + it('should throw ToolkitError when no matching profile found', async () => { + const searchStub = sinon.stub(client, 'searchGroupProfiles') + searchStub.resolves({ + items: [ + { + id: 'gp_profile1', + rolePrincipalArn: 'arn:aws:iam::123456789012:role/OtherRole', + status: 'ACTIVATED', + }, + ], + nextToken: undefined, + }) + + await assert.rejects( + () => client.getGroupProfileId(mockDomainId, mockRoleArn), + (err: any) => { + assert.ok(err.message.includes('No group profile found')) + assert.strictEqual(err.code, 'NoGroupProfileFound') + return true + } + ) + }) + + it('should handle API errors', async () => { + const searchStub = sinon.stub(client, 'searchGroupProfiles') + searchStub.rejects(new Error('API Error')) + + await assert.rejects( + () => client.getGroupProfileId(mockDomainId, mockRoleArn), + (err: any) => { + assert.ok(err.message.includes('Failed to get group profile ID')) + return true + } + ) + }) + }) + + describe('getUserProfileIdForSession', () => { + const mockDomainId = 'dzd_test123' + const mockAssumedRoleArn = 'arn:aws:sts::123456789012:assumed-role/AdminRole/my-session' + + it('should find matching user profile by role ARN and session name', async () => { + const searchStub = sinon.stub(client, 'searchUserProfiles') + searchStub.onFirstCall().resolves({ + items: [ + { + id: 'up_user1', + status: 'ACTIVATED', + details: { + iam: { + arn: 'arn:aws:iam::123456789012:role/AdminRole', + principalId: 'AIDAI123456789EXAMPLE:my-session', + }, + }, + }, + ], + nextToken: undefined, + }) + + const result = await client.getUserProfileIdForSession(mockDomainId, mockAssumedRoleArn) + + assert.strictEqual(result, 'up_user1') + assert.ok(searchStub.calledOnce) + assert.strictEqual(searchStub.firstCall.args[0], mockDomainId) + assert.strictEqual(searchStub.firstCall.args[1].userType, 'DATAZONE_IAM_USER') + assert.strictEqual(searchStub.firstCall.args[1].searchText, 'arn:aws:iam::123456789012:role/AdminRole') + }) + + it('should find matching user profile across multiple pages', async () => { + const searchStub = sinon.stub(client, 'searchUserProfiles') + + // First page - no match (different session name) + searchStub.onFirstCall().resolves({ + items: [ + { + id: 'up_user1', + status: 'ACTIVATED', + details: { + iam: { + arn: 'arn:aws:iam::123456789012:role/AdminRole', + principalId: 'AIDAI123456789EXAMPLE:other-session', + }, + }, + }, + ], + nextToken: 'next-token', + }) + + // Second page - match found + searchStub.onSecondCall().resolves({ + items: [ + { + id: 'up_user2', + status: 'ACTIVATED', + details: { + iam: { + arn: 'arn:aws:iam::123456789012:role/AdminRole', + principalId: 'AIDAI987654321EXAMPLE:my-session', + }, + }, + }, + ], + nextToken: undefined, + }) + + const result = await client.getUserProfileIdForSession(mockDomainId, mockAssumedRoleArn) + + assert.strictEqual(result, 'up_user2') + assert.strictEqual(searchStub.callCount, 2) + }) + + it('should throw ToolkitError when session name cannot be extracted', async () => { + const invalidArn = 'arn:aws:iam::123456789012:role/AdminRole' + + await assert.rejects( + () => client.getUserProfileIdForSession(mockDomainId, invalidArn), + (err: any) => { + assert.ok(err.message.includes('Unable to extract session name')) + assert.strictEqual(err.code, 'NoUserProfileFound') + return true + } + ) + }) + + it('should throw ToolkitError when no matching profile found', async () => { + const searchStub = sinon.stub(client, 'searchUserProfiles') + searchStub.resolves({ + items: [ + { + id: 'up_user1', + status: 'ACTIVATED', + details: { + iam: { + arn: 'arn:aws:iam::123456789012:role/AdminRole', + principalId: 'AIDAI123456789EXAMPLE:other-session', + }, + }, + }, + ], + nextToken: undefined, + }) + + await assert.rejects( + () => client.getUserProfileIdForSession(mockDomainId, mockAssumedRoleArn), + (err: any) => { + assert.ok(err.message.includes('No user profile found')) + assert.strictEqual(err.code, 'NoUserProfileFound') + return true + } + ) + }) + + it('should handle profiles without IAM details', async () => { + const searchStub = sinon.stub(client, 'searchUserProfiles') + searchStub.resolves({ + items: [ + { + id: 'up_user1', + status: 'ACTIVATED', + details: { + // No iam field + }, + }, + ], + nextToken: undefined, + }) + + await assert.rejects( + () => client.getUserProfileIdForSession(mockDomainId, mockAssumedRoleArn), + (err: any) => { + assert.ok(err.message.includes('No user profile found')) + return true + } + ) + }) + + it('should handle API errors', async () => { + const searchStub = sinon.stub(client, 'searchUserProfiles') + searchStub.rejects(new Error('API Error')) + + await assert.rejects( + () => client.getUserProfileIdForSession(mockDomainId, mockAssumedRoleArn), + (err: any) => { + assert.ok(err.message.includes('Failed to get user profile ID')) + return true + } + ) + }) + + it('should handle various role ARN formats', async () => { + const testCases = [ + { + arn: 'arn:aws:sts::123456789012:assumed-role/MyRole/session-123', + expectedSession: 'session-123', + }, + { + arn: 'arn:aws:sts::123456789012:assumed-role/DeveloperRole/user-session-name', + expectedSession: 'user-session-name', + }, + { + arn: 'arn:aws:sts::999888777666:assumed-role/AdminRole/admin-session', + expectedSession: 'admin-session', + }, + ] + + for (const testCase of testCases) { + const searchStub = sinon.stub(client, 'searchUserProfiles') + searchStub.resolves({ + items: [ + { + id: 'up_test', + status: 'ACTIVATED', + details: { + iam: { + principalId: `PRINCIPAL:${testCase.expectedSession}`, + }, + }, + }, + ], + nextToken: undefined, + }) + + const result = await client.getUserProfileIdForSession(mockDomainId, testCase.arn) + assert.strictEqual(result, 'up_test') + + searchStub.restore() + } + }) + }) + + describe('Project and Space Filtering', () => { + const mockDomainId = 'dzd_test123' + + describe('Project filtering by group profile', () => { + it('should filter projects when group profile is found', async () => { + const mockRoleArn = 'arn:aws:iam::123456789012:role/AdminRole' + const mockGroupProfileId = 'gp_profile1' + + const searchStub = sinon.stub(client, 'searchGroupProfiles') + searchStub.resolves({ + items: [ + { + id: mockGroupProfileId, + rolePrincipalArn: mockRoleArn, + status: 'ACTIVATED', + }, + ], + nextToken: undefined, + }) + + const result = await client.getGroupProfileId(mockDomainId, mockRoleArn) + + assert.strictEqual(result, mockGroupProfileId) + assert.ok(searchStub.calledOnce) + }) + + it('should handle empty project list for group profile', async () => { + const mockRoleArn = 'arn:aws:iam::123456789012:role/AdminRole' + + const searchStub = sinon.stub(client, 'searchGroupProfiles') + searchStub.resolves({ + items: [], + nextToken: undefined, + }) + + await assert.rejects( + () => client.getGroupProfileId(mockDomainId, mockRoleArn), + (err: any) => { + assert.ok(err.message.includes('No group profile found')) + assert.strictEqual(err.code, 'NoGroupProfileFound') + return true + } + ) + }) + + it('should handle API errors during project filtering', async () => { + const mockRoleArn = 'arn:aws:iam::123456789012:role/AdminRole' + + const searchStub = sinon.stub(client, 'searchGroupProfiles') + searchStub.rejects(new Error('API Error')) + + await assert.rejects( + () => client.getGroupProfileId(mockDomainId, mockRoleArn), + (err: any) => { + assert.ok(err.message.includes('Failed to get group profile ID')) + return true + } + ) + }) + + it('should handle AccessDeniedException during project filtering', async () => { + const mockRoleArn = 'arn:aws:iam::123456789012:role/AdminRole' + const accessDeniedError = new Error('Access denied') + accessDeniedError.name = 'AccessDeniedException' + + const searchStub = sinon.stub(client, 'searchGroupProfiles') + searchStub.rejects(accessDeniedError) + + await assert.rejects( + () => client.getGroupProfileId(mockDomainId, mockRoleArn), + (err: any) => { + assert.ok(err.message.includes('Failed to get group profile ID')) + return true + } + ) + }) + }) + + describe('Space filtering by user profile', () => { + it('should filter spaces when user profile is found', async () => { + const mockAssumedRoleArn = 'arn:aws:sts::123456789012:assumed-role/AdminRole/my-session' + const mockUserProfileId = 'up_user1' + + const searchStub = sinon.stub(client, 'searchUserProfiles') + searchStub.resolves({ + items: [ + { + id: mockUserProfileId, + status: 'ACTIVATED', + details: { + iam: { + arn: 'arn:aws:iam::123456789012:role/AdminRole', + principalId: 'AIDAI123456789EXAMPLE:my-session', + }, + }, + }, + ], + nextToken: undefined, + }) + + const result = await client.getUserProfileIdForSession(mockDomainId, mockAssumedRoleArn) + + assert.strictEqual(result, mockUserProfileId) + assert.ok(searchStub.calledOnce) + }) + + it('should handle empty space list for user profile', async () => { + const mockAssumedRoleArn = 'arn:aws:sts::123456789012:assumed-role/AdminRole/my-session' + + const searchStub = sinon.stub(client, 'searchUserProfiles') + searchStub.resolves({ + items: [], + nextToken: undefined, + }) + + await assert.rejects( + () => client.getUserProfileIdForSession(mockDomainId, mockAssumedRoleArn), + (err: any) => { + assert.ok(err.message.includes('No user profile found')) + assert.strictEqual(err.code, 'NoUserProfileFound') + return true + } + ) + }) + + it('should handle API errors during space filtering', async () => { + const mockAssumedRoleArn = 'arn:aws:sts::123456789012:assumed-role/AdminRole/my-session' + + const searchStub = sinon.stub(client, 'searchUserProfiles') + searchStub.rejects(new Error('API Error')) + + await assert.rejects( + () => client.getUserProfileIdForSession(mockDomainId, mockAssumedRoleArn), + (err: any) => { + assert.ok(err.message.includes('Failed to get user profile ID')) + return true + } + ) + }) + + it('should handle AccessDeniedException during space filtering', async () => { + const mockAssumedRoleArn = 'arn:aws:sts::123456789012:assumed-role/AdminRole/my-session' + const accessDeniedError = new Error('Access denied') + accessDeniedError.name = 'AccessDeniedException' + + const searchStub = sinon.stub(client, 'searchUserProfiles') + searchStub.rejects(accessDeniedError) + + await assert.rejects( + () => client.getUserProfileIdForSession(mockDomainId, mockAssumedRoleArn), + (err: any) => { + assert.ok(err.message.includes('Failed to get user profile ID')) + return true + } + ) + }) + + it('should handle profiles with missing principalId', async () => { + const mockAssumedRoleArn = 'arn:aws:sts::123456789012:assumed-role/AdminRole/my-session' + + const searchStub = sinon.stub(client, 'searchUserProfiles') + searchStub.resolves({ + items: [ + { + id: 'up_user1', + status: 'ACTIVATED', + details: { + iam: { + // Missing principalId + }, + }, + }, + ], + nextToken: undefined, + }) + + await assert.rejects( + () => client.getUserProfileIdForSession(mockDomainId, mockAssumedRoleArn), + (err: any) => { + assert.ok(err.message.includes('No user profile found')) + return true + } + ) + }) + }) + + describe('Error scenarios in filtering logic', () => { + it('should handle network errors during group profile search', async () => { + const mockRoleArn = 'arn:aws:iam::123456789012:role/AdminRole' + const networkError = new Error('Network error') + networkError.name = 'NetworkError' + + const searchStub = sinon.stub(client, 'searchGroupProfiles') + searchStub.rejects(networkError) + + await assert.rejects( + () => client.getGroupProfileId(mockDomainId, mockRoleArn), + (err: any) => { + assert.ok(err.message.includes('Failed to get group profile ID')) + return true + } + ) + }) + + it('should handle network errors during user profile search', async () => { + const mockAssumedRoleArn = 'arn:aws:sts::123456789012:assumed-role/AdminRole/my-session' + const networkError = new Error('Network error') + networkError.name = 'NetworkError' + + const searchStub = sinon.stub(client, 'searchUserProfiles') + searchStub.rejects(networkError) + + await assert.rejects( + () => client.getUserProfileIdForSession(mockDomainId, mockAssumedRoleArn), + (err: any) => { + assert.ok(err.message.includes('Failed to get user profile ID')) + return true + } + ) + }) + + it('should handle timeout errors during group profile search', async () => { + const mockRoleArn = 'arn:aws:iam::123456789012:role/AdminRole' + const timeoutError = new Error('Request timeout') + timeoutError.name = 'TimeoutError' + + const searchStub = sinon.stub(client, 'searchGroupProfiles') + searchStub.rejects(timeoutError) + + await assert.rejects( + () => client.getGroupProfileId(mockDomainId, mockRoleArn), + (err: any) => { + assert.ok(err.message.includes('Failed to get group profile ID')) + return true + } + ) + }) + + it('should handle malformed response during group profile search', async () => { + const mockRoleArn = 'arn:aws:iam::123456789012:role/AdminRole' + + const searchStub = sinon.stub(client, 'searchGroupProfiles') + searchStub.resolves({ + items: [ + { + // Missing required fields + status: 'ACTIVATED', + } as any, + ], + nextToken: undefined, + }) + + await assert.rejects( + () => client.getGroupProfileId(mockDomainId, mockRoleArn), + (err: any) => { + assert.ok(err.message.includes('No group profile found')) + return true + } + ) + }) + + it('should handle malformed response during user profile search', async () => { + const mockAssumedRoleArn = 'arn:aws:sts::123456789012:assumed-role/AdminRole/my-session' + + const searchStub = sinon.stub(client, 'searchUserProfiles') + searchStub.resolves({ + items: [ + { + // Missing required fields + status: 'ACTIVATED', + } as any, + ], + nextToken: undefined, + }) + + await assert.rejects( + () => client.getUserProfileIdForSession(mockDomainId, mockAssumedRoleArn), + (err: any) => { + assert.ok(err.message.includes('No user profile found')) + return true + } + ) + }) + }) + }) +}) diff --git a/packages/core/src/test/sagemakerunifiedstudio/shared/devSettingsEndpointConfiguration.test.ts b/packages/core/src/test/sagemakerunifiedstudio/shared/devSettingsEndpointConfiguration.test.ts new file mode 100644 index 00000000000..c5321283017 --- /dev/null +++ b/packages/core/src/test/sagemakerunifiedstudio/shared/devSettingsEndpointConfiguration.test.ts @@ -0,0 +1,86 @@ +/*! + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as assert from 'assert' +import * as sinon from 'sinon' +import { DevSettings } from '../../../shared/settings' + +describe('Endpoint Configuration from Settings', () => { + let sandbox: sinon.SinonSandbox + + beforeEach(() => { + sandbox = sinon.createSandbox() + }) + + afterEach(() => { + sandbox.restore() + }) + + describe('DataZone endpoint configuration', () => { + it('should return custom DataZone endpoint when configured', () => { + const customEndpoint = 'https://custom-datazone.example.com' + const getStub = sandbox.stub(DevSettings.instance, 'get') + getStub.withArgs('endpoints', {}).returns({ datazone: customEndpoint }) + + const endpoints = DevSettings.instance.get('endpoints', {}) + const datazoneEndpoint = endpoints['datazone'] + + assert.strictEqual(datazoneEndpoint, customEndpoint) + }) + }) + + describe('SageMaker endpoint configuration', () => { + it('should return custom SageMaker endpoint when configured', () => { + const customEndpoint = 'https://custom-sagemaker.example.com' + const getStub = sandbox.stub(DevSettings.instance, 'get') + getStub.withArgs('endpoints', {}).returns({ sagemaker: customEndpoint }) + + const endpoints = DevSettings.instance.get('endpoints', {}) + const sagemakerEndpoint = endpoints['sagemaker'] + + assert.strictEqual(sagemakerEndpoint, customEndpoint) + }) + }) + + describe('Endpoint fallback behavior', () => { + it('should construct default DataZone endpoint when custom endpoint is not set', () => { + const getStub = sandbox.stub(DevSettings.instance, 'get') + getStub.withArgs('endpoints', {}).returns({}) + + const region = 'us-west-2' + const endpoints = DevSettings.instance.get('endpoints', {}) + const customEndpoint = endpoints['datazone'] + const endpoint = customEndpoint || `https://datazone.${region}.api.aws` + + assert.strictEqual(endpoint, 'https://datazone.us-west-2.api.aws') + }) + + it('should construct default SageMaker endpoint when custom endpoint is not set', () => { + const getStub = sandbox.stub(DevSettings.instance, 'get') + getStub.withArgs('endpoints', {}).returns({}) + + const region = 'us-east-1' + const endpoints = DevSettings.instance.get('endpoints', {}) + const customEndpoint = endpoints['sagemaker'] + const endpoint = customEndpoint || `https://sagemaker.${region}.amazonaws.com` + + assert.strictEqual(endpoint, 'https://sagemaker.us-east-1.amazonaws.com') + }) + + it('should handle multiple endpoints in configuration', () => { + const customEndpoints = { + datazone: 'https://custom-datazone.example.com', + sagemaker: 'https://custom-sagemaker.example.com', + } + const getStub = sandbox.stub(DevSettings.instance, 'get') + getStub.withArgs('endpoints', {}).returns(customEndpoints) + + const endpoints = DevSettings.instance.get('endpoints', {}) + + assert.strictEqual(endpoints['datazone'], customEndpoints.datazone) + assert.strictEqual(endpoints['sagemaker'], customEndpoints.sagemaker) + }) + }) +}) diff --git a/packages/core/src/test/sagemakerunifiedstudio/shared/smusUtils.test.ts b/packages/core/src/test/sagemakerunifiedstudio/shared/smusUtils.test.ts index c03b55c64c6..46db2c228f5 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/shared/smusUtils.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/shared/smusUtils.test.ts @@ -13,6 +13,7 @@ import { validateCredentialFields, extractAccountIdFromSageMakerArn, extractAccountIdFromResourceMetadata, + isCredentialExpirationError, } from '../../../sagemakerunifiedstudio/shared/smusUtils' import { ToolkitError } from '../../../shared/errors' import * as extensionUtilities from '../../../shared/extensionUtilities' @@ -462,6 +463,132 @@ describe('SmusUtils', () => { assert.strictEqual(result, false) }) }) + + describe('isIamUserArn', () => { + it('should return true for IAM user ARN', () => { + const iamUserArn = 'arn:aws:iam::619071339486:user/vabharga-test' + const result = SmusUtils.isIamUserArn(iamUserArn) + assert.strictEqual(result, true) + }) + + it('should return false for IAM role session ARN', () => { + const roleSessionArn = 'arn:aws:sts::123456789012:assumed-role/MyRole/MySession' + const result = SmusUtils.isIamUserArn(roleSessionArn) + assert.strictEqual(result, false) + }) + + it('should return false for IAM role ARN', () => { + const roleArn = 'arn:aws:iam::123456789012:role/MyRole' + const result = SmusUtils.isIamUserArn(roleArn) + assert.strictEqual(result, false) + }) + + it('should return false for undefined ARN', () => { + const result = SmusUtils.isIamUserArn(undefined) + assert.strictEqual(result, false) + }) + + it('should return false for empty string', () => { + const result = SmusUtils.isIamUserArn('') + assert.strictEqual(result, false) + }) + + it('should return false for invalid ARN format', () => { + const result = SmusUtils.isIamUserArn('not-an-arn') + assert.strictEqual(result, false) + }) + + it('should return false for non-IAM ARN', () => { + const s3Arn = 'arn:aws:s3:::my-bucket' + const result = SmusUtils.isIamUserArn(s3Arn) + assert.strictEqual(result, false) + }) + }) + + describe('convertAssumedRoleArnToIamRoleArn', () => { + it('should convert basic assumed role ARN to IAM role ARN', () => { + const stsArn = 'arn:aws:sts::123456789012:assumed-role/MyRole/MySession' + const expected = 'arn:aws:iam::123456789012:role/MyRole' + + const result = SmusUtils.convertAssumedRoleArnToIamRoleArn(stsArn) + assert.strictEqual(result, expected) + }) + + it('should convert assumed role ARN with aws-cn partition', () => { + const stsArn = 'arn:aws-cn:sts::123456789012:assumed-role/MyRole/MySession' + const expected = 'arn:aws-cn:iam::123456789012:role/MyRole' + + const result = SmusUtils.convertAssumedRoleArnToIamRoleArn(stsArn) + assert.strictEqual(result, expected) + }) + + it('should convert assumed role ARN with aws-us-gov partition', () => { + const stsArn = 'arn:aws-us-gov:sts::123456789012:assumed-role/MyRole/MySession' + const expected = 'arn:aws-us-gov:iam::123456789012:role/MyRole' + + const result = SmusUtils.convertAssumedRoleArnToIamRoleArn(stsArn) + assert.strictEqual(result, expected) + }) + + it('should return IAM user ARN as-is', () => { + const iamUserArn = 'arn:aws:iam::619071339486:user/vabharga-test' + const result = SmusUtils.convertAssumedRoleArnToIamRoleArn(iamUserArn) + assert.strictEqual(result, iamUserArn) + }) + + it('should return IAM user ARN with aws-cn partition as-is', () => { + const iamUserArn = 'arn:aws-cn:iam::123456789012:user/my-user' + const result = SmusUtils.convertAssumedRoleArnToIamRoleArn(iamUserArn) + assert.strictEqual(result, iamUserArn) + }) + + it('should return IAM user ARN with aws-us-gov partition as-is', () => { + const iamUserArn = 'arn:aws-us-gov:iam::123456789012:user/my-user' + const result = SmusUtils.convertAssumedRoleArnToIamRoleArn(iamUserArn) + assert.strictEqual(result, iamUserArn) + }) + + it('should return IAM role ARN as-is', () => { + const iamRoleArn = 'arn:aws:iam::123456789012:role/MyRole' + const result = SmusUtils.convertAssumedRoleArnToIamRoleArn(iamRoleArn) + assert.strictEqual(result, iamRoleArn) + }) + + it('should return IAM role ARN with aws-cn partition as-is', () => { + const iamRoleArn = 'arn:aws-cn:iam::123456789012:role/MyRole' + const result = SmusUtils.convertAssumedRoleArnToIamRoleArn(iamRoleArn) + assert.strictEqual(result, iamRoleArn) + }) + + it('should handle IAM user ARN with special characters', () => { + const iamUserArn = 'arn:aws:iam::123456789012:user/path/to/user-name_123' + const result = SmusUtils.convertAssumedRoleArnToIamRoleArn(iamUserArn) + assert.strictEqual(result, iamUserArn) + }) + + it('should throw error for invalid ARN format - missing components', () => { + const invalidArn = 'arn:aws:sts::123456789012:assumed-role/MyRole' + + assert.throws( + () => SmusUtils.convertAssumedRoleArnToIamRoleArn(invalidArn), + (error: Error) => { + assert.ok(error.message.includes('Invalid STS ARN format')) + assert.ok(error.message.includes(invalidArn)) + return true + } + ) + }) + + it('should throw error for empty string', () => { + assert.throws( + () => SmusUtils.convertAssumedRoleArnToIamRoleArn(''), + (error: Error) => { + assert.ok(error.message.includes('Invalid STS ARN format')) + return true + } + ) + }) + }) }) describe('extractAccountIdFromSageMakerArn', () => { @@ -577,3 +704,39 @@ describe('extractAccountIdFromResourceMetadata', () => { ) }) }) + +describe('isCredentialExpirationError', () => { + describe('should return true for credential expiration errors', () => { + it('should detect ExpiredTokenException by error name (exact match)', () => { + const error = { + name: 'ExpiredTokenException', + message: 'Token has expired', + } + + const result = isCredentialExpirationError(error) + assert.strictEqual(result, true) + }) + + it('should detect ExpiredTokenException in error message', () => { + const error = { + name: 'SomeOtherError', + message: 'Request failed with ExpiredTokenException: Token has expired', + } + + const result = isCredentialExpirationError(error) + assert.strictEqual(result, true) + }) + }) + + describe('should return false for non-expiration errors', () => { + it('should return false for different error names', () => { + const error = { + name: 'AccessDeniedException', + message: 'Access denied', + } + + const result = isCredentialExpirationError(error) + assert.strictEqual(result, false) + }) + }) +}) diff --git a/packages/core/src/test/shared/errors.test.ts b/packages/core/src/test/shared/errors.test.ts index 32d18186912..093feceb012 100644 --- a/packages/core/src/test/shared/errors.test.ts +++ b/packages/core/src/test/shared/errors.test.ts @@ -683,6 +683,49 @@ describe('util', function () { assert.deepStrictEqual(scrubNames('unix ~jdoe123/.aws/config failed', fakeUser), 'unix ~x/.aws/config failed') assert.deepStrictEqual(scrubNames('unix ../../.aws/config failed', fakeUser), 'unix ../../.aws/config failed') assert.deepStrictEqual(scrubNames('unix ~/.aws/config failed', fakeUser), 'unix ~/.aws/config failed') + + // Profile name scrubbing - tests all three patterns + + // Pattern 1: profile name with space separator + const profileTest1 = scrubNames('Error with profile my-profile', fakeUser) + assert.deepStrictEqual(profileTest1, 'Error with profile [REDACTED]', 'Should handle space-separated profile') + assert.ok(!profileTest1.includes('my-profile'), 'Original profile name should not appear') + + // Pattern 2: profile name with single quotes + const profileTest2 = scrubNames("Failed to load profile 'production-admin'", fakeUser) + assert.deepStrictEqual(profileTest2, 'Failed to load profile [REDACTED]', 'Should handle single-quoted profile') + assert.ok(!profileTest2.includes('production-admin'), 'Original profile name should not appear') + assert.ok(!profileTest2.includes("'"), 'Closing quote should be removed') + + // Pattern 2: profile name with double quotes + const profileTest3 = scrubNames('Using profile "staging-env" for authentication', fakeUser) + assert.deepStrictEqual( + profileTest3, + 'Using profile [REDACTED] for authentication', + 'Should handle double-quoted profile' + ) + assert.ok(!profileTest3.includes('staging-env'), 'Original profile name should not appear') + assert.ok(!profileTest3.includes('"'), 'Closing quote should be removed') + + // Pattern 3: profile name with colon separator + const profileTest4 = scrubNames('Profile: dev-account not found', fakeUser) + assert.deepStrictEqual(profileTest4, 'Profile [REDACTED] not found', 'Should handle colon-separated profile') + assert.ok(!profileTest4.includes('dev-account'), 'Original profile name should not appear') + + // Case preservation tests + const profileTest5 = scrubNames('PROFILE: admin-user failed', fakeUser) + assert.deepStrictEqual(profileTest5, 'PROFILE [REDACTED] failed', 'Should preserve uppercase PROFILE') + + const profileTest6 = scrubNames("Profile 'test-123' is invalid", fakeUser) + assert.deepStrictEqual(profileTest6, 'Profile [REDACTED] is invalid', 'Should preserve capitalized Profile') + + // Multiple profiles in one message + const profileTest7 = scrubNames("Switching from profile 'old-profile' to profile 'new-profile'", fakeUser) + assert.ok( + !profileTest7.includes('old-profile') && !profileTest7.includes('new-profile'), + 'Should redact multiple profiles' + ) + assert.ok(profileTest7.includes('[REDACTED]'), 'Should contain redaction markers') }) }) diff --git a/packages/toolkit/.changes/3.86.0.json b/packages/toolkit/.changes/3.86.0.json new file mode 100644 index 00000000000..aae928e9939 --- /dev/null +++ b/packages/toolkit/.changes/3.86.0.json @@ -0,0 +1,10 @@ +{ + "date": "2025-11-21", + "version": "3.86.0", + "entries": [ + { + "type": "Feature", + "description": "Remote IDE connection support for IDE Spaces deployed on SageMaker HyperPod clusters" + } + ] +} \ No newline at end of file diff --git a/packages/toolkit/.changes/next-release/Feature-3af2733f-4e98-4b52-8a06-0faedb98cf70.json b/packages/toolkit/.changes/next-release/Feature-3af2733f-4e98-4b52-8a06-0faedb98cf70.json deleted file mode 100644 index f57dcc3dc4c..00000000000 --- a/packages/toolkit/.changes/next-release/Feature-3af2733f-4e98-4b52-8a06-0faedb98cf70.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "type": "Feature", - "description": "Remote IDE connection support for IDE Spaces deployed on SageMaker HyperPod clusters" -} diff --git a/packages/toolkit/.changes/next-release/Feature-a48a9132-b8fe-4888-90be-372e29a225d0.json b/packages/toolkit/.changes/next-release/Feature-a48a9132-b8fe-4888-90be-372e29a225d0.json new file mode 100644 index 00000000000..2891f446ef0 --- /dev/null +++ b/packages/toolkit/.changes/next-release/Feature-a48a9132-b8fe-4888-90be-372e29a225d0.json @@ -0,0 +1,4 @@ +{ + "type": "Feature", + "description": "Support IAM based domains for SageMaker Unified Studio" +} diff --git a/packages/toolkit/CHANGELOG.md b/packages/toolkit/CHANGELOG.md index f8c958c23b1..b55eb7ef495 100644 --- a/packages/toolkit/CHANGELOG.md +++ b/packages/toolkit/CHANGELOG.md @@ -1,3 +1,7 @@ +## 3.86.0 2025-11-21 + +- **Feature** Remote IDE connection support for IDE Spaces deployed on SageMaker HyperPod clusters + ## 3.85.0 2025-11-19 - **Bug Fix** Lambda: Attaching a debugger to your Lambda functions using LocalStack is not working diff --git a/packages/toolkit/package.json b/packages/toolkit/package.json index 5b5ef7563c1..1cbeb59de05 100644 --- a/packages/toolkit/package.json +++ b/packages/toolkit/package.json @@ -2,7 +2,7 @@ "name": "aws-toolkit-vscode", "displayName": "AWS Toolkit", "description": "Including CodeCatalyst, Infrastructure Composer, and support for Lambda, S3, CloudWatch Logs, CloudFormation, and many other services.", - "version": "3.86.0-SNAPSHOT", + "version": "3.87.0-SNAPSHOT", "extensionKind": [ "workspace" ], @@ -1938,6 +1938,11 @@ "when": "view == aws.explorer && viewItem == awsSagemakerHyperpodNode", "group": "inline@1" }, + { + "command": "aws.smus.switchProject", + "when": "view == aws.smus.rootView && viewItem == smusSelectedProject", + "group": "0_project@1" + }, { "command": "aws.smus.refreshProject", "when": "view == aws.smus.rootView && viewItem == smusSelectedProject", @@ -3245,6 +3250,16 @@ } } }, + { + "command": "aws.smus.refresh", + "title": "%AWS.command.smus.refresh%", + "category": "%AWS.title%", + "enablement": "isCloud9 || !aws.isWebExtHost", + "icon": { + "dark": "resources/icons/vscode/dark/refresh.svg", + "light": "resources/icons/vscode/light/refresh.svg" + } + }, { "command": "aws.smus.signOut", "title": "%AWS.command.smus.signOut%",