Skip to content

Commit d968af9

Browse files
Merge master into feature/LSP-gamma
2 parents d07ccae + e06830b commit d968af9

21 files changed

+683
-279
lines changed

packages/core/src/sagemakerunifiedstudio/auth/providers/connectionCredentialsProvider.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,14 @@ export class ConnectionCredentialsProvider implements CredentialsProvider {
7373
return this.smusAuthProvider.getDomainRegion()
7474
}
7575

76+
/**
77+
* Gets the domain AWS account ID
78+
* @returns Promise resolving to the domain account ID
79+
*/
80+
public async getDomainAccountId(): Promise<string> {
81+
return this.smusAuthProvider.getDomainAccountId()
82+
}
83+
7684
/**
7785
* Gets the hash code
7886
* @returns Hash code

packages/core/src/sagemakerunifiedstudio/auth/providers/smusAuthenticationProvider.ts

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import * as localizedText from '../../../shared/localizedText'
1414
import { ToolkitPromptSettings } from '../../../shared/settings'
1515
import { setContext, getContext } from '../../../shared/vscode/setContext'
1616
import { getLogger } from '../../../shared/logger/logger'
17-
import { SmusUtils, SmusErrorCodes } from '../../shared/smusUtils'
17+
import { SmusUtils, SmusErrorCodes, extractAccountIdFromArn } from '../../shared/smusUtils'
1818
import { createSmusProfile, isValidSmusConnection, SmusConnection } from '../model'
1919
import { DomainExecRoleCredentialsProvider } from './domainExecRoleCredentialsProvider'
2020
import { ProjectRoleCredentialsProvider } from './projectRoleCredentialsProvider'
@@ -23,6 +23,7 @@ import { ConnectionClientStore } from '../../shared/client/connectionClientStore
2323
import { getResourceMetadata } from '../../shared/utils/resourceMetadataUtils'
2424
import { fromIni } from '@aws-sdk/credential-providers'
2525
import { randomUUID } from '../../../shared/crypto'
26+
import { DefaultStsClient } from '../../../shared/clients/stsClient'
2627

2728
/**
2829
* Sets the context variable for SageMaker Unified Studio connection state
@@ -53,6 +54,7 @@ export class SmusAuthenticationProvider {
5354
private credentialsProviderCache = new Map<string, any>()
5455
private projectCredentialProvidersCache = new Map<string, ProjectRoleCredentialsProvider>()
5556
private connectionCredentialProvidersCache = new Map<string, ConnectionCredentialsProvider>()
57+
private cachedDomainAccountId: string | undefined
5658

5759
public constructor(
5860
public readonly auth = Auth.instance,
@@ -75,6 +77,8 @@ export class SmusAuthenticationProvider {
7577
this.projectCredentialProvidersCache.clear()
7678
// Clear connection provider cache when connection changes
7779
this.connectionCredentialProvidersCache.clear()
80+
// Clear cached domain account ID when connection changes
81+
this.cachedDomainAccountId = undefined
7882
// Clear all clients in client store when connection changes
7983
ConnectionClientStore.getInstance().clearAll()
8084
await setSmusConnectedContext(this.isConnected())
@@ -423,6 +427,99 @@ export class SmusAuthenticationProvider {
423427
return this.activeConnection.domainUrl
424428
}
425429

430+
/**
431+
* Gets the AWS account ID for the active domain connection
432+
* In SMUS space environment, extracts from ResourceArn in metadata
433+
* Otherwise, makes an STS GetCallerIdentity call using DER credentials and caches the result
434+
* @returns Promise resolving to the domain's AWS account ID
435+
* @throws ToolkitError if unable to retrieve account ID
436+
*/
437+
public async getDomainAccountId(): Promise<string> {
438+
const logger = getLogger()
439+
440+
// Return cached value if available
441+
if (this.cachedDomainAccountId) {
442+
logger.debug('SMUS: Using cached domain account ID')
443+
return this.cachedDomainAccountId
444+
}
445+
446+
// If in SMUS space environment, extract account ID from resource-metadata file
447+
if (getContext('aws.smus.inSmusSpaceEnvironment')) {
448+
try {
449+
logger.debug('SMUS: Extracting domain account ID from ResourceArn in resource-metadata file')
450+
451+
const resourceMetadata = getResourceMetadata()!
452+
const resourceArn = resourceMetadata.ResourceArn
453+
454+
if (!resourceArn) {
455+
throw new ToolkitError('ResourceArn not found in metadata file', {
456+
code: SmusErrorCodes.AccountIdNotFound,
457+
})
458+
}
459+
460+
// Extract account ID from ResourceArn using SmusUtils
461+
const accountId = extractAccountIdFromArn(resourceArn)
462+
463+
// Cache the account ID
464+
this.cachedDomainAccountId = accountId
465+
466+
logger.debug(
467+
`Successfully extracted and cached domain account ID from resource-metadata file: ${accountId}`
468+
)
469+
470+
return accountId
471+
} catch (err) {
472+
logger.error(`Failed to extract domain account ID from ResourceArn: %s`, err)
473+
474+
throw new ToolkitError('Failed to extract AWS account ID from ResourceArn in SMUS space environment', {
475+
code: SmusErrorCodes.GetDomainAccountIdFailed,
476+
cause: err instanceof Error ? err : undefined,
477+
})
478+
}
479+
}
480+
481+
if (!this.activeConnection) {
482+
throw new ToolkitError('No active SMUS connection available', { code: SmusErrorCodes.NoActiveConnection })
483+
}
484+
485+
// Use existing STS GetCallerIdentity implementation for non-SMUS space environments
486+
try {
487+
logger.debug('Fetching domain account ID via STS GetCallerIdentity')
488+
489+
// Get DER credentials provider
490+
const derCredProvider = await this.getDerCredentialsProvider()
491+
492+
// Get the region for STS client
493+
const region = this.getDomainRegion()
494+
495+
// Create STS client with DER credentials
496+
const stsClient = new DefaultStsClient(region, await derCredProvider.getCredentials())
497+
498+
// Make GetCallerIdentity call
499+
const callerIdentity = await stsClient.getCallerIdentity()
500+
501+
if (!callerIdentity.Account) {
502+
throw new ToolkitError('Account ID not found in STS GetCallerIdentity response', {
503+
code: SmusErrorCodes.AccountIdNotFound,
504+
})
505+
}
506+
507+
// Cache the account ID
508+
this.cachedDomainAccountId = callerIdentity.Account
509+
510+
logger.debug(`Successfully retrieved and cached domain account ID: ${callerIdentity.Account}`)
511+
512+
return callerIdentity.Account
513+
} catch (err) {
514+
logger.error(`Failed to retrieve domain account ID: %s`, err)
515+
516+
throw new ToolkitError('Failed to retrieve AWS account ID for active domain connection', {
517+
code: SmusErrorCodes.GetDomainAccountIdFailed,
518+
cause: err instanceof Error ? err : undefined,
519+
})
520+
}
521+
}
522+
426523
public getDomainRegion(): string {
427524
if (getContext('aws.smus.inSmusSpaceEnvironment')) {
428525
const resourceMetadata = getResourceMetadata()!
@@ -516,6 +613,10 @@ export class SmusAuthenticationProvider {
516613
logger.warn(`SMUS: Failed to invalidate connection credentials for cache key ${cacheKey}: %s`, err)
517614
}
518615
}
616+
617+
// Clear cached domain account ID
618+
this.cachedDomainAccountId = undefined
619+
logger.debug('SMUS: Cleared cached domain account ID')
519620
}
520621

