Skip to content

Commit fd73378

Browse files
vpbhargavBhargava Varadharajan
andauthored
feat(smus): Add credential providers for SMUS (aws#2187)
**Description** Added credential providers for DER, Project role credentials and Connections credentials. Added in memory caching for now. The credential providers currently are structured so that it can be used with CredentialStore but a lot of things to do with credential store is really old. In the coming days, while we do bug bashes and tests, if the in memory caching fares well enough, will simplify the credential provider to implement the AWS SDK CredentialProvider directly. **Motivation** Support auth for SMUS and ensure all clients get credentials and don't fail due to lack of credentials. **Testing Done** Updated unit tests and tested flow locally. --- - Treat all work as PUBLIC. Private `feature/x` branches will not be squash-merged at release time. - Your code changes must meet the guidelines in [CONTRIBUTING.md](https://github.com/aws/aws-toolkit-vscode/blob/master/CONTRIBUTING.md#guidelines). - License: I confirm that my contribution is made under the terms of the Apache 2.0 license. --------- Co-authored-by: Bhargava Varadharajan <[email protected]>
1 parent 3edd5a8 commit fd73378

25 files changed

+3174
-851
lines changed

packages/core/src/auth/auth.ts

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,29 @@ export class Auth implements AuthService, ConnectionManager {
219219
}
220220
}
221221

