Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ src.gen/*
**/src/auth/sso/oidcclientpkce.d.ts
**/src/sagemakerunifiedstudio/shared/client/gluecatalogapi.d.ts
**/src/sagemakerunifiedstudio/shared/client/sqlworkbench.d.ts
**/src/sagemakerunifiedstudio/shared/client/datazonecustomclient.d.ts

# Generated by tests
**/src/testFixtures/**/bin
Expand Down
10 changes: 5 additions & 5 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
"scan-licenses": "ts-node ./scripts/scan-licenses.ts"
},
"devDependencies": {
"@aws-toolkits/telemetry": "^1.0.329",
"@aws-toolkits/telemetry": "^1.0.338",
"@playwright/browser-chromium": "^1.43.1",
"@stylistic/eslint-plugin": "^2.11.0",
"@types/he": "^1.2.3",
Expand Down
1 change: 1 addition & 0 deletions packages/core/package.nls.json
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@
"AWS.command.s3.uploadFileToParent": "Upload to Parent...",
"AWS.command.smus.switchProject": "Switch Project",
"AWS.command.smus.refreshProject": "Refresh Project",
"AWS.command.smus.refresh": "Refresh",
"AWS.command.smus.signOut": "Sign Out",
"AWS.command.sagemaker.filterSpaces": "Filter Sagemaker Spaces",
"AWS.command.stepFunctions.createStateMachineFromTemplate": "Create a new Step Functions state machine",
Expand Down
4 changes: 4 additions & 0 deletions packages/core/scripts/build/generateServiceClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,10 @@ void (async () => {
serviceJsonPath: 'src/sagemakerunifiedstudio/shared/client/sqlworkbench.json',
serviceName: 'SQLWorkbench',
},
{
serviceJsonPath: 'src/sagemakerunifiedstudio/shared/client/datazonecustomclient.json',
serviceName: 'DataZoneCustomClient',
},
]
await generateServiceClients(serviceClientDefinitions)
})()
2 changes: 1 addition & 1 deletion packages/core/src/auth/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@
* e.g. https://view.awsapps.com/start/# will become https://view.awsapps.com/start
*/
public normalizeStartUrl(startUrl: string | undefined) {
return !startUrl ? undefined : startUrl.replace(/[\/#]+$/g, '')

Check failure

Code scanning / CodeQL

Polynomial regular expression used on uncontrolled data High

This
regular expression
that depends on
library input
may run slow on strings with many repetitions of '#'.
}

public isInternalAmazonUser(): boolean {
Expand Down Expand Up @@ -598,7 +598,7 @@
}

@withTelemetryContext({ name: 'updateConnectionState', class: authClassName })
private async updateConnectionState(id: Connection['id'], connectionState: ProfileMetadata['connectionState']) {
public async updateConnectionState(id: Connection['id'], connectionState: ProfileMetadata['connectionState']) {
getLogger().info(`auth: Updating connection state of ${id} to ${connectionState}`)

if (connectionState === 'authenticating') {
Expand Down
27 changes: 27 additions & 0 deletions packages/core/src/auth/sso/clients.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ import { AuthenticationFlow } from './model'
import { toSnakeCase } from '../../shared/utilities/textUtilities'
import { getUserAgent, withTelemetryContext } from '../../shared/telemetry/util'
import { oneSecond } from '../../shared/datetime'
import { telemetry } from '../../shared/telemetry/telemetry'
import { getTelemetryReason, getTelemetryReasonDesc, getHttpStatusCode } from '../../shared/errors'

export class OidcClient {
public constructor(
Expand Down Expand Up @@ -86,15 +88,40 @@ export class OidcClient {
}

public async createToken(request: CreateTokenRequest) {
const startTime = this.clock.Date.now()
const grantType = request.grantType

let response
try {
response = await this.client.createToken(request as CreateTokenRequest)
} catch (err) {
const statusCode = getHttpStatusCode(err)
telemetry.auth_ssoTokenOperation.emit({
result: 'Failed',
grantType: grantType ?? 'unknown',
duration: this.clock.Date.now() - startTime,
reason: getTelemetryReason(err),
reasonDesc: getTelemetryReasonDesc(err),
...(statusCode !== undefined ? { httpStatusCode: String(statusCode) } : {}),
})

getLogger().error(`sso-oidc: createToken failed (grantType=${grantType}): ${err}`)

const newError = AwsClientResponseError.instanceIf(err)
throw newError
}
assertHasProps(response, 'accessToken', 'expiresIn')

telemetry.auth_ssoTokenOperation.emit({
result: 'Succeeded',
grantType: grantType ?? 'unknown',
duration: this.clock.Date.now() - startTime,
})

getLogger().debug(
`sso-oidc: createToken succeeded (grantType=${grantType}, requestId=${response.$metadata.requestId})`
)

return {
...selectFrom(response, 'accessToken', 'refreshToken', 'tokenType'),
requestId: response.$metadata.requestId,
Expand Down
13 changes: 10 additions & 3 deletions packages/core/src/awsService/sagemaker/credentialMapping.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import { getLogger } from '../../shared/logger/logger'
import { parseArn } from './detached-server/utils'
import { SagemakerUnifiedStudioSpaceNode } from '../../sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioSpaceNode'
import { SageMakerUnifiedStudioSpacesParentNode } from '../../sagemakerunifiedstudio/explorer/nodes/sageMakerUnifiedStudioSpacesParentNode'
import { isSmusSsoConnection } from '../../sagemakerunifiedstudio/auth/model'

const mappingFileName = '.sagemaker-space-profiles'
const mappingFilePath = path.join(os.homedir(), '.aws', mappingFileName)
Expand Down Expand Up @@ -74,10 +75,11 @@ export async function persistLocalCredentials(spaceArn: string): Promise<void> {
export async function persistSmusProjectCreds(spaceArn: string, node: SagemakerUnifiedStudioSpaceNode): Promise<void> {
const nodeParent = node.getParent() as SageMakerUnifiedStudioSpacesParentNode
const authProvider = nodeParent.getAuthProvider()
const activeConnection = authProvider.activeConnection
const projectId = nodeParent.getProjectId()
const projectAuthProvider = await authProvider.getProjectCredentialProvider(projectId)
await projectAuthProvider.getCredentials()
await setSmusSpaceSsoProfile(spaceArn, projectId)
await setSmusSpaceProfile(spaceArn, projectId, isSmusSsoConnection(activeConnection) ? 'sso' : 'iam')
// Trigger SSH credential refresh for the project
projectAuthProvider.startProactiveCredentialRefresh()
}
Expand Down Expand Up @@ -177,11 +179,16 @@ export async function setSpaceSsoProfile(
* Sets the SM Space to map to SageMaker Unified Studio Project.
* @param spaceArn - The arn of the SageMaker Unified Studio space.
* @param projectId - The project ID associated with the SageMaker Unified Studio space.
* @param credentialType - The type of credential ('sso' or 'iam').
*/
export async function setSmusSpaceSsoProfile(spaceArn: string, projectId: string): Promise<void> {
export async function setSmusSpaceProfile(
spaceArn: string,
projectId: string,
credentialType: 'iam' | 'sso'
): Promise<void> {
const data = await loadMappings()
data.localCredential ??= {}
data.localCredential[spaceArn] = { type: 'sso', smusProjectId: projectId }
data.localCredential[spaceArn] = { type: credentialType, smusProjectId: projectId }
await saveMappings(data)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,25 @@ export async function resolveCredentialsFor(connectionIdentifier: string): Promi

switch (profile.type) {
case 'iam': {
const name = profile.profileName?.split(':')[1]
if (!name) {
throw new Error(`Invalid IAM profile name for "${connectionIdentifier}"`)
if ('profileName' in profile) {
const name = profile.profileName?.split(':')[1]
if (!name) {
throw new Error(`Invalid IAM profile name for "${connectionIdentifier}"`)
}
return fromIni({ profile: name })
} else if ('smusProjectId' in profile) {
const { accessKey, secret, token } = mapping.smusProjects?.[profile.smusProjectId] || {}
if (!accessKey || !secret || !token) {
throw new Error(`Missing ProjectRole credentials for SMUS Space "${connectionIdentifier}"`)
}
return {
accessKeyId: accessKey,
secretAccessKey: secret,
sessionToken: token,
}
} else {
throw new Error(`Missing IAM credentials for "${connectionIdentifier}"`)
}
return fromIni({ profile: name })
}
case 'sso': {
if ('accessKey' in profile && 'secret' in profile && 'token' in profile) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ export const getVSCodeErrorTitle = (error: SageMakerServiceException): string =>
return ErrorText.StartSession[ExceptionType.DEFAULT].Title
}

export const getVSCodeErrorText = (error: SageMakerServiceException, isSmus?: boolean): string => {
export const getVSCodeErrorText = (
error: SageMakerServiceException,
isSmus?: boolean,
isSmusIamConn?: boolean
): string => {
const exceptionType = error.name as ExceptionType

switch (exceptionType) {
Expand All @@ -41,9 +45,12 @@ export const getVSCodeErrorText = (error: SageMakerServiceException, isSmus?: bo
return ErrorText.StartSession[exceptionType].Text.replace('{message}', error.message)
case ExceptionType.EXPIRED_TOKEN:
// Use SMUS-specific message if in SMUS context
return isSmus
? ErrorText.StartSession[ExceptionType.EXPIRED_TOKEN].SmusText
: ErrorText.StartSession[exceptionType].Text
if (isSmus) {
return isSmusIamConn
? ErrorText.StartSession[ExceptionType.EXPIRED_TOKEN].SmusIamText
: ErrorText.StartSession[ExceptionType.EXPIRED_TOKEN].SmusSsoText
}
return ErrorText.StartSession[exceptionType].Text
case ExceptionType.INTERNAL_FAILURE:
case ExceptionType.RESOURCE_LIMIT_EXCEEDED:
case ExceptionType.THROTTLING:
Expand All @@ -66,8 +73,10 @@ export const ErrorText = {
[ExceptionType.EXPIRED_TOKEN]: {
Title: 'Authentication expired',
Text: 'Your session has expired. Please refresh your credentials and try again.',
SmusText:
'Your session has expired. This is likely due to network connectivity issues after machine sleep/resume. Please wait 10-30 seconds for automatic credential refresh, then try again. If the issue persists, try reconnecting through AWS Toolkit.',
SmusSsoText:
'Your session has expired. This is likely due to network connectivity issues after machine sleep/resume. Wait 10-30 seconds for automatic credential refresh, then try again. If the issue persists, try reconnecting through AWS Toolkit.',
SmusIamText:
'Your session has expired. Update the credentials associated with the IAM profile or use a valid IAM profile, then try again.',
},
[ExceptionType.INTERNAL_FAILURE]: {
Title: 'Failed to connect remotely to VSCode',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
// Disabled: detached server files cannot import vscode.
/* eslint-disable aws-toolkits/no-console-log */
import { IncomingMessage, ServerResponse } from 'http'
import { startSagemakerSession, parseArn, isSmusConnection } from '../utils'
import { startSagemakerSession, parseArn, isSmusConnection, isSmusIamConnection } from '../utils'
import { resolveCredentialsFor } from '../credentials'
import url from 'url'
import { SageMakerServiceException } from '@amzn/sagemaker-client'
Expand Down Expand Up @@ -35,6 +35,7 @@
const { region } = parseArn(connectionIdentifier)
// Detect if this is a SMUS connection for specialized error handling
const isSmus = await isSmusConnection(connectionIdentifier)
const isSmusIamConn = await isSmusIamConnection(connectionIdentifier)

try {
const session = await startSagemakerSession({ region, connectionIdentifier, credentials })
Expand All @@ -48,9 +49,9 @@
)
} catch (err) {
const error = err as SageMakerServiceException
console.error(`Failed to start SageMaker session for ${connectionIdentifier}:`, err)

Check failure

Code scanning / CodeQL

Use of externally-controlled format string High

Format string depends on a
user-provided value
.
const errorTitle = getVSCodeErrorTitle(error)
const errorText = getVSCodeErrorText(error, isSmus)
const errorText = getVSCodeErrorText(error, isSmus, isSmusIamConn)
await openErrorPage(errorTitle, errorText)
res.writeHead(500, { 'Content-Type': 'text/plain' })
res.end('Failed to start SageMaker session')
Expand Down
18 changes: 18 additions & 0 deletions packages/core/src/awsService/sagemaker/detached-server/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,24 @@ export async function isSmusConnection(connectionIdentifier: string): Promise<bo
}
}

/**
* Detects if the connection identifier is using SMUS IAM credentials
* @param connectionIdentifier - The connection identifier to check
* @returns Promise<boolean> - true if SMUS IAM connection, false otherwise
*/
export async function isSmusIamConnection(connectionIdentifier: string): Promise<boolean> {
try {
const mapping = await readMapping()
const profile = mapping.localCredential?.[connectionIdentifier]

// Check if profile exists, has smusProjectId, and type is 'iam'
return profile && 'smusProjectId' in profile && profile.type === 'iam'
} catch (err) {
// If we can't detect it is iam connection, assume not SMUS IAM to avoid breaking existing functionality
return false
}
}

/**
* Writes the mapping to a temp file and atomically renames it to the target path.
* Uses a queue to prevent race conditions when multiple requests try to write simultaneously.
Expand Down
21 changes: 21 additions & 0 deletions packages/core/src/awsService/sagemaker/hyperpodCommands.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,30 @@ const localize = nls.loadMessageBundle()

export async function openHyperPodRemoteConnection(node: SagemakerDevSpaceNode): Promise<void> {
await startHyperpodSpaceCommand(node)
await waitForDevSpaceRunning(node)
await connectToHyperPodDevSpace(node)
}

async function waitForDevSpaceRunning(node: SagemakerDevSpaceNode): Promise<void> {
const kubectlClient = node.getParent().getKubectlClient(node.hpCluster.clusterName)
if (!kubectlClient) {
getLogger().error(`No kubectlClient available for cluster: ${node.hpCluster.clusterName}`)
return
}
const timeout = 5 * 60 * 1000 // 5 minutes
const startTime = Date.now()

while (Date.now() - startTime < timeout) {
const status = await kubectlClient.getHyperpodSpaceStatus(node.devSpace)
if (status === 'Running') {
return
}
await new Promise((resolve) => setTimeout(resolve, 5000))
}

throw new Error('Timeout waiting for dev space to reach Running status')
}

export async function connectToHyperPodDevSpace(node: SagemakerDevSpaceNode): Promise<void> {
const logger = getLogger()

Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/awsService/sagemaker/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ export interface SpaceMappings {
export type LocalCredentialProfile =
| { type: 'iam'; profileName: string }
| { type: 'sso'; accessKey: string; secret: string; token: string }
| { type: 'sso'; smusProjectId: string }
| { type: 'sso' | 'iam'; smusProjectId: string }

export interface DeeplinkSession {
requests: Record<string, SsmConnectionInfo>
Expand Down
Loading
Loading