521622
/**
@@ -560,6 +661,10 @@ export class SmusAuthenticationProvider {
560661
}
561662
}
562663
this.credentialsProviderCache.clear()
664+
665+
// Clear cached domain account ID
666+
this.cachedDomainAccountId = undefined
667+
563668
this.logger.debug('SMUS Auth: Successfully disposed authentication provider')
564669
}
565670

packages/core/src/sagemakerunifiedstudio/explorer/activation.ts

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ import { getLogger } from '../../shared/logger/logger'
2020
import { setSmusConnectedContext, SmusAuthenticationProvider } from '../auth/providers/smusAuthenticationProvider'
2121
import { setupUserActivityMonitoring } from '../../awsService/sagemaker/sagemakerSpace'
2222
import { telemetry } from '../../shared/telemetry/telemetry'
23-
import { SageMakerUnifiedStudioSpacesParentNode } from './nodes/sageMakerUnifiedStudioSpacesParentNode'
2423
import { isSageMaker } from '../../shared/extensionUtilities'
24+
import { recordSpaceTelemetry } from '../shared/telemetry'
2525

2626
export async function activate(extensionContext: vscode.ExtensionContext): Promise<void> {
2727
// Initialize the SMUS authentication provider
@@ -75,16 +75,7 @@ export async function activate(extensionContext: vscode.ExtensionContext): Promi
7575
return
7676
}
7777
await telemetry.smus_stopSpace.run(async (span) => {
78-
span.record({
79-
smusSpaceKey: node.resource.DomainSpaceKey,
80-
smusDomainRegion: node.resource.regionCode,
81-
smusDomainId: (
82-
node.resource.getParent() as SageMakerUnifiedStudioSpacesParentNode
83-
)?.getAuthProvider()?.activeConnection?.domainId,
84-
smusProjectId: (
85-
node.resource.getParent() as SageMakerUnifiedStudioSpacesParentNode
86-
)?.getProjectId(),
87-
})
78+
await recordSpaceTelemetry(span, node)
8879
await stopSpace(node.resource, extensionContext, node.resource.sageMakerClient)
8980
})
9081
}),
@@ -95,17 +86,8 @@ export async function activate(extensionContext: vscode.ExtensionContext): Promi
9586
if (!validateNode(node)) {
9687
return
9788
}
98-
await telemetry.smus_startSpace.run(async (span) => {
99-
span.record({
100-
smusSpaceKey: node.resource.DomainSpaceKey,
101-
smusDomainRegion: node.resource.regionCode,
102-
smusDomainId: (
103-
node.resource.getParent() as SageMakerUnifiedStudioSpacesParentNode
104-
)?.getAuthProvider()?.activeConnection?.domainId,
105-
smusProjectId: (
106-
node.resource.getParent() as SageMakerUnifiedStudioSpacesParentNode
107-
)?.getProjectId(),
108-
})
89+
await telemetry.smus_openRemoteConnection.run(async (span) => {
90+
await recordSpaceTelemetry(span, node)
10991
await openRemoteConnect(node.resource, extensionContext, node.resource.sageMakerClient)
11092
})
11193
}

packages/core/src/sagemakerunifiedstudio/explorer/nodes/lakehouseStrategy.ts

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ import { createPlaceholderItem } from '../../../shared/treeview/utils'
3434
import { Column, Database, Table } from '@aws-sdk/client-glue'
3535
import { ConnectionCredentialsProvider } from '../../auth/providers/connectionCredentialsProvider'
3636
import { telemetry } from '../../../shared/telemetry/telemetry'
37-
import { getContext } from '../../../shared/vscode/setContext'
37+
import { recordDataConnectionTelemetry } from '../../shared/telemetry'
3838

3939
/**
4040
* Lakehouse data node for SageMaker Unified Studio
@@ -152,16 +152,7 @@ export function createLakehouseConnectionNode(
152152
},
153153
async (node) => {
154154
return telemetry.smus_renderLakehouseNode.run(async (span) => {
155-
const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment')
156-
157-
span.record({
158-
smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local',
159-
smusDomainId: connection.domainId,
160-
smusProjectId: connection.projectId,
161-
smusConnectionId: connection.connectionId,
162-
smusConnectionType: connection.type,
163-
smusProjectRegion: connection.location?.awsRegion,
164-
})
155+
await recordDataConnectionTelemetry(span, connection, connectionCredentialsProvider)
165156
try {
166157
logger.info(`Loading Lakehouse catalogs for connection ${connection.name}`)
167158

packages/core/src/sagemakerunifiedstudio/explorer/nodes/redshiftStrategy.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,11 @@ export function createRedshiftConnectionNode(
131131
async (node) => {
132132
return telemetry.smus_renderRedshiftNode.run(async (span) => {
133133
const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment')
134-
134+
const accountId = await connectionCredentialsProvider.getDomainAccountId()
135135
span.record({
136136
smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local',
137137
smusDomainId: connection.domainId,
138+
smusDomainAccountId: accountId,
138139
smusProjectId: connection.projectId,
139140
smusConnectionId: connection.connectionId,
140141
smusConnectionType: connection.type,

packages/core/src/sagemakerunifiedstudio/explorer/nodes/s3Strategy.ts

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import {
2020
import { S3, ListObjectsV2Command } from '@aws-sdk/client-s3'
2121
import { ConnectionCredentialsProvider } from '../../auth/providers/connectionCredentialsProvider'
2222
import { telemetry } from '../../../shared/telemetry/telemetry'
23-
import { getContext } from '../../../shared/vscode/setContext'
23+
import { recordDataConnectionTelemetry } from '../../shared/telemetry'
2424

2525
// Regex to match default S3 connection names
2626
// eslint-disable-next-line @typescript-eslint/naming-convention
@@ -144,16 +144,7 @@ export function createS3ConnectionNode(
144144
},
145145
async (node) => {
146146
return telemetry.smus_renderS3Node.run(async (span) => {
147-
const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment')
148-
149-
span.record({
150-
smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local',
151-
smusDomainId: connection.domainId,
152-
smusProjectId: connection.projectId,
153-
smusConnectionId: connection.connectionId,
154-
smusConnectionType: connection.type,
155-
smusProjectRegion: connection.location?.awsRegion,
156-
})
147+
await recordDataConnectionTelemetry(span, connection, connectionCredentialsProvider)
157148
try {
158149
if (isDefaultConnection && s3Info.prefix) {
159150
// For default connections, show the full path as the first node

packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioProjectNode.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,11 @@ export class SageMakerUnifiedStudioProjectNode implements TreeNode {
8282
return telemetry.smus_renderProjectChildrenNode.run(async (span) => {
8383
try {
8484
const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment')
85-
85+
const accountId = await this.authProvider.getDomainAccountId()
8686
span.record({
8787
smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local',
8888
smusDomainId: this.project?.domainId,
89+
smusDomainAccountId: accountId,
8990
smusProjectId: this.project?.id,
9091
smusDomainRegion: this.authProvider.getDomainRegion(),
9192
})

packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioRootNode.ts

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ import { telemetry } from '../../../shared/telemetry/telemetry'
1313
import { createQuickPick } from '../../../shared/ui/pickerPrompter'
1414
import { SageMakerUnifiedStudioProjectNode } from './sageMakerUnifiedStudioProjectNode'
1515
import { SageMakerUnifiedStudioAuthInfoNode } from './sageMakerUnifiedStudioAuthInfoNode'
16-
import { SmusUtils } from '../../shared/smusUtils'
16+
import { SmusErrorCodes, SmusUtils } from '../../shared/smusUtils'
1717
import { SmusAuthenticationProvider } from '../../auth/providers/smusAuthenticationProvider'
1818
import { ToolkitError } from '../../../../src/shared/errors'
19-
import { errorCode } from '../../shared/errors'
19+
import { recordAuthTelemetry } from '../../shared/telemetry'
2020

2121
const contextValueSmusRoot = 'sageMakerUnifiedStudioRoot'
2222
const contextValueSmusLogin = 'sageMakerUnifiedStudioLogin'
@@ -237,7 +237,10 @@ export const smusLoginCommand = Commands.declare('aws.smus.login', () => async (
237237
if (!domainUrl) {
238238
// User cancelled
239239
logger.debug('User cancelled domain URL input')
240-
return
240+
throw new ToolkitError('User cancelled domain URL input', {
241+
cancelled: true,
242+
code: SmusErrorCodes.UserCancelled,
243+
})
241244
}
242245

243246
// Show a simple status bar message instead of progress dialog
@@ -252,19 +255,16 @@ export const smusLoginCommand = Commands.declare('aws.smus.login', () => async (
252255

253256
if (!connection) {
254257
throw new ToolkitError('Failed to establish connection', {
255-
code: errorCode.failedAuthConnecton,
258+
code: SmusErrorCodes.FailedAuthConnecton,
256259
})
257260
}
258261

259-
// Extract domain ID and region for logging
262+
// Extract domain account ID, domain ID, and region for logging
260263
const domainId = connection.domainId
261264
const region = connection.ssoRegion
262265

263266
logger.info(`Connected to SageMaker Unified Studio domain: ${domainId} in region ${region}`)
264-
span.record({
265-
smusDomainId: domainId,
266-
awsRegion: region,
267-
})
267+
await recordAuthTelemetry(span, authProvider, domainId, region)
268268

269269
// Show success message
270270
void vscode.window.showInformationMessage(
@@ -292,9 +292,12 @@ export const smusLoginCommand = Commands.declare('aws.smus.login', () => async (
292292
})
293293
}
294294
} catch (err) {
295-
void vscode.window.showErrorMessage(
296-
`SageMaker Unified Studio: Failed to initiate login: ${(err as Error).message}`
297-
)
295+
const isUserCancelled = err instanceof ToolkitError && err.code === SmusErrorCodes.UserCancelled
296+
if (!isUserCancelled) {
297+
void vscode.window.showErrorMessage(
298+
`SageMaker Unified Studio: Failed to initiate login: ${(err as Error).message}`
299+
)
300+
}
298301
logger.error('Failed to initiate login: %s', (err as Error).message)
299302
throw new ToolkitError('Failed to initiate login.', {
300303
cause: err as Error,
@@ -329,11 +332,7 @@ export const smusSignOutCommand = Commands.declare('aws.smus.signOut', () => asy
329332

330333
// Show status message
331334
vscode.window.setStatusBarMessage('Signing out from SageMaker Unified Studio...', 5000)
332-
333-
span.record({
334-
smusDomainId: domainId,
335-
awsRegion: region,
336-
})
335+
await recordAuthTelemetry(span, authProvider, domainId, region)
337336

338337
// Delete the connection (this will also invalidate tokens and clear cache)
339338
if (activeConnection) {
@@ -425,10 +424,12 @@ export async function selectSMUSProject(projectNode?: SageMakerUnifiedStudioProj
425424
}
426425

427426
const selectedProject = await showQuickPick(items)
427+
const accountId = await authProvider.getDomainAccountId()
428428
span.record({
429429
smusDomainId: authProvider.getDomainId(),
430430
smusProjectId: (selectedProject as DataZoneProject).id as string | undefined,
431431
smusDomainRegion: authProvider.getDomainRegion(),
432+
smusDomainAccountId: accountId,
432433
})
433434
if (
434435
selectedProject &&

packages/core/src/sagemakerunifiedstudio/shared/errors.ts

Lines changed: 0 additions & 8 deletions
This file was deleted.

0 commit comments

Comments
 (0)