Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import * as localizedText from '../../../shared/localizedText'
import { ToolkitPromptSettings } from '../../../shared/settings'
import { setContext, getContext } from '../../../shared/vscode/setContext'
import { getLogger } from '../../../shared/logger/logger'
import { SmusUtils, SmusErrorCodes, extractAccountIdFromArn } from '../../shared/smusUtils'
import { SmusUtils, SmusErrorCodes, extractAccountIdFromResourceMetadata } from '../../shared/smusUtils'
import { createSmusProfile, isValidSmusConnection, SmusConnection } from '../model'
import { DomainExecRoleCredentialsProvider } from './domainExecRoleCredentialsProvider'
import { ProjectRoleCredentialsProvider } from './projectRoleCredentialsProvider'
Expand All @@ -24,6 +24,7 @@ import { getResourceMetadata } from '../../shared/utils/resourceMetadataUtils'
import { fromIni } from '@aws-sdk/credential-providers'
import { randomUUID } from '../../../shared/crypto'
import { DefaultStsClient } from '../../../shared/clients/stsClient'
import { DataZoneClient } from '../../shared/client/datazoneClient'

/**
* Sets the context variable for SageMaker Unified Studio connection state
Expand Down Expand Up @@ -55,6 +56,7 @@ export class SmusAuthenticationProvider {
private projectCredentialProvidersCache = new Map<string, ProjectRoleCredentialsProvider>()
private connectionCredentialProvidersCache = new Map<string, ConnectionCredentialsProvider>()
private cachedDomainAccountId: string | undefined
private cachedProjectAccountIds = new Map<string, string>()

public constructor(
public readonly auth = Auth.instance,
Expand All @@ -79,6 +81,8 @@ export class SmusAuthenticationProvider {
this.connectionCredentialProvidersCache.clear()
// Clear cached domain account ID when connection changes
this.cachedDomainAccountId = undefined
// Clear cached project account IDs when connection changes
this.cachedProjectAccountIds.clear()
// Clear all clients in client store when connection changes
ConnectionClientStore.getInstance().clearAll()
await setSmusConnectedContext(this.isConnected())
Expand Down Expand Up @@ -445,37 +449,13 @@ export class SmusAuthenticationProvider {

// If in SMUS space environment, extract account ID from resource-metadata file
if (getContext('aws.smus.inSmusSpaceEnvironment')) {
try {
logger.debug('SMUS: Extracting domain account ID from ResourceArn in resource-metadata file')

const resourceMetadata = getResourceMetadata()!
const resourceArn = resourceMetadata.ResourceArn

if (!resourceArn) {
throw new ToolkitError('ResourceArn not found in metadata file', {
code: SmusErrorCodes.AccountIdNotFound,
})
}
const accountId = await extractAccountIdFromResourceMetadata()

// Extract account ID from ResourceArn using SmusUtils
const accountId = extractAccountIdFromArn(resourceArn)

// Cache the account ID
this.cachedDomainAccountId = accountId

logger.debug(
`Successfully extracted and cached domain account ID from resource-metadata file: ${accountId}`
)

return accountId
} catch (err) {
logger.error(`Failed to extract domain account ID from ResourceArn: %s`, err)
// Cache the account ID
this.cachedDomainAccountId = accountId
logger.debug(`Successfully cached domain account ID: ${accountId}`)

throw new ToolkitError('Failed to extract AWS account ID from ResourceArn in SMUS space environment', {
code: SmusErrorCodes.GetDomainAccountIdFailed,
cause: err instanceof Error ? err : undefined,
})
}
return accountId
}

if (!this.activeConnection) {
Expand Down Expand Up @@ -520,6 +500,81 @@ export class SmusAuthenticationProvider {
}
}

/**
* Gets the AWS account ID for a specific project using project credentials
* In SMUS space environment, extracts from ResourceArn in metadata (same as domain account)
* Otherwise, makes an STS GetCallerIdentity call using project credentials
* @param projectId The DataZone project ID
* @returns Promise resolving to the project's AWS account ID
*/
public async getProjectAccountId(projectId: string): Promise<string> {
const logger = getLogger()

// Return cached value if available
if (this.cachedProjectAccountIds.has(projectId)) {
logger.debug(`SMUS: Using cached project account ID for project ${projectId}`)
return this.cachedProjectAccountIds.get(projectId)!
}

// If in SMUS space environment, extract account ID from resource-metadata file
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought the one in resource-metadata is always domain account id? can you confirm?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The resourceArn in resource-metadata.json gives the project id.

if (getContext('aws.smus.inSmusSpaceEnvironment')) {
const accountId = await extractAccountIdFromResourceMetadata()

// Cache the account ID
this.cachedProjectAccountIds.set(projectId, accountId)
logger.debug(`Successfully cached project account ID for project ${projectId}: ${accountId}`)

return accountId
}

if (!this.activeConnection) {
throw new ToolkitError('No active SMUS connection available', { code: SmusErrorCodes.NoActiveConnection })
}

// For non-SMUS space environments, use project credentials with STS
try {
logger.debug('Fetching project account ID via STS GetCallerIdentity with project credentials')

// Get project credentials
const projectCredProvider = await this.getProjectCredentialProvider(projectId)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can me move this as a util function as well

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may leave it there because, first the logic is not used in other places, second it follows getDomainAccountId theme.

const projectCreds = await projectCredProvider.getCredentials()

// Get project region from tooling environment
const dzClient = await DataZoneClient.getInstance(this)
const toolingEnv = await dzClient.getToolingEnvironment(projectId)
const projectRegion = toolingEnv.awsAccountRegion

if (!projectRegion) {
throw new ToolkitError('No AWS account region found in tooling environment', {
code: SmusErrorCodes.RegionNotFound,
})
}

// Use STS to get account ID from project credentials
const stsClient = new DefaultStsClient(projectRegion, projectCreds)
const callerIdentity = await stsClient.getCallerIdentity()

if (!callerIdentity.Account) {
throw new ToolkitError('Account ID not found in STS GetCallerIdentity response', {
code: SmusErrorCodes.AccountIdNotFound,
})
}

// Cache the account ID
this.cachedProjectAccountIds.set(projectId, callerIdentity.Account)
logger.debug(
`Successfully retrieved and cached project account ID for project ${projectId}: ${callerIdentity.Account}`
)

return callerIdentity.Account
} catch (err) {
logger.error('Failed to get project account ID: %s', err as Error)
throw new ToolkitError(`Failed to get project account ID: ${(err as Error).message}`, {
code: SmusErrorCodes.GetProjectAccountIdFailed,
})
}
}

