Skip to content

Commit 5c8fc36

Browse files
liuzulinsjc25-test
andauthored
feat(sagemakerunifiedstudio): Fix list spaces in cross region cross account cases (aws#2209)
## Problem 1. In a cross region cross account set up. seeing error fetching space 2. The error message when cannot connect is not meaningful. I had to check on telemetry log to see the error message ## Solution Fixed the above issue --- - Treat all work as PUBLIC. Private `feature/x` branches will not be squash-merged at release time. - Your code changes must meet the guidelines in [CONTRIBUTING.md](https://github.com/aws/aws-toolkit-vscode/blob/master/CONTRIBUTING.md#guidelines). - License: I confirm that my contribution is made under the terms of the Apache 2.0 license. --------- Co-authored-by: Zulin Liu <[email protected]>
1 parent d39c3ec commit 5c8fc36

File tree

10 files changed

+376
-40
lines changed

10 files changed

+376
-40
lines changed

packages/core/src/awsService/sagemaker/commands.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ export async function stopSpace(
161161
code: error.name,
162162
})
163163
} else {
164-
throw new ToolkitError(`Failed to stop space: ${spaceName}`, {
164+
throw new ToolkitError(`Failed to stop space ${spaceName}: ${(error as Error).message}`, {
165165
cause: error,
166166
code: error.name,
167167
})
@@ -198,7 +198,7 @@ export async function openRemoteConnect(
198198
} catch (err: any) {
199199
// Ignore InstanceTypeError since it means the user decided not to use an instanceType with more memory
200200
if (err.code !== InstanceTypeError) {
201-
throw new ToolkitError('Remote connection failed.', {
201+
throw new ToolkitError(`Remote connection failed: ${(err as Error).message}`, {
202202
cause: err as Error,
203203
code: err.code,
204204
})

packages/core/src/sagemakerunifiedstudio/auth/providers/domainExecRoleCredentialsProvider.ts

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -202,15 +202,16 @@ export class DomainExecRoleCredentialsProvider implements CredentialsProvider {
202202
)
203203
}
204204

205-
const data = (await response.json()) as {
205+
const responseText = await response.text()
206+
207+
const data = JSON.parse(responseText) as {
206208
credentials: {
207209
accessKeyId: string
208210
secretAccessKey: string
209211
sessionToken: string
210212
expiration: string
211213
}
212214
}
213-
214215
this.logger.debug(`SMUS DER: Successfully received credentials from API for domain ${this.domainId}`)
215216

216217
// Validate the response data structure
@@ -226,21 +227,50 @@ export class DomainExecRoleCredentialsProvider implements CredentialsProvider {
226227
validateCredentialFields(credentials, 'InvalidCredentialResponse', 'API response')
227228

228229
// Create credentials with expiration
229-
// Note: The response doesn't include expiration yet, so we set it to 10 minutes for now if it does't exist
230230
let credentialExpiresAt: Date
231231
if (credentials.expiration) {
232-
// The API returns expiration as a string, convert to Date
233-
const parsedExpiration = new Date(credentials.expiration)
232+
// Handle both epoch timestamps and ISO date strings
233+
let parsedExpiration: Date
234+
235+
// Check if expiration is a numeric string (epoch timestamp)
236+
const expirationNum = Number(credentials.expiration)
237+
if (!isNaN(expirationNum) && expirationNum > 0) {
238+
// Treat as epoch timestamp in seconds and convert to milliseconds
239+
const timestampMs = expirationNum * 1000
240+
parsedExpiration = new Date(timestampMs)
241+
this.logger.debug(
242+
`SMUS DER: Parsed epoch timestamp ${credentials.expiration} (seconds) as ${parsedExpiration.toISOString()}`
243+
)
244+
} else {
245+
// Treat as ISO date string
246+
parsedExpiration = new Date(credentials.expiration)
247+
if (!isNaN(parsedExpiration.getTime())) {
248+
this.logger.debug(
249+
`SMUS DER: Parsed ISO date string ${credentials.expiration} as ${parsedExpiration.toISOString()}`
250+
)
251+
} else {
252+
this.logger.debug(
253+
`SMUS DER: Failed to parse ISO date string ${credentials.expiration} - invalid date format`
254+
)
255+
}
256+
}
257+
234258
// Check if the parsed date is valid
235259
if (isNaN(parsedExpiration.getTime())) {
236260
this.logger.warn(
237-
`SMUS DER: Invalid expiration date string: ${credentials.expiration}, using default expiration`
261+
`SMUS DER: Invalid expiration value: ${credentials.expiration}, using default expiration`
238262
)
239263
credentialExpiresAt = new Date(Date.now() + SmusCredentialExpiry.derExpiryMs)
240264
} else {
241265
credentialExpiresAt = parsedExpiration
242266
}
267+
if (!isNaN(credentialExpiresAt.getTime())) {
268+
this.logger.debug(`SMUS DER: Credential expires at ${credentialExpiresAt.toISOString()}`)
269+
} else {
270+
this.logger.debug(`SMUS DER: Invalid credential expiration date, using default`)
271+
}
243272
} else {
273+
this.logger.debug(`SMUS DER: No expiration provided, using default`)
244274
credentialExpiresAt = new Date(Date.now() + SmusCredentialExpiry.derExpiryMs)
245275
}
246276

packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioDataNode.ts

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,13 @@ export class SageMakerUnifiedStudioDataNode implements TreeNode {
9898
(conn) => (conn.type as ConnectionType) === ConnectionType.LAKEHOUSE
9999
)
100100

101-
// Create Bucket parent node if there are S3 connections
102-
if (s3Connections.length > 0) {
103-
const bucketNode = this.createBucketParentNode(project, s3Connections, region)
104-
dataNodes.push(bucketNode)
101+
// Add Lakehouse nodes first
102+
for (const connection of lakehouseConnections) {
103+
const node = await this.createLakehouseNode(project, connection, region)
104+
dataNodes.push(node)
105105
}
106106

107+
// Add Redshift nodes second
107108
for (const connection of redshiftConnections) {
108109
if (connection.name.startsWith('project.lakehouse')) {
109110
continue
@@ -115,9 +116,10 @@ export class SageMakerUnifiedStudioDataNode implements TreeNode {
115116
dataNodes.push(node)
116117
}
117118

118-
for (const connection of lakehouseConnections) {
119-
const node = await this.createLakehouseNode(project, connection, region)
120-
dataNodes.push(node)
119+
// Add S3 Bucket parent node last
120+
if (s3Connections.length > 0) {
121+
const bucketNode = this.createBucketParentNode(project, s3Connections, region)
122+
dataNodes.push(bucketNode)
121123
}
122124

123125
this.logger.info(`Created ${dataNodes.length} total connection nodes`)

packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioProjectNode.ts

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,17 @@ export class SageMakerUnifiedStudioProjectNode implements TreeNode {
117117
return [dataNode]
118118
}
119119

120-
this.sagemakerClient = await this.initializeSagemakerClient(
121-
this.authProvider.activeConnection?.ssoRegion || 'us-east-1'
122-
)
120+
const dzClient = await DataZoneClient.getInstance(this.authProvider)
121+
if (!this.project?.id) {
122+
throw new Error('Project ID is required')
123+
}
124+
const toolingEnv = await dzClient.getToolingEnvironment(this.project.id)
125+
const spaceAwsAccountRegion = toolingEnv.awsAccountRegion
126+
127+
if (!spaceAwsAccountRegion) {
128+
throw new Error('No AWS account region found in tooling environment')
129+
}
130+
this.sagemakerClient = await this.initializeSagemakerClient(spaceAwsAccountRegion)
123131
const computeNode = new SageMakerUnifiedStudioComputeNode(
124132
this,
125133
this.extensionContext,

packages/core/src/sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioSpacesParentNode.ts

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import { SageMakerUnifiedStudioComputeNode } from './sageMakerUnifiedStudioCompu
88
import { updateInPlace } from '../../../shared/utilities/collectionUtils'
99
import { DataZoneClient } from '../../shared/client/datazoneClient'
1010
import { DescribeDomainResponse } from '@amzn/sagemaker-client'
11-
import { GetEnvironmentCommandOutput } from '@aws-sdk/client-datazone'
1211
import { getDomainUserProfileKey } from '../../../awsService/sagemaker/utils'
1312
import { getLogger } from '../../../shared/logger/logger'
1413
import { TreeNode } from '../../../shared/treeview/resourceTreeDataProvider'
@@ -31,6 +30,7 @@ export class SageMakerUnifiedStudioSpacesParentNode implements TreeNode {
3130
public readonly onDidChangeTreeItem = this.onDidChangeEmitter.event
3231
public readonly onDidChangeChildren = this.onDidChangeEmitter.event
3332
public readonly pollingSet: PollingSet<string> = new PollingSet(5, this.updatePendingNodes.bind(this))
33+
private spaceAwsAccountRegion: string | undefined
3434

3535
public constructor(
3636
private readonly parent: SageMakerUnifiedStudioComputeNode,
@@ -143,16 +143,9 @@ export class SageMakerUnifiedStudioSpacesParentNode implements TreeNode {
143143
if (!datazoneClient) {
144144
throw new Error('DataZone client is not initialized')
145145
}
146-
const toolingEnvId = await datazoneClient
147-
.getToolingEnvironmentId(datazoneClient.getDomainId(), this.projectId)
148-
.catch((err) => {
149-
this.logger.error('Failed to get tooling environment ID for project %s', this.projectId)
150-
throw new Error(`Failed to get tooling environment ID: ${err.message}`)
151-
})
152-
if (!toolingEnvId) {
153-
throw new Error('No default environment found for project')
154-
}
155-
const toolingEnv: GetEnvironmentCommandOutput = await datazoneClient.getEnvironmentDetails(toolingEnvId)
146+
147+
const toolingEnv = await datazoneClient.getToolingEnvironment(this.projectId)
148+
this.spaceAwsAccountRegion = toolingEnv.awsAccountRegion
156149
if (toolingEnv.provisionedResources) {
157150
for (const resource of toolingEnv.provisionedResources) {
158151
if (resource.name === 'sageMakerDomainId') {
@@ -224,7 +217,10 @@ export class SageMakerUnifiedStudioSpacesParentNode implements TreeNode {
224217
new SagemakerUnifiedStudioSpaceNode(
225218
this as any,
226219
this.sagemakerClient,
227-
datazoneClient.getRegion(),
220+
this.spaceAwsAccountRegion ||
221+
(() => {
222+
throw new Error('No AWS account region found in tooling environment')
223+
})(),
228224
this.spaceApps.get(key)!,
229225
true /* isSMUSSpace */
230226
)

packages/core/src/sagemakerunifiedstudio/shared/client/datazoneClient.ts

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import {
1515
S3PropertiesOutput,
1616
ConnectionType,
1717
GluePropertiesOutput,
18+
GetEnvironmentCommandOutput,
1819
} from '@aws-sdk/client-datazone'
1920
import { getLogger } from '../../../shared/logger/logger'
2021
import type { SmusAuthenticationProvider } from '../../auth/providers/smusAuthenticationProvider'
@@ -753,6 +754,33 @@ export class DataZoneClient {
753754
}
754755
}
755756

757+
/**
758+
* Gets the tooling environment details for a project
759+
* @param projectId The project ID
760+
* @returns The tooling environment details
761+
*/
762+
public async getToolingEnvironment(projectId: string): Promise<GetEnvironmentCommandOutput> {
763+
const logger = getLogger()
764+
765+
const datazoneClient = await DataZoneClient.getInstance(this.authProvider)
766+
if (!datazoneClient) {
767+
throw new Error('DataZone client is not initialized')
768+
}
769+
770+
const toolingEnvId = await datazoneClient
771+
.getToolingEnvironmentId(datazoneClient.getDomainId(), projectId)
772+
.catch((err) => {
773+
logger.error('Failed to get tooling environment ID for project %s', projectId)
774+
throw new Error(`Failed to get tooling environment ID: ${err.message}`)
775+
})
776+
777+
if (!toolingEnvId) {
778+
throw new Error('No default environment found for project')
779+
}
780+
781+
return await datazoneClient.getEnvironmentDetails(toolingEnvId)
782+
}
783+
756784
public async getUserId(): Promise<string | undefined> {
757785
const derCredProvider = await this.authProvider.getDerCredentialsProvider()
758786
this.logger.debug(`Calling STS GetCallerIdentity using DER credentials of ${this.getDomainId()}`)

0 commit comments

Comments
 (0)