Skip to content

Commit 8c03419

Browse files
committed
telemetry: add project account id and region
1 parent 6d678e9 commit 8c03419

File tree

7 files changed

+467
-94
lines changed

7 files changed

+467
-94
lines changed

packages/core/src/sagemakerunifiedstudio/auth/providers/smusAuthenticationProvider.ts

Lines changed: 120 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import * as localizedText from '../../../shared/localizedText'
1414
import { ToolkitPromptSettings } from '../../../shared/settings'
1515
import { setContext, getContext } from '../../../shared/vscode/setContext'
1616
import { getLogger } from '../../../shared/logger/logger'
17-
import { SmusUtils, SmusErrorCodes, extractAccountIdFromArn } from '../../shared/smusUtils'
17+
import { SmusUtils, SmusErrorCodes, extractAccountIdFromSageMakerArn } from '../../shared/smusUtils'
1818
import { createSmusProfile, isValidSmusConnection, SmusConnection } from '../model'
1919
import { DomainExecRoleCredentialsProvider } from './domainExecRoleCredentialsProvider'
2020
import { ProjectRoleCredentialsProvider } from './projectRoleCredentialsProvider'
@@ -24,6 +24,7 @@ import { getResourceMetadata } from '../../shared/utils/resourceMetadataUtils'
2424
import { fromIni } from '@aws-sdk/credential-providers'
2525
import { randomUUID } from '../../../shared/crypto'
2626
import { DefaultStsClient } from '../../../shared/clients/stsClient'
27+
import { DataZoneClient } from '../../shared/client/datazoneClient'
2728

