Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ export class ConnectionCredentialsProvider implements CredentialsProvider {
return this.smusAuthProvider.getDomainRegion()
}

/**
* Gets the domain AWS account ID
* @returns Promise resolving to the domain account ID
*/
public async getDomainAccountId(): Promise<string> {
return this.smusAuthProvider.getDomainAccountId()
}

/**
* Gets the hash code
* @returns Hash code
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import * as localizedText from '../../../shared/localizedText'
import { ToolkitPromptSettings } from '../../../shared/settings'
import { setContext, getContext } from '../../../shared/vscode/setContext'
import { getLogger } from '../../../shared/logger/logger'
import { SmusUtils, SmusErrorCodes } from '../../shared/smusUtils'
import { SmusUtils, SmusErrorCodes, extractAccountIdFromArn } from '../../shared/smusUtils'
import { createSmusProfile, isValidSmusConnection, SmusConnection } from '../model'
import { DomainExecRoleCredentialsProvider } from './domainExecRoleCredentialsProvider'
import { ProjectRoleCredentialsProvider } from './projectRoleCredentialsProvider'
Expand All @@ -23,6 +23,7 @@ import { ConnectionClientStore } from '../../shared/client/connectionClientStore
import { getResourceMetadata } from '../../shared/utils/resourceMetadataUtils'
import { fromIni } from '@aws-sdk/credential-providers'
import { randomUUID } from '../../../shared/crypto'
import { DefaultStsClient } from '../../../shared/clients/stsClient'

/**
* Sets the context variable for SageMaker Unified Studio connection state
Expand Down Expand Up @@ -53,6 +54,7 @@ export class SmusAuthenticationProvider {
private credentialsProviderCache = new Map<string, any>()
private projectCredentialProvidersCache = new Map<string, ProjectRoleCredentialsProvider>()
private connectionCredentialProvidersCache = new Map<string, ConnectionCredentialsProvider>()
private cachedDomainAccountId: string | undefined

public constructor(
public readonly auth = Auth.instance,
Expand All @@ -75,6 +77,8 @@ export class SmusAuthenticationProvider {
this.projectCredentialProvidersCache.clear()
// Clear connection provider cache when connection changes
this.connectionCredentialProvidersCache.clear()
// Clear cached domain account ID when connection changes
this.cachedDomainAccountId = undefined
// Clear all clients in client store when connection changes
ConnectionClientStore.getInstance().clearAll()
await setSmusConnectedContext(this.isConnected())
Expand Down Expand Up @@ -423,6 +427,99 @@ export class SmusAuthenticationProvider {
return this.activeConnection.domainUrl
}

/**
* Gets the AWS account ID for the active domain connection
* In SMUS space environment, extracts from ResourceArn in metadata
* Otherwise, makes an STS GetCallerIdentity call using DER credentials and caches the result
* @returns Promise resolving to the domain's AWS account ID
* @throws ToolkitError if unable to retrieve account ID
*/
public async getDomainAccountId(): Promise<string> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can follow up: let's add a unit test for this method

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is added

const logger = getLogger()

// Return cached value if available
if (this.cachedDomainAccountId) {
logger.debug('SMUS: Using cached domain account ID')
return this.cachedDomainAccountId
}

// If in SMUS space environment, extract account ID from resource-metadata file
if (getContext('aws.smus.inSmusSpaceEnvironment')) {
try {
logger.debug('SMUS: Extracting domain account ID from ResourceArn in resource-metadata file')

const resourceMetadata = getResourceMetadata()!
const resourceArn = resourceMetadata.ResourceArn

if (!resourceArn) {
throw new ToolkitError('ResourceArn not found in metadata file', {
code: SmusErrorCodes.AccountIdNotFound,
})
}

// Extract account ID from ResourceArn using SmusUtils
const accountId = extractAccountIdFromArn(resourceArn)

// Cache the account ID
this.cachedDomainAccountId = accountId

logger.debug(
`Successfully extracted and cached domain account ID from resource-metadata file: ${accountId}`
)

return accountId
} catch (err) {
logger.error(`Failed to extract domain account ID from ResourceArn: %s`, err)

throw new ToolkitError('Failed to extract AWS account ID from ResourceArn in SMUS space environment', {
code: SmusErrorCodes.GetDomainAccountIdFailed,
cause: err instanceof Error ? err : undefined,
})
}
}

