-
Notifications
You must be signed in to change notification settings - Fork 730
telemetry: add project account id and region #8079
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,7 +14,7 @@ import * as localizedText from '../../../shared/localizedText' | |
| import { ToolkitPromptSettings } from '../../../shared/settings' | ||
| import { setContext, getContext } from '../../../shared/vscode/setContext' | ||
| import { getLogger } from '../../../shared/logger/logger' | ||
| import { SmusUtils, SmusErrorCodes, extractAccountIdFromArn } from '../../shared/smusUtils' | ||
| import { SmusUtils, SmusErrorCodes, extractAccountIdFromResourceMetadata } from '../../shared/smusUtils' | ||
| import { createSmusProfile, isValidSmusConnection, SmusConnection } from '../model' | ||
| import { DomainExecRoleCredentialsProvider } from './domainExecRoleCredentialsProvider' | ||
| import { ProjectRoleCredentialsProvider } from './projectRoleCredentialsProvider' | ||
|
|
@@ -24,6 +24,7 @@ import { getResourceMetadata } from '../../shared/utils/resourceMetadataUtils' | |
| import { fromIni } from '@aws-sdk/credential-providers' | ||
| import { randomUUID } from '../../../shared/crypto' | ||
| import { DefaultStsClient } from '../../../shared/clients/stsClient' | ||
| import { DataZoneClient } from '../../shared/client/datazoneClient' | ||
|
|
||
| /** | ||
| * Sets the context variable for SageMaker Unified Studio connection state | ||
|
|
@@ -55,6 +56,7 @@ export class SmusAuthenticationProvider { | |
| private projectCredentialProvidersCache = new Map<string, ProjectRoleCredentialsProvider>() | ||
| private connectionCredentialProvidersCache = new Map<string, ConnectionCredentialsProvider>() | ||
| private cachedDomainAccountId: string | undefined | ||
| private cachedProjectAccountIds = new Map<string, string>() | ||
|
|
||
| public constructor( | ||
| public readonly auth = Auth.instance, | ||
|
|
@@ -79,6 +81,8 @@ export class SmusAuthenticationProvider { | |
| this.connectionCredentialProvidersCache.clear() | ||
| // Clear cached domain account ID when connection changes | ||
| this.cachedDomainAccountId = undefined | ||
| // Clear cached project account IDs when connection changes | ||
| this.cachedProjectAccountIds.clear() | ||
| // Clear all clients in client store when connection changes | ||
| ConnectionClientStore.getInstance().clearAll() | ||
| await setSmusConnectedContext(this.isConnected()) | ||
|
|
@@ -445,37 +449,13 @@ export class SmusAuthenticationProvider { | |
|
|
||
| // If in SMUS space environment, extract account ID from resource-metadata file | ||
| if (getContext('aws.smus.inSmusSpaceEnvironment')) { | ||
| try { | ||
| logger.debug('SMUS: Extracting domain account ID from ResourceArn in resource-metadata file') | ||
|
|
||
| const resourceMetadata = getResourceMetadata()! | ||
| const resourceArn = resourceMetadata.ResourceArn | ||
|
|
||
| if (!resourceArn) { | ||
| throw new ToolkitError('ResourceArn not found in metadata file', { | ||
| code: SmusErrorCodes.AccountIdNotFound, | ||
| }) | ||
| } | ||
| const accountId = await extractAccountIdFromResourceMetadata() | ||
|
|
||
| // Extract account ID from ResourceArn using SmusUtils | ||
| const accountId = extractAccountIdFromArn(resourceArn) | ||
|
|
||
| // Cache the account ID | ||
| this.cachedDomainAccountId = accountId | ||
|
|
||
| logger.debug( | ||
| `Successfully extracted and cached domain account ID from resource-metadata file: ${accountId}` | ||
| ) | ||
|
|
||
| return accountId | ||
| } catch (err) { | ||
| logger.error(`Failed to extract domain account ID from ResourceArn: %s`, err) | ||
| // Cache the account ID | ||
| this.cachedDomainAccountId = accountId | ||
| logger.debug(`Successfully cached domain account ID: ${accountId}`) | ||
|
|
||
| throw new ToolkitError('Failed to extract AWS account ID from ResourceArn in SMUS space environment', { | ||
| code: SmusErrorCodes.GetDomainAccountIdFailed, | ||
| cause: err instanceof Error ? err : undefined, | ||
| }) | ||
| } | ||
| return accountId | ||
| } | ||
|
|
||
| if (!this.activeConnection) { | ||
|
|
@@ -520,6 +500,81 @@ export class SmusAuthenticationProvider { | |
| } | ||
| } | ||
|
|
||
| /** | ||
| * Gets the AWS account ID for a specific project using project credentials | ||
| * In SMUS space environment, extracts from ResourceArn in metadata (same as domain account) | ||
| * Otherwise, makes an STS GetCallerIdentity call using project credentials | ||
| * @param projectId The DataZone project ID | ||
| * @returns Promise resolving to the project's AWS account ID | ||
| */ | ||
| public async getProjectAccountId(projectId: string): Promise<string> { | ||
| const logger = getLogger() | ||
|
|
||
| // Return cached value if available | ||
| if (this.cachedProjectAccountIds.has(projectId)) { | ||
| logger.debug(`SMUS: Using cached project account ID for project ${projectId}`) | ||
| return this.cachedProjectAccountIds.get(projectId)! | ||
| } | ||
|
|
||
| // If in SMUS space environment, extract account ID from resource-metadata file | ||
| if (getContext('aws.smus.inSmusSpaceEnvironment')) { | ||
| const accountId = await extractAccountIdFromResourceMetadata() | ||
|
|
||
| // Cache the account ID | ||
| this.cachedProjectAccountIds.set(projectId, accountId) | ||
| logger.debug(`Successfully cached project account ID for project ${projectId}: ${accountId}`) | ||
|
|
||
| return accountId | ||
| } | ||
|
|
||
| if (!this.activeConnection) { | ||
| throw new ToolkitError('No active SMUS connection available', { code: SmusErrorCodes.NoActiveConnection }) | ||
| } | ||
|
|
||
| // For non-SMUS space environments, use project credentials with STS | ||
| try { | ||
| logger.debug('Fetching project account ID via STS GetCallerIdentity with project credentials') | ||
|
|
||
| // Get project credentials | ||
| const projectCredProvider = await this.getProjectCredentialProvider(projectId) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can me move this as a util function as well
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We may leave it there because, first the logic is not used in other places, second it follows getDomainAccountId theme. |
||
| const projectCreds = await projectCredProvider.getCredentials() | ||
|
|
||
| // Get project region from tooling environment | ||
| const dzClient = await DataZoneClient.getInstance(this) | ||
| const toolingEnv = await dzClient.getToolingEnvironment(projectId) | ||
| const projectRegion = toolingEnv.awsAccountRegion | ||
|
|
||
| if (!projectRegion) { | ||
| throw new ToolkitError('No AWS account region found in tooling environment', { | ||
| code: SmusErrorCodes.RegionNotFound, | ||
| }) | ||
| } | ||
|
|
||
| // Use STS to get account ID from project credentials | ||
| const stsClient = new DefaultStsClient(projectRegion, projectCreds) | ||
| const callerIdentity = await stsClient.getCallerIdentity() | ||
|
|
||
| if (!callerIdentity.Account) { | ||
| throw new ToolkitError('Account ID not found in STS GetCallerIdentity response', { | ||
| code: SmusErrorCodes.AccountIdNotFound, | ||
| }) | ||
| } | ||
|
|
||
| // Cache the account ID | ||
| this.cachedProjectAccountIds.set(projectId, callerIdentity.Account) | ||
| logger.debug( | ||
| `Successfully retrieved and cached project account ID for project ${projectId}: ${callerIdentity.Account}` | ||
| ) | ||
|
|
||
| return callerIdentity.Account | ||
| } catch (err) { | ||
| logger.error('Failed to get project account ID: %s', err as Error) | ||
| throw new ToolkitError(`Failed to get project account ID: ${(err as Error).message}`, { | ||
| code: SmusErrorCodes.GetProjectAccountIdFailed, | ||
| }) | ||
| } | ||
| } | ||
|
|
||
| public getDomainRegion(): string { | ||
| if (getContext('aws.smus.inSmusSpaceEnvironment')) { | ||
| const resourceMetadata = getResourceMetadata()! | ||
|
|
@@ -617,6 +672,10 @@ export class SmusAuthenticationProvider { | |
| // Clear cached domain account ID | ||
| this.cachedDomainAccountId = undefined | ||
| logger.debug('SMUS: Cleared cached domain account ID') | ||
|
|
||
| // Clear cached project account IDs | ||
| this.cachedProjectAccountIds.clear() | ||
| logger.debug('SMUS: Cleared cached project account IDs') | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -665,6 +724,9 @@ export class SmusAuthenticationProvider { | |
| // Clear cached domain account ID | ||
| this.cachedDomainAccountId = undefined | ||
|
|
||
| // Clear cached project account IDs | ||
| this.cachedProjectAccountIds.clear() | ||
|
|
||
| this.logger.debug('SMUS Auth: Successfully disposed authentication provider') | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,7 +18,7 @@ import { SmusAuthenticationProvider } from '../auth/providers/smusAuthentication | |
| import { getLogger } from '../../shared/logger/logger' | ||
| import { getContext } from '../../shared/vscode/setContext' | ||
| import { ConnectionCredentialsProvider } from '../auth/providers/connectionCredentialsProvider' | ||
| import { DataZoneConnection } from './client/datazoneClient' | ||
| import { DataZoneConnection, DataZoneClient } from './client/datazoneClient' | ||
|
|
||
| /** | ||
| * Records space telemetry | ||
|
|
@@ -27,16 +27,39 @@ export async function recordSpaceTelemetry( | |
| span: Span<SmusOpenRemoteConnection> | Span<SmusStopSpace>, | ||
| node: SagemakerUnifiedStudioSpaceNode | ||
| ) { | ||
| const parent = node.resource.getParent() as SageMakerUnifiedStudioSpacesParentNode | ||
| const authProvider = SmusAuthenticationProvider.fromContext() | ||
| const accountId = await authProvider.getDomainAccountId() | ||
| span.record({ | ||
| smusSpaceKey: node.resource.DomainSpaceKey, | ||
| smusDomainRegion: node.resource.regionCode, | ||
| smusDomainId: parent?.getAuthProvider()?.activeConnection?.domainId, | ||
| smusDomainAccountId: accountId, | ||
| smusProjectId: parent?.getProjectId(), | ||
| }) | ||
| const logger = getLogger() | ||
|
|
||
| try { | ||
| const parent = node.resource.getParent() as SageMakerUnifiedStudioSpacesParentNode | ||
| const authProvider = SmusAuthenticationProvider.fromContext() | ||
| const accountId = await authProvider.getDomainAccountId() | ||
| const projectId = parent?.getProjectId() | ||
|
|
||
| // Get project account ID and region | ||
| let projectAccountId: string | undefined | ||
| let projectRegion: string | undefined | ||
|
|
||
| if (projectId) { | ||
| projectAccountId = await authProvider.getProjectAccountId(projectId) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just want to confirm, this method is only used by retrieving project account id for right?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. |
||
|
|
||
| // Get project region from tooling environment | ||
| const dzClient = await DataZoneClient.getInstance(authProvider) | ||
| const toolingEnv = await dzClient.getToolingEnvironment(projectId) | ||
| projectRegion = toolingEnv.awsAccountRegion | ||
| } | ||
|
|
||
| span.record({ | ||
| smusSpaceKey: node.resource.DomainSpaceKey, | ||
| smusDomainRegion: node.resource.regionCode, | ||
| smusDomainId: parent?.getAuthProvider()?.activeConnection?.domainId, | ||
| smusDomainAccountId: accountId, | ||
| smusProjectId: projectId, | ||
| smusProjectAccountId: projectAccountId, | ||
| smusProjectRegion: projectRegion, | ||
| }) | ||
| } catch (err) { | ||
| logger.error(`Failed to record space telemetry: ${(err as Error).message}`) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -65,7 +88,7 @@ export async function recordAuthTelemetry( | |
| }) | ||
| } catch (err) { | ||
| logger.error( | ||
| `Failed to resolve AWS account ID via STS Client for domain ${domainId} in region ${region}: ${err}` | ||
| `Failed to record Domain AccountId in data connection telemetry for domain ${domainId} in region ${region}: ${err}` | ||
| ) | ||
| } | ||
| } | ||
|
|
@@ -78,15 +101,22 @@ export async function recordDataConnectionTelemetry( | |
| connection: DataZoneConnection, | ||
| connectionCredentialsProvider: ConnectionCredentialsProvider | ||
| ) { | ||
| const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment') | ||
| const accountId = await connectionCredentialsProvider.getDomainAccountId() | ||
| span.record({ | ||
| smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local', | ||
| smusDomainId: connection.domainId, | ||
| smusDomainAccountId: accountId, | ||
| smusProjectId: connection.projectId, | ||
| smusConnectionId: connection.connectionId, | ||
| smusConnectionType: connection.type, | ||
| smusProjectRegion: connection.location?.awsRegion, | ||
| }) | ||
| const logger = getLogger() | ||
|
|
||
| try { | ||
| const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment') | ||
| const accountId = await connectionCredentialsProvider.getDomainAccountId() | ||
| span.record({ | ||
| smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local', | ||
| smusDomainId: connection.domainId, | ||
| smusDomainAccountId: accountId, | ||
| smusProjectId: connection.projectId, | ||
| smusConnectionId: connection.connectionId, | ||
| smusConnectionType: connection.type, | ||
| smusProjectRegion: connection.location?.awsRegion, | ||
| smusProjectAccountId: connection.location?.awsAccountId, | ||
| }) | ||
| } catch (err) { | ||
| logger.error(`Failed to record data connection telemetry: ${(err as Error).message}`) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought the one in resource-metadata is always domain account id? can you confirm?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The resourceArn in resource-metadata.json gives the project id.