222+
/**
223+
* Gets the SSO access token for a connection
224+
* @param connection The SSO connection to get the token for
225+
* @returns Promise resolving to the access token string
226+
*/
227+
@withTelemetryContext({ name: 'getSsoAccessToken', class: authClassName })
228+
public async getSsoAccessToken(connection: Pick<SsoConnection, 'id'>): Promise<string> {
229+
const profile = this.store.getProfileOrThrow(connection.id)
230+
231+
if (profile.type !== 'sso') {
232+
throw new Error(`Connection ${connection.id} is not an SSO connection`)
233+
}
234+
235+
const provider = this.getSsoTokenProvider(connection.id, profile)
236+
const token = await provider.getToken()
237+
238+
if (!token?.accessToken) {
239+
throw new Error(`No access token available for connection ${connection.id}`)
240+
}
241+
242+
return token.accessToken
243+
}
244+
222245
public async useConnection({ id }: Pick<SsoConnection, 'id'>): Promise<SsoConnection>
223246
public async useConnection({ id }: Pick<IamConnection, 'id'>): Promise<IamConnection>
224247
@withTelemetryContext({ name: 'useConnection', class: authClassName })
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
/*!
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
import { getLogger } from '../../../shared/logger/logger'
7+
import { ToolkitError } from '../../../shared/errors'
8+
import * as AWS from '@aws-sdk/types'
9+
import { CredentialsId, CredentialsProvider, CredentialsProviderType } from '../../../auth/providers/credentials'
10+
11+
import { DataZoneClient } from '../../shared/client/datazoneClient'
12+
import { SmusAuthenticationProvider } from './smusAuthenticationProvider'
13+
import { CredentialType } from '../../../shared/telemetry/telemetry'
14+
import { SmusCredentialExpiry, validateCredentialFields } from '../../shared/smusUtils'
15+
16+
/**
17+
* Credentials provider for SageMaker Unified Studio Connection credentials
18+
* Uses DataZone API to get connection credentials for a specific connection *
19+
* This provider implements independent caching with 10-minute expiry
20+
*/
21+
export class ConnectionCredentialsProvider implements CredentialsProvider {
22+
private readonly logger = getLogger()
23+
private credentialCache?: {
24+
credentials: AWS.Credentials
25+
expiresAt: Date
26+
}
27+
28+
constructor(
29+
private readonly smusAuthProvider: SmusAuthenticationProvider,
30+
private readonly connectionId: string
31+
) {}
32+
33+
/**
34+
* Gets the connection ID
35+
* @returns Connection ID
36+
*/
37+
public getConnectionId(): string {
38+
return this.connectionId
39+
}
40+
41+
/**
42+
* Gets the credentials ID
43+
* @returns Credentials ID
44+
*/
45+
public getCredentialsId(): CredentialsId {
46+
return {
47+
credentialSource: 'temp',
48+
credentialTypeId: `${this.smusAuthProvider.getDomainId()}:${this.connectionId}`,
49+
}
50+
}
51+
52+
/**
53+
* Gets the provider type
54+
* @returns Provider type
55+
*/
56+
public getProviderType(): CredentialsProviderType {
57+
return 'temp'
58+
}
59+
60+
/**
61+
* Gets the telemetry type
62+
* @returns Telemetry type
63+
*/
64+
public getTelemetryType(): CredentialType {
65+
return 'other'
66+
}
67+
68+
/**
69+
* Gets the default region
70+
* @returns Default region
71+
*/
72+
public getDefaultRegion(): string | undefined {
73+
return this.smusAuthProvider.getDomainRegion()
74+
}
75+
76+
/**
77+
* Gets the hash code
78+
* @returns Hash code
79+
*/
80+
public getHashCode(): string {
81+
const hashCode = `smus-connection:${this.smusAuthProvider.getDomainId()}:${this.connectionId}`
82+
return hashCode
83+
}
84+
85+
/**
86+
* Determines if the provider can auto-connect
87+
* @returns Promise resolving to boolean
88+
*/
89+
public async canAutoConnect(): Promise<boolean> {
90+
return false // SMUS requires manual authentication
91+
}
92+
93+
/**
94+
* Determines if the provider is available
95+
* @returns Promise resolving to boolean
96+
*/
97+
public async isAvailable(): Promise<boolean> {
98+
try {
99+
return this.smusAuthProvider.isConnected()
100+
} catch (err) {
101+
this.logger.error('SMUS Connection: Error checking if auth provider is connected: %s', err)
102+
return false
103+
}
104+
}
105+
106+
/**
107+
* Gets Connection credentials with independent caching
108+
* @returns Promise resolving to credentials
109+
*/
110+
public async getCredentials(): Promise<AWS.Credentials> {
111+
this.logger.debug(`SMUS Connection: Getting credentials for connection ${this.connectionId}`)
112+
113+
// Check cache first (10-minute expiry)
114+
if (this.credentialCache && this.credentialCache.expiresAt > new Date()) {
115+
this.logger.debug(
116+
`SMUS Connection: Using cached connection credentials for connection ${this.connectionId}`
117+
)
118+
return this.credentialCache.credentials
119+
}
120+
121+
this.logger.debug(
122+
`SMUS Connection: Calling GetConnection to fetch credentials for connection ${this.connectionId}`
123+
)
124+
125+
try {
126+
const datazoneClient = await DataZoneClient.getInstance(this.smusAuthProvider)
127+
const getConnectionResponse = await datazoneClient.getConnection({
128+
domainIdentifier: this.smusAuthProvider.getDomainId(),
129+
identifier: this.connectionId,
130+
withSecret: true,
131+
})
132+
133+
this.logger.debug(`SMUS Connection: Successfully retrieved connection details for ${this.connectionId}`)
134+
135+
// Extract connection credentials
136+
const connectionCredentials = getConnectionResponse.connectionCredentials
137+
if (!connectionCredentials) {
138+
throw new ToolkitError(
139+
`No connection credentials available in response for connection ${this.connectionId}`,
140+
{
141+
code: 'NoConnectionCredentials',
142+
}
143+
)
144+
}
145+
146+
// Validate credential fields
147+
validateCredentialFields(
148+
connectionCredentials,
149+
'InvalidConnectionCredentials',
150+
'connection credential response'
151+
)
152+
153+
// Create AWS credentials with expiration
154+
// Use the expiration from the response if available, otherwise default to 10 minutes
155+
let expiresAt: Date
156+
if (connectionCredentials.expiration) {
157+
// The API returns expiration as a string or Date, handle both cases
158+
expiresAt =
159+
connectionCredentials.expiration instanceof Date
160+
? connectionCredentials.expiration
161+
: new Date(connectionCredentials.expiration)
162+
} else {
163+
expiresAt = new Date(Date.now() + SmusCredentialExpiry.connectionExpiryMs)
164+
}
165+
166+
const awsCredentials: AWS.Credentials = {
167+
accessKeyId: connectionCredentials.accessKeyId as string,
168+
secretAccessKey: connectionCredentials.secretAccessKey as string,
169+
sessionToken: connectionCredentials.sessionToken as string,
170+
expiration: expiresAt,
171+
}
172+
173+
// Cache connection credentials (10-minute expiry)
174+
const cacheExpiresAt = new Date(Date.now() + SmusCredentialExpiry.connectionExpiryMs)
175+
this.credentialCache = {
176+
credentials: awsCredentials,
177+
expiresAt: cacheExpiresAt,
178+
}
179+
180+
this.logger.debug(
181+
`SMUS Connection: Successfully cached connection credentials for connection ${this.connectionId}, expires in %s minutes`,
182+
Math.round((cacheExpiresAt.getTime() - Date.now()) / 60000)
183+
)
184+
185+
return awsCredentials
186+
} catch (err) {
187+
this.logger.error(
188+
`SMUS Connection: Failed to get connection credentials for connection ${this.connectionId}: %s`,
189+
err
190+
)
191+
192+
// Re-throw ToolkitErrors with specific codes (NoConnectionCredentials, InvalidConnectionCredentials)
193+
if (
194+
err instanceof ToolkitError &&
195+
(err.code === 'NoConnectionCredentials' || err.code === 'InvalidConnectionCredentials')
196+
) {
197+
throw err
198+
}
199+
200+
// Wrap other errors in ConnectionCredentialsFetchFailed
201+
throw new ToolkitError(`Failed to get connection credentials for ${this.connectionId}: ${err}`, {
202+
code: 'ConnectionCredentialsFetchFailed',
203+
cause: err instanceof Error ? err : undefined,
204+
})
205+
}
206+
}
207+
208+
/**
209+
* Invalidates cached connection credentials
210+
* Clears the internal cache without fetching new credentials
211+
*/
212+
public invalidate(): void {
213+
this.logger.debug(`SMUS Connection: Invalidating cached credentials for connection ${this.connectionId}`)
214+
// Clear cache to force fresh fetch on next getCredentials() call
215+
this.credentialCache = undefined
216+
this.logger.debug(
217+
`SMUS Connection: Successfully invalidated connection credentials cache for connection ${this.connectionId}`
218+
)
219+
}
220+
}

0 commit comments

Comments
 (0)