if (!this.activeConnection) {
throw new ToolkitError('No active SMUS connection available', { code: SmusErrorCodes.NoActiveConnection })
}

// Use existing STS GetCallerIdentity implementation for non-SMUS space environments
try {
logger.debug('Fetching domain account ID via STS GetCallerIdentity')

// Get DER credentials provider
const derCredProvider = await this.getDerCredentialsProvider()

// Get the region for STS client
const region = this.getDomainRegion()

// Create STS client with DER credentials
const stsClient = new DefaultStsClient(region, await derCredProvider.getCredentials())

// Make GetCallerIdentity call
const callerIdentity = await stsClient.getCallerIdentity()

if (!callerIdentity.Account) {
throw new ToolkitError('Account ID not found in STS GetCallerIdentity response', {
code: SmusErrorCodes.AccountIdNotFound,
})
}

// Cache the account ID
this.cachedDomainAccountId = callerIdentity.Account

logger.debug(`Successfully retrieved and cached domain account ID: ${callerIdentity.Account}`)

return callerIdentity.Account
} catch (err) {
logger.error(`Failed to retrieve domain account ID: %s`, err)

throw new ToolkitError('Failed to retrieve AWS account ID for active domain connection', {
code: SmusErrorCodes.GetDomainAccountIdFailed,
cause: err instanceof Error ? err : undefined,
})
}
}

public getDomainRegion(): string {
if (getContext('aws.smus.inSmusSpaceEnvironment')) {
const resourceMetadata = getResourceMetadata()!
Expand Down Expand Up @@ -516,6 +613,10 @@ export class SmusAuthenticationProvider {
logger.warn(`SMUS: Failed to invalidate connection credentials for cache key ${cacheKey}: %s`, err)
}
}

// Clear cached domain account ID
this.cachedDomainAccountId = undefined
logger.debug('SMUS: Cleared cached domain account ID')
}

