Skip to content

Commit e7b652d

Browse files
authored
feat: sm ide #3837
Problem: New SM VSCode IDE with IAM role, we want to use the IAM role for Code Whisperer auth. Currently Code Whisperer auth used Builder ID/SSO/Cloud9(IAM). Current IAM auth is for Cloud9 IDE and does not allow for paginated results. Solution: Add new IDE type for SM VSCode that can be used for this feature and future modifications. Use the new IDE type for conditional on connection type to use, defaulting to IAM. Use IAM for Code Whisperer and make sure we use the correct, `generateRecommendations`, endpoint to allow for auth with IAM and paginated results.
1 parent fa9fe68 commit e7b652d

File tree

10 files changed

+142
-29
lines changed

10 files changed

+142
-29
lines changed

src/auth/auth.ts

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ import {
5353
loadLinkedProfilesIntoStore,
5454
ssoAccountAccessScopes,
5555
} from './connection'
56+
import { isSageMaker, isCloud9 } from '../shared/extensionUtilities'
5657

5758
interface AuthService {
5859
/**
@@ -777,3 +778,15 @@ export class Auth implements AuthService, ConnectionManager {
777778
: `${localizedText.iamIdentityCenter} (${truncatedUrl})`
778779
}
779780
}
781+
/**
782+
* Returns true if credentials are provided by the environment (ex. via ~/.aws/)
783+
*
784+
* @param isC9 boolean for if Cloud9 is host
785+
* @param isSM boolean for if SageMaker is host
786+
* @returns boolean for if C9 "OR" SM
787+
*/
788+
export function hasVendedIamCredentials(isC9?: boolean, isSM?: boolean) {
789+
isC9 ??= isCloud9()
790+
isSM ??= isSageMaker()
791+
return isSM || isC9
792+
}

src/codewhisperer/client/codewhisperer.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import * as CodeWhispererUserClient from './codewhispereruserclient'
99
import { ListAvailableCustomizationsResponse, SendTelemetryEventRequest } from './codewhispereruserclient'
1010
import * as CodeWhispererConstants from '../models/constants'
1111
import { ServiceOptions } from '../../shared/awsClientBuilder'
12-
import { isCloud9 } from '../../shared/extensionUtilities'
12+
import { hasVendedIamCredentials } from '../../auth/auth'
1313
import { CodeWhispererSettings } from '../util/codewhispererSettings'
1414
import { PromiseResult } from 'aws-sdk/lib/request'
1515
import { AuthUtil } from '../util/authUtil'
@@ -85,7 +85,8 @@ export class DefaultCodeWhispererClient {
8585
}
8686
// This logic is for backward compatability with legacy SDK v2 behavior for refreshing
8787
// credentials. Once the Toolkit adds a file watcher for credentials it won't be needed.
88-
if (isCloud9()) {
88+
89+
if (hasVendedIamCredentials()) {
8990
req.on('retry', resp => {
9091
if (
9192
resp.error?.code === 'AccessDeniedException' &&

src/codewhisperer/commands/basicCommands.ts

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,15 @@ export const createGettingStartedNode = () =>
6060

6161
export const enableCodeSuggestions = Commands.declare(
6262
'aws.codeWhisperer.enableCodeSuggestions',
63-
(context: ExtContext) => async () => {
64-
await set(CodeWhispererConstants.autoTriggerEnabledKey, true, context.extensionContext.globalState)
65-
await vscode.commands.executeCommand('setContext', 'CODEWHISPERER_ENABLED', true)
66-
await vscode.commands.executeCommand('aws.codeWhisperer.refresh')
67-
if (!isCloud9()) {
68-
await vscode.commands.executeCommand('aws.codeWhisperer.refreshStatusBar')
63+
(context: ExtContext) =>
64+
async (isAuto: boolean = true) => {
65+
await set(CodeWhispererConstants.autoTriggerEnabledKey, isAuto, context.extensionContext.globalState)
66+
await vscode.commands.executeCommand('setContext', 'CODEWHISPERER_ENABLED', true)
67+
await vscode.commands.executeCommand('aws.codeWhisperer.refresh')
68+
if (!isCloud9()) {
69+
await vscode.commands.executeCommand('aws.codeWhisperer.refreshStatusBar')
70+
}
6971
}
70-
}
7172
)
7273

7374
export const showReferenceLog = Commands.declare(

src/codewhisperer/explorer/codewhispererNode.ts

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import {
1919
import { createGettingStartedNode } from '../commands/basicCommands'
2020
import { Commands } from '../../shared/vscode/commands2'
2121
import { RootNode } from '../../awsexplorer/localExplorer'
22-
import { isCloud9 } from '../../shared/extensionUtilities'
22+
import { hasVendedIamCredentials } from '../../auth/auth'
2323
import { AuthUtil } from '../util/authUtil'
2424
import { TreeNode } from '../../shared/treeview/resourceTreeDataProvider'
2525

@@ -57,9 +57,13 @@ export class CodeWhispererNode implements RootNode {
5757

5858
private getDescription(): string {
5959
if (AuthUtil.instance.isConnectionValid()) {
60-
return AuthUtil.instance.isEnterpriseSsoInUse()
61-
? 'IAM Identity Center Connected'
62-
: 'AWS Builder ID Connected'
60+
if (AuthUtil.instance.isEnterpriseSsoInUse()) {
61+
return 'IAM Identity Center Connected'
62+
} else if (AuthUtil.instance.isBuilderIdInUse()) {
63+
return 'AWS Builder ID Connected'
64+
} else {
65+
return 'IAM Connected'
66+
}
6367
} else if (AuthUtil.instance.isConnectionExpired()) {
6468
return 'Expired Connection'
6569
}
@@ -76,13 +80,13 @@ export class CodeWhispererNode implements RootNode {
7680
return [createSsoSignIn(), createLearnMore()]
7781
}
7882
if (this._showFreeTierLimitReachedNode) {
79-
if (isCloud9()) {
83+
if (hasVendedIamCredentials()) {
8084
return [createFreeTierLimitMetNode(), createOpenReferenceLogNode()]
8185
} else {
8286
return [createFreeTierLimitMetNode(), createSecurityScanNode(), createOpenReferenceLogNode()]
8387
}
8488
} else {
85-
if (isCloud9()) {
89+
if (hasVendedIamCredentials()) {
8690
return [createAutoSuggestionsNode(autoTriggerEnabled), createOpenReferenceLogNode()]
8791
} else {
8892
if (AuthUtil.instance.isValidEnterpriseSsoInUse() && AuthUtil.instance.isCustomizationFeatureEnabled) {

src/codewhisperer/service/recommendationHandler.ts

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ import { AWSError } from 'aws-sdk'
1414
import { isAwsError } from '../../shared/errors'
1515
import { TelemetryHelper } from '../util/telemetryHelper'
1616
import { getLogger } from '../../shared/logger'
17-
import { isCloud9 } from '../../shared/extensionUtilities'
17+
import { isCloud9, isSageMaker } from '../../shared/extensionUtilities'
18+
import { hasVendedIamCredentials } from '../../auth/auth'
1819
import {
1920
asyncCallWithTimeout,
2021
isInlineCompletionEnabled,
@@ -40,6 +41,7 @@ import { AuthUtil } from '../util/authUtil'
4041
import { CodeWhispererUserGroupSettings } from '../util/userGroupUtil'
4142
import { CWInlineCompletionItemProvider } from './inlineCompletionItemProvider'
4243
import { application } from '../util/codeWhispererApplication'
44+
import { openUrl } from '../../shared/utilities/vsCodeUtils'
4345
import { indent } from '../../shared/utilities/textUtilities'
4446

4547
/**
@@ -110,8 +112,10 @@ export class RecommendationHandler {
110112
isFirstPaginationCall: boolean,
111113
promise: Promise<any>
112114
): Promise<any> {
113-
const timeoutMessage = isCloud9() ? `Generate recommendation timeout.` : `List recommendation timeout`
114-
if (isManualTriggerOn && triggerType === 'OnDemand' && (isCloud9() || isFirstPaginationCall)) {
115+
const timeoutMessage = hasVendedIamCredentials()
116+
? 'Generate recommendation timeout.'
117+
: 'List recommendation timeout'
118+
if (isManualTriggerOn && triggerType === 'OnDemand' && (hasVendedIamCredentials() || isFirstPaginationCall)) {
115119
return vscode.window.withProgress(
116120
{
117121
location: vscode.ProgressLocation.Notification,
@@ -154,6 +158,7 @@ export class RecommendationHandler {
154158
autoTriggerType?: CodewhispererAutomatedTriggerType,
155159
pagination: boolean = true,
156160
page: number = 0,
161+
isSM: boolean = isSageMaker(),
157162
retry: boolean = false
158163
): Promise<GetRecommendationsResponse> {
159164
let invocationResult: 'Succeeded' | 'Failed' = 'Failed'
@@ -227,9 +232,8 @@ export class RecommendationHandler {
227232
startTime = performance.now()
228233
this.lastInvocationTime = startTime
229234
const mappedReq = runtimeLanguageContext.mapToRuntimeLanguage(request)
230-
const codewhispererPromise = pagination
231-
? client.listRecommendations(mappedReq)
232-
: client.generateRecommendations(mappedReq)
235+
const codewhispererPromise =
236+
pagination && !isSM ? client.listRecommendations(mappedReq) : client.generateRecommendations(mappedReq)
233237
const resp = await this.getServerResponse(
234238
triggerType,
235239
config.isManualTriggerEnabled,
@@ -270,6 +274,18 @@ export class RecommendationHandler {
270274
errorCode = error.code
271275
reason = `CodeWhisperer Invocation Exception: ${error?.code ?? error?.name ?? 'unknown'}`
272276
await this.onThrottlingException(error, triggerType)
277+
278+
if (error?.code === 'AccessDeniedException' && errorMessage?.includes('no identity-based policy')) {
279+
getLogger().error('CodeWhisperer AccessDeniedException : %s', (error as Error).message)
280+
vscode.window
281+
.showErrorMessage(`CodeWhisperer: ${error?.message}`, CodeWhispererConstants.settingsLearnMore)
282+
.then(async resp => {
283+
if (resp === CodeWhispererConstants.settingsLearnMore) {
284+
openUrl(vscode.Uri.parse(CodeWhispererConstants.learnMoreUri))
285+
}
286+
})
287+
await vscode.commands.executeCommand('aws.codeWhisperer.enableCodeSuggestions', false)
288+
}
273289
} else {
274290
errorMessage = error as string
275291
reason = error ? String(error) : 'unknown'

src/codewhisperer/util/authUtil.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import * as CodeWhispererConstants from '../models/constants'
88
import { Auth } from '../../auth/auth'
99
import { ToolkitError } from '../../shared/errors'
1010
import { getSecondaryAuth } from '../../auth/secondaryAuth'
11-
import { isCloud9 } from '../../shared/extensionUtilities'
11+
import { isCloud9, isSageMaker } from '../../shared/extensionUtilities'
1212
import { PromptSettings } from '../../shared/settings'
1313
import {
1414
ssoAccountAccessScopes,
@@ -32,6 +32,10 @@ export const isValidCodeWhispererConnection = (conn?: Connection): conn is Conne
3232
return isIamConnection(conn)
3333
}
3434

35+
if (isSageMaker()) {
36+
return isIamConnection(conn)
37+
}
38+
3539
return (
3640
(isCloud9('codecatalyst') && isIamConnection(conn)) ||
3741
(isSsoConnection(conn) && hasScopes(conn, codewhispererScopes))

src/extension.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import {
2727
getToolkitEnvironmentDetails,
2828
initializeComputeRegion,
2929
isCloud9,
30+
isSageMaker,
3031
showQuickStartWebview,
3132
showWelcomeMessage,
3233
} from './shared/extensionUtilities'
@@ -209,7 +210,10 @@ export async function activate(context: vscode.ExtensionContext) {
209210
})
210211
)
211212

212-
await codecatalyst.activate(extContext)
213+
// do not enable codecatalyst for sagemaker
214+
if (!isSageMaker()) {
215+
await codecatalyst.activate(extContext)
216+
}
213217

214218
await activateCloudFormationTemplateRegistry(context)
215219

src/shared/extensionUtilities.ts

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@ const localize = nls.loadMessageBundle()
2222
const vscodeAppname = 'Visual Studio Code'
2323
const cloud9Appname = 'AWS Cloud9'
2424
const cloud9CnAppname = 'Amazon Cloud9'
25+
const sageMakerAppname = 'SageMaker Code Editor'
2526
const notInitialized = 'notInitialized'
2627

2728
export const mostRecentVersionKey: string = 'globalsMostRecentVersion'
2829

2930
export enum IDE {
3031
vscode,
3132
cloud9,
33+
sagemaker,
3234
unknown,
3335
}
3436

@@ -44,7 +46,11 @@ export function getIdeType(): IDE {
4446
return IDE.cloud9
4547
}
4648

47-
// Theia doesn't necessarily have all env propertie
49+
if (vscode.env.appName === sageMakerAppname) {
50+
return IDE.sagemaker
51+
}
52+
53+
// Theia doesn't necessarily have all env properties
4854
// so we should be defensive and assume appName is nullable.
4955
if (vscode.env.appName?.startsWith(vscodeAppname)) {
5056
return IDE.vscode
@@ -80,12 +86,29 @@ export function getIdeProperties(): IdeProperties {
8086
return createCloud9Properties(localize('AWS.title.cn', 'Amazon'))
8187
}
8288
return createCloud9Properties(company)
89+
case IDE.sagemaker:
90+
if (isCn()) {
91+
// check for cn region
92+
return createSageMakerProperties(localize('AWS.title.cn', 'Amazon'))
93+
}
94+
return createSageMakerProperties(company)
8395
// default is IDE.vscode
8496
default:
8597
return vscodeVals
8698
}
8799
}
88100

101+
function createSageMakerProperties(company: string): IdeProperties {
102+
return {
103+
shortName: localize('AWS.vscode.shortName', '{0} Code Editor', company),
104+
longName: localize('AWS.vscode.longName', '{0} SageMaker Code Editor', company),
105+
commandPalette: localize('AWS.vscode.commandPalette', 'Command Palette'),
106+
codelens: localize('AWS.vscode.codelens', 'CodeLens'),
107+
codelenses: localize('AWS.vscode.codelenses', 'CodeLenses'),
108+
company,
109+
}
110+
}
111+
89112
function createCloud9Properties(company: string): IdeProperties {
90113
return {
91114
shortName: localize('AWS.cloud9.shortName', 'Cloud9'),
@@ -109,6 +132,10 @@ export function isCloud9(flavor: 'classic' | 'codecatalyst' | 'any' = 'any'): bo
109132
return (flavor === 'classic' && !codecat) || (flavor === 'codecatalyst' && codecat)
110133
}
111134

135+
export function isSageMaker(): boolean {
136+
return vscode.env.appName === sageMakerAppname
137+
}
138+
112139
export function isCn(): boolean {
113140
return getComputeRegion()?.startsWith('cn') ?? false
114141
}
@@ -338,10 +365,17 @@ export function getToolkitEnvironmentDetails(): string {
338365
}
339366

340367
/**
341-
* Returns the Cloud9 compute region or 'unknown' if we can't pull a region, or `undefined` if this is not Cloud9.
368+
* Returns the Cloud9/SageMaker compute region or 'unknown' if we can't pull a region, or `undefined` if this is not Cloud9 or SageMaker.
342369
*/
343-
export async function initializeComputeRegion(metadata?: Ec2MetadataClient, isC9: boolean = isCloud9()): Promise<void> {
344-
if (isC9) {
370+
371+
export async function initializeComputeRegion(
372+
metadata?: Ec2MetadataClient,
373+
isC9?: boolean,
374+
isSM?: boolean
375+
): Promise<void> {
376+
isC9 ??= isCloud9()
377+
isSM ??= isSageMaker()
378+
if (isC9 || isSM) {
345379
metadata ??= new DefaultEc2MetadataClient()
346380
try {
347381
const identity = await metadata.getInstanceIdentity()

src/test/codewhisperer/explorer/codewhispererNode.test.ts

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ describe('codewhispererNode', function () {
3939

4040
it('should create a node showing AWS Builder ID connection', function () {
4141
sinon.stub(AuthUtil.instance, 'isUsingSavedConnection').get(() => true)
42+
sinon.stub(AuthUtil.instance, 'isBuilderIdInUse').resolves(true)
4243
isConnectionValid.returns(true)
4344

4445
const node = codewhispererNode
@@ -50,6 +51,20 @@ describe('codewhispererNode', function () {
5051
assert.strictEqual(treeItem.collapsibleState, vscode.TreeItemCollapsibleState.Collapsed)
5152
})
5253

54+
it('should create a node showing IAM connection', function () {
55+
sinon.stub(AuthUtil.instance, 'isUsingSavedConnection').get(() => true)
56+
//sinon.stub(AuthUtil.instance, 'isBuilderIdInUse').resolves(false)
57+
isConnectionValid.returns(true)
58+
59+
const node = codewhispererNode
60+
const treeItem = node.getTreeItem()
61+
62+
assert.strictEqual(treeItem.label, 'CodeWhisperer')
63+
assert.strictEqual(treeItem.contextValue, 'awsCodeWhispererNodeSaved')
64+
assert.strictEqual(treeItem.description, 'IAM Connected')
65+
assert.strictEqual(treeItem.collapsibleState, vscode.TreeItemCollapsibleState.Collapsed)
66+
})
67+
5368
it('should create a node showing enterprise SSO connection', function () {
5469
sinon.stub(AuthUtil.instance, 'isUsingSavedConnection').get(() => true)
5570
sinon.stub(AuthUtil.instance, 'isEnterpriseSsoInUse').resolves(true)

src/test/shared/extensionUtilities.test.ts

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,24 +166,45 @@ describe('initializeComputeRegion, getComputeRegion', async function () {
166166
assert.strictEqual(getComputeRegion(), 'us-weast-1')
167167
})
168168

169+
it('returns a compute region for sagemaker', async function () {
170+
sandbox.stub(metadataService, 'getInstanceIdentity').resolves({ region: 'us-weast-1' })
171+
172+
await initializeComputeRegion(metadataService, false, true)
173+
assert.strictEqual(getComputeRegion(), 'us-weast-1')
174+
})
175+
169176
it('returns "unknown" if cloud9 and the MetadataService request fails', async function () {
170177
sandbox.stub(metadataService, 'getInstanceIdentity').rejects({} as AWSError)
171178

172179
await initializeComputeRegion(metadataService, true)
173180
assert.strictEqual(getComputeRegion(), 'unknown')
174181
})
175182

183+
it('returns "unknown" if sagemaker and the MetadataService request fails', async function () {
184+
sandbox.stub(metadataService, 'getInstanceIdentity').rejects({} as AWSError)
185+
186+
await initializeComputeRegion(metadataService, false, true)
187+
assert.strictEqual(getComputeRegion(), 'unknown')
188+
})
189+
176190
it('returns "unknown" if cloud9 and can not find a region', async function () {
177191
sandbox.stub(metadataService, 'getInstanceIdentity').resolves({} as InstanceIdentity)
178192

179193
await initializeComputeRegion(metadataService, true)
180194
assert.strictEqual(getComputeRegion(), 'unknown')
181195
})
182196

183-
it('returns undefined if not cloud9', async function () {
197+
it('returns "unknown" if sagemaker and can not find a region', async function () {
198+
sandbox.stub(metadataService, 'getInstanceIdentity').resolves({} as InstanceIdentity)
199+
200+
await initializeComputeRegion(metadataService, false, true)
201+
assert.strictEqual(getComputeRegion(), 'unknown')
202+
})
203+
204+
it('returns undefined if not cloud9 or sagemaker', async function () {
184205
sandbox.stub(metadataService, 'getInstanceIdentity').callsArgWith(1, undefined, 'lol')
185206

186-
await initializeComputeRegion(metadataService, false)
207+
await initializeComputeRegion(metadataService, false, false)
187208
assert.strictEqual(getComputeRegion(), undefined)
188209
})
189210

0 commit comments

Comments
 (0)