diff --git a/packages/core/src/codewhisperer/util/customizationUtil.ts b/packages/core/src/codewhisperer/util/customizationUtil.ts index 9bacb06c48a..02ec5a1a3bd 100644 --- a/packages/core/src/codewhisperer/util/customizationUtil.ts +++ b/packages/core/src/codewhisperer/util/customizationUtil.ts @@ -18,7 +18,6 @@ import { showMessageWithUrl } from '../../shared/utilities/messages' import { parse } from '@aws-sdk/util-arn-parser' import { Commands } from '../../shared/vscode/commands2' import { vsCodeState } from '../models/model' -import { FeatureConfigProvider, Features } from '../../shared/featureConfig' /** * @@ -92,10 +91,7 @@ export const baseCustomization = { } /** - * Gets the customization that should be used for user requests. If a user has manually selected - * a customization, always respect that choice. If not, check if the user is part of an AB - * group assigned a specific customization. If so, use that customization. If not, use the - * base customization. + * @returns customization selected by users, `baseCustomization` if none is selected */ export const getSelectedCustomization = (): Customization => { if ( @@ -116,25 +112,27 @@ export const getSelectedCustomization = (): Customization => { if (selectedCustomization && selectedCustomization.name !== '') { return selectedCustomization } else { - const customizationFeature = FeatureConfigProvider.getFeature(Features.customizationArnOverride) - const arnOverride = customizationFeature?.value.stringValue - const customizationOverrideName = customizationFeature?.variation - if (arnOverride === undefined) { - return baseCustomization - } else { - return { - arn: arnOverride, - name: customizationOverrideName, - description: baseCustomization.description, - } - } + return baseCustomization } } -export const setSelectedCustomization = async (customization: Customization) => { +/** + * @param customization customization to select + * @param isOverride if the API call is made from us (Q) but not users' intent, set isOverride to TRUE + * Override happens when ALL following conditions are met + * 1. service returns non-empty override customization arn, refer to [featureConfig.ts] + * 2. the override customization arn is different from the previous override customization if any. The purpose is to only do override once on users' behalf. + */ +export const setSelectedCustomization = async (customization: Customization, isOverride: boolean = false) => { if (!AuthUtil.instance.isValidEnterpriseSsoInUse() || !AuthUtil.instance.conn) { return } + if (isOverride) { + const previousOverride = globals.globalState.tryGet('aws.amazonq.customization.override', String) + if (customization.arn === previousOverride) { + return + } + } const selectedCustomizationObj = globals.globalState.tryGet<{ [label: string]: Customization }>( 'CODEWHISPERER_SELECTED_CUSTOMIZATION', Object, @@ -144,6 +142,9 @@ export const setSelectedCustomization = async (customization: Customization) => getLogger().debug(`Selected customization ${customization.name} for ${AuthUtil.instance.conn.label}`) await globals.globalState.update('CODEWHISPERER_SELECTED_CUSTOMIZATION', selectedCustomizationObj) + if (isOverride) { + await globals.globalState.update('aws.amazonq.customization.override', customization.arn) + } vsCodeState.isFreeTierLimitReached = false await Commands.tryExecute('aws.amazonq.refreshStatusBar') } diff --git a/packages/core/src/shared/featureConfig.ts b/packages/core/src/shared/featureConfig.ts index 2c330648483..dd8bbb901e5 100644 --- a/packages/core/src/shared/featureConfig.ts +++ b/packages/core/src/shared/featureConfig.ts @@ -21,6 +21,7 @@ import { getClientId, getOperatingSystem } from './telemetry/util' import { extensionVersion } from './vscode/env' import { telemetry } from './telemetry/telemetry' import { Commands } from './vscode/commands2' +import { setSelectedCustomization } from '../codewhisperer/util/customizationUtil' const localize = nls.loadMessageBundle() @@ -148,7 +149,7 @@ export class FeatureConfigProvider { if (isBuilderIdConnection(AuthUtil.instance.conn)) { this.featureConfigs.delete(Features.customizationArnOverride) } else if (isIdcSsoConnection(AuthUtil.instance.conn)) { - let availableCustomizations = undefined + let availableCustomizations: Customization[] = [] try { const items: Customization[] = [] const response = await client.listAvailableCustomizations() @@ -157,17 +158,20 @@ export class FeatureConfigProvider { )) { items.push(...customizations) } - availableCustomizations = items.map((c) => c.arn) + availableCustomizations = items } catch (e) { getLogger().debug('amazonq: Failed to list available customizations') } // If customizationArn from A/B is not available in listAvailableCustomizations response, don't use this value - if (!availableCustomizations?.includes(customizationArnOverride)) { + const targetCustomization = availableCustomizations?.find((c) => c.arn === customizationArnOverride) + if (!targetCustomization) { getLogger().debug( `Customization arn ${customizationArnOverride} not available in listAvailableCustomizations, not using` ) this.featureConfigs.delete(Features.customizationArnOverride) + } else { + await setSelectedCustomization(targetCustomization, true) } await vscode.commands.executeCommand('aws.amazonq.refreshStatusBar') diff --git a/packages/core/src/shared/globalState.ts b/packages/core/src/shared/globalState.ts index defb7658f68..83fc24359c2 100644 --- a/packages/core/src/shared/globalState.ts +++ b/packages/core/src/shared/globalState.ts @@ -45,6 +45,7 @@ export type globalKey = | 'aws.toolkit.amazonq.dismissed' | 'aws.toolkit.amazonqInstall.dismissed' | 'aws.amazonq.workspaceIndexToggleOn' + | 'aws.amazonq.customization.override' // Deprecated/legacy names. New keys should start with "aws.". | '#sessionCreationDates' // Legacy name from `ssoAccessTokenProvider.ts`. | 'CODECATALYST_RECONNECT' diff --git a/packages/core/src/test/amazonq/customizationUtil.test.ts b/packages/core/src/test/amazonq/customizationUtil.test.ts index ea9ca7bf562..505c89ae0c9 100644 --- a/packages/core/src/test/amazonq/customizationUtil.test.ts +++ b/packages/core/src/test/amazonq/customizationUtil.test.ts @@ -82,16 +82,39 @@ describe('CodeWhisperer-customizationUtils', function () { assert.strictEqual(actualCustomization.name, selectedCustomization.name) }) - it('Returns AB customization', async function () { - sinon.stub(AuthUtil.instance, 'isValidEnterpriseSsoInUse').returns(true) + it(`setSelectedCustomization should set to the customization provided if override option is false or not specified`, async function () { + await setSelectedCustomization({ arn: 'FOO' }, false) + assert.strictEqual(getSelectedCustomization().arn, 'FOO') + + await setSelectedCustomization({ arn: 'BAR' }) + assert.strictEqual(getSelectedCustomization().arn, 'BAR') + + await setSelectedCustomization({ arn: 'BAZ' }) + assert.strictEqual(getSelectedCustomization().arn, 'BAZ') + + await setSelectedCustomization({ arn: 'QOO' }, false) + assert.strictEqual(getSelectedCustomization().arn, 'QOO') + }) + + it(`setSelectedCustomization should only set to the customization provided once for override per customization arn if override is true`, async function () { + await setSelectedCustomization({ arn: 'OVERRIDE' }, true) + assert.strictEqual(getSelectedCustomization().arn, 'OVERRIDE') + + await setSelectedCustomization({ arn: 'FOO' }, false) + assert.strictEqual(getSelectedCustomization().arn, 'FOO') + + // Should NOT override only happen per customization arn + await setSelectedCustomization({ arn: 'OVERRIDE' }, true) + assert.strictEqual(getSelectedCustomization().arn, 'FOO') - await setSelectedCustomization({ - arn: '', - name: '', - }) + await setSelectedCustomization({ arn: 'FOO' }, false) + assert.strictEqual(getSelectedCustomization().arn, 'FOO') - const returnedCustomization = getSelectedCustomization() + await setSelectedCustomization({ arn: 'BAR' }, false) + assert.strictEqual(getSelectedCustomization().arn, 'BAR') - assert.strictEqual(returnedCustomization.name, featureCustomization.name) + // Sould override as it's a different arn + await setSelectedCustomization({ arn: 'OVERRIDE_V2' }, true) + assert.strictEqual(getSelectedCustomization().arn, 'OVERRIDE_V2') }) })