/**
Expand Down Expand Up @@ -560,6 +661,10 @@ export class SmusAuthenticationProvider {
}
}
this.credentialsProviderCache.clear()

// Clear cached domain account ID
this.cachedDomainAccountId = undefined

this.logger.debug('SMUS Auth: Successfully disposed authentication provider')
}

Expand Down
26 changes: 4 additions & 22 deletions packages/core/src/sagemakerunifiedstudio/explorer/activation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ import { getLogger } from '../../shared/logger/logger'
import { setSmusConnectedContext, SmusAuthenticationProvider } from '../auth/providers/smusAuthenticationProvider'
import { setupUserActivityMonitoring } from '../../awsService/sagemaker/sagemakerSpace'
import { telemetry } from '../../shared/telemetry/telemetry'
import { SageMakerUnifiedStudioSpacesParentNode } from './nodes/sageMakerUnifiedStudioSpacesParentNode'
import { isSageMaker } from '../../shared/extensionUtilities'
import { recordSpaceTelemetry } from '../shared/telemetry'

export async function activate(extensionContext: vscode.ExtensionContext): Promise<void> {
// Initialize the SMUS authentication provider
Expand Down Expand Up @@ -75,16 +75,7 @@ export async function activate(extensionContext: vscode.ExtensionContext): Promi
return
}
await telemetry.smus_stopSpace.run(async (span) => {
span.record({
smusSpaceKey: node.resource.DomainSpaceKey,
smusDomainRegion: node.resource.regionCode,
smusDomainId: (
node.resource.getParent() as SageMakerUnifiedStudioSpacesParentNode
)?.getAuthProvider()?.activeConnection?.domainId,
smusProjectId: (
node.resource.getParent() as SageMakerUnifiedStudioSpacesParentNode
)?.getProjectId(),
})
await recordSpaceTelemetry(span, node)
await stopSpace(node.resource, extensionContext, node.resource.sageMakerClient)
})
}),
Expand All @@ -95,17 +86,8 @@ export async function activate(extensionContext: vscode.ExtensionContext): Promi
if (!validateNode(node)) {
return
}
await telemetry.smus_startSpace.run(async (span) => {
span.record({
smusSpaceKey: node.resource.DomainSpaceKey,
smusDomainRegion: node.resource.regionCode,
smusDomainId: (
node.resource.getParent() as SageMakerUnifiedStudioSpacesParentNode
)?.getAuthProvider()?.activeConnection?.domainId,
smusProjectId: (
node.resource.getParent() as SageMakerUnifiedStudioSpacesParentNode
)?.getProjectId(),
})
await telemetry.smus_openRemoteConnection.run(async (span) => {
await recordSpaceTelemetry(span, node)
await openRemoteConnect(node.resource, extensionContext, node.resource.sageMakerClient)
})
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import { createPlaceholderItem } from '../../../shared/treeview/utils'
import { Column, Database, Table } from '@aws-sdk/client-glue'
import { ConnectionCredentialsProvider } from '../../auth/providers/connectionCredentialsProvider'
import { telemetry } from '../../../shared/telemetry/telemetry'
import { getContext } from '../../../shared/vscode/setContext'
import { recordDataConnectionTelemetry } from '../../shared/telemetry'

/**
* Lakehouse data node for SageMaker Unified Studio
Expand Down Expand Up @@ -152,16 +152,7 @@ export function createLakehouseConnectionNode(
},
async (node) => {
return telemetry.smus_renderLakehouseNode.run(async (span) => {
const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment')

span.record({
smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local',
smusDomainId: connection.domainId,
smusProjectId: connection.projectId,
smusConnectionId: connection.connectionId,
smusConnectionType: connection.type,
smusProjectRegion: connection.location?.awsRegion,
})
await recordDataConnectionTelemetry(span, connection, connectionCredentialsProvider)
try {
logger.info(`Loading Lakehouse catalogs for connection ${connection.name}`)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,11 @@ export function createRedshiftConnectionNode(
async (node) => {
return telemetry.smus_renderRedshiftNode.run(async (span) => {
const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment')

const accountId = await connectionCredentialsProvider.getDomainAccountId()
span.record({
smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local',
smusDomainId: connection.domainId,
smusDomainAccountId: accountId,
smusProjectId: connection.projectId,
smusConnectionId: connection.connectionId,
smusConnectionType: connection.type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import {
import { S3, ListObjectsV2Command } from '@aws-sdk/client-s3'
import { ConnectionCredentialsProvider } from '../../auth/providers/connectionCredentialsProvider'
import { telemetry } from '../../../shared/telemetry/telemetry'
import { getContext } from '../../../shared/vscode/setContext'
import { recordDataConnectionTelemetry } from '../../shared/telemetry'

// Regex to match default S3 connection names
// eslint-disable-next-line @typescript-eslint/naming-convention
Expand Down Expand Up @@ -144,16 +144,7 @@ export function createS3ConnectionNode(
},
async (node) => {
return telemetry.smus_renderS3Node.run(async (span) => {
const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment')

span.record({
smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local',
smusDomainId: connection.domainId,
smusProjectId: connection.projectId,
smusConnectionId: connection.connectionId,
smusConnectionType: connection.type,
smusProjectRegion: connection.location?.awsRegion,
})
await recordDataConnectionTelemetry(span, connection, connectionCredentialsProvider)
try {
if (isDefaultConnection && s3Info.prefix) {
// For default connections, show the full path as the first node
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,11 @@ export class SageMakerUnifiedStudioProjectNode implements TreeNode {
return telemetry.smus_renderProjectChildrenNode.run(async (span) => {
try {
const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment')

const accountId = await this.authProvider.getDomainAccountId()
span.record({
smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local',
smusDomainId: this.project?.domainId,
smusDomainAccountId: accountId,
smusProjectId: this.project?.id,
smusDomainRegion: this.authProvider.getDomainRegion(),
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ import { telemetry } from '../../../shared/telemetry/telemetry'
import { createQuickPick } from '../../../shared/ui/pickerPrompter'
import { SageMakerUnifiedStudioProjectNode } from './sageMakerUnifiedStudioProjectNode'
import { SageMakerUnifiedStudioAuthInfoNode } from './sageMakerUnifiedStudioAuthInfoNode'
import { SmusUtils } from '../../shared/smusUtils'
import { SmusErrorCodes, SmusUtils } from '../../shared/smusUtils'
import { SmusAuthenticationProvider } from '../../auth/providers/smusAuthenticationProvider'
import { ToolkitError } from '../../../../src/shared/errors'
import { errorCode } from '../../shared/errors'
import { recordAuthTelemetry } from '../../shared/telemetry'

const contextValueSmusRoot = 'sageMakerUnifiedStudioRoot'
const contextValueSmusLogin = 'sageMakerUnifiedStudioLogin'
Expand Down Expand Up @@ -237,7 +237,10 @@ export const smusLoginCommand = Commands.declare('aws.smus.login', () => async (
if (!domainUrl) {
// User cancelled
logger.debug('User cancelled domain URL input')
return
throw new ToolkitError('User cancelled domain URL input', {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this typical? Why throw error, this is just a normal scenario where user cancelled workflow right?
Will this show up as a Fault or no in telemetry/dashboards/alarms?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And did we validate and test? Are the telemetry events being emitted correctly? Can you share examples events? On Slack is fine too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

simple return signal toolkit to overwrite result:succeeded.
Yes it is tested:

Metadata: {
    metricId: 'e07ce894-93e6-4f9b-a0b2-363f01a95c6e',
    traceId: '7fb087c7-dcd5-4e4e-bbf5-1f01abf0330f',
    command: 'aws.smus.login',
    duration: '3026',
    result: 'Cancelled',
    reason: 'Error',
    reasonDesc: 'Failed to initiate login. | User cancelled domain URL input',
    awsAccount: 'not-set',
    awsRegion: 'us-east-1'
  },
  Value: 1,
  Unit: 'None',
  Passive: true
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would user see in this case? It's a very common case for user to firstly click on log in and realize they didn't have the domain url ready. So they leave to find the domain url. I don't think in this case they should see an error when they come back.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The users will see this error message when they dont provide domainUrl. They can retry to provide domain url again. I was not able to override result:cancelled without ToolkitError. We rely on ToolkitError for now.

Screenshot 2025-09-15 at 5 14 20 PM

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed it by skipping showErrorMessage when UserCancelled

cancelled: true,
code: SmusErrorCodes.UserCancelled,
})
}

// Show a simple status bar message instead of progress dialog
Expand All @@ -252,19 +255,16 @@ export const smusLoginCommand = Commands.declare('aws.smus.login', () => async (

if (!connection) {
throw new ToolkitError('Failed to establish connection', {
code: errorCode.failedAuthConnecton,
code: SmusErrorCodes.FailedAuthConnecton,
})
}

// Extract domain ID and region for logging
// Extract domain account ID, domain ID, and region for logging
const domainId = connection.domainId
const region = connection.ssoRegion

logger.info(`Connected to SageMaker Unified Studio domain: ${domainId} in region ${region}`)
span.record({
smusDomainId: domainId,
awsRegion: region,
})
await recordAuthTelemetry(span, authProvider, domainId, region)

// Show success message
void vscode.window.showInformationMessage(
Expand Down Expand Up @@ -292,9 +292,12 @@ export const smusLoginCommand = Commands.declare('aws.smus.login', () => async (
})
}
} catch (err) {
void vscode.window.showErrorMessage(
`SageMaker Unified Studio: Failed to initiate login: ${(err as Error).message}`
)
const isUserCancelled = err instanceof ToolkitError && err.code === SmusErrorCodes.UserCancelled
if (!isUserCancelled) {
void vscode.window.showErrorMessage(
`SageMaker Unified Studio: Failed to initiate login: ${(err as Error).message}`
)
}
logger.error('Failed to initiate login: %s', (err as Error).message)
throw new ToolkitError('Failed to initiate login.', {
cause: err as Error,
Expand Down Expand Up @@ -329,11 +332,7 @@ export const smusSignOutCommand = Commands.declare('aws.smus.signOut', () => asy

// Show status message
vscode.window.setStatusBarMessage('Signing out from SageMaker Unified Studio...', 5000)

span.record({
smusDomainId: domainId,
awsRegion: region,
})
await recordAuthTelemetry(span, authProvider, domainId, region)

// Delete the connection (this will also invalidate tokens and clear cache)
if (activeConnection) {
Expand Down Expand Up @@ -425,10 +424,12 @@ export async function selectSMUSProject(projectNode?: SageMakerUnifiedStudioProj
}

const selectedProject = await showQuickPick(items)
const accountId = await authProvider.getDomainAccountId()
span.record({
smusDomainId: authProvider.getDomainId(),
smusProjectId: (selectedProject as DataZoneProject).id as string | undefined,
smusDomainRegion: authProvider.getDomainRegion(),
smusDomainAccountId: accountId,
})
if (
selectedProject &&
Expand Down
8 changes: 0 additions & 8 deletions packages/core/src/sagemakerunifiedstudio/shared/errors.ts

This file was deleted.

Loading
Loading