@@ -14,7 +14,7 @@ import * as localizedText from '../../../shared/localizedText'
1414import { ToolkitPromptSettings } from '../../../shared/settings'
1515import { setContext , getContext } from '../../../shared/vscode/setContext'
1616import { getLogger } from '../../../shared/logger/logger'
17- import { SmusUtils , SmusErrorCodes , extractAccountIdFromArn } from '../../shared/smusUtils'
17+ import { SmusUtils , SmusErrorCodes , extractAccountIdFromResourceMetadata } from '../../shared/smusUtils'
1818import { createSmusProfile , isValidSmusConnection , SmusConnection } from '../model'
1919import { DomainExecRoleCredentialsProvider } from './domainExecRoleCredentialsProvider'
2020import { ProjectRoleCredentialsProvider } from './projectRoleCredentialsProvider'
@@ -24,6 +24,7 @@ import { getResourceMetadata } from '../../shared/utils/resourceMetadataUtils'
2424import { fromIni } from '@aws-sdk/credential-providers'
2525import { randomUUID } from '../../../shared/crypto'
2626import { DefaultStsClient } from '../../../shared/clients/stsClient'
27+ import { DataZoneClient } from '../../shared/client/datazoneClient'
2728
2829/**
2930 * Sets the context variable for SageMaker Unified Studio connection state
@@ -55,6 +56,7 @@ export class SmusAuthenticationProvider {
5556 private projectCredentialProvidersCache = new Map < string , ProjectRoleCredentialsProvider > ( )
5657 private connectionCredentialProvidersCache = new Map < string , ConnectionCredentialsProvider > ( )
5758 private cachedDomainAccountId : string | undefined
59+ private cachedProjectAccountIds = new Map < string , string > ( )
5860
5961 public constructor (
6062 public readonly auth = Auth . instance ,
@@ -79,6 +81,8 @@ export class SmusAuthenticationProvider {
7981 this . connectionCredentialProvidersCache . clear ( )
8082 // Clear cached domain account ID when connection changes
8183 this . cachedDomainAccountId = undefined
84+ // Clear cached project account IDs when connection changes
85+ this . cachedProjectAccountIds . clear ( )
8286 // Clear all clients in client store when connection changes
8387 ConnectionClientStore . getInstance ( ) . clearAll ( )
8488 await setSmusConnectedContext ( this . isConnected ( ) )
@@ -445,37 +449,13 @@ export class SmusAuthenticationProvider {
445449
446450 // If in SMUS space environment, extract account ID from resource-metadata file
447451 if ( getContext ( 'aws.smus.inSmusSpaceEnvironment' ) ) {
448- try {
449- logger . debug ( 'SMUS: Extracting domain account ID from ResourceArn in resource-metadata file' )
450-
451- const resourceMetadata = getResourceMetadata ( ) !
452- const resourceArn = resourceMetadata . ResourceArn
453-
454- if ( ! resourceArn ) {
455- throw new ToolkitError ( 'ResourceArn not found in metadata file' , {
456- code : SmusErrorCodes . AccountIdNotFound ,
457- } )
458- }
452+ const accountId = await extractAccountIdFromResourceMetadata ( )
459453
460- // Extract account ID from ResourceArn using SmusUtils
461- const accountId = extractAccountIdFromArn ( resourceArn )
462-
463- // Cache the account ID
464- this . cachedDomainAccountId = accountId
465-
466- logger . debug (
467- `Successfully extracted and cached domain account ID from resource-metadata file: ${ accountId } `
468- )
469-
470- return accountId
471- } catch ( err ) {
472- logger . error ( `Failed to extract domain account ID from ResourceArn: %s` , err )
454+ // Cache the account ID
455+ this . cachedDomainAccountId = accountId
456+ logger . debug ( `Successfully cached domain account ID: ${ accountId } ` )
473457
474- throw new ToolkitError ( 'Failed to extract AWS account ID from ResourceArn in SMUS space environment' , {
475- code : SmusErrorCodes . GetDomainAccountIdFailed ,
476- cause : err instanceof Error ? err : undefined ,
477- } )
478- }
458+ return accountId
479459 }
480460
481461 if ( ! this . activeConnection ) {
@@ -520,6 +500,81 @@ export class SmusAuthenticationProvider {
520500 }
521501 }
522502
503+ /**
504+ * Gets the AWS account ID for a specific project using project credentials
505+ * In SMUS space environment, extracts from ResourceArn in metadata (same as domain account)
506+ * Otherwise, makes an STS GetCallerIdentity call using project credentials
507+ * @param projectId The DataZone project ID
508+ * @returns Promise resolving to the project's AWS account ID
509+ */
510+ public async getProjectAccountId ( projectId : string ) : Promise < string > {
511+ const logger = getLogger ( )
512+
513+ // Return cached value if available
514+ if ( this . cachedProjectAccountIds . has ( projectId ) ) {
515+ logger . debug ( `SMUS: Using cached project account ID for project ${ projectId } ` )
516+ return this . cachedProjectAccountIds . get ( projectId ) !
517+ }
518+
519+ // If in SMUS space environment, extract account ID from resource-metadata file
520+ if ( getContext ( 'aws.smus.inSmusSpaceEnvironment' ) ) {
521+ const accountId = await extractAccountIdFromResourceMetadata ( )
522+
523+ // Cache the account ID
524+ this . cachedProjectAccountIds . set ( projectId , accountId )
525+ logger . debug ( `Successfully cached project account ID for project ${ projectId } : ${ accountId } ` )
526+
527+ return accountId
528+ }
529+
530+ if ( ! this . activeConnection ) {
531+ throw new ToolkitError ( 'No active SMUS connection available' , { code : SmusErrorCodes . NoActiveConnection } )
532+ }
533+
534+ // For non-SMUS space environments, use project credentials with STS
535+ try {
536+ logger . debug ( 'Fetching project account ID via STS GetCallerIdentity with project credentials' )
537+
538+ // Get project credentials
539+ const projectCredProvider = await this . getProjectCredentialProvider ( projectId )
540+ const projectCreds = await projectCredProvider . getCredentials ( )
541+
542+ // Get project region from tooling environment
543+ const dzClient = await DataZoneClient . getInstance ( this )
544+ const toolingEnv = await dzClient . getToolingEnvironment ( projectId )
545+ const projectRegion = toolingEnv . awsAccountRegion
546+
547+ if ( ! projectRegion ) {
548+ throw new ToolkitError ( 'No AWS account region found in tooling environment' , {
549+ code : SmusErrorCodes . RegionNotFound ,
550+ } )
551+ }
552+
553+ // Use STS to get account ID from project credentials
554+ const stsClient = new DefaultStsClient ( projectRegion , projectCreds )
555+ const callerIdentity = await stsClient . getCallerIdentity ( )
556+
557+ if ( ! callerIdentity . Account ) {
558+ throw new ToolkitError ( 'Account ID not found in STS GetCallerIdentity response' , {
559+ code : SmusErrorCodes . AccountIdNotFound ,
560+ } )
561+ }
562+
563+ // Cache the account ID
564+ this . cachedProjectAccountIds . set ( projectId , callerIdentity . Account )
565+ logger . debug (
566+ `Successfully retrieved and cached project account ID for project ${ projectId } : ${ callerIdentity . Account } `
567+ )
568+
569+ return callerIdentity . Account
570+ } catch ( err ) {
571+ logger . error ( 'Failed to get project account ID: %s' , err as Error )
572+ throw new ToolkitError ( `Failed to get project account ID: ${ ( err as Error ) . message } ` , {
573+ code : SmusErrorCodes . GetProjectAccountIdFailed ,
574+ } )
575+ }
576+ }
577+
523578 public getDomainRegion ( ) : string {
524579 if ( getContext ( 'aws.smus.inSmusSpaceEnvironment' ) ) {
525580 const resourceMetadata = getResourceMetadata ( ) !
@@ -617,6 +672,10 @@ export class SmusAuthenticationProvider {
617672 // Clear cached domain account ID
618673 this . cachedDomainAccountId = undefined
619674 logger . debug ( 'SMUS: Cleared cached domain account ID' )
675+
676+ // Clear cached project account IDs
677+ this . cachedProjectAccountIds . clear ( )
678+ logger . debug ( 'SMUS: Cleared cached project account IDs' )
620679 }
621680
622681 /**
@@ -665,6 +724,9 @@ export class SmusAuthenticationProvider {
665724 // Clear cached domain account ID
666725 this . cachedDomainAccountId = undefined
667726
727+ // Clear cached project account IDs
728+ this . cachedProjectAccountIds . clear ( )
729+
668730 this . logger . debug ( 'SMUS Auth: Successfully disposed authentication provider' )
669731 }
670732
0 commit comments