2829
/**
2930
* Sets the context variable for SageMaker Unified Studio connection state
@@ -55,6 +56,7 @@ export class SmusAuthenticationProvider {
5556
private projectCredentialProvidersCache = new Map<string, ProjectRoleCredentialsProvider>()
5657
private connectionCredentialProvidersCache = new Map<string, ConnectionCredentialsProvider>()
5758
private cachedDomainAccountId: string | undefined
59+
private cachedProjectAccountIds = new Map<string, string>()
5860

5961
public constructor(
6062
public readonly auth = Auth.instance,
@@ -79,6 +81,8 @@ export class SmusAuthenticationProvider {
7981
this.connectionCredentialProvidersCache.clear()
8082
// Clear cached domain account ID when connection changes
8183
this.cachedDomainAccountId = undefined
84+
// Clear cached project account IDs when connection changes
85+
this.cachedProjectAccountIds.clear()
8286
// Clear all clients in client store when connection changes
8387
ConnectionClientStore.getInstance().clearAll()
8488
await setSmusConnectedContext(this.isConnected())
@@ -427,6 +431,34 @@ export class SmusAuthenticationProvider {
427431
return this.activeConnection.domainUrl
428432
}
429433

434+
/**
435+
* Extracts account ID from ResourceArn in SMUS space environment
436+
* @returns Promise resolving to the account ID
437+
* @throws ToolkitError if unable to extract account ID
438+
*/
439+
private async extractAccountIdFromResourceMetadata(): Promise<string> {
440+
const logger = getLogger()
441+
442+
try {
443+
logger.debug('SMUS: Extracting account ID from ResourceArn in resource-metadata file')
444+
445+
const resourceMetadata = getResourceMetadata()!
446+
const resourceArn = resourceMetadata.ResourceArn
447+
448+
if (!resourceArn) {
449+
throw new Error('ResourceArn not found in metadata file')
450+
}
451+
452+
const accountId = extractAccountIdFromSageMakerArn(resourceArn)
453+
logger.debug(`Successfully extracted account ID from resource-metadata file: ${accountId}`)
454+
455+
return accountId
456+
} catch (err) {
457+
logger.error(`Failed to extract account ID from ResourceArn: %s`, err)
458+
throw new Error('Failed to extract AWS account ID from ResourceArn in SMUS space environment')
459+
}
460+
}
461+
430462
/**
431463
* Gets the AWS account ID for the active domain connection
432464
* In SMUS space environment, extracts from ResourceArn in metadata
@@ -445,37 +477,13 @@ export class SmusAuthenticationProvider {
445477

446478
// If in SMUS space environment, extract account ID from resource-metadata file
447479
if (getContext('aws.smus.inSmusSpaceEnvironment')) {
448-
try {
449-
logger.debug('SMUS: Extracting domain account ID from ResourceArn in resource-metadata file')
450-
451-
const resourceMetadata = getResourceMetadata()!
452-
const resourceArn = resourceMetadata.ResourceArn
453-
454-
if (!resourceArn) {
455-
throw new ToolkitError('ResourceArn not found in metadata file', {
456-
code: SmusErrorCodes.AccountIdNotFound,
457-
})
458-
}
459-
460-
// Extract account ID from ResourceArn using SmusUtils
461-
const accountId = extractAccountIdFromArn(resourceArn)
462-
463-
// Cache the account ID
464-
this.cachedDomainAccountId = accountId
480+
const accountId = await this.extractAccountIdFromResourceMetadata()
465481

466-
logger.debug(
467-
`Successfully extracted and cached domain account ID from resource-metadata file: ${accountId}`
468-
)
469-
470-
return accountId
471-
} catch (err) {
472-
logger.error(`Failed to extract domain account ID from ResourceArn: %s`, err)
482+
// Cache the account ID
483+
this.cachedDomainAccountId = accountId
484+
logger.debug(`Successfully cached domain account ID: ${accountId}`)
473485

474-
throw new ToolkitError('Failed to extract AWS account ID from ResourceArn in SMUS space environment', {
475-
code: SmusErrorCodes.GetDomainAccountIdFailed,
476-
cause: err instanceof Error ? err : undefined,
477-
})
478-
}
486+
return accountId
479487
}
480488

481489
if (!this.activeConnection) {
@@ -520,6 +528,81 @@ export class SmusAuthenticationProvider {
520528
}
521529
}
522530

531+
/**
532+
* Gets the AWS account ID for a specific project using project credentials
533+
* In SMUS space environment, extracts from ResourceArn in metadata (same as domain account)
534+
* Otherwise, makes an STS GetCallerIdentity call using project credentials
535+
* @param projectId The DataZone project ID
536+
* @returns Promise resolving to the project's AWS account ID
537+
*/
538+
public async getProjectAccountId(projectId: string): Promise<string> {
539+
const logger = getLogger()
540+
541+
// Return cached value if available
542+
if (this.cachedProjectAccountIds.has(projectId)) {
543+
logger.debug(`SMUS: Using cached project account ID for project ${projectId}`)
544+
return this.cachedProjectAccountIds.get(projectId)!
545+
}
546+
547+
// If in SMUS space environment, extract account ID from resource-metadata file
548+
if (getContext('aws.smus.inSmusSpaceEnvironment')) {
549+
const accountId = await this.extractAccountIdFromResourceMetadata()
550+
551+
// Cache the account ID
552+
this.cachedProjectAccountIds.set(projectId, accountId)
553+
logger.debug(`Successfully cached project account ID for project ${projectId}: ${accountId}`)
554+
555+
return accountId
556+
}
557+
558+
if (!this.activeConnection) {
559+
throw new ToolkitError('No active SMUS connection available', { code: SmusErrorCodes.NoActiveConnection })
560+
}
561+
562+
// For non-SMUS space environments, use project credentials with STS
563+
try {
564+
logger.debug('Fetching project account ID via STS GetCallerIdentity with project credentials')
565+
566+
// Get project credentials
567+
const projectCredProvider = await this.getProjectCredentialProvider(projectId)
568+
const projectCreds = await projectCredProvider.getCredentials()
569+
570+
// Get project region from tooling environment
571+
const dzClient = await DataZoneClient.getInstance(this)
572+
const toolingEnv = await dzClient.getToolingEnvironment(projectId)
573+
const projectRegion = toolingEnv.awsAccountRegion
574+
575+
if (!projectRegion) {
576+
throw new ToolkitError('No AWS account region found in tooling environment', {
577+
code: SmusErrorCodes.RegionNotFound,
578+
})
579+
}
580+
581+
// Use STS to get account ID from project credentials
582+
const stsClient = new DefaultStsClient(projectRegion, projectCreds)
583+
const callerIdentity = await stsClient.getCallerIdentity()
584+
585+
if (!callerIdentity.Account) {
586+
throw new ToolkitError('Account ID not found in STS GetCallerIdentity response', {
587+
code: SmusErrorCodes.AccountIdNotFound,
588+
})
589+
}
590+
591+
// Cache the account ID
592+
this.cachedProjectAccountIds.set(projectId, callerIdentity.Account)
593+
logger.debug(
594+
`Successfully retrieved and cached project account ID for project ${projectId}: ${callerIdentity.Account}`
595+
)
596+
597+
return callerIdentity.Account
598+
} catch (err) {
599+
logger.error('Failed to get project account ID: %s', err as Error)
600+
throw new ToolkitError(`Failed to get project account ID: ${(err as Error).message}`, {
601+
code: SmusErrorCodes.GetProjectAccountIdFailed,
602+
})
603+
}
604+
}
605+
523606
public getDomainRegion(): string {
524607
if (getContext('aws.smus.inSmusSpaceEnvironment')) {
525608
const resourceMetadata = getResourceMetadata()!
@@ -617,6 +700,10 @@ export class SmusAuthenticationProvider {
617700
// Clear cached domain account ID
618701
this.cachedDomainAccountId = undefined
619702
logger.debug('SMUS: Cleared cached domain account ID')
703+
704+
// Clear cached project account IDs
705+
this.cachedProjectAccountIds.clear()
706+
logger.debug('SMUS: Cleared cached project account IDs')
620707
}
621708

622709
/**
@@ -665,6 +752,9 @@ export class SmusAuthenticationProvider {
665752
// Clear cached domain account ID
666753
this.cachedDomainAccountId = undefined
667754

755+
// Clear cached project account IDs
756+
this.cachedProjectAccountIds.clear()
757+
668758
this.logger.debug('SMUS Auth: Successfully disposed authentication provider')
669759
}
670760

packages/core/src/sagemakerunifiedstudio/explorer/nodes/redshiftStrategy.ts

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import { createPlaceholderItem } from '../../../shared/treeview/utils'
2424
import { ConnectionCredentialsProvider } from '../../auth/providers/connectionCredentialsProvider'
2525
import { GlueCatalog } from '../../shared/client/glueCatalogClient'
2626
import { telemetry } from '../../../shared/telemetry/telemetry'
27-
import { getContext } from '../../../shared/vscode/setContext'
27+
import { recordDataConnectionTelemetry } from '../../shared/telemetry'
2828

2929
/**
3030
* Redshift data node for SageMaker Unified Studio
@@ -119,6 +119,7 @@ export function createRedshiftConnectionNode(
119119
connection: DataZoneConnection,
120120
connectionCredentialsProvider: ConnectionCredentialsProvider
121121
): RedshiftNode {
122+
const logger = getLogger()
122123
return new RedshiftNode(
123124
{
124125
id: connection.connectionId,
@@ -130,19 +131,8 @@ export function createRedshiftConnectionNode(
130131
},
131132
async (node) => {
132133
return telemetry.smus_renderRedshiftNode.run(async (span) => {
133-
const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment')
134-
const accountId = await connectionCredentialsProvider.getDomainAccountId()
135-
span.record({
136-
smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local',
137-
smusDomainId: connection.domainId,
138-
smusDomainAccountId: accountId,
139-
smusProjectId: connection.projectId,
140-
smusConnectionId: connection.connectionId,
141-
smusConnectionType: connection.type,
142-
smusProjectRegion: connection.location?.awsRegion,
143-
})
144-
const logger = getLogger()
145134
logger.info(`Loading Redshift resources for connection ${connection.name}`)
135+
await recordDataConnectionTelemetry(span, connection, connectionCredentialsProvider)
146136

147137
const connectionParams = extractConnectionParams(connection)
148138
if (!connectionParams) {

packages/core/src/sagemakerunifiedstudio/shared/smusUtils.ts

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,14 @@ export const SmusErrorCodes = {
5656
UserCancelled: 'UserCancelled',
5757
/** Error code for when domain account Id is missing */
5858
AccountIdNotFound: 'AccountIdNotFound',
59+
/** Error code for when resource ARN is missing */
60+
ResourceArnNotFound: 'ResourceArnNotFound',
5961
/** Error code for when fails to get domain account Id */
6062
GetDomainAccountIdFailed: 'GetDomainAccountIdFailed',
63+
/** Error code for when fails to get project account Id */
64+
GetProjectAccountIdFailed: 'GetProjectAccountIdFailed',
65+
/** Error code for when region is missing */
66+
RegionNotFound: 'RegionNotFound',
6167
} as const
6268

6369
/**
@@ -369,7 +375,7 @@ export class SmusUtils {
369375
* @returns The account ID from the ARN
370376
* @throws If the ARN format is invalid
371377
*/
372-
export function extractAccountIdFromArn(arn: string): string {
378+
export function extractAccountIdFromSageMakerArn(arn: string): string {
373379
// Match the ARN components to extract account ID
374380
const regex = /^arn:aws:sagemaker:(?<region>[^:]+):(?<accountId>\d+):(app|space|domain)\/.+$/i
375381
const match = arn.match(regex)

packages/core/src/sagemakerunifiedstudio/shared/telemetry.ts

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import { SmusAuthenticationProvider } from '../auth/providers/smusAuthentication
1818
import { getLogger } from '../../shared/logger/logger'
1919
import { getContext } from '../../shared/vscode/setContext'
2020
import { ConnectionCredentialsProvider } from '../auth/providers/connectionCredentialsProvider'
21-
import { DataZoneConnection } from './client/datazoneClient'
21+
import { DataZoneConnection, DataZoneClient } from './client/datazoneClient'
2222

2323
/**
2424
* Records space telemetry
@@ -27,16 +27,39 @@ export async function recordSpaceTelemetry(
2727
span: Span<SmusOpenRemoteConnection> | Span<SmusStopSpace>,
2828
node: SagemakerUnifiedStudioSpaceNode
2929
) {
30-
const parent = node.resource.getParent() as SageMakerUnifiedStudioSpacesParentNode
31-
const authProvider = SmusAuthenticationProvider.fromContext()
32-
const accountId = await authProvider.getDomainAccountId()
33-
span.record({
34-
smusSpaceKey: node.resource.DomainSpaceKey,
35-
smusDomainRegion: node.resource.regionCode,
36-
smusDomainId: parent?.getAuthProvider()?.activeConnection?.domainId,
37-
smusDomainAccountId: accountId,
38-
smusProjectId: parent?.getProjectId(),
39-
})
30+
const logger = getLogger()
31+
32+
try {
33+
const parent = node.resource.getParent() as SageMakerUnifiedStudioSpacesParentNode
34+
const authProvider = SmusAuthenticationProvider.fromContext()
35+
const accountId = await authProvider.getDomainAccountId()
36+
const projectId = parent?.getProjectId()
37+
38+
// Get project account ID and region
39+
let projectAccountId: string | undefined
40+
let projectRegion: string | undefined
41+
42+
if (projectId) {
43+
projectAccountId = await authProvider.getProjectAccountId(projectId)
44+
45+
// Get project region from tooling environment
46+
const dzClient = await DataZoneClient.getInstance(authProvider)
47+
const toolingEnv = await dzClient.getToolingEnvironment(projectId)
48+
projectRegion = toolingEnv.awsAccountRegion
49+
}
50+
51+
span.record({
52+
smusSpaceKey: node.resource.DomainSpaceKey,
53+
smusDomainRegion: node.resource.regionCode,
54+
smusDomainId: parent?.getAuthProvider()?.activeConnection?.domainId,
55+
smusDomainAccountId: accountId,
56+
smusProjectId: projectId,
57+
smusProjectAccountId: projectAccountId,
58+
smusProjectRegion: projectRegion,
59+
})
60+
} catch (err) {
61+
logger.error(`Failed to record space telemetry: ${(err as Error).message}`)
62+
}
4063
}
4164

4265
/**
@@ -78,15 +101,22 @@ export async function recordDataConnectionTelemetry(
78101
connection: DataZoneConnection,
79102
connectionCredentialsProvider: ConnectionCredentialsProvider
80103
) {
81-
const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment')
82-
const accountId = await connectionCredentialsProvider.getDomainAccountId()
83-
span.record({
84-
smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local',
85-
smusDomainId: connection.domainId,
86-
smusDomainAccountId: accountId,
87-
smusProjectId: connection.projectId,
88-
smusConnectionId: connection.connectionId,
89-
smusConnectionType: connection.type,
90-
smusProjectRegion: connection.location?.awsRegion,
91-
})
104+
const logger = getLogger()
105+
106+
try {
107+
const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment')
108+
const accountId = await connectionCredentialsProvider.getDomainAccountId()
109+
span.record({
110+
smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local',
111+
smusDomainId: connection.domainId,
112+
smusDomainAccountId: accountId,
113+
smusProjectId: connection.projectId,
114+
smusConnectionId: connection.connectionId,
115+
smusConnectionType: connection.type,
116+
smusProjectRegion: connection.location?.awsRegion,
117+
smusProjectAccountId: connection.location?.awsAccountId,
118+
})
119+
} catch (err) {
120+
logger.error(`Failed to record data connection telemetry: ${(err as Error).message}`)
121+
}
92122
}

0 commit comments

Comments
 (0)