@@ -14,7 +14,7 @@ import * as localizedText from '../../../shared/localizedText'
14
14
import { ToolkitPromptSettings } from '../../../shared/settings'
15
15
import { setContext , getContext } from '../../../shared/vscode/setContext'
16
16
import { getLogger } from '../../../shared/logger/logger'
17
- import { SmusUtils , SmusErrorCodes , extractAccountIdFromArn } from '../../shared/smusUtils'
17
+ import { SmusUtils , SmusErrorCodes , extractAccountIdFromResourceMetadata } from '../../shared/smusUtils'
18
18
import { createSmusProfile , isValidSmusConnection , SmusConnection } from '../model'
19
19
import { DomainExecRoleCredentialsProvider } from './domainExecRoleCredentialsProvider'
20
20
import { ProjectRoleCredentialsProvider } from './projectRoleCredentialsProvider'
@@ -24,6 +24,7 @@ import { getResourceMetadata } from '../../shared/utils/resourceMetadataUtils'
24
24
import { fromIni } from '@aws-sdk/credential-providers'
25
25
import { randomUUID } from '../../../shared/crypto'
26
26
import { DefaultStsClient } from '../../../shared/clients/stsClient'
27
+ import { DataZoneClient } from '../../shared/client/datazoneClient'
27
28
28
29
/**
29
30
* Sets the context variable for SageMaker Unified Studio connection state
@@ -55,6 +56,7 @@ export class SmusAuthenticationProvider {
55
56
private projectCredentialProvidersCache = new Map < string , ProjectRoleCredentialsProvider > ( )
56
57
private connectionCredentialProvidersCache = new Map < string , ConnectionCredentialsProvider > ( )
57
58
private cachedDomainAccountId : string | undefined
59
+ private cachedProjectAccountIds = new Map < string , string > ( )
58
60
59
61
public constructor (
60
62
public readonly auth = Auth . instance ,
@@ -79,6 +81,8 @@ export class SmusAuthenticationProvider {
79
81
this . connectionCredentialProvidersCache . clear ( )
80
82
// Clear cached domain account ID when connection changes
81
83
this . cachedDomainAccountId = undefined
84
+ // Clear cached project account IDs when connection changes
85
+ this . cachedProjectAccountIds . clear ( )
82
86
// Clear all clients in client store when connection changes
83
87
ConnectionClientStore . getInstance ( ) . clearAll ( )
84
88
await setSmusConnectedContext ( this . isConnected ( ) )
@@ -445,37 +449,13 @@ export class SmusAuthenticationProvider {
445
449
446
450
// If in SMUS space environment, extract account ID from resource-metadata file
447
451
if ( getContext ( 'aws.smus.inSmusSpaceEnvironment' ) ) {
448
- try {
449
- logger . debug ( 'SMUS: Extracting domain account ID from ResourceArn in resource-metadata file' )
450
-
451
- const resourceMetadata = getResourceMetadata ( ) !
452
- const resourceArn = resourceMetadata . ResourceArn
453
-
454
- if ( ! resourceArn ) {
455
- throw new ToolkitError ( 'ResourceArn not found in metadata file' , {
456
- code : SmusErrorCodes . AccountIdNotFound ,
457
- } )
458
- }
452
+ const accountId = await extractAccountIdFromResourceMetadata ( )
459
453
460
- // Extract account ID from ResourceArn using SmusUtils
461
- const accountId = extractAccountIdFromArn ( resourceArn )
462
-
463
- // Cache the account ID
464
- this . cachedDomainAccountId = accountId
465
-
466
- logger . debug (
467
- `Successfully extracted and cached domain account ID from resource-metadata file: ${ accountId } `
468
- )
469
-
470
- return accountId
471
- } catch ( err ) {
472
- logger . error ( `Failed to extract domain account ID from ResourceArn: %s` , err )
454
+ // Cache the account ID
455
+ this . cachedDomainAccountId = accountId
456
+ logger . debug ( `Successfully cached domain account ID: ${ accountId } ` )
473
457
474
- throw new ToolkitError ( 'Failed to extract AWS account ID from ResourceArn in SMUS space environment' , {
475
- code : SmusErrorCodes . GetDomainAccountIdFailed ,
476
- cause : err instanceof Error ? err : undefined ,
477
- } )
478
- }
458
+ return accountId
479
459
}
480
460
481
461
if ( ! this . activeConnection ) {
@@ -520,6 +500,81 @@ export class SmusAuthenticationProvider {
520
500
}
521
501
}
522
502
503
+ /**
504
+ * Gets the AWS account ID for a specific project using project credentials
505
+ * In SMUS space environment, extracts from ResourceArn in metadata (same as domain account)
506
+ * Otherwise, makes an STS GetCallerIdentity call using project credentials
507
+ * @param projectId The DataZone project ID
508
+ * @returns Promise resolving to the project's AWS account ID
509
+ */
510
+ public async getProjectAccountId ( projectId : string ) : Promise < string > {
511
+ const logger = getLogger ( )
512
+
513
+ // Return cached value if available
514
+ if ( this . cachedProjectAccountIds . has ( projectId ) ) {
515
+ logger . debug ( `SMUS: Using cached project account ID for project ${ projectId } ` )
516
+ return this . cachedProjectAccountIds . get ( projectId ) !
517
+ }
518
+
519
+ // If in SMUS space environment, extract account ID from resource-metadata file
520
+ if ( getContext ( 'aws.smus.inSmusSpaceEnvironment' ) ) {
521
+ const accountId = await extractAccountIdFromResourceMetadata ( )
522
+
523
+ // Cache the account ID
524
+ this . cachedProjectAccountIds . set ( projectId , accountId )
525
+ logger . debug ( `Successfully cached project account ID for project ${ projectId } : ${ accountId } ` )
526
+
527
+ return accountId
528
+ }
529
+
530
+ if ( ! this . activeConnection ) {
531
+ throw new ToolkitError ( 'No active SMUS connection available' , { code : SmusErrorCodes . NoActiveConnection } )
532
+ }
533
+
534
+ // For non-SMUS space environments, use project credentials with STS
535
+ try {
536
+ logger . debug ( 'Fetching project account ID via STS GetCallerIdentity with project credentials' )
537
+
538
+ // Get project credentials
539
+ const projectCredProvider = await this . getProjectCredentialProvider ( projectId )
540
+ const projectCreds = await projectCredProvider . getCredentials ( )
541
+
542
+ // Get project region from tooling environment
543
+ const dzClient = await DataZoneClient . getInstance ( this )
544
+ const toolingEnv = await dzClient . getToolingEnvironment ( projectId )
545
+ const projectRegion = toolingEnv . awsAccountRegion
546
+
547
+ if ( ! projectRegion ) {
548
+ throw new ToolkitError ( 'No AWS account region found in tooling environment' , {
549
+ code : SmusErrorCodes . RegionNotFound ,
550
+ } )
551
+ }
552
+
553
+ // Use STS to get account ID from project credentials
554
+ const stsClient = new DefaultStsClient ( projectRegion , projectCreds )
555
+ const callerIdentity = await stsClient . getCallerIdentity ( )
556
+
557
+ if ( ! callerIdentity . Account ) {
558
+ throw new ToolkitError ( 'Account ID not found in STS GetCallerIdentity response' , {
559
+ code : SmusErrorCodes . AccountIdNotFound ,
560
+ } )
561
+ }
562
+
563
+ // Cache the account ID
564
+ this . cachedProjectAccountIds . set ( projectId , callerIdentity . Account )
565
+ logger . debug (
566
+ `Successfully retrieved and cached project account ID for project ${ projectId } : ${ callerIdentity . Account } `
567
+ )
568
+
569
+ return callerIdentity . Account
570
+ } catch ( err ) {
571
+ logger . error ( 'Failed to get project account ID: %s' , err as Error )
572
+ throw new ToolkitError ( `Failed to get project account ID: ${ ( err as Error ) . message } ` , {
573
+ code : SmusErrorCodes . GetProjectAccountIdFailed ,
574
+ } )
575
+ }
576
+ }
577
+
523
578
public getDomainRegion ( ) : string {
524
579
if ( getContext ( 'aws.smus.inSmusSpaceEnvironment' ) ) {
525
580
const resourceMetadata = getResourceMetadata ( ) !
@@ -617,6 +672,10 @@ export class SmusAuthenticationProvider {
617
672
// Clear cached domain account ID
618
673
this . cachedDomainAccountId = undefined
619
674
logger . debug ( 'SMUS: Cleared cached domain account ID' )
675
+
676
+ // Clear cached project account IDs
677
+ this . cachedProjectAccountIds . clear ( )
678
+ logger . debug ( 'SMUS: Cleared cached project account IDs' )
620
679
}
621
680
622
681
/**
@@ -665,6 +724,9 @@ export class SmusAuthenticationProvider {
665
724
// Clear cached domain account ID
666
725
this . cachedDomainAccountId = undefined
667
726
727
+ // Clear cached project account IDs
728
+ this . cachedProjectAccountIds . clear ( )
729
+
668
730
this . logger . debug ( 'SMUS Auth: Successfully disposed authentication provider' )
669
731
}
670
732
0 commit comments