@@ -14,7 +14,7 @@ import * as localizedText from '../../../shared/localizedText'
1414import { ToolkitPromptSettings } from '../../../shared/settings'
1515import { setContext , getContext } from '../../../shared/vscode/setContext'
1616import { getLogger } from '../../../shared/logger/logger'
17- import { SmusUtils , SmusErrorCodes , extractAccountIdFromArn } from '../../shared/smusUtils'
17+ import { SmusUtils , SmusErrorCodes , extractAccountIdFromSageMakerArn } from '../../shared/smusUtils'
1818import { createSmusProfile , isValidSmusConnection , SmusConnection } from '../model'
1919import { DomainExecRoleCredentialsProvider } from './domainExecRoleCredentialsProvider'
2020import { ProjectRoleCredentialsProvider } from './projectRoleCredentialsProvider'
@@ -24,6 +24,7 @@ import { getResourceMetadata } from '../../shared/utils/resourceMetadataUtils'
2424import { fromIni } from '@aws-sdk/credential-providers'
2525import { randomUUID } from '../../../shared/crypto'
2626import { DefaultStsClient } from '../../../shared/clients/stsClient'
27+ import { DataZoneClient } from '../../shared/client/datazoneClient'
2728
2829/**
2930 * Sets the context variable for SageMaker Unified Studio connection state
@@ -55,6 +56,7 @@ export class SmusAuthenticationProvider {
5556 private projectCredentialProvidersCache = new Map < string , ProjectRoleCredentialsProvider > ( )
5657 private connectionCredentialProvidersCache = new Map < string , ConnectionCredentialsProvider > ( )
5758 private cachedDomainAccountId : string | undefined
59+ private cachedProjectAccountIds = new Map < string , string > ( )
5860
5961 public constructor (
6062 public readonly auth = Auth . instance ,
@@ -79,6 +81,8 @@ export class SmusAuthenticationProvider {
7981 this . connectionCredentialProvidersCache . clear ( )
8082 // Clear cached domain account ID when connection changes
8183 this . cachedDomainAccountId = undefined
84+ // Clear cached project account IDs when connection changes
85+ this . cachedProjectAccountIds . clear ( )
8286 // Clear all clients in client store when connection changes
8387 ConnectionClientStore . getInstance ( ) . clearAll ( )
8488 await setSmusConnectedContext ( this . isConnected ( ) )
@@ -427,6 +431,34 @@ export class SmusAuthenticationProvider {
427431 return this . activeConnection . domainUrl
428432 }
429433
434+ /**
435+ * Extracts account ID from ResourceArn in SMUS space environment
436+ * @returns Promise resolving to the account ID
437+ * @throws ToolkitError if unable to extract account ID
438+ */
439+ private async extractAccountIdFromResourceMetadata ( ) : Promise < string > {
440+ const logger = getLogger ( )
441+
442+ try {
443+ logger . debug ( 'SMUS: Extracting account ID from ResourceArn in resource-metadata file' )
444+
445+ const resourceMetadata = getResourceMetadata ( ) !
446+ const resourceArn = resourceMetadata . ResourceArn
447+
448+ if ( ! resourceArn ) {
449+ throw new Error ( 'ResourceArn not found in metadata file' )
450+ }
451+
452+ const accountId = extractAccountIdFromSageMakerArn ( resourceArn )
453+ logger . debug ( `Successfully extracted account ID from resource-metadata file: ${ accountId } ` )
454+
455+ return accountId
456+ } catch ( err ) {
457+ logger . error ( `Failed to extract account ID from ResourceArn: %s` , err )
458+ throw new Error ( 'Failed to extract AWS account ID from ResourceArn in SMUS space environment' )
459+ }
460+ }
461+
430462 /**
431463 * Gets the AWS account ID for the active domain connection
432464 * In SMUS space environment, extracts from ResourceArn in metadata
@@ -445,37 +477,13 @@ export class SmusAuthenticationProvider {
445477
446478 // If in SMUS space environment, extract account ID from resource-metadata file
447479 if ( getContext ( 'aws.smus.inSmusSpaceEnvironment' ) ) {
448- try {
449- logger . debug ( 'SMUS: Extracting domain account ID from ResourceArn in resource-metadata file' )
450-
451- const resourceMetadata = getResourceMetadata ( ) !
452- const resourceArn = resourceMetadata . ResourceArn
453-
454- if ( ! resourceArn ) {
455- throw new ToolkitError ( 'ResourceArn not found in metadata file' , {
456- code : SmusErrorCodes . AccountIdNotFound ,
457- } )
458- }
459-
460- // Extract account ID from ResourceArn using SmusUtils
461- const accountId = extractAccountIdFromArn ( resourceArn )
462-
463- // Cache the account ID
464- this . cachedDomainAccountId = accountId
480+ const accountId = await this . extractAccountIdFromResourceMetadata ( )
465481
466- logger . debug (
467- `Successfully extracted and cached domain account ID from resource-metadata file: ${ accountId } `
468- )
469-
470- return accountId
471- } catch ( err ) {
472- logger . error ( `Failed to extract domain account ID from ResourceArn: %s` , err )
482+ // Cache the account ID
483+ this . cachedDomainAccountId = accountId
484+ logger . debug ( `Successfully cached domain account ID: ${ accountId } ` )
473485
474- throw new ToolkitError ( 'Failed to extract AWS account ID from ResourceArn in SMUS space environment' , {
475- code : SmusErrorCodes . GetDomainAccountIdFailed ,
476- cause : err instanceof Error ? err : undefined ,
477- } )
478- }
486+ return accountId
479487 }
480488
481489 if ( ! this . activeConnection ) {
@@ -520,6 +528,81 @@ export class SmusAuthenticationProvider {
520528 }
521529 }
522530
531+ /**
532+ * Gets the AWS account ID for a specific project using project credentials
533+ * In SMUS space environment, extracts from ResourceArn in metadata (same as domain account)
534+ * Otherwise, makes an STS GetCallerIdentity call using project credentials
535+ * @param projectId The DataZone project ID
536+ * @returns Promise resolving to the project's AWS account ID
537+ */
538+ public async getProjectAccountId ( projectId : string ) : Promise < string > {
539+ const logger = getLogger ( )
540+
541+ // Return cached value if available
542+ if ( this . cachedProjectAccountIds . has ( projectId ) ) {
543+ logger . debug ( `SMUS: Using cached project account ID for project ${ projectId } ` )
544+ return this . cachedProjectAccountIds . get ( projectId ) !
545+ }
546+
547+ // If in SMUS space environment, extract account ID from resource-metadata file
548+ if ( getContext ( 'aws.smus.inSmusSpaceEnvironment' ) ) {
549+ const accountId = await this . extractAccountIdFromResourceMetadata ( )
550+
551+ // Cache the account ID
552+ this . cachedProjectAccountIds . set ( projectId , accountId )
553+ logger . debug ( `Successfully cached project account ID for project ${ projectId } : ${ accountId } ` )
554+
555+ return accountId
556+ }
557+
558+ if ( ! this . activeConnection ) {
559+ throw new ToolkitError ( 'No active SMUS connection available' , { code : SmusErrorCodes . NoActiveConnection } )
560+ }
561+
562+ // For non-SMUS space environments, use project credentials with STS
563+ try {
564+ logger . debug ( 'Fetching project account ID via STS GetCallerIdentity with project credentials' )
565+
566+ // Get project credentials
567+ const projectCredProvider = await this . getProjectCredentialProvider ( projectId )
568+ const projectCreds = await projectCredProvider . getCredentials ( )
569+
570+ // Get project region from tooling environment
571+ const dzClient = await DataZoneClient . getInstance ( this )
572+ const toolingEnv = await dzClient . getToolingEnvironment ( projectId )
573+ const projectRegion = toolingEnv . awsAccountRegion
574+
575+ if ( ! projectRegion ) {
576+ throw new ToolkitError ( 'No AWS account region found in tooling environment' , {
577+ code : SmusErrorCodes . RegionNotFound ,
578+ } )
579+ }
580+
581+ // Use STS to get account ID from project credentials
582+ const stsClient = new DefaultStsClient ( projectRegion , projectCreds )
583+ const callerIdentity = await stsClient . getCallerIdentity ( )
584+
585+ if ( ! callerIdentity . Account ) {
586+ throw new ToolkitError ( 'Account ID not found in STS GetCallerIdentity response' , {
587+ code : SmusErrorCodes . AccountIdNotFound ,
588+ } )
589+ }
590+
591+ // Cache the account ID
592+ this . cachedProjectAccountIds . set ( projectId , callerIdentity . Account )
593+ logger . debug (
594+ `Successfully retrieved and cached project account ID for project ${ projectId } : ${ callerIdentity . Account } `
595+ )
596+
597+ return callerIdentity . Account
598+ } catch ( err ) {
599+ logger . error ( 'Failed to get project account ID: %s' , err as Error )
600+ throw new ToolkitError ( `Failed to get project account ID: ${ ( err as Error ) . message } ` , {
601+ code : SmusErrorCodes . GetProjectAccountIdFailed ,
602+ } )
603+ }
604+ }
605+
523606 public getDomainRegion ( ) : string {
524607 if ( getContext ( 'aws.smus.inSmusSpaceEnvironment' ) ) {
525608 const resourceMetadata = getResourceMetadata ( ) !
@@ -617,6 +700,10 @@ export class SmusAuthenticationProvider {
617700 // Clear cached domain account ID
618701 this . cachedDomainAccountId = undefined
619702 logger . debug ( 'SMUS: Cleared cached domain account ID' )
703+
704+ // Clear cached project account IDs
705+ this . cachedProjectAccountIds . clear ( )
706+ logger . debug ( 'SMUS: Cleared cached project account IDs' )
620707 }
621708
622709 /**
@@ -665,6 +752,9 @@ export class SmusAuthenticationProvider {
665752 // Clear cached domain account ID
666753 this . cachedDomainAccountId = undefined
667754
755+ // Clear cached project account IDs
756+ this . cachedProjectAccountIds . clear ( )
757+
668758 this . logger . debug ( 'SMUS Auth: Successfully disposed authentication provider' )
669759 }
670760
0 commit comments