diff --git a/packages/core/src/sagemakerunifiedstudio/auth/providers/smusAuthenticationProvider.ts b/packages/core/src/sagemakerunifiedstudio/auth/providers/smusAuthenticationProvider.ts index 8aadda5fcc0..6c0f204cbd3 100644 --- a/packages/core/src/sagemakerunifiedstudio/auth/providers/smusAuthenticationProvider.ts +++ b/packages/core/src/sagemakerunifiedstudio/auth/providers/smusAuthenticationProvider.ts @@ -14,7 +14,7 @@ import * as localizedText from '../../../shared/localizedText' import { ToolkitPromptSettings } from '../../../shared/settings' import { setContext, getContext } from '../../../shared/vscode/setContext' import { getLogger } from '../../../shared/logger/logger' -import { SmusUtils, SmusErrorCodes, extractAccountIdFromArn } from '../../shared/smusUtils' +import { SmusUtils, SmusErrorCodes, extractAccountIdFromResourceMetadata } from '../../shared/smusUtils' import { createSmusProfile, isValidSmusConnection, SmusConnection } from '../model' import { DomainExecRoleCredentialsProvider } from './domainExecRoleCredentialsProvider' import { ProjectRoleCredentialsProvider } from './projectRoleCredentialsProvider' @@ -24,6 +24,7 @@ import { getResourceMetadata } from '../../shared/utils/resourceMetadataUtils' import { fromIni } from '@aws-sdk/credential-providers' import { randomUUID } from '../../../shared/crypto' import { DefaultStsClient } from '../../../shared/clients/stsClient' +import { DataZoneClient } from '../../shared/client/datazoneClient' /** * Sets the context variable for SageMaker Unified Studio connection state @@ -55,6 +56,7 @@ export class SmusAuthenticationProvider { private projectCredentialProvidersCache = new Map() private connectionCredentialProvidersCache = new Map() private cachedDomainAccountId: string | undefined + private cachedProjectAccountIds = new Map() public constructor( public readonly auth = Auth.instance, @@ -79,6 +81,8 @@ export class SmusAuthenticationProvider { this.connectionCredentialProvidersCache.clear() // Clear cached domain account ID when connection changes this.cachedDomainAccountId = undefined + // Clear cached project account IDs when connection changes + this.cachedProjectAccountIds.clear() // Clear all clients in client store when connection changes ConnectionClientStore.getInstance().clearAll() await setSmusConnectedContext(this.isConnected()) @@ -445,37 +449,13 @@ export class SmusAuthenticationProvider { // If in SMUS space environment, extract account ID from resource-metadata file if (getContext('aws.smus.inSmusSpaceEnvironment')) { - try { - logger.debug('SMUS: Extracting domain account ID from ResourceArn in resource-metadata file') - - const resourceMetadata = getResourceMetadata()! - const resourceArn = resourceMetadata.ResourceArn - - if (!resourceArn) { - throw new ToolkitError('ResourceArn not found in metadata file', { - code: SmusErrorCodes.AccountIdNotFound, - }) - } + const accountId = await extractAccountIdFromResourceMetadata() - // Extract account ID from ResourceArn using SmusUtils - const accountId = extractAccountIdFromArn(resourceArn) - - // Cache the account ID - this.cachedDomainAccountId = accountId - - logger.debug( - `Successfully extracted and cached domain account ID from resource-metadata file: ${accountId}` - ) - - return accountId - } catch (err) { - logger.error(`Failed to extract domain account ID from ResourceArn: %s`, err) + // Cache the account ID + this.cachedDomainAccountId = accountId + logger.debug(`Successfully cached domain account ID: ${accountId}`) - throw new ToolkitError('Failed to extract AWS account ID from ResourceArn in SMUS space environment', { - code: SmusErrorCodes.GetDomainAccountIdFailed, - cause: err instanceof Error ? err : undefined, - }) - } + return accountId } if (!this.activeConnection) { @@ -520,6 +500,81 @@ export class SmusAuthenticationProvider { } } + /** + * Gets the AWS account ID for a specific project using project credentials + * In SMUS space environment, extracts from ResourceArn in metadata (same as domain account) + * Otherwise, makes an STS GetCallerIdentity call using project credentials + * @param projectId The DataZone project ID + * @returns Promise resolving to the project's AWS account ID + */ + public async getProjectAccountId(projectId: string): Promise { + const logger = getLogger() + + // Return cached value if available + if (this.cachedProjectAccountIds.has(projectId)) { + logger.debug(`SMUS: Using cached project account ID for project ${projectId}`) + return this.cachedProjectAccountIds.get(projectId)! + } + + // If in SMUS space environment, extract account ID from resource-metadata file + if (getContext('aws.smus.inSmusSpaceEnvironment')) { + const accountId = await extractAccountIdFromResourceMetadata() + + // Cache the account ID + this.cachedProjectAccountIds.set(projectId, accountId) + logger.debug(`Successfully cached project account ID for project ${projectId}: ${accountId}`) + + return accountId + } + + if (!this.activeConnection) { + throw new ToolkitError('No active SMUS connection available', { code: SmusErrorCodes.NoActiveConnection }) + } + + // For non-SMUS space environments, use project credentials with STS + try { + logger.debug('Fetching project account ID via STS GetCallerIdentity with project credentials') + + // Get project credentials + const projectCredProvider = await this.getProjectCredentialProvider(projectId) + const projectCreds = await projectCredProvider.getCredentials() + + // Get project region from tooling environment + const dzClient = await DataZoneClient.getInstance(this) + const toolingEnv = await dzClient.getToolingEnvironment(projectId) + const projectRegion = toolingEnv.awsAccountRegion + + if (!projectRegion) { + throw new ToolkitError('No AWS account region found in tooling environment', { + code: SmusErrorCodes.RegionNotFound, + }) + } + + // Use STS to get account ID from project credentials + const stsClient = new DefaultStsClient(projectRegion, projectCreds) + const callerIdentity = await stsClient.getCallerIdentity() + + if (!callerIdentity.Account) { + throw new ToolkitError('Account ID not found in STS GetCallerIdentity response', { + code: SmusErrorCodes.AccountIdNotFound, + }) + } + + // Cache the account ID + this.cachedProjectAccountIds.set(projectId, callerIdentity.Account) + logger.debug( + `Successfully retrieved and cached project account ID for project ${projectId}: ${callerIdentity.Account}` + ) + + return callerIdentity.Account + } catch (err) { + logger.error('Failed to get project account ID: %s', err as Error) + throw new ToolkitError(`Failed to get project account ID: ${(err as Error).message}`, { + code: SmusErrorCodes.GetProjectAccountIdFailed, + }) + } + } + public getDomainRegion(): string { if (getContext('aws.smus.inSmusSpaceEnvironment')) { const resourceMetadata = getResourceMetadata()! @@ -617,6 +672,10 @@ export class SmusAuthenticationProvider { // Clear cached domain account ID this.cachedDomainAccountId = undefined logger.debug('SMUS: Cleared cached domain account ID') + + // Clear cached project account IDs + this.cachedProjectAccountIds.clear() + logger.debug('SMUS: Cleared cached project account IDs') } /** @@ -665,6 +724,9 @@ export class SmusAuthenticationProvider { // Clear cached domain account ID this.cachedDomainAccountId = undefined + // Clear cached project account IDs + this.cachedProjectAccountIds.clear() + this.logger.debug('SMUS Auth: Successfully disposed authentication provider') } diff --git a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/redshiftStrategy.ts b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/redshiftStrategy.ts index 9dc38b33a7a..af0d7cfbbac 100644 --- a/packages/core/src/sagemakerunifiedstudio/explorer/nodes/redshiftStrategy.ts +++ b/packages/core/src/sagemakerunifiedstudio/explorer/nodes/redshiftStrategy.ts @@ -24,7 +24,7 @@ import { createPlaceholderItem } from '../../../shared/treeview/utils' import { ConnectionCredentialsProvider } from '../../auth/providers/connectionCredentialsProvider' import { GlueCatalog } from '../../shared/client/glueCatalogClient' import { telemetry } from '../../../shared/telemetry/telemetry' -import { getContext } from '../../../shared/vscode/setContext' +import { recordDataConnectionTelemetry } from '../../shared/telemetry' /** * Redshift data node for SageMaker Unified Studio @@ -119,6 +119,7 @@ export function createRedshiftConnectionNode( connection: DataZoneConnection, connectionCredentialsProvider: ConnectionCredentialsProvider ): RedshiftNode { + const logger = getLogger() return new RedshiftNode( { id: connection.connectionId, @@ -130,19 +131,8 @@ export function createRedshiftConnectionNode( }, async (node) => { return telemetry.smus_renderRedshiftNode.run(async (span) => { - const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment') - const accountId = await connectionCredentialsProvider.getDomainAccountId() - span.record({ - smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local', - smusDomainId: connection.domainId, - smusDomainAccountId: accountId, - smusProjectId: connection.projectId, - smusConnectionId: connection.connectionId, - smusConnectionType: connection.type, - smusProjectRegion: connection.location?.awsRegion, - }) - const logger = getLogger() logger.info(`Loading Redshift resources for connection ${connection.name}`) + await recordDataConnectionTelemetry(span, connection, connectionCredentialsProvider) const connectionParams = extractConnectionParams(connection) if (!connectionParams) { diff --git a/packages/core/src/sagemakerunifiedstudio/shared/smusUtils.ts b/packages/core/src/sagemakerunifiedstudio/shared/smusUtils.ts index 5950ba16c65..35858f0dc5a 100644 --- a/packages/core/src/sagemakerunifiedstudio/shared/smusUtils.ts +++ b/packages/core/src/sagemakerunifiedstudio/shared/smusUtils.ts @@ -56,8 +56,14 @@ export const SmusErrorCodes = { UserCancelled: 'UserCancelled', /** Error code for when domain account Id is missing */ AccountIdNotFound: 'AccountIdNotFound', + /** Error code for when resource ARN is missing */ + ResourceArnNotFound: 'ResourceArnNotFound', /** Error code for when fails to get domain account Id */ GetDomainAccountIdFailed: 'GetDomainAccountIdFailed', + /** Error code for when fails to get project account Id */ + GetProjectAccountIdFailed: 'GetProjectAccountIdFailed', + /** Error code for when region is missing */ + RegionNotFound: 'RegionNotFound', } as const /** @@ -369,9 +375,9 @@ export class SmusUtils { * @returns The account ID from the ARN * @throws If the ARN format is invalid */ -export function extractAccountIdFromArn(arn: string): string { +export function extractAccountIdFromSageMakerArn(arn: string): string { // Match the ARN components to extract account ID - const regex = /^arn:aws:sagemaker:(?[^:]+):(?\d+):(app|space|domain)\/.+$/i + const regex = /^arn:aws:sagemaker:(?[^:]+):(?\d+):(app|space)\/.+$/i const match = arn.match(regex) if (!match?.groups) { @@ -380,3 +386,31 @@ export function extractAccountIdFromArn(arn: string): string { return match.groups.accountId } + +/** + * Extracts account ID from ResourceArn in SMUS space environment + * @returns Promise resolving to the account ID + * @throws ToolkitError if unable to extract account ID + */ +export async function extractAccountIdFromResourceMetadata(): Promise { + const logger = getLogger() + + try { + logger.debug('SMUS: Extracting account ID from ResourceArn in resource-metadata file') + + const resourceMetadata = getResourceMetadata()! + const resourceArn = resourceMetadata.ResourceArn + + if (!resourceArn) { + throw new Error('ResourceArn not found in metadata file') + } + + const accountId = extractAccountIdFromSageMakerArn(resourceArn) + logger.debug(`Successfully extracted account ID from resource-metadata file: ${accountId}`) + + return accountId + } catch (err) { + logger.error(`Failed to extract account ID from ResourceArn: %s`, err) + throw new Error('Failed to extract AWS account ID from ResourceArn in SMUS space environment') + } +} diff --git a/packages/core/src/sagemakerunifiedstudio/shared/telemetry.ts b/packages/core/src/sagemakerunifiedstudio/shared/telemetry.ts index b97762270b9..ceeb4828b83 100644 --- a/packages/core/src/sagemakerunifiedstudio/shared/telemetry.ts +++ b/packages/core/src/sagemakerunifiedstudio/shared/telemetry.ts @@ -18,7 +18,7 @@ import { SmusAuthenticationProvider } from '../auth/providers/smusAuthentication import { getLogger } from '../../shared/logger/logger' import { getContext } from '../../shared/vscode/setContext' import { ConnectionCredentialsProvider } from '../auth/providers/connectionCredentialsProvider' -import { DataZoneConnection } from './client/datazoneClient' +import { DataZoneConnection, DataZoneClient } from './client/datazoneClient' /** * Records space telemetry @@ -27,16 +27,39 @@ export async function recordSpaceTelemetry( span: Span | Span, node: SagemakerUnifiedStudioSpaceNode ) { - const parent = node.resource.getParent() as SageMakerUnifiedStudioSpacesParentNode - const authProvider = SmusAuthenticationProvider.fromContext() - const accountId = await authProvider.getDomainAccountId() - span.record({ - smusSpaceKey: node.resource.DomainSpaceKey, - smusDomainRegion: node.resource.regionCode, - smusDomainId: parent?.getAuthProvider()?.activeConnection?.domainId, - smusDomainAccountId: accountId, - smusProjectId: parent?.getProjectId(), - }) + const logger = getLogger() + + try { + const parent = node.resource.getParent() as SageMakerUnifiedStudioSpacesParentNode + const authProvider = SmusAuthenticationProvider.fromContext() + const accountId = await authProvider.getDomainAccountId() + const projectId = parent?.getProjectId() + + // Get project account ID and region + let projectAccountId: string | undefined + let projectRegion: string | undefined + + if (projectId) { + projectAccountId = await authProvider.getProjectAccountId(projectId) + + // Get project region from tooling environment + const dzClient = await DataZoneClient.getInstance(authProvider) + const toolingEnv = await dzClient.getToolingEnvironment(projectId) + projectRegion = toolingEnv.awsAccountRegion + } + + span.record({ + smusSpaceKey: node.resource.DomainSpaceKey, + smusDomainRegion: node.resource.regionCode, + smusDomainId: parent?.getAuthProvider()?.activeConnection?.domainId, + smusDomainAccountId: accountId, + smusProjectId: projectId, + smusProjectAccountId: projectAccountId, + smusProjectRegion: projectRegion, + }) + } catch (err) { + logger.error(`Failed to record space telemetry: ${(err as Error).message}`) + } } /** @@ -65,7 +88,7 @@ export async function recordAuthTelemetry( }) } catch (err) { logger.error( - `Failed to resolve AWS account ID via STS Client for domain ${domainId} in region ${region}: ${err}` + `Failed to record Domain AccountId in data connection telemetry for domain ${domainId} in region ${region}: ${err}` ) } } @@ -78,15 +101,22 @@ export async function recordDataConnectionTelemetry( connection: DataZoneConnection, connectionCredentialsProvider: ConnectionCredentialsProvider ) { - const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment') - const accountId = await connectionCredentialsProvider.getDomainAccountId() - span.record({ - smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local', - smusDomainId: connection.domainId, - smusDomainAccountId: accountId, - smusProjectId: connection.projectId, - smusConnectionId: connection.connectionId, - smusConnectionType: connection.type, - smusProjectRegion: connection.location?.awsRegion, - }) + const logger = getLogger() + + try { + const isInSmusSpace = getContext('aws.smus.inSmusSpaceEnvironment') + const accountId = await connectionCredentialsProvider.getDomainAccountId() + span.record({ + smusToolkitEnv: isInSmusSpace ? 'smus_space' : 'local', + smusDomainId: connection.domainId, + smusDomainAccountId: accountId, + smusProjectId: connection.projectId, + smusConnectionId: connection.connectionId, + smusConnectionType: connection.type, + smusProjectRegion: connection.location?.awsRegion, + smusProjectAccountId: connection.location?.awsAccountId, + }) + } catch (err) { + logger.error(`Failed to record data connection telemetry: ${(err as Error).message}`) + } } diff --git a/packages/core/src/shared/telemetry/vscodeTelemetry.json b/packages/core/src/shared/telemetry/vscodeTelemetry.json index 6fbfa22e394..fee97143abd 100644 --- a/packages/core/src/shared/telemetry/vscodeTelemetry.json +++ b/packages/core/src/shared/telemetry/vscodeTelemetry.json @@ -1462,6 +1462,14 @@ { "type": "smusDomainRegion", "required": false + }, + { + "type": "smusProjectRegion", + "required": false + }, + { + "type": "smusProjectAccountId", + "required": false } ] }, @@ -1488,6 +1496,14 @@ { "type": "smusDomainRegion", "required": false + }, + { + "type": "smusProjectRegion", + "required": false + }, + { + "type": "smusProjectAccountId", + "required": false } ] }, @@ -1515,6 +1531,10 @@ "type": "smusProjectRegion", "required": false }, + { + "type": "smusProjectAccountId", + "required": false + }, { "type": "smusConnectionId", "required": false @@ -1549,6 +1569,10 @@ "type": "smusProjectRegion", "required": false }, + { + "type": "smusProjectAccountId", + "required": false + }, { "type": "smusConnectionId", "required": false @@ -1583,6 +1607,10 @@ "type": "smusProjectRegion", "required": false }, + { + "type": "smusProjectAccountId", + "required": false + }, { "type": "smusConnectionId", "required": false diff --git a/packages/core/src/test/sagemakerunifiedstudio/auth/smusAuthenticationProvider.test.ts b/packages/core/src/test/sagemakerunifiedstudio/auth/smusAuthenticationProvider.test.ts index f971ff5520f..7cd2662f467 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/auth/smusAuthenticationProvider.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/auth/smusAuthenticationProvider.test.ts @@ -418,7 +418,6 @@ describe('SmusAuthenticationProvider', function () { describe('getDomainAccountId', function () { let getContextStub: sinon.SinonStub let getResourceMetadataStub: sinon.SinonStub - let extractAccountIdFromArnStub: sinon.SinonStub let getDerCredentialsProviderStub: sinon.SinonStub let getDomainRegionStub: sinon.SinonStub let mockStsClient: any @@ -428,7 +427,6 @@ describe('SmusAuthenticationProvider', function () { // Mock dependencies getContextStub = sinon.stub(vscodeSetContext, 'getContext') getResourceMetadataStub = sinon.stub(resourceMetadataUtils, 'getResourceMetadata') - extractAccountIdFromArnStub = sinon.stub(smusUtils, 'extractAccountIdFromArn') // Mock STS client mockStsClient = { @@ -476,41 +474,32 @@ describe('SmusAuthenticationProvider', function () { }) describe('in SMUS space environment', function () { + let extractAccountIdFromResourceMetadataStub: sinon.SinonStub + beforeEach(function () { getContextStub.withArgs('aws.smus.inSmusSpaceEnvironment').returns(true) + extractAccountIdFromResourceMetadataStub = sinon + .stub(smusUtils, 'extractAccountIdFromResourceMetadata') + .resolves('123456789012') }) - it('should extract account ID from ResourceArn and cache it', async function () { + it('should extract account from resource metadata and cache result', async function () { const testAccountId = '123456789012' - const testResourceArn = `arn:aws:sagemaker:us-east-1:${testAccountId}:domain/test-domain` - - getResourceMetadataStub.returns({ - ResourceArn: testResourceArn, - }) - extractAccountIdFromArnStub.returns(testAccountId) const result = await smusAuthProvider.getDomainAccountId() assert.strictEqual(result, testAccountId) assert.strictEqual(smusAuthProvider['cachedDomainAccountId'], testAccountId) - assert.ok(getResourceMetadataStub.called) - assert.ok(extractAccountIdFromArnStub.calledWith(testResourceArn)) + assert.ok(extractAccountIdFromResourceMetadataStub.called) assert.ok(mockStsClient.getCallerIdentity.notCalled) }) - it('should throw error when ResourceArn is missing from metadata', async function () { - getResourceMetadataStub.returns({}) + it('should throw error when extractAccountIdFromResourceMetadata fails', async function () { + extractAccountIdFromResourceMetadataStub.rejects(new ToolkitError('Metadata extraction failed')) await assert.rejects( () => smusAuthProvider.getDomainAccountId(), - (err: ToolkitError) => { - return ( - err.code === 'GetDomainAccountIdFailed' && - err.message.includes( - 'Failed to extract AWS account ID from ResourceArn in SMUS space environment' - ) - ) - } + (err: ToolkitError) => err.message.includes('Metadata extraction failed') ) assert.strictEqual(smusAuthProvider['cachedDomainAccountId'], undefined) @@ -576,4 +565,196 @@ describe('SmusAuthenticationProvider', function () { }) }) }) + + describe('getProjectAccountId', function () { + let getContextStub: sinon.SinonStub + let extractAccountIdFromResourceMetadataStub: sinon.SinonStub + let getProjectCredentialProviderStub: sinon.SinonStub + let mockProjectCredentialsProvider: any + let mockStsClient: any + let mockDataZoneClientForProject: any + + const testProjectId = 'test-project-id' + const testAccountId = '123456789012' + const testRegion = 'us-east-1' + + beforeEach(function () { + // Mock dependencies + getContextStub = sinon.stub(vscodeSetContext, 'getContext') + extractAccountIdFromResourceMetadataStub = sinon + .stub(smusUtils, 'extractAccountIdFromResourceMetadata') + .resolves(testAccountId) + + // Mock project credentials provider + mockProjectCredentialsProvider = { + getCredentials: sinon.stub().resolves({ + accessKeyId: 'test-key', + secretAccessKey: 'test-secret', + sessionToken: 'test-token', + }), + } + getProjectCredentialProviderStub = sinon + .stub(smusAuthProvider, 'getProjectCredentialProvider') + .resolves(mockProjectCredentialsProvider) + + // Update the existing mockDataZoneClient to include getToolingEnvironment + mockDataZoneClientForProject = { + getToolingEnvironment: sinon.stub().resolves({ + awsAccountRegion: testRegion, + projectId: testProjectId, + domainId: testDomainId, + createdBy: 'test-user', + name: 'test-environment', + id: 'test-env-id', + status: 'ACTIVE', + }), + } + // Update the existing mockDataZoneClient instead of creating a new stub + Object.assign(mockDataZoneClient, mockDataZoneClientForProject) + + // Mock STS client + mockStsClient = { + getCallerIdentity: sinon.stub().resolves({ + Account: testAccountId, + UserId: 'test-user-id', + Arn: 'arn:aws:sts::123456789012:assumed-role/test-role/test-session', + }), + } + + // Clear cache + smusAuthProvider['cachedProjectAccountIds'].clear() + mockSecondaryAuthState.activeConnection = mockSmusConnection + }) + + afterEach(function () { + sinon.restore() + }) + + describe('when cached value exists', function () { + it('should return cached project account ID without making any calls', async function () { + smusAuthProvider['cachedProjectAccountIds'].set(testProjectId, testAccountId) + + const result = await smusAuthProvider.getProjectAccountId(testProjectId) + + assert.strictEqual(result, testAccountId) + assert.ok(getContextStub.notCalled) + assert.ok(extractAccountIdFromResourceMetadataStub.notCalled) + assert.ok(getProjectCredentialProviderStub.notCalled) + assert.ok(mockStsClient.getCallerIdentity.notCalled) + }) + }) + + describe('in SMUS space environment', function () { + beforeEach(function () { + getContextStub.withArgs('aws.smus.inSmusSpaceEnvironment').returns(true) + }) + + it('should extract account ID from resource metadata and cache it', async function () { + const result = await smusAuthProvider.getProjectAccountId(testProjectId) + + assert.strictEqual(result, testAccountId) + assert.strictEqual(smusAuthProvider['cachedProjectAccountIds'].get(testProjectId), testAccountId) + assert.ok(extractAccountIdFromResourceMetadataStub.called) + assert.ok(getProjectCredentialProviderStub.notCalled) + assert.ok(mockStsClient.getCallerIdentity.notCalled) + }) + + it('should throw error when extractAccountIdFromResourceMetadata fails', async function () { + extractAccountIdFromResourceMetadataStub.rejects(new ToolkitError('Metadata extraction failed')) + + await assert.rejects( + () => smusAuthProvider.getProjectAccountId(testProjectId), + (err: ToolkitError) => err.message.includes('Metadata extraction failed') + ) + + assert.ok(!smusAuthProvider['cachedProjectAccountIds'].has(testProjectId)) + }) + }) + + describe('in non-SMUS space environment', function () { + let stsConstructorStub: sinon.SinonStub + + beforeEach(function () { + getContextStub.withArgs('aws.smus.inSmusSpaceEnvironment').returns(false) + // Stub the DefaultStsClient constructor to return our mock instance + const stsClientModule = require('../../../shared/clients/stsClient') + stsConstructorStub = sinon.stub(stsClientModule, 'DefaultStsClient').callsFake(() => mockStsClient) + }) + + afterEach(function () { + if (stsConstructorStub) { + stsConstructorStub.restore() + } + }) + + it('should use project credentials with STS to get account ID and cache it', async function () { + const result = await smusAuthProvider.getProjectAccountId(testProjectId) + + assert.strictEqual(result, testAccountId) + assert.strictEqual(smusAuthProvider['cachedProjectAccountIds'].get(testProjectId), testAccountId) + assert.ok(getProjectCredentialProviderStub.calledWith(testProjectId)) + assert.ok(mockProjectCredentialsProvider.getCredentials.called) + assert.ok((DataZoneClient.getInstance as sinon.SinonStub).called) + assert.ok(mockDataZoneClientForProject.getToolingEnvironment.calledWith(testProjectId)) + assert.ok(mockStsClient.getCallerIdentity.called) + }) + + it('should throw error when no active connection exists', async function () { + mockSecondaryAuthState.activeConnection = undefined + + await assert.rejects( + () => smusAuthProvider.getProjectAccountId(testProjectId), + (err: ToolkitError) => { + return ( + err.code === 'NoActiveConnection' && + err.message.includes('No active SMUS connection available') + ) + } + ) + + assert.ok(!smusAuthProvider['cachedProjectAccountIds'].has(testProjectId)) + }) + + it('should throw error when tooling environment has no region', async function () { + mockDataZoneClientForProject.getToolingEnvironment.resolves({ + id: 'env-123', + awsAccountRegion: undefined, + projectId: undefined, + domainId: undefined, + createdBy: undefined, + name: undefined, + provider: undefined, + $metadata: {}, + }) + + await assert.rejects( + () => smusAuthProvider.getProjectAccountId(testProjectId), + (err: ToolkitError) => { + return ( + err.message.includes('Failed to get project account ID') && + err.message.includes('No AWS account region found in tooling environment') + ) + } + ) + + assert.ok(!smusAuthProvider['cachedProjectAccountIds'].has(testProjectId)) + }) + + it('should throw error when STS GetCallerIdentity fails', async function () { + mockStsClient.getCallerIdentity.rejects(new Error('STS call failed')) + + await assert.rejects( + () => smusAuthProvider.getProjectAccountId(testProjectId), + (err: ToolkitError) => { + return ( + err.message.includes('Failed to get project account ID') && + err.message.includes('STS call failed') + ) + } + ) + + assert.ok(!smusAuthProvider['cachedProjectAccountIds'].has(testProjectId)) + }) + }) + }) }) diff --git a/packages/core/src/test/sagemakerunifiedstudio/shared/smusUtils.test.ts b/packages/core/src/test/sagemakerunifiedstudio/shared/smusUtils.test.ts index f895fe6ea13..c03b55c64c6 100644 --- a/packages/core/src/test/sagemakerunifiedstudio/shared/smusUtils.test.ts +++ b/packages/core/src/test/sagemakerunifiedstudio/shared/smusUtils.test.ts @@ -11,7 +11,8 @@ import { SmusTimeouts, SmusCredentialExpiry, validateCredentialFields, - extractAccountIdFromArn, + extractAccountIdFromSageMakerArn, + extractAccountIdFromResourceMetadata, } from '../../../sagemakerunifiedstudio/shared/smusUtils' import { ToolkitError } from '../../../shared/errors' import * as extensionUtilities from '../../../shared/extensionUtilities' @@ -463,11 +464,11 @@ describe('SmusUtils', () => { }) }) -describe('extractAccountIdFromArn', () => { +describe('extractAccountIdFromSageMakerArn', () => { describe('valid ARN formats', () => { it('should extract account ID from valid ARN', () => { const arn = 'arn:aws:sagemaker:us-west-2:123456789012:app/domain-id/ce/CodeEditor/default' - const result = extractAccountIdFromArn(arn) + const result = extractAccountIdFromSageMakerArn(arn) assert.strictEqual(result, '123456789012') }) @@ -476,7 +477,7 @@ describe('extractAccountIdFromArn', () => { describe('invalid ARN formats', () => { it('should throw error for empty ARN', () => { assert.throws( - () => extractAccountIdFromArn(''), + () => extractAccountIdFromSageMakerArn(''), (error: any) => { assert.ok(error instanceof ToolkitError) assert.ok(error.message.includes('Invalid SageMaker ARN format')) @@ -487,7 +488,7 @@ describe('extractAccountIdFromArn', () => { it('should throw error for non-ARN string', () => { assert.throws( - () => extractAccountIdFromArn('not-an-arn'), + () => extractAccountIdFromSageMakerArn('not-an-arn'), (error: any) => { assert.ok(error instanceof ToolkitError) assert.ok(error.message.includes('Invalid SageMaker ARN format')) @@ -499,7 +500,7 @@ describe('extractAccountIdFromArn', () => { it('should throw error for wrong service', () => { const arn = 'arn:aws:s3:us-east-1:123456789012:bucket/my-bucket' assert.throws( - () => extractAccountIdFromArn(arn), + () => extractAccountIdFromSageMakerArn(arn), (error: any) => { assert.ok(error instanceof ToolkitError) assert.ok(error.message.includes('Invalid SageMaker ARN format')) @@ -511,7 +512,7 @@ describe('extractAccountIdFromArn', () => { it('should throw error for missing account ID', () => { const arn = 'arn:aws:sagemaker:us-east-1::space/domain/space' assert.throws( - () => extractAccountIdFromArn(arn), + () => extractAccountIdFromSageMakerArn(arn), (error: any) => { assert.ok(error instanceof ToolkitError) assert.ok(error.message.includes('Invalid SageMaker ARN format')) @@ -521,3 +522,58 @@ describe('extractAccountIdFromArn', () => { }) }) }) + +describe('extractAccountIdFromResourceMetadata', () => { + let getResourceMetadataStub: sinon.SinonStub + + beforeEach(() => { + getResourceMetadataStub = sinon.stub(resourceMetadataUtils, 'getResourceMetadata') + }) + + afterEach(() => { + sinon.restore() + }) + + it('should extract account ID from ResourceArn successfully', async () => { + const testAccountId = '123456789012' + const testResourceArn = `arn:aws:sagemaker:us-east-1:${testAccountId}:app/domain-id/appName/CodeEditor/default` + + getResourceMetadataStub.returns({ + ResourceArn: testResourceArn, + }) + + const result = await extractAccountIdFromResourceMetadata() + + assert.strictEqual(result, testAccountId) + assert.ok(getResourceMetadataStub.called) + }) + + it('should throw error when ResourceArn is missing', async () => { + getResourceMetadataStub.returns({}) + + await assert.rejects( + () => extractAccountIdFromResourceMetadata(), + (err: Error) => { + return err.message.includes( + 'Failed to extract AWS account ID from ResourceArn in SMUS space environment' + ) + } + ) + }) + + it('should throw error when extractAccountIdFromSageMakerArn fails', async () => { + const testResourceArn = 'invalid-arn' + getResourceMetadataStub.returns({ + ResourceArn: testResourceArn, + }) + + await assert.rejects( + () => extractAccountIdFromResourceMetadata(), + (err: Error) => { + return err.message.includes( + 'Failed to extract AWS account ID from ResourceArn in SMUS space environment' + ) + } + ) + }) +})