public getDomainRegion(): string {
if (getContext('aws.smus.inSmusSpaceEnvironment')) {
const resourceMetadata = getResourceMetadata()!
Expand Down Expand Up @@ -617,6 +672,10 @@ export class SmusAuthenticationProvider {
// Clear cached domain account ID
this.cachedDomainAccountId = undefined
logger.debug('SMUS: Cleared cached domain account ID')

// Clear cached project account IDs
this.cachedProjectAccountIds.clear()
logger.debug('SMUS: Cleared cached project account IDs')
}

/**
Expand Down Expand Up @@ -665,6 +724,9 @@ export class SmusAuthenticationProvider {
// Clear cached domain account ID
this.cachedDomainAccountId = undefined

// Clear cached project account IDs
this.cachedProjectAccountIds.clear()

this.logger.debug('SMUS Auth: Successfully disposed authentication provider')
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import { createPlaceholderItem } from '../../../shared/treeview/utils'
import { ConnectionCredentialsProvider } from '../../auth/providers/connectionCredentialsProvider'
import { GlueCatalog } from '../../shared/client/glueCatalogClient'
import { telemetry } from '../../../shared/telemetry/telemetry'
import { getContext } from '../../../shared/vscode/setContext'
import { recordDataConnectionTelemetry } from '../../shared/telemetry'

/**
* Redshift data node for SageMaker Unified Studio
Expand Down Expand Up @@ -119,6 +119,7 @@ export function createRedshiftConnectionNode(
connection: DataZoneConnection,
connectionCredentialsProvider: ConnectionCredentialsProvider
): RedshiftNode {
const logger = getLogger()
return new RedshiftNode(
{
id: connection.connectionId,
Expand All @@ -130,19 +131,8 @@ export function createRedshiftConnectionNode(
},
async (node) => {
return telemetry.smus_renderRedshiftNode.run(async (span) => {
const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment')
const accountId = await connectionCredentialsProvider.getDomainAccountId()
span.record({
smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local',
smusDomainId: connection.domainId,
smusDomainAccountId: accountId,
smusProjectId: connection.projectId,
smusConnectionId: connection.connectionId,
smusConnectionType: connection.type,
smusProjectRegion: connection.location?.awsRegion,
})
const logger = getLogger()
logger.info(`Loading Redshift resources for connection ${connection.name}`)
await recordDataConnectionTelemetry(span, connection, connectionCredentialsProvider)

const connectionParams = extractConnectionParams(connection)
if (!connectionParams) {
Expand Down
38 changes: 36 additions & 2 deletions packages/core/src/sagemakerunifiedstudio/shared/smusUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,14 @@ export const SmusErrorCodes = {
UserCancelled: 'UserCancelled',
/** Error code for when domain account Id is missing */
AccountIdNotFound: 'AccountIdNotFound',
/** Error code for when resource ARN is missing */
ResourceArnNotFound: 'ResourceArnNotFound',
/** Error code for when fails to get domain account Id */
GetDomainAccountIdFailed: 'GetDomainAccountIdFailed',
/** Error code for when fails to get project account Id */
GetProjectAccountIdFailed: 'GetProjectAccountIdFailed',
/** Error code for when region is missing */
RegionNotFound: 'RegionNotFound',
} as const

/**
Expand Down Expand Up @@ -369,9 +375,9 @@ export class SmusUtils {
* @returns The account ID from the ARN
* @throws If the ARN format is invalid
*/
export function extractAccountIdFromArn(arn: string): string {
export function extractAccountIdFromSageMakerArn(arn: string): string {
// Match the ARN components to extract account ID
const regex = /^arn:aws:sagemaker:(?<region>[^:]+):(?<accountId>\d+):(app|space|domain)\/.+$/i
const regex = /^arn:aws:sagemaker:(?<region>[^:]+):(?<accountId>\d+):(app|space)\/.+$/i
const match = arn.match(regex)

if (!match?.groups) {
Expand All @@ -380,3 +386,31 @@ export function extractAccountIdFromArn(arn: string): string {

return match.groups.accountId
}

/**
* Extracts account ID from ResourceArn in SMUS space environment
* @returns Promise resolving to the account ID
* @throws ToolkitError if unable to extract account ID
*/
export async function extractAccountIdFromResourceMetadata(): Promise<string> {
const logger = getLogger()

try {
logger.debug('SMUS: Extracting account ID from ResourceArn in resource-metadata file')

const resourceMetadata = getResourceMetadata()!
const resourceArn = resourceMetadata.ResourceArn

if (!resourceArn) {
throw new Error('ResourceArn not found in metadata file')
}

const accountId = extractAccountIdFromSageMakerArn(resourceArn)
logger.debug(`Successfully extracted account ID from resource-metadata file: ${accountId}`)

return accountId
} catch (err) {
logger.error(`Failed to extract account ID from ResourceArn: %s`, err)
throw new Error('Failed to extract AWS account ID from ResourceArn in SMUS space environment')
}
}
76 changes: 53 additions & 23 deletions packages/core/src/sagemakerunifiedstudio/shared/telemetry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import { SmusAuthenticationProvider } from '../auth/providers/smusAuthentication
import { getLogger } from '../../shared/logger/logger'
import { getContext } from '../../shared/vscode/setContext'
import { ConnectionCredentialsProvider } from '../auth/providers/connectionCredentialsProvider'
import { DataZoneConnection } from './client/datazoneClient'
import { DataZoneConnection, DataZoneClient } from './client/datazoneClient'

/**
* Records space telemetry
Expand All @@ -27,16 +27,39 @@ export async function recordSpaceTelemetry(
span: Span<SmusOpenRemoteConnection> | Span<SmusStopSpace>,
node: SagemakerUnifiedStudioSpaceNode
) {
const parent = node.resource.getParent() as SageMakerUnifiedStudioSpacesParentNode
const authProvider = SmusAuthenticationProvider.fromContext()
const accountId = await authProvider.getDomainAccountId()
span.record({
smusSpaceKey: node.resource.DomainSpaceKey,
smusDomainRegion: node.resource.regionCode,
smusDomainId: parent?.getAuthProvider()?.activeConnection?.domainId,
smusDomainAccountId: accountId,
smusProjectId: parent?.getProjectId(),
})
const logger = getLogger()

try {
const parent = node.resource.getParent() as SageMakerUnifiedStudioSpacesParentNode
const authProvider = SmusAuthenticationProvider.fromContext()
const accountId = await authProvider.getDomainAccountId()
const projectId = parent?.getProjectId()

// Get project account ID and region
let projectAccountId: string | undefined
let projectRegion: string | undefined

if (projectId) {
projectAccountId = await authProvider.getProjectAccountId(projectId)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just want to confirm, this method is only used by retrieving project account id for right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.


// Get project region from tooling environment
const dzClient = await DataZoneClient.getInstance(authProvider)
const toolingEnv = await dzClient.getToolingEnvironment(projectId)
projectRegion = toolingEnv.awsAccountRegion
}

span.record({
smusSpaceKey: node.resource.DomainSpaceKey,
smusDomainRegion: node.resource.regionCode,
smusDomainId: parent?.getAuthProvider()?.activeConnection?.domainId,
smusDomainAccountId: accountId,
smusProjectId: projectId,
smusProjectAccountId: projectAccountId,
smusProjectRegion: projectRegion,
})
} catch (err) {
logger.error(`Failed to record space telemetry: ${(err as Error).message}`)
}
}

/**
Expand Down Expand Up @@ -65,7 +88,7 @@ export async function recordAuthTelemetry(
})
} catch (err) {
logger.error(
`Failed to resolve AWS account ID via STS Client for domain ${domainId} in region ${region}: ${err}`
`Failed to record Domain AccountId in data connection telemetry for domain ${domainId} in region ${region}: ${err}`
)
}
}
Expand All @@ -78,15 +101,22 @@ export async function recordDataConnectionTelemetry(
connection: DataZoneConnection,
connectionCredentialsProvider: ConnectionCredentialsProvider
) {
const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment')
const accountId = await connectionCredentialsProvider.getDomainAccountId()
span.record({
smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local',
smusDomainId: connection.domainId,
smusDomainAccountId: accountId,
smusProjectId: connection.projectId,
smusConnectionId: connection.connectionId,
smusConnectionType: connection.type,
smusProjectRegion: connection.location?.awsRegion,
})
const logger = getLogger()

try {
const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment')
const accountId = await connectionCredentialsProvider.getDomainAccountId()
span.record({
smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local',
smusDomainId: connection.domainId,
smusDomainAccountId: accountId,
smusProjectId: connection.projectId,
smusConnectionId: connection.connectionId,
smusConnectionType: connection.type,
smusProjectRegion: connection.location?.awsRegion,
smusProjectAccountId: connection.location?.awsAccountId,
})
} catch (err) {
logger.error(`Failed to record data connection telemetry: ${(err as Error).message}`)
}
}
Loading
Loading